mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
merge: resolve conflicts between refactor/agent and main
This commit is contained in:
+27
-16
@@ -222,13 +222,10 @@ func (cb *ContextBuilder) InvalidateCache() {
|
||||
// invalidation (bootstrap files + memory). Skill roots are handled separately
|
||||
// because they require both directory-level and recursive file-level checks.
|
||||
func (cb *ContextBuilder) sourcePaths() []string {
|
||||
return []string{
|
||||
filepath.Join(cb.workspace, "AGENTS.md"),
|
||||
filepath.Join(cb.workspace, "SOUL.md"),
|
||||
filepath.Join(cb.workspace, "USER.md"),
|
||||
filepath.Join(cb.workspace, "IDENTITY.md"),
|
||||
filepath.Join(cb.workspace, "memory", "MEMORY.md"),
|
||||
}
|
||||
agentDefinition := cb.LoadAgentDefinition()
|
||||
paths := agentDefinition.trackedPaths(cb.workspace)
|
||||
paths = append(paths, filepath.Join(cb.workspace, "memory", "MEMORY.md"))
|
||||
return uniquePaths(paths)
|
||||
}
|
||||
|
||||
// skillRoots returns all skill root directories that can affect
|
||||
@@ -432,18 +429,32 @@ func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Ti
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) LoadBootstrapFiles() string {
|
||||
bootstrapFiles := []string{
|
||||
"AGENTS.md",
|
||||
"SOUL.md",
|
||||
"USER.md",
|
||||
"IDENTITY.md",
|
||||
var sb strings.Builder
|
||||
|
||||
agentDefinition := cb.LoadAgentDefinition()
|
||||
if agentDefinition.Agent != nil {
|
||||
label := string(agentDefinition.Source)
|
||||
if label == "" {
|
||||
label = relativeWorkspacePath(cb.workspace, agentDefinition.Agent.Path)
|
||||
}
|
||||
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", label, agentDefinition.Agent.Body)
|
||||
}
|
||||
if agentDefinition.Soul != nil {
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
"## %s\n\n%s\n\n",
|
||||
relativeWorkspacePath(cb.workspace, agentDefinition.Soul.Path),
|
||||
agentDefinition.Soul.Content,
|
||||
)
|
||||
}
|
||||
if agentDefinition.User != nil {
|
||||
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", "USER.md", agentDefinition.User.Content)
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for _, filename := range bootstrapFiles {
|
||||
filePath := filepath.Join(cb.workspace, filename)
|
||||
if agentDefinition.Source != AgentDefinitionSourceAgent {
|
||||
filePath := filepath.Join(cb.workspace, "IDENTITY.md")
|
||||
if data, err := os.ReadFile(filePath); err == nil {
|
||||
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", filename, data)
|
||||
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", "IDENTITY.md", data)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// parseTurnBoundaries returns the starting index of each Turn in the history.
|
||||
// A Turn is a complete "user input → LLM iterations → final response" cycle
|
||||
// (as defined in #1316). Each Turn begins at a user message and extends
|
||||
// through all subsequent assistant/tool messages until the next user message.
|
||||
//
|
||||
// Cutting at a Turn boundary guarantees that no tool-call sequence
|
||||
// (assistant+ToolCalls → tool results) is split across the cut.
|
||||
func parseTurnBoundaries(history []providers.Message) []int {
|
||||
var starts []int
|
||||
for i, msg := range history {
|
||||
if msg.Role == "user" {
|
||||
starts = append(starts, i)
|
||||
}
|
||||
}
|
||||
return starts
|
||||
}
|
||||
|
||||
// isSafeBoundary reports whether index is a valid Turn boundary — i.e.,
|
||||
// a position where the kept portion (history[index:]) begins at a user
|
||||
// message, so no tool-call sequence is torn apart.
|
||||
func isSafeBoundary(history []providers.Message, index int) bool {
|
||||
if index <= 0 || index >= len(history) {
|
||||
return true
|
||||
}
|
||||
return history[index].Role == "user"
|
||||
}
|
||||
|
||||
// findSafeBoundary locates the nearest Turn boundary to targetIndex.
|
||||
// It prefers the boundary at or before targetIndex (preserving more recent
|
||||
// context). Falls back to the nearest boundary after targetIndex, and
|
||||
// returns targetIndex unchanged only when no Turn boundary exists at all.
|
||||
func findSafeBoundary(history []providers.Message, targetIndex int) int {
|
||||
if len(history) == 0 {
|
||||
return 0
|
||||
}
|
||||
if targetIndex <= 0 {
|
||||
return 0
|
||||
}
|
||||
if targetIndex >= len(history) {
|
||||
return len(history)
|
||||
}
|
||||
|
||||
turns := parseTurnBoundaries(history)
|
||||
if len(turns) == 0 {
|
||||
return targetIndex
|
||||
}
|
||||
|
||||
// Find the last Turn boundary at or before targetIndex.
|
||||
// Prefer backward: keeps more recent messages.
|
||||
backward := -1
|
||||
for _, t := range turns {
|
||||
if t <= targetIndex {
|
||||
backward = t
|
||||
}
|
||||
}
|
||||
if backward > 0 {
|
||||
return backward
|
||||
}
|
||||
|
||||
// No valid Turn boundary before target (or only at index 0 which
|
||||
// would keep everything). Use the first Turn after targetIndex.
|
||||
for _, t := range turns {
|
||||
if t > targetIndex {
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
// No Turn boundary after targetIndex either. The only boundary is at
|
||||
// index 0, meaning the entire history is a single Turn. Return 0 to
|
||||
// signal that safe compression is not possible — callers check for
|
||||
// mid <= 0 and skip compression in that case.
|
||||
return 0
|
||||
}
|
||||
|
||||
// estimateMessageTokens estimates the token count for a single message,
|
||||
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
|
||||
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
|
||||
func estimateMessageTokens(msg providers.Message) int {
|
||||
chars := utf8.RuneCountInString(msg.Content)
|
||||
|
||||
// ReasoningContent (extended thinking / chain-of-thought) can be
|
||||
// substantial and is stored in session history via AddFullMessage.
|
||||
if msg.ReasoningContent != "" {
|
||||
chars += utf8.RuneCountInString(msg.ReasoningContent)
|
||||
}
|
||||
|
||||
for _, tc := range msg.ToolCalls {
|
||||
chars += len(tc.ID) + len(tc.Type)
|
||||
if tc.Function != nil {
|
||||
// Count function name + arguments (the wire format for most providers).
|
||||
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
|
||||
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||
} else {
|
||||
// Fallback: some provider formats use top-level Name without Function.
|
||||
chars += len(tc.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if msg.ToolCallID != "" {
|
||||
chars += len(msg.ToolCallID)
|
||||
}
|
||||
|
||||
// Per-message overhead for role label, JSON structure, separators.
|
||||
const messageOverhead = 12
|
||||
chars += messageOverhead
|
||||
|
||||
tokens := chars * 2 / 5
|
||||
|
||||
// Media items (images, files) are serialized by provider adapters into
|
||||
// multipart or image_url payloads. Add a fixed per-item token estimate
|
||||
// directly (not through the chars heuristic) since actual cost depends
|
||||
// on resolution and provider-specific image tokenization.
|
||||
const mediaTokensPerItem = 256
|
||||
tokens += len(msg.Media) * mediaTokensPerItem
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// estimateToolDefsTokens estimates the total token cost of tool definitions
|
||||
// as they appear in the LLM request. Each tool's name, description, and
|
||||
// JSON schema parameters contribute to the context window budget.
|
||||
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||
if len(defs) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
totalChars := 0
|
||||
for _, d := range defs {
|
||||
totalChars += len(d.Function.Name) + len(d.Function.Description)
|
||||
|
||||
if d.Function.Parameters != nil {
|
||||
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
|
||||
totalChars += len(paramJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// Per-tool overhead: type field, JSON structure, separators.
|
||||
totalChars += 20
|
||||
}
|
||||
|
||||
return totalChars * 2 / 5
|
||||
}
|
||||
|
||||
// isOverContextBudget checks whether the assembled messages plus tool definitions
|
||||
// and output reserve would exceed the model's context window. This enables
|
||||
// proactive compression before calling the LLM, rather than reacting to 400 errors.
|
||||
func isOverContextBudget(
|
||||
contextWindow int,
|
||||
messages []providers.Message,
|
||||
toolDefs []providers.ToolDefinition,
|
||||
maxTokens int,
|
||||
) bool {
|
||||
msgTokens := 0
|
||||
for _, m := range messages {
|
||||
msgTokens += estimateMessageTokens(m)
|
||||
}
|
||||
|
||||
toolTokens := estimateToolDefsTokens(toolDefs)
|
||||
total := msgTokens + toolTokens + maxTokens
|
||||
|
||||
return total > contextWindow
|
||||
}
|
||||
@@ -0,0 +1,826 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// msgUser creates a user message.
|
||||
func msgUser(content string) providers.Message {
|
||||
return providers.Message{Role: "user", Content: content}
|
||||
}
|
||||
|
||||
// msgAssistant creates a plain assistant message (no tool calls).
|
||||
func msgAssistant(content string) providers.Message {
|
||||
return providers.Message{Role: "assistant", Content: content}
|
||||
}
|
||||
|
||||
// msgAssistantTC creates an assistant message with tool calls.
|
||||
func msgAssistantTC(toolIDs ...string) providers.Message {
|
||||
tcs := make([]providers.ToolCall, len(toolIDs))
|
||||
for i, id := range toolIDs {
|
||||
tcs[i] = providers.ToolCall{
|
||||
ID: id,
|
||||
Type: "function",
|
||||
Name: "tool_" + id,
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_" + id,
|
||||
Arguments: `{"key":"value"}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
return providers.Message{Role: "assistant", ToolCalls: tcs}
|
||||
}
|
||||
|
||||
// msgTool creates a tool result message.
|
||||
func msgTool(callID, content string) providers.Message {
|
||||
return providers.Message{Role: "tool", ToolCallID: callID, Content: content}
|
||||
}
|
||||
|
||||
func TestParseTurnBoundaries(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
history []providers.Message
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
name: "empty history",
|
||||
history: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "simple exchange",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistant("a2"),
|
||||
},
|
||||
want: []int{0, 2},
|
||||
},
|
||||
{
|
||||
name: "tool-call Turn",
|
||||
history: []providers.Message{
|
||||
msgUser("search"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "result"),
|
||||
msgAssistant("found it"),
|
||||
msgUser("thanks"),
|
||||
msgAssistant("welcome"),
|
||||
},
|
||||
want: []int{0, 4},
|
||||
},
|
||||
{
|
||||
name: "chained tool calls in single Turn",
|
||||
history: []providers.Message{
|
||||
msgUser("save and notify"),
|
||||
msgAssistantTC("tc_save"),
|
||||
msgTool("tc_save", "saved"),
|
||||
msgAssistantTC("tc_notify"),
|
||||
msgTool("tc_notify", "notified"),
|
||||
msgAssistant("done"),
|
||||
},
|
||||
want: []int{0},
|
||||
},
|
||||
{
|
||||
name: "no user messages",
|
||||
history: []providers.Message{
|
||||
msgAssistant("a1"),
|
||||
msgAssistant("a2"),
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "leading non-user messages",
|
||||
history: []providers.Message{
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgAssistant("greeting"),
|
||||
msgUser("hello"),
|
||||
msgAssistant("hi"),
|
||||
},
|
||||
want: []int{3},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := parseTurnBoundaries(tt.history)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("parseTurnBoundaries() = %v, want %v", got, tt.want)
|
||||
return
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("parseTurnBoundaries()[%d] = %d, want %d", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeBoundary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
history []providers.Message
|
||||
index int
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty history, index 0",
|
||||
history: nil,
|
||||
index: 0,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "single user message, index 0",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
index: 0,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "single user message, index 1 (end)",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
index: 1,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "at user message",
|
||||
history: []providers.Message{
|
||||
msgAssistant("hello"),
|
||||
msgUser("how are you"),
|
||||
msgAssistant("fine"),
|
||||
},
|
||||
index: 1,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "at assistant without tool calls",
|
||||
history: []providers.Message{
|
||||
msgUser("hello"),
|
||||
msgAssistant("response"),
|
||||
msgUser("follow up"),
|
||||
},
|
||||
index: 1,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "at assistant with tool calls",
|
||||
history: []providers.Message{
|
||||
msgUser("search something"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "result"),
|
||||
msgAssistant("here is what I found"),
|
||||
},
|
||||
index: 1,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "at tool result",
|
||||
history: []providers.Message{
|
||||
msgUser("do something"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "done"),
|
||||
msgAssistant("completed"),
|
||||
},
|
||||
index: 2,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "negative index",
|
||||
history: []providers.Message{
|
||||
msgUser("hello"),
|
||||
},
|
||||
index: -1,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "index beyond length",
|
||||
history: []providers.Message{
|
||||
msgUser("hello"),
|
||||
},
|
||||
index: 5,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isSafeBoundary(tt.history, tt.index)
|
||||
if got != tt.want {
|
||||
t.Errorf("isSafeBoundary(history, %d) = %v, want %v", tt.index, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSafeBoundary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
history []providers.Message
|
||||
targetIndex int
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "empty history",
|
||||
history: nil,
|
||||
targetIndex: 0,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "target at 0",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
targetIndex: 0,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "target beyond length",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
targetIndex: 5,
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "target already at user message",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistant("a2"),
|
||||
},
|
||||
targetIndex: 2,
|
||||
want: 2,
|
||||
},
|
||||
{
|
||||
name: "target at assistant, scan backward finds user",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistant("a2"),
|
||||
msgUser("q3"),
|
||||
},
|
||||
targetIndex: 3, // assistant "a2"
|
||||
want: 2, // backward to user "q2"
|
||||
},
|
||||
{
|
||||
name: "target inside tool sequence, scan backward finds user",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistantTC("tc1", "tc2"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgTool("tc2", "r2"),
|
||||
msgAssistant("summary"),
|
||||
msgUser("q3"),
|
||||
},
|
||||
targetIndex: 4, // tool result "r1"
|
||||
want: 2, // backward: 3=assistant+TC (not safe), 2=user → safe
|
||||
},
|
||||
{
|
||||
name: "target inside tool sequence, backward finds user before chain",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistantTC("tc1", "tc2"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgTool("tc2", "r2"),
|
||||
msgAssistant("summary"),
|
||||
msgUser("q3"),
|
||||
},
|
||||
targetIndex: 5, // tool result "r2"
|
||||
want: 2, // backward: 4=tool, 3=assistant+TC, 2=user → safe
|
||||
},
|
||||
{
|
||||
name: "no backward user, scan forward finds one",
|
||||
history: []providers.Message{
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q1"),
|
||||
},
|
||||
targetIndex: 1, // tool result
|
||||
want: 3, // forward to user "q1"
|
||||
},
|
||||
{
|
||||
name: "multi-step tool chain preserves atomicity",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgAssistantTC("tc2"),
|
||||
msgTool("tc2", "r2"),
|
||||
msgAssistant("final"),
|
||||
msgUser("q3"),
|
||||
msgAssistant("a3"),
|
||||
},
|
||||
targetIndex: 5, // second assistant+TC
|
||||
want: 2, // backward: 4=tool, 3=assistant+TC, 2=user → safe
|
||||
},
|
||||
{
|
||||
name: "all non-user messages returns target unchanged",
|
||||
history: []providers.Message{
|
||||
msgAssistant("a1"),
|
||||
msgAssistant("a2"),
|
||||
msgAssistant("a3"),
|
||||
},
|
||||
targetIndex: 1,
|
||||
want: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := findSafeBoundary(tt.history, tt.targetIndex)
|
||||
if got != tt.want {
|
||||
t.Errorf("findSafeBoundary(history, %d) = %d, want %d",
|
||||
tt.targetIndex, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSafeBoundary_SingleTurnReturnsZero(t *testing.T) {
|
||||
// A single Turn with no subsequent user message. The only Turn boundary
|
||||
// is at index 0; cutting anywhere else would split the Turn's tool
|
||||
// sequence. findSafeBoundary must return 0 so callers skip compression.
|
||||
history := []providers.Message{
|
||||
msgUser("do everything"), // 0 ← only Turn boundary
|
||||
msgAssistantTC("tc1"), // 1
|
||||
msgTool("tc1", "result"), // 2
|
||||
msgAssistant("all done"), // 3
|
||||
}
|
||||
|
||||
got := findSafeBoundary(history, 2)
|
||||
if got != 0 {
|
||||
t.Errorf("findSafeBoundary(single_turn, 2) = %d, want 0 (cannot split single Turn)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSafeBoundary_BackwardScanSkipsToolSequence(t *testing.T) {
|
||||
// A long tool-call chain: user → assistant+TC → tool → tool → ... → assistant → user
|
||||
// Target is inside the chain; boundary should skip the entire chain backward.
|
||||
history := []providers.Message{
|
||||
msgUser("start"), // 0
|
||||
msgAssistant("before chain"), // 1
|
||||
msgUser("trigger"), // 2 ← expected safe boundary
|
||||
msgAssistantTC("t1", "t2", "t3"), // 3
|
||||
msgTool("t1", "r1"), // 4
|
||||
msgTool("t2", "r2"), // 5
|
||||
msgTool("t3", "r3"), // 6
|
||||
msgAssistantTC("t4"), // 7
|
||||
msgTool("t4", "r4"), // 8
|
||||
msgAssistant("chain done"), // 9
|
||||
msgUser("next"), // 10
|
||||
}
|
||||
|
||||
// Target at index 6 (middle of tool results)
|
||||
got := findSafeBoundary(history, 6)
|
||||
if got != 2 {
|
||||
t.Errorf("findSafeBoundary(history, 6) = %d, want 2 (user before chain)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg providers.Message
|
||||
want int // minimum expected tokens (exact value depends on overhead)
|
||||
}{
|
||||
{
|
||||
name: "plain user message",
|
||||
msg: msgUser("Hello, world!"),
|
||||
want: 1, // at least some tokens
|
||||
},
|
||||
{
|
||||
name: "empty message still has overhead",
|
||||
msg: providers.Message{Role: "user"},
|
||||
want: 1, // message overhead alone
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls",
|
||||
msg: msgAssistantTC("tc_123"),
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "tool result with ID",
|
||||
msg: msgTool("call_abc", "Here is the search result with lots of content"),
|
||||
want: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateMessageTokens(tt.msg)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
|
||||
plain := msgAssistant("thinking")
|
||||
withTC := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "thinking",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "web_search",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "web_search",
|
||||
Arguments: `{"query":"picoclaw agent framework","max_results":5}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
withTCTokens := estimateMessageTokens(withTC)
|
||||
|
||||
if withTCTokens <= plainTokens {
|
||||
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
|
||||
withTCTokens, plainTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
|
||||
// Multi-byte characters (e.g. emoji, accented letters) are single runes
|
||||
// but may map to different token counts. The heuristic should still produce
|
||||
// reasonable estimates via RuneCountInString.
|
||||
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
|
||||
tokens := estimateMessageTokens(msg)
|
||||
if tokens <= 0 {
|
||||
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
|
||||
// Simulate a tool call with large JSON arguments.
|
||||
largeArgs := fmt.Sprintf(`{"content":"%s"}`, strings.Repeat("x", 5000))
|
||||
msg := providers.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_large",
|
||||
Type: "function",
|
||||
Name: "write_file",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "write_file",
|
||||
Arguments: largeArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
|
||||
if tokens < 2000 {
|
||||
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
|
||||
plain := msgAssistant("result")
|
||||
withReasoning := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "result",
|
||||
ReasoningContent: strings.Repeat("thinking step ", 200),
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
reasoningTokens := estimateMessageTokens(withReasoning)
|
||||
|
||||
if reasoningTokens <= plainTokens {
|
||||
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
|
||||
reasoningTokens, plainTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_MediaItems(t *testing.T) {
|
||||
plain := msgUser("describe this")
|
||||
withMedia := providers.Message{
|
||||
Role: "user",
|
||||
Content: "describe this",
|
||||
Media: []string{"media://img1.png", "media://img2.png"},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
mediaTokens := estimateMessageTokens(withMedia)
|
||||
|
||||
if mediaTokens <= plainTokens {
|
||||
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
|
||||
mediaTokens, plainTokens)
|
||||
}
|
||||
|
||||
// Each media item should add exactly 256 tokens (not run through chars*2/5).
|
||||
expectedDelta := 256 * 2
|
||||
actualDelta := mediaTokens - plainTokens
|
||||
if actualDelta != expectedDelta {
|
||||
t.Errorf("2 media items should add %d tokens, got delta %d", expectedDelta, actualDelta)
|
||||
}
|
||||
}
|
||||
|
||||
// --- estimateToolDefsTokens tests ---
|
||||
|
||||
func TestEstimateToolDefsTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
defs []providers.ToolDefinition
|
||||
want int // minimum expected tokens
|
||||
}{
|
||||
{
|
||||
name: "empty tool list",
|
||||
defs: nil,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "single tool with params",
|
||||
defs: []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "web_search",
|
||||
Description: "Search the web for information",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "tool without params",
|
||||
defs: []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "list_dir",
|
||||
Description: "List directory contents",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateToolDefsTokens(tt.defs)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
|
||||
makeTool := func(name string) providers.ToolDefinition {
|
||||
return providers.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: name,
|
||||
Description: "A test tool that does something useful",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"input": map[string]any{"type": "string", "description": "Input value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
||||
three := estimateToolDefsTokens([]providers.ToolDefinition{
|
||||
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
|
||||
})
|
||||
|
||||
if three <= one {
|
||||
t.Errorf("3 tools (%d tokens) should exceed 1 tool (%d tokens)", three, one)
|
||||
}
|
||||
}
|
||||
|
||||
// --- isOverContextBudget tests ---
|
||||
|
||||
func TestIsOverContextBudget(t *testing.T) {
|
||||
systemMsg := providers.Message{Role: "system", Content: strings.Repeat("x", 1000)}
|
||||
userMsg := msgUser("hello")
|
||||
smallHistory := []providers.Message{systemMsg, msgUser("q1"), msgAssistant("a1"), userMsg}
|
||||
|
||||
tools := []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "test_tool",
|
||||
Description: "A test tool",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
contextWindow int
|
||||
messages []providers.Message
|
||||
toolDefs []providers.ToolDefinition
|
||||
maxTokens int
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "within budget",
|
||||
contextWindow: 100000,
|
||||
messages: smallHistory,
|
||||
toolDefs: tools,
|
||||
maxTokens: 4096,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "over budget with small window",
|
||||
contextWindow: 100, // very small window
|
||||
messages: smallHistory,
|
||||
toolDefs: tools,
|
||||
maxTokens: 4096,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "large max_tokens eats budget",
|
||||
contextWindow: 2000,
|
||||
messages: smallHistory,
|
||||
toolDefs: tools,
|
||||
maxTokens: 1800, // leaves almost no room
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty messages within budget",
|
||||
contextWindow: 10000,
|
||||
messages: nil,
|
||||
toolDefs: nil,
|
||||
maxTokens: 4096,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isOverContextBudget(tt.contextWindow, tt.messages, tt.toolDefs, tt.maxTokens)
|
||||
if got != tt.want {
|
||||
t.Errorf("isOverContextBudget() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests reflecting actual session data shape ---
|
||||
// Session history never contains system messages. The system prompt is
|
||||
// built dynamically by BuildMessages. These tests use realistic history
|
||||
// shapes: user/assistant/tool only, with tool chains and reasoning content.
|
||||
|
||||
func TestFindSafeBoundary_SessionHistoryNoSystem(t *testing.T) {
|
||||
// Real session history starts with a user message, not a system message.
|
||||
history := []providers.Message{
|
||||
msgUser("hello"), // 0
|
||||
msgAssistant("hi there"), // 1
|
||||
msgUser("search for X"), // 2
|
||||
msgAssistantTC("tc1"), // 3
|
||||
msgTool("tc1", "found X"), // 4
|
||||
msgAssistant("here is X"), // 5
|
||||
msgUser("thanks"), // 6
|
||||
msgAssistant("you're welcome"), // 7
|
||||
}
|
||||
|
||||
// Mid-point is 4 (tool result). Should snap backward to 2 (user).
|
||||
got := findSafeBoundary(history, 4)
|
||||
if got != 2 {
|
||||
t.Errorf("findSafeBoundary(session_history, 4) = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSafeBoundary_SessionWithChainedTools(t *testing.T) {
|
||||
// Session with chained tool calls (save then notify).
|
||||
history := []providers.Message{
|
||||
msgUser("save and notify"), // 0
|
||||
msgAssistantTC("tc_save"), // 1
|
||||
msgTool("tc_save", "saved"), // 2
|
||||
msgAssistantTC("tc_notify"), // 3
|
||||
msgTool("tc_notify", "notified"), // 4
|
||||
msgAssistant("done"), // 5
|
||||
msgUser("check status"), // 6
|
||||
msgAssistant("all good"), // 7
|
||||
}
|
||||
|
||||
// Target at 3 (inside chain). Should find user at 0, but backward
|
||||
// scan stops at i>0, so forward scan finds user at 6.
|
||||
// Actually: backward from 3: 2=tool (no), 1=assistantTC (no). Forward: 4=tool, 5=asst, 6=user ✓
|
||||
got := findSafeBoundary(history, 3)
|
||||
if got != 6 {
|
||||
t.Errorf("findSafeBoundary(chained_tools, 3) = %d, want 6", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
||||
// Message with all fields populated — mirrors what AddFullMessage stores.
|
||||
msg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "Here is the analysis.",
|
||||
ReasoningContent: strings.Repeat("Let me think about this carefully. ", 50),
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "analyze",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "analyze",
|
||||
Arguments: `{"data":"sample","depth":3}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
|
||||
// ReasoningContent alone is ~1700 chars → ~680 tokens.
|
||||
// Content + TC + overhead adds more. Should be well above 500.
|
||||
if tokens < 500 {
|
||||
t.Errorf("message with reasoning+toolcalls should have significant tokens, got %d", tokens)
|
||||
}
|
||||
|
||||
// Compare without reasoning to ensure it's counted.
|
||||
msgNoReasoning := msg
|
||||
msgNoReasoning.ReasoningContent = ""
|
||||
tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
|
||||
|
||||
if tokens <= tokensNoReasoning {
|
||||
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOverContextBudget_RealisticSession(t *testing.T) {
|
||||
// Simulate what BuildMessages produces: system + session history + current user.
|
||||
// System message is built by BuildMessages, not stored in session.
|
||||
systemMsg := providers.Message{
|
||||
Role: "system",
|
||||
Content: strings.Repeat("system prompt content ", 100),
|
||||
}
|
||||
sessionHistory := []providers.Message{
|
||||
msgUser("first question"),
|
||||
msgAssistant("first answer"),
|
||||
msgUser("use tool X"),
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I'll use tool X",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "tc1", Type: "function", Name: "tool_x",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_x",
|
||||
Arguments: `{"query":"test","verbose":true}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: strings.Repeat("result data ", 200), ToolCallID: "tc1"},
|
||||
msgAssistant("Here are the results from tool X."),
|
||||
}
|
||||
currentUser := msgUser("follow up question")
|
||||
|
||||
// Assemble as BuildMessages would.
|
||||
messages := make([]providers.Message, 0, 1+len(sessionHistory)+1)
|
||||
messages = append(messages, systemMsg)
|
||||
messages = append(messages, sessionHistory...)
|
||||
messages = append(messages, currentUser)
|
||||
|
||||
tools := []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "tool_x",
|
||||
Description: "A useful tool",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// With a large context window, should be within budget.
|
||||
if isOverContextBudget(131072, messages, tools, 32768) {
|
||||
t.Error("realistic session should be within 131072 context window")
|
||||
}
|
||||
|
||||
// With a tiny context window, should exceed budget.
|
||||
if !isOverContextBudget(500, messages, tools, 32768) {
|
||||
t.Error("realistic session should exceed 500 context window")
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func setupWorkspace(t *testing.T, files map[string]string) string {
|
||||
// Codex (only reads last system message as instructions).
|
||||
func TestSingleSystemMessage(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Identity\nTest agent.",
|
||||
"AGENT.md": "# Agent\nTest agent.",
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
@@ -202,10 +202,10 @@ func TestMtimeAutoInvalidation(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "bootstrap file change",
|
||||
file: "IDENTITY.md",
|
||||
contentV1: "# Original Identity",
|
||||
contentV2: "# Updated Identity",
|
||||
checkField: "Updated Identity",
|
||||
file: "AGENT.md",
|
||||
contentV1: "# Original Agent",
|
||||
contentV2: "# Updated Agent",
|
||||
checkField: "Updated Agent",
|
||||
},
|
||||
{
|
||||
name: "memory file change",
|
||||
@@ -280,7 +280,7 @@ func TestMtimeAutoInvalidation(t *testing.T) {
|
||||
// even when source files haven't changed (useful for tests and reload commands).
|
||||
func TestExplicitInvalidateCache(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Test Identity",
|
||||
"AGENT.md": "# Test Agent",
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
@@ -307,8 +307,8 @@ func TestExplicitInvalidateCache(t *testing.T) {
|
||||
// when no files change (regression test for issue #607).
|
||||
func TestCacheStability(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Identity\nContent",
|
||||
"SOUL.md": "# Soul\nContent",
|
||||
"AGENT.md": "# Agent\nContent",
|
||||
"SOUL.md": "# Soul\nContent",
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
@@ -607,7 +607,7 @@ description: delete-me-v1
|
||||
// Run with: go test -race ./pkg/agent/ -run TestConcurrentBuildSystemPromptWithCache
|
||||
func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Identity\nConcurrency test agent.",
|
||||
"AGENT.md": "# Agent\nConcurrency test agent.",
|
||||
"SOUL.md": "# Soul\nBe helpful.",
|
||||
"memory/MEMORY.md": "# Memory\nUser prefers Go.",
|
||||
"skills/demo/SKILL.md": "---\nname: demo\ndescription: \"demo skill\"\n---\n# Demo",
|
||||
@@ -714,7 +714,7 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
|
||||
|
||||
os.MkdirAll(filepath.Join(tmpDir, "memory"), 0o755)
|
||||
os.MkdirAll(filepath.Join(tmpDir, "skills"), 0o755)
|
||||
for _, name := range []string{"IDENTITY.md", "SOUL.md", "USER.md"} {
|
||||
for _, name := range []string{"AGENT.md", "SOUL.md"} {
|
||||
os.WriteFile(filepath.Join(tmpDir, name), []byte(strings.Repeat("Content.\n", 10)), 0o644)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/gomarkdown/markdown/parser"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// AgentDefinitionSource identifies which agent bootstrap file produced the definition.
|
||||
type AgentDefinitionSource string
|
||||
|
||||
const (
|
||||
// AgentDefinitionSourceAgent indicates the new AGENT.md format.
|
||||
AgentDefinitionSourceAgent AgentDefinitionSource = "AGENT.md"
|
||||
// AgentDefinitionSourceAgents indicates the legacy AGENTS.md format.
|
||||
AgentDefinitionSourceAgents AgentDefinitionSource = "AGENTS.md"
|
||||
)
|
||||
|
||||
// AgentFrontmatter holds machine-readable AGENT.md configuration.
|
||||
//
|
||||
// Known fields are exposed directly for convenience. Fields keeps the full
|
||||
// parsed frontmatter so future refactors can read additional keys without
|
||||
// changing the loader contract again.
|
||||
type AgentFrontmatter struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Tools []string `json:"tools,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
MaxTurns *int `json:"maxTurns,omitempty"`
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
MCPServers []string `json:"mcpServers,omitempty"`
|
||||
Fields map[string]any `json:"fields,omitempty"`
|
||||
}
|
||||
|
||||
// AgentPromptDefinition represents the parsed AGENT.md or AGENTS.md prompt file.
|
||||
type AgentPromptDefinition struct {
|
||||
Path string `json:"path"`
|
||||
Raw string `json:"raw"`
|
||||
Body string `json:"body"`
|
||||
RawFrontmatter string `json:"raw_frontmatter,omitempty"`
|
||||
Frontmatter AgentFrontmatter `json:"frontmatter"`
|
||||
}
|
||||
|
||||
// SoulDefinition represents the resolved SOUL.md file linked to the agent.
|
||||
type SoulDefinition struct {
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// UserDefinition represents the resolved USER.md file linked to the workspace.
|
||||
type UserDefinition struct {
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// AgentContextDefinition captures the workspace agent definition in a runtime-friendly shape.
|
||||
type AgentContextDefinition struct {
|
||||
Source AgentDefinitionSource `json:"source,omitempty"`
|
||||
Agent *AgentPromptDefinition `json:"agent,omitempty"`
|
||||
Soul *SoulDefinition `json:"soul,omitempty"`
|
||||
User *UserDefinition `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// LoadAgentDefinition parses the workspace agent bootstrap files.
|
||||
//
|
||||
// It prefers the new AGENT.md format and its paired SOUL.md file. When the
|
||||
// structured files are absent, it falls back to the legacy AGENTS.md layout so
|
||||
// the current runtime can transition incrementally.
|
||||
func (cb *ContextBuilder) LoadAgentDefinition() AgentContextDefinition {
|
||||
return loadAgentDefinition(cb.workspace)
|
||||
}
|
||||
|
||||
func loadAgentDefinition(workspace string) AgentContextDefinition {
|
||||
definition := AgentContextDefinition{}
|
||||
definition.User = loadUserDefinition(workspace)
|
||||
agentPath := filepath.Join(workspace, string(AgentDefinitionSourceAgent))
|
||||
if content, err := os.ReadFile(agentPath); err == nil {
|
||||
prompt := parseAgentPromptDefinition(agentPath, string(content))
|
||||
definition.Source = AgentDefinitionSourceAgent
|
||||
definition.Agent = &prompt
|
||||
soulPath := filepath.Join(workspace, "SOUL.md")
|
||||
if content, err := os.ReadFile(soulPath); err == nil {
|
||||
definition.Soul = &SoulDefinition{
|
||||
Path: soulPath,
|
||||
Content: string(content),
|
||||
}
|
||||
}
|
||||
return definition
|
||||
}
|
||||
|
||||
legacyPath := filepath.Join(workspace, string(AgentDefinitionSourceAgents))
|
||||
if content, err := os.ReadFile(legacyPath); err == nil {
|
||||
definition.Source = AgentDefinitionSourceAgents
|
||||
definition.Agent = &AgentPromptDefinition{
|
||||
Path: legacyPath,
|
||||
Raw: string(content),
|
||||
Body: string(content),
|
||||
}
|
||||
}
|
||||
|
||||
defaultSoulPath := filepath.Join(workspace, "SOUL.md")
|
||||
if definition.Source != "" || fileExists(defaultSoulPath) {
|
||||
if content, err := os.ReadFile(defaultSoulPath); err == nil {
|
||||
definition.Soul = &SoulDefinition{
|
||||
Path: defaultSoulPath,
|
||||
Content: string(content),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return definition
|
||||
}
|
||||
|
||||
func (definition AgentContextDefinition) trackedPaths(workspace string) []string {
|
||||
paths := []string{
|
||||
filepath.Join(workspace, string(AgentDefinitionSourceAgent)),
|
||||
filepath.Join(workspace, "SOUL.md"),
|
||||
filepath.Join(workspace, "USER.md"),
|
||||
}
|
||||
if definition.Source != AgentDefinitionSourceAgent {
|
||||
paths = append(paths,
|
||||
filepath.Join(workspace, string(AgentDefinitionSourceAgents)),
|
||||
filepath.Join(workspace, "IDENTITY.md"),
|
||||
)
|
||||
}
|
||||
return uniquePaths(paths)
|
||||
}
|
||||
|
||||
func loadUserDefinition(workspace string) *UserDefinition {
|
||||
userPath := filepath.Join(workspace, "USER.md")
|
||||
if content, err := os.ReadFile(userPath); err == nil {
|
||||
return &UserDefinition{
|
||||
Path: userPath,
|
||||
Content: string(content),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAgentPromptDefinition(path, content string) AgentPromptDefinition {
|
||||
frontmatter, body := splitAgentFrontmatter(content)
|
||||
return AgentPromptDefinition{
|
||||
Path: path,
|
||||
Raw: content,
|
||||
Body: body,
|
||||
RawFrontmatter: frontmatter,
|
||||
Frontmatter: parseAgentFrontmatter(path, frontmatter),
|
||||
}
|
||||
}
|
||||
|
||||
func parseAgentFrontmatter(path, frontmatter string) AgentFrontmatter {
|
||||
frontmatter = strings.TrimSpace(frontmatter)
|
||||
if frontmatter == "" {
|
||||
return AgentFrontmatter{}
|
||||
}
|
||||
|
||||
rawFields := make(map[string]any)
|
||||
if err := yaml.Unmarshal([]byte(frontmatter), &rawFields); err != nil {
|
||||
logger.WarnCF("agent", "Failed to parse AGENT.md frontmatter", map[string]any{
|
||||
"path": path,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return AgentFrontmatter{}
|
||||
}
|
||||
|
||||
var typed struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
Tools []string `yaml:"tools"`
|
||||
Model string `yaml:"model"`
|
||||
MaxTurns *int `yaml:"maxTurns"`
|
||||
Skills []string `yaml:"skills"`
|
||||
MCPServers []string `yaml:"mcpServers"`
|
||||
}
|
||||
if err := yaml.Unmarshal([]byte(frontmatter), &typed); err != nil {
|
||||
logger.WarnCF("agent", "Failed to decode AGENT.md frontmatter fields", map[string]any{
|
||||
"path": path,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return AgentFrontmatter{}
|
||||
}
|
||||
|
||||
return AgentFrontmatter{
|
||||
Name: strings.TrimSpace(typed.Name),
|
||||
Description: strings.TrimSpace(typed.Description),
|
||||
Tools: append([]string(nil), typed.Tools...),
|
||||
Model: strings.TrimSpace(typed.Model),
|
||||
MaxTurns: typed.MaxTurns,
|
||||
Skills: append([]string(nil), typed.Skills...),
|
||||
MCPServers: append([]string(nil), typed.MCPServers...),
|
||||
Fields: rawFields,
|
||||
}
|
||||
}
|
||||
|
||||
func splitAgentFrontmatter(content string) (frontmatter, body string) {
|
||||
normalized := string(parser.NormalizeNewlines([]byte(content)))
|
||||
lines := strings.Split(normalized, "\n")
|
||||
if len(lines) == 0 || lines[0] != "---" {
|
||||
return "", content
|
||||
}
|
||||
|
||||
end := -1
|
||||
for i := 1; i < len(lines); i++ {
|
||||
if lines[i] == "---" {
|
||||
end = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if end == -1 {
|
||||
return "", content
|
||||
}
|
||||
|
||||
frontmatter = strings.Join(lines[1:end], "\n")
|
||||
body = strings.Join(lines[end+1:], "\n")
|
||||
body = strings.TrimLeft(body, "\n")
|
||||
return frontmatter, body
|
||||
}
|
||||
|
||||
func relativeWorkspacePath(workspace, path string) string {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return ""
|
||||
}
|
||||
relativePath, err := filepath.Rel(workspace, path)
|
||||
if err == nil && relativePath != "." && !strings.HasPrefix(relativePath, "..") {
|
||||
return filepath.ToSlash(relativePath)
|
||||
}
|
||||
return filepath.Clean(path)
|
||||
}
|
||||
|
||||
func uniquePaths(paths []string) []string {
|
||||
result := make([]string, 0, len(paths))
|
||||
for _, path := range paths {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
continue
|
||||
}
|
||||
cleaned := filepath.Clean(path)
|
||||
if slices.Contains(result, cleaned) {
|
||||
continue
|
||||
}
|
||||
result = append(result, cleaned)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoadAgentDefinitionParsesFrontmatterAndSoul(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": `---
|
||||
name: pico
|
||||
description: Structured agent
|
||||
model: claude-3-7-sonnet
|
||||
tools:
|
||||
- shell
|
||||
- search
|
||||
maxTurns: 8
|
||||
skills:
|
||||
- review
|
||||
- search-docs
|
||||
mcpServers:
|
||||
- github
|
||||
metadata:
|
||||
mode: strict
|
||||
---
|
||||
# Agent
|
||||
|
||||
Act directly and use tools first.
|
||||
`,
|
||||
"SOUL.md": "# Soul\nStay precise.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
definition := cb.LoadAgentDefinition()
|
||||
|
||||
if definition.Source != AgentDefinitionSourceAgent {
|
||||
t.Fatalf("expected source %q, got %q", AgentDefinitionSourceAgent, definition.Source)
|
||||
}
|
||||
if definition.Agent == nil {
|
||||
t.Fatal("expected AGENT.md definition to be loaded")
|
||||
}
|
||||
if definition.Agent.Body == "" || !strings.Contains(definition.Agent.Body, "Act directly") {
|
||||
t.Fatalf("expected AGENT.md body to be preserved, got %q", definition.Agent.Body)
|
||||
}
|
||||
if definition.Agent.Frontmatter.Name != "pico" {
|
||||
t.Fatalf("expected name to be parsed, got %q", definition.Agent.Frontmatter.Name)
|
||||
}
|
||||
if definition.Agent.Frontmatter.Model != "claude-3-7-sonnet" {
|
||||
t.Fatalf("expected model to be parsed, got %q", definition.Agent.Frontmatter.Model)
|
||||
}
|
||||
if len(definition.Agent.Frontmatter.Tools) != 2 {
|
||||
t.Fatalf("expected tools to be parsed, got %v", definition.Agent.Frontmatter.Tools)
|
||||
}
|
||||
if definition.Agent.Frontmatter.MaxTurns == nil || *definition.Agent.Frontmatter.MaxTurns != 8 {
|
||||
t.Fatalf("expected maxTurns to be parsed, got %v", definition.Agent.Frontmatter.MaxTurns)
|
||||
}
|
||||
if len(definition.Agent.Frontmatter.Skills) != 2 {
|
||||
t.Fatalf("expected skills to be parsed, got %v", definition.Agent.Frontmatter.Skills)
|
||||
}
|
||||
if len(definition.Agent.Frontmatter.MCPServers) != 1 || definition.Agent.Frontmatter.MCPServers[0] != "github" {
|
||||
t.Fatalf("expected mcpServers to be parsed, got %v", definition.Agent.Frontmatter.MCPServers)
|
||||
}
|
||||
if definition.Agent.Frontmatter.Fields["metadata"] == nil {
|
||||
t.Fatal("expected arbitrary frontmatter fields to remain available")
|
||||
}
|
||||
|
||||
if definition.Soul == nil {
|
||||
t.Fatal("expected SOUL.md to be loaded")
|
||||
}
|
||||
if !strings.Contains(definition.Soul.Content, "Stay precise") {
|
||||
t.Fatalf("expected soul content to be loaded, got %q", definition.Soul.Content)
|
||||
}
|
||||
if definition.Soul.Path != filepath.Join(tmpDir, "SOUL.md") {
|
||||
t.Fatalf("expected default SOUL.md path, got %q", definition.Soul.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAgentDefinitionFallsBackToLegacyAgentsMarkdown(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENTS.md": "# Legacy Agent\nKeep compatibility.",
|
||||
"SOUL.md": "# Soul\nLegacy soul.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
definition := cb.LoadAgentDefinition()
|
||||
|
||||
if definition.Source != AgentDefinitionSourceAgents {
|
||||
t.Fatalf("expected source %q, got %q", AgentDefinitionSourceAgents, definition.Source)
|
||||
}
|
||||
if definition.Agent == nil {
|
||||
t.Fatal("expected AGENTS.md to be loaded")
|
||||
}
|
||||
if definition.Agent.RawFrontmatter != "" {
|
||||
t.Fatalf("legacy AGENTS.md should not have frontmatter, got %q", definition.Agent.RawFrontmatter)
|
||||
}
|
||||
if !strings.Contains(definition.Agent.Body, "Keep compatibility") {
|
||||
t.Fatalf("expected legacy body to be preserved, got %q", definition.Agent.Body)
|
||||
}
|
||||
if definition.Soul == nil || !strings.Contains(definition.Soul.Content, "Legacy soul") {
|
||||
t.Fatal("expected default SOUL.md to be loaded for legacy format")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAgentDefinitionLoadsWorkspaceUserMarkdown(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": "# Agent\nStructured agent.",
|
||||
"USER.md": "# User\nWorkspace preferences.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
definition := cb.LoadAgentDefinition()
|
||||
|
||||
if definition.User == nil {
|
||||
t.Fatal("expected USER.md to be loaded")
|
||||
}
|
||||
if definition.User.Path != filepath.Join(tmpDir, "USER.md") {
|
||||
t.Fatalf("expected workspace USER.md path, got %q", definition.User.Path)
|
||||
}
|
||||
if !strings.Contains(definition.User.Content, "Workspace preferences") {
|
||||
t.Fatalf("expected workspace USER.md content, got %q", definition.User.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAgentDefinitionInvalidFrontmatterFallsBackToEmptyStructuredFields(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": `---
|
||||
name: pico
|
||||
tools:
|
||||
- shell
|
||||
broken
|
||||
---
|
||||
# Agent
|
||||
|
||||
Keep going.
|
||||
`,
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
definition := cb.LoadAgentDefinition()
|
||||
|
||||
if definition.Agent == nil {
|
||||
t.Fatal("expected AGENT.md definition to be loaded")
|
||||
}
|
||||
if !strings.Contains(definition.Agent.Body, "Keep going.") {
|
||||
t.Fatalf("expected AGENT.md body to be preserved, got %q", definition.Agent.Body)
|
||||
}
|
||||
if definition.Agent.Frontmatter.Name != "" ||
|
||||
definition.Agent.Frontmatter.Description != "" ||
|
||||
definition.Agent.Frontmatter.Model != "" ||
|
||||
definition.Agent.Frontmatter.MaxTurns != nil ||
|
||||
len(definition.Agent.Frontmatter.Tools) != 0 ||
|
||||
len(definition.Agent.Frontmatter.Skills) != 0 ||
|
||||
len(definition.Agent.Frontmatter.MCPServers) != 0 ||
|
||||
len(definition.Agent.Frontmatter.Fields) != 0 {
|
||||
t.Fatalf("expected invalid frontmatter to decode as empty struct, got %+v", definition.Agent.Frontmatter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBootstrapFilesUsesAgentBodyNotFrontmatter(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": `---
|
||||
name: pico
|
||||
model: codex-mini
|
||||
---
|
||||
# Agent
|
||||
|
||||
Follow the body prompt.
|
||||
`,
|
||||
"SOUL.md": "# Soul\nSpeak plainly.",
|
||||
"IDENTITY.md": "# Identity\nWorkspace identity.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
bootstrap := cb.LoadBootstrapFiles()
|
||||
|
||||
if !strings.Contains(bootstrap, "Follow the body prompt") {
|
||||
t.Fatalf("expected AGENT.md body in bootstrap, got %q", bootstrap)
|
||||
}
|
||||
if !strings.Contains(bootstrap, "Speak plainly") {
|
||||
t.Fatalf("expected resolved soul content in bootstrap, got %q", bootstrap)
|
||||
}
|
||||
if strings.Contains(bootstrap, "name: pico") {
|
||||
t.Fatalf("bootstrap should not expose raw frontmatter, got %q", bootstrap)
|
||||
}
|
||||
if strings.Contains(bootstrap, "model: codex-mini") {
|
||||
t.Fatalf("bootstrap should not expose raw frontmatter, got %q", bootstrap)
|
||||
}
|
||||
if !strings.Contains(bootstrap, "SOUL.md") {
|
||||
t.Fatalf("expected bootstrap to label SOUL.md, got %q", bootstrap)
|
||||
}
|
||||
if strings.Contains(bootstrap, "Workspace identity") {
|
||||
t.Fatalf("structured bootstrap should ignore IDENTITY.md, got %q", bootstrap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBootstrapFilesIncludesWorkspaceUserMarkdown(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": "# Agent\nFollow the new structure.",
|
||||
"SOUL.md": "# Soul\nSpeak plainly.",
|
||||
"USER.md": "# User\nShared profile.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
bootstrap := cb.LoadBootstrapFiles()
|
||||
|
||||
if !strings.Contains(bootstrap, "Shared profile") {
|
||||
t.Fatalf("expected workspace USER.md in bootstrap, got %q", bootstrap)
|
||||
}
|
||||
if !strings.Contains(bootstrap, "## USER.md") {
|
||||
t.Fatalf("expected USER.md heading in bootstrap, got %q", bootstrap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructuredAgentIgnoresIdentityChanges(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": "# Agent\nFollow the new structure.",
|
||||
"SOUL.md": "# Soul\nVersion one.",
|
||||
"IDENTITY.md": "# Identity\nLegacy identity.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
|
||||
promptV1 := cb.BuildSystemPromptWithCache()
|
||||
if strings.Contains(promptV1, "Legacy identity") {
|
||||
t.Fatalf("structured prompt should not include IDENTITY.md, got %q", promptV1)
|
||||
}
|
||||
|
||||
identityPath := filepath.Join(tmpDir, "IDENTITY.md")
|
||||
if err := os.WriteFile(identityPath, []byte("# Identity\nVersion two."), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(identityPath, future, future); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if changed {
|
||||
t.Fatal("IDENTITY.md should not invalidate cache for structured agent definitions")
|
||||
}
|
||||
|
||||
promptV2 := cb.BuildSystemPromptWithCache()
|
||||
if promptV1 != promptV2 {
|
||||
t.Fatal("structured prompt should remain stable after IDENTITY.md changes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructuredAgentUserChangesInvalidateCache(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": "# Agent\nFollow the new structure.",
|
||||
"SOUL.md": "# Soul\nVersion one.",
|
||||
"USER.md": "# User\nInitial workspace preferences.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
|
||||
promptV1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(promptV1, "Initial workspace preferences") {
|
||||
t.Fatalf("expected workspace USER.md in prompt, got %q", promptV1)
|
||||
}
|
||||
|
||||
userPath := filepath.Join(tmpDir, "USER.md")
|
||||
if err := os.WriteFile(userPath, []byte("# User\nUpdated workspace preferences."), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(userPath, future, future); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("workspace USER.md changes should invalidate cache")
|
||||
}
|
||||
|
||||
promptV2 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(promptV2, "Updated workspace preferences") {
|
||||
t.Fatalf("expected updated workspace USER.md in prompt, got %q", promptV2)
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupWorkspace(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
if err := os.RemoveAll(path); err != nil {
|
||||
t.Fatalf("failed to clean up workspace %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultEventSubscriberBuffer = 16
|
||||
|
||||
// EventSubscription identifies a subscriber channel returned by EventBus.Subscribe.
|
||||
type EventSubscription struct {
|
||||
ID uint64
|
||||
C <-chan Event
|
||||
}
|
||||
|
||||
type eventSubscriber struct {
|
||||
ch chan Event
|
||||
}
|
||||
|
||||
// EventBus is a lightweight multi-subscriber broadcaster for agent-loop events.
|
||||
type EventBus struct {
|
||||
mu sync.RWMutex
|
||||
subs map[uint64]eventSubscriber
|
||||
nextID uint64
|
||||
closed bool
|
||||
dropped [eventKindCount]atomic.Int64
|
||||
}
|
||||
|
||||
// NewEventBus creates a new in-process event broadcaster.
|
||||
func NewEventBus() *EventBus {
|
||||
return &EventBus{
|
||||
subs: make(map[uint64]eventSubscriber),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe registers a new subscriber with the requested channel buffer size.
|
||||
// A non-positive buffer uses the default size.
|
||||
func (b *EventBus) Subscribe(buffer int) EventSubscription {
|
||||
if buffer <= 0 {
|
||||
buffer = defaultEventSubscriberBuffer
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
ch := make(chan Event)
|
||||
close(ch)
|
||||
return EventSubscription{C: ch}
|
||||
}
|
||||
|
||||
b.nextID++
|
||||
id := b.nextID
|
||||
ch := make(chan Event, buffer)
|
||||
b.subs[id] = eventSubscriber{ch: ch}
|
||||
return EventSubscription{ID: id, C: ch}
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscriber and closes its channel.
|
||||
func (b *EventBus) Unsubscribe(id uint64) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
sub, ok := b.subs[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(b.subs, id)
|
||||
close(sub.ch)
|
||||
}
|
||||
|
||||
// Emit broadcasts an event to all current subscribers without blocking.
|
||||
// When a subscriber channel is full, the event is dropped for that subscriber.
|
||||
func (b *EventBus) Emit(evt Event) {
|
||||
if evt.Time.IsZero() {
|
||||
evt.Time = time.Now()
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
|
||||
for _, sub := range b.subs {
|
||||
select {
|
||||
case sub.ch <- evt:
|
||||
default:
|
||||
if evt.Kind < eventKindCount {
|
||||
b.dropped[evt.Kind].Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dropped returns the number of dropped events for a given kind.
|
||||
func (b *EventBus) Dropped(kind EventKind) int64 {
|
||||
if kind >= eventKindCount {
|
||||
return 0
|
||||
}
|
||||
return b.dropped[kind].Load()
|
||||
}
|
||||
|
||||
// Close closes all subscriber channels and stops future broadcasts.
|
||||
func (b *EventBus) Close() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
|
||||
b.closed = true
|
||||
for id, sub := range b.subs {
|
||||
close(sub.ch)
|
||||
delete(b.subs, id)
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package agent
|
||||
|
||||
import "fmt"
|
||||
|
||||
// MockEventBus - for POC
|
||||
var MockEventBus = struct {
|
||||
Emit func(event any)
|
||||
}{
|
||||
Emit: func(event any) {
|
||||
fmt.Printf("[Mock EventBus] %T %+v\n", event, event)
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,684 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
func TestEventBus_SubscribeEmitUnsubscribeClose(t *testing.T) {
|
||||
eventBus := NewEventBus()
|
||||
sub := eventBus.Subscribe(1)
|
||||
|
||||
eventBus.Emit(Event{
|
||||
Kind: EventKindTurnStart,
|
||||
Meta: EventMeta{TurnID: "turn-1"},
|
||||
})
|
||||
|
||||
select {
|
||||
case evt := <-sub.C:
|
||||
if evt.Kind != EventKindTurnStart {
|
||||
t.Fatalf("expected %v, got %v", EventKindTurnStart, evt.Kind)
|
||||
}
|
||||
if evt.Meta.TurnID != "turn-1" {
|
||||
t.Fatalf("expected turn id turn-1, got %q", evt.Meta.TurnID)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for event")
|
||||
}
|
||||
|
||||
eventBus.Unsubscribe(sub.ID)
|
||||
if _, ok := <-sub.C; ok {
|
||||
t.Fatal("expected subscriber channel to be closed after unsubscribe")
|
||||
}
|
||||
|
||||
eventBus.Close()
|
||||
closedSub := eventBus.Subscribe(1)
|
||||
if _, ok := <-closedSub.C; ok {
|
||||
t.Fatal("expected closed bus to return a closed subscriber channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBus_DropsWhenSubscriberIsFull(t *testing.T) {
|
||||
eventBus := NewEventBus()
|
||||
sub := eventBus.Subscribe(1)
|
||||
defer eventBus.Unsubscribe(sub.ID)
|
||||
|
||||
start := time.Now()
|
||||
for i := 0; i < 1000; i++ {
|
||||
eventBus.Emit(Event{Kind: EventKindLLMRequest})
|
||||
}
|
||||
|
||||
if elapsed := time.Since(start); elapsed > 100*time.Millisecond {
|
||||
t.Fatalf("Emit took too long with a blocked subscriber: %s", elapsed)
|
||||
}
|
||||
|
||||
if got := eventBus.Dropped(EventKindLLMRequest); got != 999 {
|
||||
t.Fatalf("expected 999 dropped events, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
type scriptedToolProvider struct {
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *scriptedToolProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
toolDefs []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Name: "mock_custom",
|
||||
Arguments: map[string]any{"task": "ping"},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: "done",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *scriptedToolProvider) GetDefaultModel() string {
|
||||
return "scripted-tool-model"
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-eventbus-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &scriptedToolProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(&mockCustomTool{})
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if response != "done" {
|
||||
t.Fatalf("expected final response 'done', got %q", response)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
if len(events) != 8 {
|
||||
t.Fatalf("expected 8 events, got %d", len(events))
|
||||
}
|
||||
|
||||
kinds := make([]EventKind, 0, len(events))
|
||||
for _, evt := range events {
|
||||
kinds = append(kinds, evt.Kind)
|
||||
}
|
||||
|
||||
expectedKinds := []EventKind{
|
||||
EventKindTurnStart,
|
||||
EventKindLLMRequest,
|
||||
EventKindLLMResponse,
|
||||
EventKindToolExecStart,
|
||||
EventKindToolExecEnd,
|
||||
EventKindLLMRequest,
|
||||
EventKindLLMResponse,
|
||||
EventKindTurnEnd,
|
||||
}
|
||||
if !slices.Equal(kinds, expectedKinds) {
|
||||
t.Fatalf("unexpected event sequence: got %v want %v", kinds, expectedKinds)
|
||||
}
|
||||
|
||||
turnID := events[0].Meta.TurnID
|
||||
for i, evt := range events {
|
||||
if evt.Meta.TurnID != turnID {
|
||||
t.Fatalf("event %d has mismatched turn id %q, want %q", i, evt.Meta.TurnID, turnID)
|
||||
}
|
||||
if evt.Meta.SessionKey != "session-1" {
|
||||
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
startPayload, ok := events[0].Payload.(TurnStartPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected TurnStartPayload, got %T", events[0].Payload)
|
||||
}
|
||||
if startPayload.UserMessage != "run tool" {
|
||||
t.Fatalf("expected user message 'run tool', got %q", startPayload.UserMessage)
|
||||
}
|
||||
|
||||
toolStartPayload, ok := events[3].Payload.(ToolExecStartPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecStartPayload, got %T", events[3].Payload)
|
||||
}
|
||||
if toolStartPayload.Tool != "mock_custom" {
|
||||
t.Fatalf("expected tool name mock_custom, got %q", toolStartPayload.Tool)
|
||||
}
|
||||
|
||||
toolEndPayload, ok := events[4].Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", events[4].Payload)
|
||||
}
|
||||
if toolEndPayload.Tool != "mock_custom" {
|
||||
t.Fatalf("expected tool end payload for mock_custom, got %q", toolEndPayload.Tool)
|
||||
}
|
||||
if toolEndPayload.IsError {
|
||||
t.Fatal("expected mock_custom tool to succeed")
|
||||
}
|
||||
|
||||
turnEndPayload, ok := events[len(events)-1].Payload.(TurnEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected TurnEndPayload, got %T", events[len(events)-1].Payload)
|
||||
}
|
||||
if turnEndPayload.Status != TurnEndStatusCompleted {
|
||||
t.Fatalf("expected completed turn, got %q", turnEndPayload.Status)
|
||||
}
|
||||
if turnEndPayload.Iterations != 2 {
|
||||
t.Fatalf("expected 2 iterations, got %d", turnEndPayload.Iterations)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-eventbus-steering-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool1ExecCh := make(chan struct{})
|
||||
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
|
||||
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
|
||||
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "tool_one",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_one",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Type: "function",
|
||||
Name: "tool_two",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_two",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "steered response",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resultCh := make(chan string, 1)
|
||||
go func() {
|
||||
resp, _ := al.ProcessDirectWithChannel(context.Background(), "do something", "test-session", "test", "chat1")
|
||||
resultCh <- resp
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-tool1ExecCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for tool_one to start")
|
||||
}
|
||||
|
||||
if err := al.Steer(providers.Message{Role: "user", Content: "change course"}); err != nil {
|
||||
t.Fatalf("Steer failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case resp := <-resultCh:
|
||||
if resp != "steered response" {
|
||||
t.Fatalf("expected steered response, got %q", resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for steered response")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
steeringEvt, ok := findEvent(events, EventKindSteeringInjected)
|
||||
if !ok {
|
||||
t.Fatal("expected steering injected event")
|
||||
}
|
||||
steeringPayload, ok := steeringEvt.Payload.(SteeringInjectedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected SteeringInjectedPayload, got %T", steeringEvt.Payload)
|
||||
}
|
||||
if steeringPayload.Count != 1 {
|
||||
t.Fatalf("expected 1 steering message, got %d", steeringPayload.Count)
|
||||
}
|
||||
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected skipped tool event")
|
||||
}
|
||||
skippedPayload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
||||
}
|
||||
if skippedPayload.Tool != "tool_two" {
|
||||
t.Fatalf("expected skipped tool_two, got %q", skippedPayload.Tool)
|
||||
}
|
||||
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
||||
}
|
||||
if interruptPayload.Role != "user" {
|
||||
t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
|
||||
}
|
||||
if interruptPayload.Kind != InterruptKindSteering {
|
||||
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
|
||||
}
|
||||
if interruptPayload.ContentLen != len("change course") {
|
||||
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-eventbus-compress-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
contextErr := stringError("InvalidParameter: Total tokens of image and text exceed max message tokens")
|
||||
provider := &failFirstMockProvider{
|
||||
failures: 1,
|
||||
failError: contextErr,
|
||||
successResp: "Recovered from context error",
|
||||
}
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
|
||||
{Role: "user", Content: "Old message 1"},
|
||||
{Role: "assistant", Content: "Old response 1"},
|
||||
{Role: "user", Content: "Old message 2"},
|
||||
{Role: "assistant", Content: "Old response 2"},
|
||||
{Role: "user", Content: "Trigger message"},
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "Trigger message",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "Recovered from context error" {
|
||||
t.Fatalf("expected retry success, got %q", resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
retryEvt, ok := findEvent(events, EventKindLLMRetry)
|
||||
if !ok {
|
||||
t.Fatal("expected llm retry event")
|
||||
}
|
||||
retryPayload, ok := retryEvt.Payload.(LLMRetryPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected LLMRetryPayload, got %T", retryEvt.Payload)
|
||||
}
|
||||
if retryPayload.Reason != "context_limit" {
|
||||
t.Fatalf("expected context_limit retry reason, got %q", retryPayload.Reason)
|
||||
}
|
||||
if retryPayload.Attempt != 1 {
|
||||
t.Fatalf("expected retry attempt 1, got %d", retryPayload.Attempt)
|
||||
}
|
||||
|
||||
compressEvt, ok := findEvent(events, EventKindContextCompress)
|
||||
if !ok {
|
||||
t.Fatal("expected context compress event")
|
||||
}
|
||||
payload, ok := compressEvt.Payload.(ContextCompressPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
|
||||
}
|
||||
if payload.Reason != ContextCompressReasonRetry {
|
||||
t.Fatalf("expected retry compress reason, got %q", payload.Reason)
|
||||
}
|
||||
if payload.DroppedMessages == 0 {
|
||||
t.Fatal("expected dropped messages to be recorded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-eventbus-summary-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextWindow: 8000,
|
||||
SummarizeMessageThreshold: 2,
|
||||
SummarizeTokenPercent: 75,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "summary text"})
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
|
||||
{Role: "user", Content: "Question one"},
|
||||
{Role: "assistant", Content: "Answer one"},
|
||||
{Role: "user", Content: "Question two"},
|
||||
{Role: "assistant", Content: "Answer two"},
|
||||
{Role: "user", Content: "Question three"},
|
||||
{Role: "assistant", Content: "Answer three"},
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
turnScope := al.newTurnEventScope(defaultAgent.ID, "session-1")
|
||||
al.summarizeSession(defaultAgent, "session-1", turnScope)
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
|
||||
if !ok {
|
||||
t.Fatal("expected session summarize event")
|
||||
}
|
||||
payload, ok := summaryEvt.Payload.(SessionSummarizePayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected SessionSummarizePayload, got %T", summaryEvt.Payload)
|
||||
}
|
||||
if payload.SummaryLen == 0 {
|
||||
t.Fatal("expected non-empty summary length")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-eventbus-followup-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_async_1",
|
||||
Type: "function",
|
||||
Name: "async_followup",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "async_followup",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "async launched",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
doneCh := make(chan struct{})
|
||||
al.RegisterTool(&asyncFollowUpTool{
|
||||
name: "async_followup",
|
||||
followUpText: "background result",
|
||||
completionSig: doneCh,
|
||||
})
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run async tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "async launched" {
|
||||
t.Fatalf("expected final response 'async launched', got %q", resp)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-doneCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for async tool completion")
|
||||
}
|
||||
|
||||
followUpEvt := waitForEvent(t, sub.C, 2*time.Second, func(evt Event) bool {
|
||||
return evt.Kind == EventKindFollowUpQueued
|
||||
})
|
||||
payload, ok := followUpEvt.Payload.(FollowUpQueuedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected FollowUpQueuedPayload, got %T", followUpEvt.Payload)
|
||||
}
|
||||
if payload.SourceTool != "async_followup" {
|
||||
t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
|
||||
}
|
||||
if payload.Channel != "cli" {
|
||||
t.Fatalf("expected channel cli, got %q", payload.Channel)
|
||||
}
|
||||
if payload.ChatID != "direct" {
|
||||
t.Fatalf("expected chat id direct, got %q", payload.ChatID)
|
||||
}
|
||||
if payload.ContentLen != len("background result") {
|
||||
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
|
||||
}
|
||||
if followUpEvt.Meta.SessionKey != "session-1" {
|
||||
t.Fatalf("expected session key session-1, got %q", followUpEvt.Meta.SessionKey)
|
||||
}
|
||||
if followUpEvt.Meta.TurnID == "" {
|
||||
t.Fatal("expected follow-up event to include turn id")
|
||||
}
|
||||
}
|
||||
|
||||
func collectEventStream(ch <-chan Event) []Event {
|
||||
var events []Event
|
||||
for {
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
return events
|
||||
}
|
||||
events = append(events, evt)
|
||||
default:
|
||||
return events
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitForEvent(t *testing.T, ch <-chan Event, timeout time.Duration, match func(Event) bool) Event {
|
||||
t.Helper()
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("event stream closed before expected event arrived")
|
||||
}
|
||||
if match(evt) {
|
||||
return evt
|
||||
}
|
||||
case <-timer.C:
|
||||
t.Fatal("timed out waiting for expected event")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findEvent(events []Event, kind EventKind) (Event, bool) {
|
||||
for _, evt := range events {
|
||||
if evt.Kind == kind {
|
||||
return evt, true
|
||||
}
|
||||
}
|
||||
return Event{}, false
|
||||
}
|
||||
|
||||
type stringError string
|
||||
|
||||
func (e stringError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
type asyncFollowUpTool struct {
|
||||
name string
|
||||
followUpText string
|
||||
completionSig chan struct{}
|
||||
}
|
||||
|
||||
func (t *asyncFollowUpTool) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *asyncFollowUpTool) Description() string {
|
||||
return "async follow-up tool for testing"
|
||||
}
|
||||
|
||||
func (t *asyncFollowUpTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *asyncFollowUpTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
return tools.AsyncResult("async follow-up scheduled")
|
||||
}
|
||||
|
||||
func (t *asyncFollowUpTool) ExecuteAsync(
|
||||
ctx context.Context,
|
||||
args map[string]any,
|
||||
cb tools.AsyncCallback,
|
||||
) *tools.ToolResult {
|
||||
go func() {
|
||||
cb(ctx, &tools.ToolResult{ForLLM: t.followUpText})
|
||||
if t.completionSig != nil {
|
||||
close(t.completionSig)
|
||||
}
|
||||
}()
|
||||
return tools.AsyncResult("async follow-up scheduled")
|
||||
}
|
||||
|
||||
var (
|
||||
_ tools.Tool = (*mockCustomTool)(nil)
|
||||
_ tools.AsyncExecutor = (*asyncFollowUpTool)(nil)
|
||||
)
|
||||
@@ -0,0 +1,271 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EventKind identifies a structured agent-loop event.
|
||||
type EventKind uint8
|
||||
|
||||
const (
|
||||
// EventKindTurnStart is emitted when a turn begins processing.
|
||||
EventKindTurnStart EventKind = iota
|
||||
// EventKindTurnEnd is emitted when a turn finishes, successfully or with an error.
|
||||
EventKindTurnEnd
|
||||
// EventKindLLMRequest is emitted before a provider chat request is made.
|
||||
EventKindLLMRequest
|
||||
// EventKindLLMDelta is emitted when a streaming provider yields a partial delta.
|
||||
EventKindLLMDelta
|
||||
// EventKindLLMResponse is emitted after a provider chat response is received.
|
||||
EventKindLLMResponse
|
||||
// EventKindLLMRetry is emitted when an LLM request is retried.
|
||||
EventKindLLMRetry
|
||||
// EventKindContextCompress is emitted when session history is forcibly compressed.
|
||||
EventKindContextCompress
|
||||
// EventKindSessionSummarize is emitted when asynchronous summarization completes.
|
||||
EventKindSessionSummarize
|
||||
// EventKindToolExecStart is emitted immediately before a tool executes.
|
||||
EventKindToolExecStart
|
||||
// EventKindToolExecEnd is emitted immediately after a tool finishes executing.
|
||||
EventKindToolExecEnd
|
||||
// EventKindToolExecSkipped is emitted when a queued tool call is skipped.
|
||||
EventKindToolExecSkipped
|
||||
// EventKindSteeringInjected is emitted when queued steering is injected into context.
|
||||
EventKindSteeringInjected
|
||||
// EventKindFollowUpQueued is emitted when an async tool queues a follow-up system message.
|
||||
EventKindFollowUpQueued
|
||||
// EventKindInterruptReceived is emitted when a soft interrupt message is accepted.
|
||||
EventKindInterruptReceived
|
||||
// EventKindSubTurnSpawn is emitted when a sub-turn is spawned.
|
||||
EventKindSubTurnSpawn
|
||||
// EventKindSubTurnEnd is emitted when a sub-turn finishes.
|
||||
EventKindSubTurnEnd
|
||||
// EventKindSubTurnResultDelivered is emitted when a sub-turn result is delivered.
|
||||
EventKindSubTurnResultDelivered
|
||||
// EventKindSubTurnOrphan is emitted when a sub-turn result cannot be delivered.
|
||||
EventKindSubTurnOrphan
|
||||
// EventKindError is emitted when a turn encounters an execution error.
|
||||
EventKindError
|
||||
|
||||
eventKindCount
|
||||
)
|
||||
|
||||
var eventKindNames = [...]string{
|
||||
"turn_start",
|
||||
"turn_end",
|
||||
"llm_request",
|
||||
"llm_delta",
|
||||
"llm_response",
|
||||
"llm_retry",
|
||||
"context_compress",
|
||||
"session_summarize",
|
||||
"tool_exec_start",
|
||||
"tool_exec_end",
|
||||
"tool_exec_skipped",
|
||||
"steering_injected",
|
||||
"follow_up_queued",
|
||||
"interrupt_received",
|
||||
"subturn_spawn",
|
||||
"subturn_end",
|
||||
"subturn_result_delivered",
|
||||
"subturn_orphan",
|
||||
"error",
|
||||
}
|
||||
|
||||
// String returns the stable string form of an EventKind.
|
||||
func (k EventKind) String() string {
|
||||
if k >= eventKindCount {
|
||||
return fmt.Sprintf("event_kind(%d)", k)
|
||||
}
|
||||
return eventKindNames[k]
|
||||
}
|
||||
|
||||
// Event is the structured envelope broadcast by the agent EventBus.
|
||||
type Event struct {
|
||||
Kind EventKind
|
||||
Time time.Time
|
||||
Meta EventMeta
|
||||
Payload any
|
||||
}
|
||||
|
||||
// EventMeta contains correlation fields shared by all agent-loop events.
|
||||
type EventMeta struct {
|
||||
AgentID string
|
||||
TurnID string
|
||||
ParentTurnID string
|
||||
SessionKey string
|
||||
Iteration int
|
||||
TracePath string
|
||||
Source string
|
||||
}
|
||||
|
||||
// TurnEndStatus describes the terminal state of a turn.
|
||||
type TurnEndStatus string
|
||||
|
||||
const (
|
||||
// TurnEndStatusCompleted indicates the turn finished normally.
|
||||
TurnEndStatusCompleted TurnEndStatus = "completed"
|
||||
// TurnEndStatusError indicates the turn ended because of an error.
|
||||
TurnEndStatusError TurnEndStatus = "error"
|
||||
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
|
||||
TurnEndStatusAborted TurnEndStatus = "aborted"
|
||||
)
|
||||
|
||||
// TurnStartPayload describes the start of a turn.
|
||||
type TurnStartPayload struct {
|
||||
Channel string
|
||||
ChatID string
|
||||
UserMessage string
|
||||
MediaCount int
|
||||
}
|
||||
|
||||
// TurnEndPayload describes the completion of a turn.
|
||||
type TurnEndPayload struct {
|
||||
Status TurnEndStatus
|
||||
Iterations int
|
||||
Duration time.Duration
|
||||
FinalContentLen int
|
||||
}
|
||||
|
||||
// LLMRequestPayload describes an outbound LLM request.
|
||||
type LLMRequestPayload struct {
|
||||
Model string
|
||||
MessagesCount int
|
||||
ToolsCount int
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
}
|
||||
|
||||
// LLMResponsePayload describes an inbound LLM response.
|
||||
type LLMResponsePayload struct {
|
||||
ContentLen int
|
||||
ToolCalls int
|
||||
HasReasoning bool
|
||||
}
|
||||
|
||||
// LLMDeltaPayload describes a streamed LLM delta.
|
||||
type LLMDeltaPayload struct {
|
||||
ContentDeltaLen int
|
||||
ReasoningDeltaLen int
|
||||
}
|
||||
|
||||
// LLMRetryPayload describes a retry of an LLM request.
|
||||
type LLMRetryPayload struct {
|
||||
Attempt int
|
||||
MaxRetries int
|
||||
Reason string
|
||||
Error string
|
||||
Backoff time.Duration
|
||||
}
|
||||
|
||||
// ContextCompressReason identifies why emergency compression ran.
|
||||
type ContextCompressReason string
|
||||
|
||||
const (
|
||||
// ContextCompressReasonProactive indicates compression before the first LLM call.
|
||||
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
|
||||
// ContextCompressReasonRetry indicates compression during context-error retry handling.
|
||||
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
|
||||
)
|
||||
|
||||
// ContextCompressPayload describes a forced history compression.
|
||||
type ContextCompressPayload struct {
|
||||
Reason ContextCompressReason
|
||||
DroppedMessages int
|
||||
RemainingMessages int
|
||||
}
|
||||
|
||||
// SessionSummarizePayload describes a completed async session summarization.
|
||||
type SessionSummarizePayload struct {
|
||||
SummarizedMessages int
|
||||
KeptMessages int
|
||||
SummaryLen int
|
||||
OmittedOversized bool
|
||||
}
|
||||
|
||||
// ToolExecStartPayload describes a tool execution request.
|
||||
type ToolExecStartPayload struct {
|
||||
Tool string
|
||||
Arguments map[string]any
|
||||
}
|
||||
|
||||
// ToolExecEndPayload describes the outcome of a tool execution.
|
||||
type ToolExecEndPayload struct {
|
||||
Tool string
|
||||
Duration time.Duration
|
||||
ForLLMLen int
|
||||
ForUserLen int
|
||||
IsError bool
|
||||
Async bool
|
||||
}
|
||||
|
||||
// ToolExecSkippedPayload describes a skipped tool call.
|
||||
type ToolExecSkippedPayload struct {
|
||||
Tool string
|
||||
Reason string
|
||||
}
|
||||
|
||||
// SteeringInjectedPayload describes steering messages appended before the next LLM call.
|
||||
type SteeringInjectedPayload struct {
|
||||
Count int
|
||||
TotalContentLen int
|
||||
}
|
||||
|
||||
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
|
||||
type FollowUpQueuedPayload struct {
|
||||
SourceTool string
|
||||
Channel string
|
||||
ChatID string
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
type InterruptKind string
|
||||
|
||||
const (
|
||||
InterruptKindSteering InterruptKind = "steering"
|
||||
InterruptKindGraceful InterruptKind = "graceful"
|
||||
InterruptKindHard InterruptKind = "hard_abort"
|
||||
)
|
||||
|
||||
// InterruptReceivedPayload describes accepted turn-control input.
|
||||
type InterruptReceivedPayload struct {
|
||||
Kind InterruptKind
|
||||
Role string
|
||||
ContentLen int
|
||||
QueueDepth int
|
||||
HintLen int
|
||||
}
|
||||
|
||||
// SubTurnSpawnPayload describes the creation of a child turn.
|
||||
type SubTurnSpawnPayload struct {
|
||||
AgentID string
|
||||
Label string
|
||||
ParentTurnID string
|
||||
}
|
||||
|
||||
// SubTurnEndPayload describes the completion of a child turn.
|
||||
type SubTurnEndPayload struct {
|
||||
AgentID string
|
||||
Status string
|
||||
}
|
||||
|
||||
// SubTurnResultDeliveredPayload describes delivery of a sub-turn result.
|
||||
type SubTurnResultDeliveredPayload struct {
|
||||
TargetChannel string
|
||||
TargetChatID string
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
// SubTurnOrphanPayload describes a sub-turn result that could not be delivered.
|
||||
type SubTurnOrphanPayload struct {
|
||||
ParentTurnID string
|
||||
ChildTurnID string
|
||||
Reason string
|
||||
}
|
||||
|
||||
// ErrorPayload describes an execution error inside the agent loop.
|
||||
type ErrorPayload struct {
|
||||
Stage string
|
||||
Message string
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type hookRuntime struct {
|
||||
initOnce sync.Once
|
||||
mu sync.Mutex
|
||||
initErr error
|
||||
mounted []string
|
||||
}
|
||||
|
||||
func (r *hookRuntime) setInitErr(err error) {
|
||||
r.mu.Lock()
|
||||
r.initErr = err
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *hookRuntime) getInitErr() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.initErr
|
||||
}
|
||||
|
||||
func (r *hookRuntime) setMounted(names []string) {
|
||||
r.mu.Lock()
|
||||
r.mounted = append([]string(nil), names...)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *hookRuntime) reset(al *AgentLoop) {
|
||||
r.mu.Lock()
|
||||
names := append([]string(nil), r.mounted...)
|
||||
r.mounted = nil
|
||||
r.initErr = nil
|
||||
r.initOnce = sync.Once{}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, name := range names {
|
||||
al.UnmountHook(name)
|
||||
}
|
||||
}
|
||||
|
||||
// BuiltinHookFactory constructs an in-process hook from config.
|
||||
type BuiltinHookFactory func(ctx context.Context, spec config.BuiltinHookConfig) (any, error)
|
||||
|
||||
var (
|
||||
builtinHookRegistryMu sync.RWMutex
|
||||
builtinHookRegistry = map[string]BuiltinHookFactory{}
|
||||
)
|
||||
|
||||
// RegisterBuiltinHook registers a named in-process hook factory for config-driven mounting.
|
||||
func RegisterBuiltinHook(name string, factory BuiltinHookFactory) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("builtin hook name is required")
|
||||
}
|
||||
if factory == nil {
|
||||
return fmt.Errorf("builtin hook %q factory is nil", name)
|
||||
}
|
||||
|
||||
builtinHookRegistryMu.Lock()
|
||||
defer builtinHookRegistryMu.Unlock()
|
||||
|
||||
if _, exists := builtinHookRegistry[name]; exists {
|
||||
return fmt.Errorf("builtin hook %q is already registered", name)
|
||||
}
|
||||
builtinHookRegistry[name] = factory
|
||||
return nil
|
||||
}
|
||||
|
||||
func unregisterBuiltinHook(name string) {
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
builtinHookRegistryMu.Lock()
|
||||
delete(builtinHookRegistry, name)
|
||||
builtinHookRegistryMu.Unlock()
|
||||
}
|
||||
|
||||
func lookupBuiltinHook(name string) (BuiltinHookFactory, bool) {
|
||||
builtinHookRegistryMu.RLock()
|
||||
defer builtinHookRegistryMu.RUnlock()
|
||||
|
||||
factory, ok := builtinHookRegistry[name]
|
||||
return factory, ok
|
||||
}
|
||||
|
||||
func configureHookManagerFromConfig(hm *HookManager, cfg *config.Config) {
|
||||
if hm == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
hm.ConfigureTimeouts(
|
||||
hookTimeoutFromMS(cfg.Hooks.Defaults.ObserverTimeoutMS),
|
||||
hookTimeoutFromMS(cfg.Hooks.Defaults.InterceptorTimeoutMS),
|
||||
hookTimeoutFromMS(cfg.Hooks.Defaults.ApprovalTimeoutMS),
|
||||
)
|
||||
}
|
||||
|
||||
func hookTimeoutFromMS(ms int) time.Duration {
|
||||
if ms <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
|
||||
func (al *AgentLoop) ensureHooksInitialized(ctx context.Context) error {
|
||||
if al == nil || al.cfg == nil || al.hooks == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
al.hookRuntime.initOnce.Do(func() {
|
||||
al.hookRuntime.setInitErr(al.loadConfiguredHooks(ctx))
|
||||
})
|
||||
|
||||
return al.hookRuntime.getInitErr()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) loadConfiguredHooks(ctx context.Context) (err error) {
|
||||
if al == nil || al.cfg == nil || !al.cfg.Hooks.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
mounted := make([]string, 0)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
for _, name := range mounted {
|
||||
al.UnmountHook(name)
|
||||
}
|
||||
return
|
||||
}
|
||||
al.hookRuntime.setMounted(mounted)
|
||||
}()
|
||||
|
||||
builtinNames := enabledBuiltinHookNames(al.cfg.Hooks.Builtins)
|
||||
for _, name := range builtinNames {
|
||||
spec := al.cfg.Hooks.Builtins[name]
|
||||
factory, ok := lookupBuiltinHook(name)
|
||||
if !ok {
|
||||
return fmt.Errorf("builtin hook %q is not registered", name)
|
||||
}
|
||||
|
||||
hook, factoryErr := factory(ctx, spec)
|
||||
if factoryErr != nil {
|
||||
return fmt.Errorf("build builtin hook %q: %w", name, factoryErr)
|
||||
}
|
||||
if err := al.MountHook(HookRegistration{
|
||||
Name: name,
|
||||
Priority: spec.Priority,
|
||||
Source: HookSourceInProcess,
|
||||
Hook: hook,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("mount builtin hook %q: %w", name, err)
|
||||
}
|
||||
mounted = append(mounted, name)
|
||||
}
|
||||
|
||||
processNames := enabledProcessHookNames(al.cfg.Hooks.Processes)
|
||||
for _, name := range processNames {
|
||||
spec := al.cfg.Hooks.Processes[name]
|
||||
opts, buildErr := processHookOptionsFromConfig(spec)
|
||||
if buildErr != nil {
|
||||
return fmt.Errorf("configure process hook %q: %w", name, buildErr)
|
||||
}
|
||||
|
||||
processHook, buildErr := NewProcessHook(ctx, name, opts)
|
||||
if buildErr != nil {
|
||||
return fmt.Errorf("start process hook %q: %w", name, buildErr)
|
||||
}
|
||||
if err := al.MountHook(HookRegistration{
|
||||
Name: name,
|
||||
Priority: spec.Priority,
|
||||
Source: HookSourceProcess,
|
||||
Hook: processHook,
|
||||
}); err != nil {
|
||||
_ = processHook.Close()
|
||||
return fmt.Errorf("mount process hook %q: %w", name, err)
|
||||
}
|
||||
mounted = append(mounted, name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enabledBuiltinHookNames(specs map[string]config.BuiltinHookConfig) []string {
|
||||
if len(specs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(specs))
|
||||
for name, spec := range specs {
|
||||
if spec.Enabled {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func enabledProcessHookNames(specs map[string]config.ProcessHookConfig) []string {
|
||||
if len(specs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(specs))
|
||||
for name, spec := range specs {
|
||||
if spec.Enabled {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func processHookOptionsFromConfig(spec config.ProcessHookConfig) (ProcessHookOptions, error) {
|
||||
transport := spec.Transport
|
||||
if transport == "" {
|
||||
transport = "stdio"
|
||||
}
|
||||
if transport != "stdio" {
|
||||
return ProcessHookOptions{}, fmt.Errorf("unsupported transport %q", transport)
|
||||
}
|
||||
if len(spec.Command) == 0 {
|
||||
return ProcessHookOptions{}, fmt.Errorf("command is required")
|
||||
}
|
||||
|
||||
opts := ProcessHookOptions{
|
||||
Command: append([]string(nil), spec.Command...),
|
||||
Dir: spec.Dir,
|
||||
Env: processHookEnvFromMap(spec.Env),
|
||||
}
|
||||
|
||||
observeKinds, observeEnabled, err := processHookObserveKindsFromConfig(spec.Observe)
|
||||
if err != nil {
|
||||
return ProcessHookOptions{}, err
|
||||
}
|
||||
opts.Observe = observeEnabled
|
||||
opts.ObserveKinds = observeKinds
|
||||
|
||||
for _, intercept := range spec.Intercept {
|
||||
switch intercept {
|
||||
case "before_llm", "after_llm":
|
||||
opts.InterceptLLM = true
|
||||
case "before_tool", "after_tool":
|
||||
opts.InterceptTool = true
|
||||
case "approve_tool":
|
||||
opts.ApproveTool = true
|
||||
case "":
|
||||
continue
|
||||
default:
|
||||
return ProcessHookOptions{}, fmt.Errorf("unsupported intercept %q", intercept)
|
||||
}
|
||||
}
|
||||
|
||||
if !opts.Observe && !opts.InterceptLLM && !opts.InterceptTool && !opts.ApproveTool {
|
||||
return ProcessHookOptions{}, fmt.Errorf("no hook modes enabled")
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func processHookEnvFromMap(envMap map[string]string) []string {
|
||||
if len(envMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(envMap))
|
||||
for key := range envMap {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
env := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
env = append(env, key+"="+envMap[key])
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error) {
|
||||
if len(observe) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
validKinds := validHookEventKinds()
|
||||
normalized := make([]string, 0, len(observe))
|
||||
for _, kind := range observe {
|
||||
switch kind {
|
||||
case "", "*", "all":
|
||||
return nil, true, nil
|
||||
default:
|
||||
if _, ok := validKinds[kind]; !ok {
|
||||
return nil, false, fmt.Errorf("unsupported observe event %q", kind)
|
||||
}
|
||||
normalized = append(normalized, kind)
|
||||
}
|
||||
}
|
||||
|
||||
if len(normalized) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
return normalized, true, nil
|
||||
}
|
||||
|
||||
func validHookEventKinds() map[string]struct{} {
|
||||
kinds := make(map[string]struct{}, int(eventKindCount))
|
||||
for kind := EventKind(0); kind < eventKindCount; kind++ {
|
||||
kinds[kind.String()] = struct{}{}
|
||||
}
|
||||
return kinds
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type builtinAutoHookConfig struct {
|
||||
Model string `json:"model"`
|
||||
Suffix string `json:"suffix"`
|
||||
}
|
||||
|
||||
type builtinAutoHook struct {
|
||||
model string
|
||||
suffix string
|
||||
}
|
||||
|
||||
func (h *builtinAutoHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
next := req.Clone()
|
||||
next.Model = h.model
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *builtinAutoHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
next := resp.Clone()
|
||||
if next.Response != nil {
|
||||
next.Response.Content += h.suffix
|
||||
}
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func newConfiguredHookLoop(t *testing.T, provider *llmHookTestProvider, hooks config.HooksConfig) *AgentLoop {
|
||||
t.Helper()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
Hooks: hooks,
|
||||
}
|
||||
|
||||
return NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook(t *testing.T) {
|
||||
const hookName = "test-auto-builtin-hook"
|
||||
|
||||
if err := RegisterBuiltinHook(hookName, func(
|
||||
ctx context.Context,
|
||||
spec config.BuiltinHookConfig,
|
||||
) (any, error) {
|
||||
var hookCfg builtinAutoHookConfig
|
||||
if len(spec.Config) > 0 {
|
||||
if err := json.Unmarshal(spec.Config, &hookCfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &builtinAutoHook{
|
||||
model: hookCfg.Model,
|
||||
suffix: hookCfg.Suffix,
|
||||
}, nil
|
||||
}); err != nil {
|
||||
t.Fatalf("RegisterBuiltinHook failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
unregisterBuiltinHook(hookName)
|
||||
})
|
||||
|
||||
rawCfg, err := json.Marshal(builtinAutoHookConfig{
|
||||
Model: "builtin-model",
|
||||
Suffix: "|builtin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
provider := &llmHookTestProvider{}
|
||||
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
|
||||
Enabled: true,
|
||||
Builtins: map[string]config.BuiltinHookConfig{
|
||||
hookName: {
|
||||
Enabled: true,
|
||||
Config: rawCfg,
|
||||
},
|
||||
},
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|builtin" {
|
||||
t.Fatalf("expected builtin-hooked content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "builtin-model" {
|
||||
t.Fatalf("expected builtin model, got %q", lastModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
eventLog := filepath.Join(t.TempDir(), "events.log")
|
||||
|
||||
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
|
||||
Enabled: true,
|
||||
Processes: map[string]config.ProcessHookConfig{
|
||||
"ipc-auto": {
|
||||
Enabled: true,
|
||||
Command: processHookHelperCommand(),
|
||||
Env: map[string]string{
|
||||
"PICOCLAW_HOOK_HELPER": "1",
|
||||
"PICOCLAW_HOOK_MODE": "rewrite",
|
||||
"PICOCLAW_HOOK_EVENT_LOG": eventLog,
|
||||
},
|
||||
Observe: []string{"turn_end"},
|
||||
Intercept: []string{"before_llm", "after_llm"},
|
||||
},
|
||||
},
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|ipc" {
|
||||
t.Fatalf("expected process-hooked content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "process-model" {
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
|
||||
waitForFileContains(t, eventLog, "turn_end")
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
|
||||
Enabled: true,
|
||||
Processes: map[string]config.ProcessHookConfig{
|
||||
"bad-hook": {
|
||||
Enabled: true,
|
||||
Command: processHookHelperCommand(),
|
||||
Intercept: []string{"not_supported"},
|
||||
},
|
||||
},
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
_, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid configured hook error")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,511 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
processHookJSONRPCVersion = "2.0"
|
||||
processHookReadBufferSize = 1024 * 1024
|
||||
processHookCloseTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
type ProcessHookOptions struct {
|
||||
Command []string
|
||||
Dir string
|
||||
Env []string
|
||||
Observe bool
|
||||
ObserveKinds []string
|
||||
InterceptLLM bool
|
||||
InterceptTool bool
|
||||
ApproveTool bool
|
||||
}
|
||||
|
||||
type ProcessHook struct {
|
||||
name string
|
||||
opts ProcessHookOptions
|
||||
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
observeKinds map[string]struct{}
|
||||
|
||||
writeMu sync.Mutex
|
||||
|
||||
pendingMu sync.Mutex
|
||||
pending map[uint64]chan processHookRPCMessage
|
||||
nextID atomic.Uint64
|
||||
|
||||
closed atomic.Bool
|
||||
done chan struct{}
|
||||
closeErr error
|
||||
closeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type processHookRPCMessage struct {
|
||||
JSONRPC string `json:"jsonrpc,omitempty"`
|
||||
ID uint64 `json:"id,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *processHookRPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type processHookRPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type processHookHelloParams struct {
|
||||
Name string `json:"name"`
|
||||
Version int `json:"version"`
|
||||
Modes []string `json:"modes,omitempty"`
|
||||
}
|
||||
|
||||
type processHookDecisionResponse struct {
|
||||
Action HookAction `json:"action"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
type processHookBeforeLLMResponse struct {
|
||||
processHookDecisionResponse
|
||||
Request *LLMHookRequest `json:"request,omitempty"`
|
||||
}
|
||||
|
||||
type processHookAfterLLMResponse struct {
|
||||
processHookDecisionResponse
|
||||
Response *LLMHookResponse `json:"response,omitempty"`
|
||||
}
|
||||
|
||||
type processHookBeforeToolResponse struct {
|
||||
processHookDecisionResponse
|
||||
Call *ToolCallHookRequest `json:"call,omitempty"`
|
||||
}
|
||||
|
||||
type processHookAfterToolResponse struct {
|
||||
processHookDecisionResponse
|
||||
Result *ToolResultHookResponse `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (*ProcessHook, error) {
|
||||
if len(opts.Command) == 0 {
|
||||
return nil, fmt.Errorf("process hook command is required")
|
||||
}
|
||||
|
||||
cmd := exec.Command(opts.Command[0], opts.Command[1:]...)
|
||||
cmd.Dir = opts.Dir
|
||||
if len(opts.Env) > 0 {
|
||||
cmd.Env = append(os.Environ(), opts.Env...)
|
||||
}
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stdin: %w", err)
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stdout: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stderr: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start process hook: %w", err)
|
||||
}
|
||||
|
||||
ph := &ProcessHook{
|
||||
name: name,
|
||||
opts: opts,
|
||||
cmd: cmd,
|
||||
stdin: stdin,
|
||||
observeKinds: newProcessHookObserveKinds(opts.ObserveKinds),
|
||||
pending: make(map[uint64]chan processHookRPCMessage),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go ph.readLoop(stdout)
|
||||
go ph.readStderr(stderr)
|
||||
go ph.waitLoop()
|
||||
|
||||
helloCtx := ctx
|
||||
if helloCtx == nil {
|
||||
var cancel context.CancelFunc
|
||||
helloCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
if err := ph.hello(helloCtx); err != nil {
|
||||
_ = ph.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ph, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) Close() error {
|
||||
if ph == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ph.closeOnce.Do(func() {
|
||||
ph.closed.Store(true)
|
||||
if ph.stdin != nil {
|
||||
_ = ph.stdin.Close()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ph.done:
|
||||
case <-time.After(processHookCloseTimeout):
|
||||
if ph.cmd != nil && ph.cmd.Process != nil {
|
||||
_ = ph.cmd.Process.Kill()
|
||||
}
|
||||
<-ph.done
|
||||
}
|
||||
})
|
||||
|
||||
ph.closeMu.Lock()
|
||||
defer ph.closeMu.Unlock()
|
||||
return ph.closeErr
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
if ph == nil || !ph.opts.Observe {
|
||||
return nil
|
||||
}
|
||||
if len(ph.observeKinds) > 0 {
|
||||
if _, ok := ph.observeKinds[evt.Kind.String()]; !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ph.notify(ctx, "hook.event", evt)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptLLM {
|
||||
return req, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var resp processHookBeforeLLMResponse
|
||||
if err := ph.call(ctx, "hook.before_llm", req, &resp); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if resp.Request == nil {
|
||||
resp.Request = req
|
||||
}
|
||||
return resp.Request, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptLLM {
|
||||
return resp, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var result processHookAfterLLMResponse
|
||||
if err := ph.call(ctx, "hook.after_llm", resp, &result); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if result.Response == nil {
|
||||
result.Response = resp
|
||||
}
|
||||
return result.Response, HookDecision{Action: result.Action, Reason: result.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptTool {
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var resp processHookBeforeToolResponse
|
||||
if err := ph.call(ctx, "hook.before_tool", call, &resp); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if resp.Call == nil {
|
||||
resp.Call = call
|
||||
}
|
||||
return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptTool {
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var resp processHookAfterToolResponse
|
||||
if err := ph.call(ctx, "hook.after_tool", result, &resp); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if resp.Result == nil {
|
||||
resp.Result = result
|
||||
}
|
||||
return resp.Result, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
||||
if ph == nil || !ph.opts.ApproveTool {
|
||||
return ApprovalDecision{Approved: true}, nil
|
||||
}
|
||||
|
||||
var resp ApprovalDecision
|
||||
if err := ph.call(ctx, "hook.approve_tool", req, &resp); err != nil {
|
||||
return ApprovalDecision{}, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) hello(ctx context.Context) error {
|
||||
modes := make([]string, 0, 4)
|
||||
if ph.opts.Observe {
|
||||
modes = append(modes, "observe")
|
||||
}
|
||||
if ph.opts.InterceptLLM {
|
||||
modes = append(modes, "llm")
|
||||
}
|
||||
if ph.opts.InterceptTool {
|
||||
modes = append(modes, "tool")
|
||||
}
|
||||
if ph.opts.ApproveTool {
|
||||
modes = append(modes, "approve")
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
return ph.call(ctx, "hook.hello", processHookHelloParams{
|
||||
Name: ph.name,
|
||||
Version: 1,
|
||||
Modes: modes,
|
||||
}, &result)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) notify(ctx context.Context, method string, params any) error {
|
||||
msg := processHookRPCMessage{
|
||||
JSONRPC: processHookJSONRPCVersion,
|
||||
Method: method,
|
||||
}
|
||||
if params != nil {
|
||||
body, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg.Params = body
|
||||
}
|
||||
return ph.send(ctx, msg)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) call(ctx context.Context, method string, params any, out any) error {
|
||||
if ph.closed.Load() {
|
||||
return fmt.Errorf("process hook %q is closed", ph.name)
|
||||
}
|
||||
|
||||
id := ph.nextID.Add(1)
|
||||
respCh := make(chan processHookRPCMessage, 1)
|
||||
ph.pendingMu.Lock()
|
||||
ph.pending[id] = respCh
|
||||
ph.pendingMu.Unlock()
|
||||
|
||||
msg := processHookRPCMessage{
|
||||
JSONRPC: processHookJSONRPCVersion,
|
||||
ID: id,
|
||||
Method: method,
|
||||
}
|
||||
if params != nil {
|
||||
body, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
ph.removePending(id)
|
||||
return err
|
||||
}
|
||||
msg.Params = body
|
||||
}
|
||||
|
||||
if err := ph.send(ctx, msg); err != nil {
|
||||
ph.removePending(id)
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case resp, ok := <-respCh:
|
||||
if !ok {
|
||||
return fmt.Errorf("process hook %q closed while waiting for %s", ph.name, method)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("process hook %q %s failed: %s", ph.name, method, resp.Error.Message)
|
||||
}
|
||||
if out != nil && len(resp.Result) > 0 {
|
||||
if err := json.Unmarshal(resp.Result, out); err != nil {
|
||||
return fmt.Errorf("decode process hook %q %s result: %w", ph.name, method, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
ph.removePending(id)
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) send(ctx context.Context, msg processHookRPCMessage) error {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = append(body, '\n')
|
||||
|
||||
ph.writeMu.Lock()
|
||||
defer ph.writeMu.Unlock()
|
||||
|
||||
if ph.closed.Load() {
|
||||
return fmt.Errorf("process hook %q is closed", ph.name)
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, writeErr := ph.stdin.Write(body)
|
||||
done <- writeErr
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("write process hook %q message: %w", ph.name, err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) readLoop(stdout io.Reader) {
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
|
||||
|
||||
for scanner.Scan() {
|
||||
var msg processHookRPCMessage
|
||||
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
||||
logger.WarnCF("hooks", "Failed to decode process hook message", map[string]any{
|
||||
"hook": ph.name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if msg.ID == 0 {
|
||||
continue
|
||||
}
|
||||
ph.pendingMu.Lock()
|
||||
respCh, ok := ph.pending[msg.ID]
|
||||
if ok {
|
||||
delete(ph.pending, msg.ID)
|
||||
}
|
||||
ph.pendingMu.Unlock()
|
||||
if ok {
|
||||
respCh <- msg
|
||||
close(respCh)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) readStderr(stderr io.Reader) {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
scanner.Buffer(make([]byte, 0, 16*1024), processHookReadBufferSize)
|
||||
for scanner.Scan() {
|
||||
logger.WarnCF("hooks", "Process hook stderr", map[string]any{
|
||||
"hook": ph.name,
|
||||
"stderr": scanner.Text(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) waitLoop() {
|
||||
err := ph.cmd.Wait()
|
||||
ph.closeMu.Lock()
|
||||
ph.closeErr = err
|
||||
ph.closeMu.Unlock()
|
||||
ph.failPending(err)
|
||||
close(ph.done)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) failPending(err error) {
|
||||
ph.pendingMu.Lock()
|
||||
defer ph.pendingMu.Unlock()
|
||||
|
||||
msg := processHookRPCMessage{
|
||||
Error: &processHookRPCError{
|
||||
Code: -32000,
|
||||
Message: "process exited",
|
||||
},
|
||||
}
|
||||
if err != nil {
|
||||
msg.Error.Message = err.Error()
|
||||
}
|
||||
|
||||
for id, ch := range ph.pending {
|
||||
delete(ph.pending, id)
|
||||
ch <- msg
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) removePending(id uint64) {
|
||||
ph.pendingMu.Lock()
|
||||
defer ph.pendingMu.Unlock()
|
||||
|
||||
if ch, ok := ph.pending[id]; ok {
|
||||
delete(ph.pending, id)
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) MountProcessHook(ctx context.Context, name string, opts ProcessHookOptions) error {
|
||||
if al == nil {
|
||||
return fmt.Errorf("agent loop is nil")
|
||||
}
|
||||
processHook, err := NewProcessHook(ctx, name, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := al.MountHook(HookRegistration{
|
||||
Name: name,
|
||||
Source: HookSourceProcess,
|
||||
Hook: processHook,
|
||||
}); err != nil {
|
||||
_ = processHook.Close()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newProcessHookObserveKinds(kinds []string) map[string]struct{} {
|
||||
if len(kinds) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := make(map[string]struct{}, len(kinds))
|
||||
for _, kind := range kinds {
|
||||
if kind == "" {
|
||||
continue
|
||||
}
|
||||
normalized[kind] = struct{}{}
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
@@ -0,0 +1,339 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestProcessHook_HelperProcess(t *testing.T) {
|
||||
if os.Getenv("PICOCLAW_HOOK_HELPER") != "1" {
|
||||
return
|
||||
}
|
||||
if err := runProcessHookHelper(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
eventLog := filepath.Join(t.TempDir(), "events.log")
|
||||
if err := al.MountProcessHook(context.Background(), "ipc-llm", ProcessHookOptions{
|
||||
Command: processHookHelperCommand(),
|
||||
Env: processHookHelperEnv("rewrite", eventLog),
|
||||
Observe: true,
|
||||
InterceptLLM: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("MountProcessHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|ipc" {
|
||||
t.Fatalf("expected process-hooked llm content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "process-model" {
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
|
||||
waitForFileContains(t, eventLog, "turn_end")
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountProcessHook(context.Background(), "ipc-tool", ProcessHookOptions{
|
||||
Command: processHookHelperCommand(),
|
||||
Env: processHookHelperEnv("rewrite", ""),
|
||||
InterceptTool: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("MountProcessHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "ipc:ipc" {
|
||||
t.Fatalf("expected rewritten process-hook tool result, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
type blockedToolProvider struct {
|
||||
calls int
|
||||
}
|
||||
|
||||
func (p *blockedToolProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.calls++
|
||||
if p.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Name: "blocked_tool",
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: messages[len(messages)-1].Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *blockedToolProvider) GetDefaultModel() string {
|
||||
return "blocked-tool-provider"
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
|
||||
provider := &blockedToolProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
if err := al.MountProcessHook(context.Background(), "ipc-approval", ProcessHookOptions{
|
||||
Command: processHookHelperCommand(),
|
||||
Env: processHookHelperEnv("deny", ""),
|
||||
ApproveTool: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("MountProcessHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run blocked tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
expected := "Tool execution denied by approval hook: blocked by ipc hook"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected tool skipped event")
|
||||
}
|
||||
payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
||||
}
|
||||
if payload.Reason != expected {
|
||||
t.Fatalf("expected reason %q, got %q", expected, payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func processHookHelperCommand() []string {
|
||||
return []string{os.Args[0], "-test.run=TestProcessHook_HelperProcess", "--"}
|
||||
}
|
||||
|
||||
func processHookHelperEnv(mode, eventLog string) []string {
|
||||
env := []string{
|
||||
"PICOCLAW_HOOK_HELPER=1",
|
||||
"PICOCLAW_HOOK_MODE=" + mode,
|
||||
}
|
||||
if eventLog != "" {
|
||||
env = append(env, "PICOCLAW_HOOK_EVENT_LOG="+eventLog)
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func waitForFileContains(t *testing.T, path, substring string) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err == nil && strings.Contains(string(data), substring) {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(path)
|
||||
t.Fatalf("timed out waiting for %q in %s; current content: %q", substring, path, string(data))
|
||||
}
|
||||
|
||||
func runProcessHookHelper() error {
|
||||
mode := os.Getenv("PICOCLAW_HOOK_MODE")
|
||||
eventLog := os.Getenv("PICOCLAW_HOOK_EVENT_LOG")
|
||||
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
|
||||
encoder := json.NewEncoder(os.Stdout)
|
||||
|
||||
for scanner.Scan() {
|
||||
var msg processHookRPCMessage
|
||||
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.ID == 0 {
|
||||
if msg.Method == "hook.event" && eventLog != "" {
|
||||
var evt map[string]any
|
||||
if err := json.Unmarshal(msg.Params, &evt); err == nil {
|
||||
if rawKind, ok := evt["Kind"].(float64); ok {
|
||||
kind := EventKind(rawKind)
|
||||
_ = os.WriteFile(eventLog, []byte(kind.String()+"\n"), 0o644)
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
result, rpcErr := handleProcessHookRequest(mode, msg)
|
||||
resp := processHookRPCMessage{
|
||||
JSONRPC: processHookJSONRPCVersion,
|
||||
ID: msg.ID,
|
||||
}
|
||||
if rpcErr != nil {
|
||||
resp.Error = rpcErr
|
||||
} else if result != nil {
|
||||
body, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Result = body
|
||||
} else {
|
||||
resp.Result = []byte("{}")
|
||||
}
|
||||
|
||||
if err := encoder.Encode(resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func handleProcessHookRequest(mode string, msg processHookRPCMessage) (any, *processHookRPCError) {
|
||||
switch msg.Method {
|
||||
case "hook.hello":
|
||||
return map[string]any{"ok": true}, nil
|
||||
case "hook.before_llm":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var req map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &req)
|
||||
req["model"] = "process-model"
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"request": req,
|
||||
}, nil
|
||||
case "hook.after_llm":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var resp map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &resp)
|
||||
if rawResponse, ok := resp["response"].(map[string]any); ok {
|
||||
if content, ok := rawResponse["content"].(string); ok {
|
||||
rawResponse["content"] = content + "|ipc"
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"response": resp,
|
||||
}, nil
|
||||
case "hook.before_tool":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var call map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &call)
|
||||
rawArgs, ok := call["arguments"].(map[string]any)
|
||||
if !ok || rawArgs == nil {
|
||||
rawArgs = map[string]any{}
|
||||
}
|
||||
rawArgs["text"] = "ipc"
|
||||
call["arguments"] = rawArgs
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"call": call,
|
||||
}, nil
|
||||
case "hook.after_tool":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var result map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &result)
|
||||
if rawResult, ok := result["result"].(map[string]any); ok {
|
||||
if forLLM, ok := rawResult["for_llm"].(string); ok {
|
||||
rawResult["for_llm"] = "ipc:" + forLLM
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"result": result,
|
||||
}, nil
|
||||
case "hook.approve_tool":
|
||||
if mode == "deny" {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: "blocked by ipc hook",
|
||||
}, nil
|
||||
}
|
||||
return ApprovalDecision{Approved: true}, nil
|
||||
default:
|
||||
return nil, &processHookRPCError{
|
||||
Code: -32601,
|
||||
Message: "method not found",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,809 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultHookObserverTimeout = 500 * time.Millisecond
|
||||
defaultHookInterceptorTimeout = 5 * time.Second
|
||||
defaultHookApprovalTimeout = 60 * time.Second
|
||||
hookObserverBufferSize = 64
|
||||
)
|
||||
|
||||
type HookAction string
|
||||
|
||||
const (
|
||||
HookActionContinue HookAction = "continue"
|
||||
HookActionModify HookAction = "modify"
|
||||
HookActionDenyTool HookAction = "deny_tool"
|
||||
HookActionAbortTurn HookAction = "abort_turn"
|
||||
HookActionHardAbort HookAction = "hard_abort"
|
||||
)
|
||||
|
||||
type HookDecision struct {
|
||||
Action HookAction `json:"action"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
func (d HookDecision) normalizedAction() HookAction {
|
||||
if d.Action == "" {
|
||||
return HookActionContinue
|
||||
}
|
||||
return d.Action
|
||||
}
|
||||
|
||||
type ApprovalDecision struct {
|
||||
Approved bool `json:"approved"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
type HookSource uint8
|
||||
|
||||
const (
|
||||
HookSourceInProcess HookSource = iota
|
||||
HookSourceProcess
|
||||
)
|
||||
|
||||
type HookRegistration struct {
|
||||
Name string
|
||||
Priority int
|
||||
Source HookSource
|
||||
Hook any
|
||||
}
|
||||
|
||||
func NamedHook(name string, hook any) HookRegistration {
|
||||
return HookRegistration{
|
||||
Name: name,
|
||||
Source: HookSourceInProcess,
|
||||
Hook: hook,
|
||||
}
|
||||
}
|
||||
|
||||
type EventObserver interface {
|
||||
OnEvent(ctx context.Context, evt Event) error
|
||||
}
|
||||
|
||||
type LLMInterceptor interface {
|
||||
BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error)
|
||||
AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error)
|
||||
}
|
||||
|
||||
type ToolInterceptor interface {
|
||||
BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
|
||||
AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
|
||||
}
|
||||
|
||||
type ToolApprover interface {
|
||||
ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error)
|
||||
}
|
||||
|
||||
type LLMHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Model string `json:"model"`
|
||||
Messages []providers.Message `json:"messages,omitempty"`
|
||||
Tools []providers.ToolDefinition `json:"tools,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
|
||||
}
|
||||
|
||||
func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Messages = cloneProviderMessages(r.Messages)
|
||||
cloned.Tools = cloneToolDefinitions(r.Tools)
|
||||
cloned.Options = cloneStringAnyMap(r.Options)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type LLMHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Model string `json:"model"`
|
||||
Response *providers.LLMResponse `json:"response,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Response = cloneLLMResponse(r.Response)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolCallHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolApprovalRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolResultHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Result *tools.ToolResult `json:"result,omitempty"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.Result = cloneToolResult(r.Result)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type HookManager struct {
|
||||
eventBus *EventBus
|
||||
observerTimeout time.Duration
|
||||
interceptorTimeout time.Duration
|
||||
approvalTimeout time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
hooks map[string]HookRegistration
|
||||
ordered []HookRegistration
|
||||
|
||||
sub EventSubscription
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewHookManager(eventBus *EventBus) *HookManager {
|
||||
hm := &HookManager{
|
||||
eventBus: eventBus,
|
||||
observerTimeout: defaultHookObserverTimeout,
|
||||
interceptorTimeout: defaultHookInterceptorTimeout,
|
||||
approvalTimeout: defaultHookApprovalTimeout,
|
||||
hooks: make(map[string]HookRegistration),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if eventBus == nil {
|
||||
close(hm.done)
|
||||
return hm
|
||||
}
|
||||
|
||||
hm.sub = eventBus.Subscribe(hookObserverBufferSize)
|
||||
go hm.dispatchEvents()
|
||||
return hm
|
||||
}
|
||||
|
||||
func (hm *HookManager) Close() {
|
||||
if hm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
hm.closeOnce.Do(func() {
|
||||
if hm.eventBus != nil {
|
||||
hm.eventBus.Unsubscribe(hm.sub.ID)
|
||||
}
|
||||
<-hm.done
|
||||
hm.closeAllHooks()
|
||||
})
|
||||
}
|
||||
|
||||
func (hm *HookManager) ConfigureTimeouts(observer, interceptor, approval time.Duration) {
|
||||
if hm == nil {
|
||||
return
|
||||
}
|
||||
if observer > 0 {
|
||||
hm.observerTimeout = observer
|
||||
}
|
||||
if interceptor > 0 {
|
||||
hm.interceptorTimeout = interceptor
|
||||
}
|
||||
if approval > 0 {
|
||||
hm.approvalTimeout = approval
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) Mount(reg HookRegistration) error {
|
||||
if hm == nil {
|
||||
return fmt.Errorf("hook manager is nil")
|
||||
}
|
||||
if reg.Name == "" {
|
||||
return fmt.Errorf("hook name is required")
|
||||
}
|
||||
if reg.Hook == nil {
|
||||
return fmt.Errorf("hook %q is nil", reg.Name)
|
||||
}
|
||||
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if existing, ok := hm.hooks[reg.Name]; ok {
|
||||
closeHookIfPossible(existing.Hook)
|
||||
}
|
||||
hm.hooks[reg.Name] = reg
|
||||
hm.rebuildOrdered()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hm *HookManager) Unmount(name string) {
|
||||
if hm == nil || name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if existing, ok := hm.hooks[name]; ok {
|
||||
closeHookIfPossible(existing.Hook)
|
||||
}
|
||||
delete(hm.hooks, name)
|
||||
hm.rebuildOrdered()
|
||||
}
|
||||
|
||||
func (hm *HookManager) dispatchEvents() {
|
||||
defer close(hm.done)
|
||||
|
||||
for evt := range hm.sub.C {
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
observer, ok := reg.Hook.(EventObserver)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hm.runObserver(reg.Name, observer, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
|
||||
if hm == nil || req == nil {
|
||||
return req, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := req.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(LLMInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callBeforeLLM(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "before_llm", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
|
||||
if hm == nil || resp == nil {
|
||||
return resp, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := resp.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(LLMInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callAfterLLM(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "after_llm", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision) {
|
||||
if hm == nil || call == nil {
|
||||
return call, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := call.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(ToolInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callBeforeTool(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "before_tool", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision) {
|
||||
if hm == nil || result == nil {
|
||||
return result, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := result.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(ToolInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callAfterTool(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "after_tool", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision {
|
||||
if hm == nil || req == nil {
|
||||
return ApprovalDecision{Approved: true}
|
||||
}
|
||||
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
approver, ok := reg.Hook.(ToolApprover)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
decision, ok := hm.callApproveTool(ctx, reg.Name, approver, req.Clone())
|
||||
if !ok {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: fmt.Sprintf("tool approval hook %q failed", reg.Name),
|
||||
}
|
||||
}
|
||||
if !decision.Approved {
|
||||
return decision
|
||||
}
|
||||
}
|
||||
|
||||
return ApprovalDecision{Approved: true}
|
||||
}
|
||||
|
||||
func (hm *HookManager) rebuildOrdered() {
|
||||
hm.ordered = hm.ordered[:0]
|
||||
for _, reg := range hm.hooks {
|
||||
hm.ordered = append(hm.ordered, reg)
|
||||
}
|
||||
sort.SliceStable(hm.ordered, func(i, j int) bool {
|
||||
if hm.ordered[i].Source != hm.ordered[j].Source {
|
||||
return hm.ordered[i].Source < hm.ordered[j].Source
|
||||
}
|
||||
if hm.ordered[i].Priority == hm.ordered[j].Priority {
|
||||
return hm.ordered[i].Name < hm.ordered[j].Name
|
||||
}
|
||||
return hm.ordered[i].Priority < hm.ordered[j].Priority
|
||||
})
|
||||
}
|
||||
|
||||
func (hm *HookManager) snapshotHooks() []HookRegistration {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
|
||||
snapshot := make([]HookRegistration, len(hm.ordered))
|
||||
copy(snapshot, hm.ordered)
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (hm *HookManager) closeAllHooks() {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
for name, reg := range hm.hooks {
|
||||
closeHookIfPossible(reg.Hook)
|
||||
delete(hm.hooks, name)
|
||||
}
|
||||
hm.ordered = nil
|
||||
}
|
||||
|
||||
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- observer.OnEvent(ctx, evt)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.WarnCF("hooks", "Event observer failed", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Event observer timed out", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"timeout_ms": hm.observerTimeout.Milliseconds(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) callBeforeLLM(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor LLMInterceptor,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"before_llm",
|
||||
func(ctx context.Context) (*LLMHookRequest, HookDecision, error) {
|
||||
return interceptor.BeforeLLM(ctx, req)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callAfterLLM(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor LLMInterceptor,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"after_llm",
|
||||
func(ctx context.Context) (*LLMHookResponse, HookDecision, error) {
|
||||
return interceptor.AfterLLM(ctx, resp)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callBeforeTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor ToolInterceptor,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"before_tool",
|
||||
func(ctx context.Context) (*ToolCallHookRequest, HookDecision, error) {
|
||||
return interceptor.BeforeTool(ctx, call)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callAfterTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor ToolInterceptor,
|
||||
resultView *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"after_tool",
|
||||
func(ctx context.Context) (*ToolResultHookResponse, HookDecision, error) {
|
||||
return interceptor.AfterTool(ctx, resultView)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callApproveTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
approver ToolApprover,
|
||||
req *ToolApprovalRequest,
|
||||
) (ApprovalDecision, bool) {
|
||||
return runApprovalHook(
|
||||
parent,
|
||||
hm.approvalTimeout,
|
||||
name,
|
||||
"approve_tool",
|
||||
func(ctx context.Context) (ApprovalDecision, error) {
|
||||
return approver.ApproveTool(ctx, req)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func runInterceptorHook[T any](
|
||||
parent context.Context,
|
||||
timeout time.Duration,
|
||||
name string,
|
||||
stage string,
|
||||
fn func(ctx context.Context) (T, HookDecision, error),
|
||||
) (T, HookDecision, bool) {
|
||||
var zero T
|
||||
|
||||
ctx, cancel := context.WithTimeout(parent, timeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
value T
|
||||
decision HookDecision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
value, decision, err := fn(ctx)
|
||||
done <- result{value: value, decision: decision, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-done:
|
||||
if res.err != nil {
|
||||
logger.WarnCF("hooks", "Interceptor hook failed", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"error": res.err.Error(),
|
||||
})
|
||||
return zero, HookDecision{}, false
|
||||
}
|
||||
return res.value, res.decision, true
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Interceptor hook timed out", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
})
|
||||
return zero, HookDecision{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func runApprovalHook(
|
||||
parent context.Context,
|
||||
timeout time.Duration,
|
||||
name string,
|
||||
stage string,
|
||||
fn func(ctx context.Context) (ApprovalDecision, error),
|
||||
) (ApprovalDecision, bool) {
|
||||
ctx, cancel := context.WithTimeout(parent, timeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
decision ApprovalDecision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
decision, err := fn(ctx)
|
||||
done <- result{decision: decision, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-done:
|
||||
if res.err != nil {
|
||||
logger.WarnCF("hooks", "Approval hook failed", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"error": res.err.Error(),
|
||||
})
|
||||
return ApprovalDecision{}, false
|
||||
}
|
||||
return res.decision, true
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Approval hook timed out", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
})
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: fmt.Sprintf("tool approval hook %q timed out", name),
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) logUnsupportedAction(name, stage string, action HookAction) {
|
||||
logger.WarnCF("hooks", "Hook returned unsupported action for stage", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"action": action,
|
||||
})
|
||||
}
|
||||
|
||||
func cloneProviderMessages(messages []providers.Message) []providers.Message {
|
||||
if len(messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.Message, len(messages))
|
||||
for i, msg := range messages {
|
||||
cloned[i] = msg
|
||||
if len(msg.Media) > 0 {
|
||||
cloned[i].Media = append([]string(nil), msg.Media...)
|
||||
}
|
||||
if len(msg.SystemParts) > 0 {
|
||||
cloned[i].SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
cloned[i].ToolCalls = cloneProviderToolCalls(msg.ToolCalls)
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneProviderToolCalls(calls []providers.ToolCall) []providers.ToolCall {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.ToolCall, len(calls))
|
||||
for i, call := range calls {
|
||||
cloned[i] = call
|
||||
if call.Function != nil {
|
||||
fn := *call.Function
|
||||
cloned[i].Function = &fn
|
||||
}
|
||||
if call.Arguments != nil {
|
||||
cloned[i].Arguments = cloneStringAnyMap(call.Arguments)
|
||||
}
|
||||
if call.ExtraContent != nil {
|
||||
extra := *call.ExtraContent
|
||||
if call.ExtraContent.Google != nil {
|
||||
google := *call.ExtraContent.Google
|
||||
extra.Google = &google
|
||||
}
|
||||
cloned[i].ExtraContent = &extra
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
|
||||
if len(defs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.ToolDefinition, len(defs))
|
||||
for i, def := range defs {
|
||||
cloned[i] = def
|
||||
cloned[i].Function.Parameters = cloneStringAnyMap(def.Function.Parameters)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneLLMResponse(resp *providers.LLMResponse) *providers.LLMResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *resp
|
||||
cloned.ToolCalls = cloneProviderToolCalls(resp.ToolCalls)
|
||||
if len(resp.ReasoningDetails) > 0 {
|
||||
cloned.ReasoningDetails = append(cloned.ReasoningDetails[:0:0], resp.ReasoningDetails...)
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
usage := *resp.Usage
|
||||
cloned.Usage = &usage
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneStringAnyMap(src map[string]any) map[string]any {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make(map[string]any, len(src))
|
||||
for k, v := range src {
|
||||
cloned[k] = v
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := *result
|
||||
if len(result.Media) > 0 {
|
||||
cloned.Media = append([]string(nil), result.Media...)
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func closeHookIfPossible(hook any) {
|
||||
closer, ok := hook.(io.Closer)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := closer.Close(); err != nil {
|
||||
logger.WarnCF("hooks", "Failed to close hook", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
func newHookTestLoop(
|
||||
t *testing.T,
|
||||
provider providers.LLMProvider,
|
||||
) (*AgentLoop, *AgentInstance, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "agent-hooks-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
return al, agent, func() {
|
||||
al.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookManager_SortsInProcessBeforeProcess(t *testing.T) {
|
||||
hm := NewHookManager(nil)
|
||||
defer hm.Close()
|
||||
|
||||
if err := hm.Mount(HookRegistration{
|
||||
Name: "process",
|
||||
Priority: -10,
|
||||
Source: HookSourceProcess,
|
||||
Hook: struct{}{},
|
||||
}); err != nil {
|
||||
t.Fatalf("mount process hook: %v", err)
|
||||
}
|
||||
if err := hm.Mount(HookRegistration{
|
||||
Name: "in-process",
|
||||
Priority: 100,
|
||||
Source: HookSourceInProcess,
|
||||
Hook: struct{}{},
|
||||
}); err != nil {
|
||||
t.Fatalf("mount in-process hook: %v", err)
|
||||
}
|
||||
|
||||
ordered := hm.snapshotHooks()
|
||||
if len(ordered) != 2 {
|
||||
t.Fatalf("expected 2 hooks, got %d", len(ordered))
|
||||
}
|
||||
if ordered[0].Name != "in-process" {
|
||||
t.Fatalf("expected in-process hook first, got %q", ordered[0].Name)
|
||||
}
|
||||
if ordered[1].Name != "process" {
|
||||
t.Fatalf("expected process hook second, got %q", ordered[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
type llmHookTestProvider struct {
|
||||
mu sync.Mutex
|
||||
lastModel string
|
||||
}
|
||||
|
||||
func (p *llmHookTestProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.lastModel = model
|
||||
p.mu.Unlock()
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: "provider content",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *llmHookTestProvider) GetDefaultModel() string {
|
||||
return "llm-hook-provider"
|
||||
}
|
||||
|
||||
type llmObserverHook struct {
|
||||
eventCh chan Event
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
if evt.Kind == EventKindTurnEnd {
|
||||
select {
|
||||
case h.eventCh <- evt:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
next := req.Clone()
|
||||
next.Model = "hook-model"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
next := resp.Clone()
|
||||
next.Response.Content = "hooked content"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
|
||||
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "hooked content" {
|
||||
t.Fatalf("expected hooked content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "hook-model" {
|
||||
t.Fatalf("expected model hook-model, got %q", lastModel)
|
||||
}
|
||||
|
||||
select {
|
||||
case evt := <-hook.eventCh:
|
||||
if evt.Kind != EventKindTurnEnd {
|
||||
t.Fatalf("expected turn end event, got %v", evt.Kind)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for hook observer event")
|
||||
}
|
||||
}
|
||||
|
||||
type toolHookProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
}
|
||||
|
||||
func (p *toolHookProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.calls++
|
||||
if p.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Name: "echo_text",
|
||||
Arguments: map[string]any{"text": "original"},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
last := messages[len(messages)-1]
|
||||
return &providers.LLMResponse{
|
||||
Content: last.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *toolHookProvider) GetDefaultModel() string {
|
||||
return "tool-hook-provider"
|
||||
}
|
||||
|
||||
type echoTextTool struct{}
|
||||
|
||||
func (t *echoTextTool) Name() string {
|
||||
return "echo_text"
|
||||
}
|
||||
|
||||
func (t *echoTextTool) Description() string {
|
||||
return "echo a text argument"
|
||||
}
|
||||
|
||||
func (t *echoTextTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"text": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
text, _ := args["text"].(string)
|
||||
return tools.SilentResult(text)
|
||||
}
|
||||
|
||||
type toolRewriteHook struct{}
|
||||
|
||||
func (h *toolRewriteHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
next := call.Clone()
|
||||
next.Arguments["text"] = "modified"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *toolRewriteHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
next := result.Clone()
|
||||
next.Result.ForLLM = "after:" + next.Result.ForLLM
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("tool-rewrite", &toolRewriteHook{})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "after:modified" {
|
||||
t.Fatalf("expected rewritten tool result, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
type denyApprovalHook struct{}
|
||||
|
||||
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: "blocked",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("deny-approval", &denyApprovalHook{})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
expected := "Tool execution denied by approval hook: blocked"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected tool skipped event")
|
||||
}
|
||||
payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
||||
}
|
||||
if payload.Reason != expected {
|
||||
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
|
||||
}
|
||||
}
|
||||
+12
-1
@@ -130,6 +130,17 @@ func NewAgentInstance(
|
||||
maxTokens = 8192
|
||||
}
|
||||
|
||||
contextWindow := defaults.ContextWindow
|
||||
if contextWindow == 0 {
|
||||
// Default heuristic: 4x the output token limit.
|
||||
// Most models have context windows well above their output limits
|
||||
// (e.g., GPT-4o 128k ctx / 16k out, Claude 200k ctx / 8k out).
|
||||
// 4x is a conservative lower bound that avoids premature
|
||||
// summarization while remaining safe — the reactive
|
||||
// forceCompression handles any overshoot.
|
||||
contextWindow = maxTokens * 4
|
||||
}
|
||||
|
||||
temperature := 0.7
|
||||
if defaults.Temperature != nil {
|
||||
temperature = *defaults.Temperature
|
||||
@@ -182,7 +193,7 @@ func NewAgentInstance(
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
ThinkingLevel: thinkingLevel,
|
||||
ContextWindow: maxTokens,
|
||||
ContextWindow: contextWindow,
|
||||
SummarizeMessageThreshold: summarizeMessageThreshold,
|
||||
SummarizeTokenPercent: summarizeTokenPercent,
|
||||
Provider: provider,
|
||||
|
||||
+1340
-526
File diff suppressed because it is too large
Load Diff
@@ -1078,11 +1078,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Inject some history to simulate a full context
|
||||
// Inject some history to simulate a full context.
|
||||
// Session history only stores user/assistant/tool messages — the system
|
||||
// prompt is built dynamically by BuildMessages and is NOT stored here.
|
||||
sessionKey := "test-session-context"
|
||||
// Create dummy history
|
||||
history := []providers.Message{
|
||||
{Role: "system", Content: "System prompt"},
|
||||
{Role: "user", Content: "Old message 1"},
|
||||
{Role: "assistant", Content: "Old response 1"},
|
||||
{Role: "user", Content: "Old message 2"},
|
||||
@@ -1120,12 +1120,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
// Check final history length
|
||||
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
// We verify that the history has been modified (compressed)
|
||||
// Original length: 6
|
||||
// Expected behavior: compression drops ~50% of history (mid slice)
|
||||
// We can assert that the length is NOT what it would be without compression.
|
||||
// Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
|
||||
if len(finalHistory) >= 8 {
|
||||
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
|
||||
// Original length: 5
|
||||
// Expected behavior: compression drops ~50% of Turns
|
||||
// Without compression: 5 + 1 (new user msg) + 1 (assistant msg) = 7
|
||||
if len(finalHistory) >= 7 {
|
||||
t.Errorf("Expected history to be compressed (len < 7), got %d", len(finalHistory))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+253
-69
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -21,6 +22,9 @@ const (
|
||||
SteeringAll SteeringMode = "all"
|
||||
// MaxQueueSize number of possible messages in the Steering Queue
|
||||
MaxQueueSize = 10
|
||||
// manualSteeringScope is the legacy fallback queue used when no active
|
||||
// turn/session scope is available.
|
||||
manualSteeringScope = "__manual__"
|
||||
)
|
||||
|
||||
// parseSteeringMode normalizes a config string into a SteeringMode.
|
||||
@@ -36,56 +40,117 @@ func parseSteeringMode(s string) SteeringMode {
|
||||
// steeringQueue is a thread-safe queue of user messages that can be injected
|
||||
// into a running agent loop to interrupt it between tool calls.
|
||||
type steeringQueue struct {
|
||||
mu sync.Mutex
|
||||
queue []providers.Message
|
||||
mode SteeringMode
|
||||
mu sync.Mutex
|
||||
queues map[string][]providers.Message
|
||||
mode SteeringMode
|
||||
}
|
||||
|
||||
func newSteeringQueue(mode SteeringMode) *steeringQueue {
|
||||
return &steeringQueue{
|
||||
mode: mode,
|
||||
queues: make(map[string][]providers.Message),
|
||||
mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// push enqueues a steering message.
|
||||
func normalizeSteeringScope(scope string) string {
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope == "" {
|
||||
return manualSteeringScope
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
// push enqueues a steering message in the legacy fallback scope.
|
||||
func (sq *steeringQueue) push(msg providers.Message) error {
|
||||
return sq.pushScope(manualSteeringScope, msg)
|
||||
}
|
||||
|
||||
// pushScope enqueues a steering message for the provided scope.
|
||||
func (sq *steeringQueue) pushScope(scope string, msg providers.Message) error {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
if len(sq.queue) >= MaxQueueSize {
|
||||
|
||||
scope = normalizeSteeringScope(scope)
|
||||
queue := sq.queues[scope]
|
||||
if len(queue) >= MaxQueueSize {
|
||||
return fmt.Errorf("steering queue is full")
|
||||
}
|
||||
sq.queue = append(sq.queue, msg)
|
||||
sq.queues[scope] = append(queue, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// dequeue removes and returns pending steering messages according to the
|
||||
// configured mode. Returns nil when the queue is empty.
|
||||
// dequeue removes and returns pending steering messages from the legacy
|
||||
// fallback scope according to the configured mode.
|
||||
func (sq *steeringQueue) dequeue() []providers.Message {
|
||||
return sq.dequeueScope(manualSteeringScope)
|
||||
}
|
||||
|
||||
// dequeueScope removes and returns pending steering messages for the provided
|
||||
// scope according to the configured mode.
|
||||
func (sq *steeringQueue) dequeueScope(scope string) []providers.Message {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
if len(sq.queue) == 0 {
|
||||
return sq.dequeueLocked(normalizeSteeringScope(scope))
|
||||
}
|
||||
|
||||
// dequeueScopeWithFallback drains the scoped queue first and falls back to the
|
||||
// legacy manual scope for backwards compatibility.
|
||||
func (sq *steeringQueue) dequeueScopeWithFallback(scope string) []providers.Message {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope != "" {
|
||||
if msgs := sq.dequeueLocked(scope); len(msgs) > 0 {
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
|
||||
return sq.dequeueLocked(manualSteeringScope)
|
||||
}
|
||||
|
||||
func (sq *steeringQueue) dequeueLocked(scope string) []providers.Message {
|
||||
queue := sq.queues[scope]
|
||||
if len(queue) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch sq.mode {
|
||||
case SteeringAll:
|
||||
msgs := sq.queue
|
||||
sq.queue = nil
|
||||
msgs := append([]providers.Message(nil), queue...)
|
||||
delete(sq.queues, scope)
|
||||
return msgs
|
||||
default: // one-at-a-time
|
||||
msg := sq.queue[0]
|
||||
sq.queue[0] = providers.Message{} // Clear reference for GC
|
||||
sq.queue = sq.queue[1:]
|
||||
default:
|
||||
msg := queue[0]
|
||||
queue[0] = providers.Message{} // Clear reference for GC
|
||||
queue = queue[1:]
|
||||
if len(queue) == 0 {
|
||||
delete(sq.queues, scope)
|
||||
} else {
|
||||
sq.queues[scope] = queue
|
||||
}
|
||||
return []providers.Message{msg}
|
||||
}
|
||||
}
|
||||
|
||||
// len returns the number of queued messages.
|
||||
// len returns the number of queued messages across all scopes.
|
||||
func (sq *steeringQueue) len() int {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
return len(sq.queue)
|
||||
|
||||
total := 0
|
||||
for _, queue := range sq.queues {
|
||||
total += len(queue)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// lenScope returns the number of queued messages for a specific scope.
|
||||
func (sq *steeringQueue) lenScope(scope string) int {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
return len(sq.queues[normalizeSteeringScope(scope)])
|
||||
}
|
||||
|
||||
// setMode updates the steering mode.
|
||||
@@ -102,28 +167,76 @@ func (sq *steeringQueue) getMode() SteeringMode {
|
||||
return sq.mode
|
||||
}
|
||||
|
||||
// --- AgentLoop steering API ---
|
||||
|
||||
// Steer enqueues a user message to be injected into the currently running
|
||||
// agent loop. The message will be picked up after the current tool finishes
|
||||
// executing, causing any remaining tool calls in the batch to be skipped.
|
||||
func (al *AgentLoop) Steer(msg providers.Message) error {
|
||||
scope := ""
|
||||
agentID := ""
|
||||
if ts := al.getAnyActiveTurnState(); ts != nil {
|
||||
scope = ts.sessionKey
|
||||
agentID = ts.agentID
|
||||
}
|
||||
return al.enqueueSteeringMessage(scope, agentID, msg)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers.Message) error {
|
||||
if al.steering == nil {
|
||||
return fmt.Errorf("steering queue is not initialized")
|
||||
}
|
||||
if err := al.steering.push(msg); err != nil {
|
||||
|
||||
if err := al.steering.pushScope(scope, msg); err != nil {
|
||||
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
|
||||
"error": err.Error(),
|
||||
"role": msg.Role,
|
||||
"scope": normalizeSteeringScope(scope),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
queueDepth := al.steering.lenScope(scope)
|
||||
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
|
||||
"role": msg.Role,
|
||||
"content_len": len(msg.Content),
|
||||
"queue_len": al.steering.len(),
|
||||
"media_count": len(msg.Media),
|
||||
"queue_len": queueDepth,
|
||||
"scope": normalizeSteeringScope(scope),
|
||||
})
|
||||
|
||||
meta := EventMeta{
|
||||
Source: "Steer",
|
||||
TracePath: "turn.interrupt.received",
|
||||
}
|
||||
if ts := al.getAnyActiveTurnState(); ts != nil {
|
||||
meta = ts.eventMeta("Steer", "turn.interrupt.received")
|
||||
} else {
|
||||
if strings.TrimSpace(agentID) != "" {
|
||||
meta.AgentID = agentID
|
||||
}
|
||||
normalizedScope := normalizeSteeringScope(scope)
|
||||
if normalizedScope != manualSteeringScope {
|
||||
meta.SessionKey = normalizedScope
|
||||
}
|
||||
if meta.AgentID == "" {
|
||||
if registry := al.GetRegistry(); registry != nil {
|
||||
if agent := registry.GetDefaultAgent(); agent != nil {
|
||||
meta.AgentID = agent.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
meta,
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindSteering,
|
||||
Role: msg.Role,
|
||||
ContentLen: len(msg.Content),
|
||||
QueueDepth: queueDepth,
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -144,7 +257,7 @@ func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
|
||||
}
|
||||
|
||||
// dequeueSteeringMessages is the internal method called by the agent loop
|
||||
// to poll for steering messages. Returns nil when no messages are pending.
|
||||
// to poll for steering messages in the legacy fallback scope.
|
||||
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
|
||||
if al.steering == nil {
|
||||
return nil
|
||||
@@ -152,6 +265,60 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
|
||||
return al.steering.dequeue()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) dequeueSteeringMessagesForScope(scope string) []providers.Message {
|
||||
if al.steering == nil {
|
||||
return nil
|
||||
}
|
||||
return al.steering.dequeueScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) dequeueSteeringMessagesForScopeWithFallback(scope string) []providers.Message {
|
||||
if al.steering == nil {
|
||||
return nil
|
||||
}
|
||||
return al.steering.dequeueScopeWithFallback(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
|
||||
if al.steering == nil {
|
||||
return 0
|
||||
}
|
||||
return al.steering.lenScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) continueWithSteeringMessages(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
sessionKey, channel, chatID string,
|
||||
steeringMsgs []providers.Message,
|
||||
) (string, error) {
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
InitialSteeringMessages: steeringMsgs,
|
||||
SkipInitialSteeringPoll: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
|
||||
registry := al.GetRegistry()
|
||||
if registry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
|
||||
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
|
||||
return agent
|
||||
}
|
||||
}
|
||||
|
||||
return registry.GetDefaultAgent()
|
||||
}
|
||||
|
||||
// Continue resumes an idle agent by dequeuing any pending steering messages
|
||||
// and running them through the agent loop. This is used when the agent's last
|
||||
// message was from the assistant (i.e., it has stopped processing) and the
|
||||
@@ -159,33 +326,74 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
|
||||
//
|
||||
// If no steering messages are pending, it returns an empty string.
|
||||
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
|
||||
steeringMsgs := al.dequeueSteeringMessages()
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
return "", fmt.Errorf("turn %s is still active", active.TurnID)
|
||||
}
|
||||
if err := al.ensureHooksInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey)
|
||||
if len(steeringMsgs) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
agent := al.agentForSession(sessionKey)
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent available")
|
||||
return "", fmt.Errorf("no agent available for session %q", sessionKey)
|
||||
}
|
||||
|
||||
// Build a combined user message from the steering messages.
|
||||
var contents []string
|
||||
for _, msg := range steeringMsgs {
|
||||
contents = append(contents, msg.Content)
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
|
||||
resetter.ResetSentInRound()
|
||||
}
|
||||
}
|
||||
combinedContent := strings.Join(contents, "\n")
|
||||
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
UserMessage: combinedContent,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
SkipInitialSteeringPoll: true,
|
||||
})
|
||||
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptGraceful(hint string) error {
|
||||
ts := al.getAnyActiveTurnState()
|
||||
if ts == nil {
|
||||
return fmt.Errorf("no active turn")
|
||||
}
|
||||
if !ts.requestGracefulInterrupt(hint) {
|
||||
return fmt.Errorf("turn %s cannot accept graceful interrupt", ts.turnID)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
ts.eventMeta("InterruptGraceful", "turn.interrupt.received"),
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindGraceful,
|
||||
HintLen: len(hint),
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptHard() error {
|
||||
ts := al.getAnyActiveTurnState()
|
||||
if ts == nil {
|
||||
return fmt.Errorf("no active turn")
|
||||
}
|
||||
if !ts.requestHardAbort() {
|
||||
return fmt.Errorf("turn %s is already aborting", ts.turnID)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
ts.eventMeta("InterruptHard", "turn.interrupt.received"),
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindHard,
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ====================== SubTurn Result Polling ======================
|
||||
@@ -206,7 +414,10 @@ func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.To
|
||||
var results []*tools.ToolResult
|
||||
for {
|
||||
select {
|
||||
case result := <-ts.pendingResults:
|
||||
case result, ok := <-ts.pendingResults:
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
if result != nil {
|
||||
results = append(results, result)
|
||||
}
|
||||
@@ -249,20 +460,6 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
|
||||
// Use isHardAbort=true for hard abort to immediately cancel all children.
|
||||
ts.Finish(true)
|
||||
|
||||
// Rollback session history to the state before this turn started.
|
||||
// This must happen AFTER Finish() to ensure no child turns are still writing.
|
||||
if ts.session != nil {
|
||||
currentHistory := ts.session.GetHistory("")
|
||||
if len(currentHistory) > ts.initialHistoryLength {
|
||||
logger.InfoCF("agent", "Rolling back session history", map[string]any{
|
||||
"from": len(currentHistory),
|
||||
"to": ts.initialHistoryLength,
|
||||
})
|
||||
// SetHistory with the truncated slice to rollback
|
||||
ts.session.SetHistory("", currentHistory[:ts.initialHistoryLength])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -291,19 +488,6 @@ func (al *AgentLoop) InjectFollowUp(msg providers.Message) error {
|
||||
|
||||
// ====================== API Aliases for Design Document Compatibility ======================
|
||||
|
||||
// InterruptGraceful is an alias for Steer() to match the design document naming.
|
||||
// It gracefully interrupts the current execution by injecting a user message
|
||||
// that will be processed after the current tool finishes.
|
||||
func (al *AgentLoop) InterruptGraceful(msg providers.Message) error {
|
||||
return al.Steer(msg)
|
||||
}
|
||||
|
||||
// InterruptHard is an alias for HardAbort() to match the design document naming.
|
||||
// It immediately terminates execution and rolls back the session state.
|
||||
func (al *AgentLoop) InterruptHard(sessionKey string) error {
|
||||
return al.HardAbort(sessionKey)
|
||||
}
|
||||
|
||||
// InjectSteering is an alias for Steer() to match the design document naming.
|
||||
// It injects a steering message into the currently running agent loop.
|
||||
func (al *AgentLoop) InjectSteering(msg providers.Message) error {
|
||||
|
||||
@@ -5,13 +5,18 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -335,6 +340,97 @@ func TestAgentLoop_Continue_WithMessages(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
Session: config.SessionConfig{
|
||||
DMScope: "per-peer",
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &mockProvider{})
|
||||
|
||||
activeMsg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "active turn",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
|
||||
if !ok {
|
||||
t.Fatal("expected active message to resolve to a steering scope")
|
||||
}
|
||||
|
||||
otherMsg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user2",
|
||||
ChatID: "chat2",
|
||||
Content: "other session",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user2",
|
||||
},
|
||||
}
|
||||
otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
|
||||
if !ok {
|
||||
t.Fatal("expected other message to resolve to a steering scope")
|
||||
}
|
||||
if otherScope == activeScope {
|
||||
t.Fatalf("expected different steering scopes, got same scope %q", activeScope)
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), otherMsg); err != nil {
|
||||
t.Fatalf("PublishInbound failed: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
al.drainBusToSteering(ctx, activeScope, activeAgentID)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for drainBusToSteering to stop")
|
||||
}
|
||||
|
||||
if msgs := al.dequeueSteeringMessagesForScope(activeScope); len(msgs) != 0 {
|
||||
t.Fatalf("expected no steering messages for active scope, got %v", msgs)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("timeout waiting for requeued message on outbound bus")
|
||||
case requeued := <-msgBus.OutboundChan():
|
||||
if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID ||
|
||||
requeued.Content != otherMsg.Content {
|
||||
t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// slowTool simulates a tool that takes some time to execute.
|
||||
type slowTool struct {
|
||||
name string
|
||||
@@ -396,6 +492,149 @@ func (m *toolCallProvider) GetDefaultModel() string {
|
||||
return "tool-call-mock"
|
||||
}
|
||||
|
||||
type gracefulCaptureProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
toolCalls []providers.ToolCall
|
||||
finalResp string
|
||||
terminalMessages []providers.Message
|
||||
terminalToolsCount int
|
||||
}
|
||||
|
||||
func (p *gracefulCaptureProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.calls++
|
||||
|
||||
if p.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: p.toolCalls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
p.terminalMessages = append([]providers.Message(nil), messages...)
|
||||
p.terminalToolsCount = len(tools)
|
||||
return &providers.LLMResponse{
|
||||
Content: p.finalResp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *gracefulCaptureProvider) GetDefaultModel() string {
|
||||
return "graceful-capture-mock"
|
||||
}
|
||||
|
||||
type lateSteeringProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
firstCallStarted chan struct{}
|
||||
releaseFirstCall chan struct{}
|
||||
firstStartOnce sync.Once
|
||||
secondCallMessages []providers.Message
|
||||
}
|
||||
|
||||
func (p *lateSteeringProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.calls++
|
||||
call := p.calls
|
||||
p.mu.Unlock()
|
||||
|
||||
if call == 1 {
|
||||
p.firstStartOnce.Do(func() { close(p.firstCallStarted) })
|
||||
<-p.releaseFirstCall
|
||||
return &providers.LLMResponse{Content: "first response"}, nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.secondCallMessages = append([]providers.Message(nil), messages...)
|
||||
p.mu.Unlock()
|
||||
return &providers.LLMResponse{Content: "continued response"}, nil
|
||||
}
|
||||
|
||||
func (p *lateSteeringProvider) GetDefaultModel() string {
|
||||
return "late-steering-mock"
|
||||
}
|
||||
|
||||
type blockingDirectProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
firstStarted chan struct{}
|
||||
releaseFirst chan struct{}
|
||||
firstResp string
|
||||
finalResp string
|
||||
}
|
||||
|
||||
func (p *blockingDirectProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.calls++
|
||||
call := p.calls
|
||||
firstStarted := p.firstStarted
|
||||
releaseFirst := p.releaseFirst
|
||||
firstResp := p.firstResp
|
||||
finalResp := p.finalResp
|
||||
if call == 1 && p.firstStarted != nil {
|
||||
close(p.firstStarted)
|
||||
p.firstStarted = nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if call == 1 {
|
||||
select {
|
||||
case <-releaseFirst:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return &providers.LLMResponse{Content: firstResp}, nil
|
||||
}
|
||||
|
||||
_ = firstStarted
|
||||
return &providers.LLMResponse{Content: finalResp}, nil
|
||||
}
|
||||
|
||||
func (p *blockingDirectProvider) GetDefaultModel() string {
|
||||
return "blocking-direct-mock"
|
||||
}
|
||||
|
||||
type interruptibleTool struct {
|
||||
name string
|
||||
started chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (t *interruptibleTool) Name() string { return t.name }
|
||||
func (t *interruptibleTool) Description() string { return "interruptible tool for testing" }
|
||||
func (t *interruptibleTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *interruptibleTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
if t.started != nil {
|
||||
t.once.Do(func() { close(t.started) })
|
||||
}
|
||||
<-ctx.Done()
|
||||
return tools.ErrorResult(ctx.Err().Error()).WithError(ctx.Err())
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
@@ -568,6 +807,614 @@ func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &lateSteeringProvider{
|
||||
firstCallStarted: make(chan struct{}),
|
||||
releaseFirstCall: make(chan struct{}),
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "first message",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
late := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "late append",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
|
||||
t.Fatalf("publish first inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstCallStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first provider call to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, late); err != nil {
|
||||
t.Fatalf("publish late inbound: %v", err)
|
||||
}
|
||||
|
||||
close(provider.releaseFirstCall)
|
||||
|
||||
subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer subCancel()
|
||||
|
||||
var out1 bus.OutboundMessage
|
||||
select {
|
||||
case out1 = <-msgBus.OutboundChan():
|
||||
case <-subCtx.Done():
|
||||
t.Fatal("expected outbound response")
|
||||
}
|
||||
if out1.Content != "continued response" {
|
||||
t.Fatalf("expected continued response, got %q", out1.Content)
|
||||
}
|
||||
|
||||
noExtraCtx, cancelNoExtra := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancelNoExtra()
|
||||
select {
|
||||
case out2 := <-msgBus.OutboundChan():
|
||||
t.Fatalf("expected stale direct response to be suppressed, got extra outbound %q", out2.Content)
|
||||
case <-noExtraCtx.Done():
|
||||
}
|
||||
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
|
||||
provider.mu.Unlock()
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", calls)
|
||||
}
|
||||
|
||||
foundLateMessage := false
|
||||
for _, msg := range secondMessages {
|
||||
if msg.Role == "user" && msg.Content == "late append" {
|
||||
foundLateMessage = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundLateMessage {
|
||||
t.Fatal("expected queued late message to be processed in an automatic follow-up turn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
provider := &blockingDirectProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
firstResp: "stale direct response",
|
||||
finalResp: "fresh response after steering",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
resultCh := make(chan struct {
|
||||
resp string
|
||||
err error
|
||||
}, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"initial request",
|
||||
sessionKey,
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- struct {
|
||||
resp string
|
||||
err error
|
||||
}{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-provider.firstStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first LLM call to start")
|
||||
}
|
||||
|
||||
if err := al.Steer(providers.Message{Role: "user", Content: "follow-up instruction"}); err != nil {
|
||||
t.Fatalf("Steer failed: %v", err)
|
||||
}
|
||||
close(provider.releaseFirst)
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.err != nil {
|
||||
t.Fatalf("unexpected error: %v", result.err)
|
||||
}
|
||||
if result.resp != "fresh response after steering" {
|
||||
t.Fatalf("expected refreshed response, got %q", result.resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for ProcessDirectWithChannel")
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
provider.mu.Unlock()
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", calls)
|
||||
}
|
||||
|
||||
if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 {
|
||||
t.Fatalf("expected steering queue to be empty after continuation, got %v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
pngPath := filepath.Join(tmpDir, "steer.png")
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
||||
0x00, 0x00, 0x00, 0x0D,
|
||||
0x49, 0x48, 0x44, 0x52,
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
||||
0x00, 0x00, 0x00,
|
||||
0x90, 0x77, 0x53, 0xDE,
|
||||
}
|
||||
if err = os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile failed: %v", err)
|
||||
}
|
||||
ref, err := store.Store(pngPath, media.MediaMeta{Filename: "steer.png", ContentType: "image/png"}, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Store failed: %v", err)
|
||||
}
|
||||
|
||||
var capturedMessages []providers.Message
|
||||
var capMu sync.Mutex
|
||||
provider := &capturingMockProvider{
|
||||
response: "ack",
|
||||
captureFn: func(msgs []providers.Message) {
|
||||
capMu.Lock()
|
||||
defer capMu.Unlock()
|
||||
capturedMessages = append([]providers.Message(nil), msgs...)
|
||||
},
|
||||
}
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.SetMediaStore(store)
|
||||
|
||||
if err = al.Steer(providers.Message{
|
||||
Role: "user",
|
||||
Content: "describe this image",
|
||||
Media: []string{ref},
|
||||
}); err != nil {
|
||||
t.Fatalf("Steer failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.Continue(context.Background(), sessionKey, "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("Continue failed: %v", err)
|
||||
}
|
||||
if resp != "ack" {
|
||||
t.Fatalf("expected ack, got %q", resp)
|
||||
}
|
||||
|
||||
capMu.Lock()
|
||||
msgs := append([]providers.Message(nil), capturedMessages...)
|
||||
capMu.Unlock()
|
||||
|
||||
foundResolvedMedia := false
|
||||
for _, msg := range msgs {
|
||||
if msg.Role != "user" || msg.Content != "describe this image" || len(msg.Media) != 1 {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(msg.Media[0], "data:image/png;base64,") {
|
||||
foundResolvedMedia = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundResolvedMedia {
|
||||
t.Fatal("expected continue path to inject steering media into the provider request")
|
||||
}
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
history := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
foundOriginalRef := false
|
||||
for _, msg := range history {
|
||||
if msg.Role == "user" && len(msg.Media) == 1 && msg.Media[0] == ref {
|
||||
foundOriginalRef = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundOriginalRef {
|
||||
t.Fatal("expected original steering media ref to be preserved in session history")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool1ExecCh := make(chan struct{})
|
||||
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
|
||||
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
|
||||
|
||||
provider := &gracefulCaptureProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "tool_one",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_one",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Type: "function",
|
||||
Name: "tool_two",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_two",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "graceful summary",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"do something",
|
||||
sessionKey,
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-tool1ExecCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for tool_one to start")
|
||||
}
|
||||
|
||||
active := al.GetActiveTurn()
|
||||
if active == nil {
|
||||
t.Fatal("expected active turn while tool is running")
|
||||
}
|
||||
if active.SessionKey != sessionKey {
|
||||
t.Fatalf("expected active session %q, got %q", sessionKey, active.SessionKey)
|
||||
}
|
||||
if active.Channel != "test" || active.ChatID != "chat1" {
|
||||
t.Fatalf("unexpected active turn target: %#v", active)
|
||||
}
|
||||
|
||||
if err := al.InterruptGraceful("wrap it up"); err != nil {
|
||||
t.Fatalf("InterruptGraceful failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
if r.resp != "graceful summary" {
|
||||
t.Fatalf("expected graceful summary, got %q", r.resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for graceful interrupt result")
|
||||
}
|
||||
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
t.Fatalf("expected no active turn after completion, got %#v", active)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
terminalMessages := append([]providers.Message(nil), provider.terminalMessages...)
|
||||
terminalToolsCount := provider.terminalToolsCount
|
||||
calls := provider.calls
|
||||
provider.mu.Unlock()
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", calls)
|
||||
}
|
||||
if terminalToolsCount != 0 {
|
||||
t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount)
|
||||
}
|
||||
|
||||
foundHint := false
|
||||
foundSkipped := false
|
||||
expectedHint := "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\n" +
|
||||
"Interrupt hint: wrap it up"
|
||||
for _, msg := range terminalMessages {
|
||||
if msg.Role == "user" && msg.Content == expectedHint {
|
||||
foundHint = true
|
||||
}
|
||||
if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
|
||||
foundSkipped = true
|
||||
}
|
||||
}
|
||||
if !foundHint {
|
||||
t.Fatal("expected graceful terminal call to include interrupt hint message")
|
||||
}
|
||||
if !foundSkipped {
|
||||
t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
||||
}
|
||||
if interruptPayload.Kind != InterruptKindGraceful {
|
||||
t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind)
|
||||
}
|
||||
|
||||
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected turn end event")
|
||||
}
|
||||
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
|
||||
}
|
||||
if turnEndPayload.Status != TurnEndStatusCompleted {
|
||||
t.Fatalf("expected completed turn after graceful interrupt, got %q", turnEndPayload.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "cancel_tool",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "cancel_tool",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "should not happen",
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
started := make(chan struct{})
|
||||
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
originalHistory := []providers.Message{
|
||||
{Role: "user", Content: "before"},
|
||||
{Role: "assistant", Content: "after"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, originalHistory)
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"do work",
|
||||
sessionKey,
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for interruptible tool to start")
|
||||
}
|
||||
|
||||
if active := al.GetActiveTurn(); active == nil {
|
||||
t.Fatal("expected active turn before hard abort")
|
||||
}
|
||||
|
||||
if err := al.InterruptHard(); err != nil {
|
||||
t.Fatalf("InterruptHard failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
if r.resp != "" {
|
||||
t.Fatalf("expected no final response after hard abort, got %q", r.resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for hard abort result")
|
||||
}
|
||||
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
t.Fatalf("expected no active turn after hard abort, got %#v", active)
|
||||
}
|
||||
|
||||
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
if !reflect.DeepEqual(finalHistory, originalHistory) {
|
||||
t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
||||
}
|
||||
if interruptPayload.Kind != InterruptKindHard {
|
||||
t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind)
|
||||
}
|
||||
|
||||
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected turn end event")
|
||||
}
|
||||
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
|
||||
}
|
||||
if turnEndPayload.Status != TurnEndStatusAborted {
|
||||
t.Fatalf("expected aborted turn, got %q", turnEndPayload.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// capturingMockProvider captures messages sent to Chat for inspection.
|
||||
type capturingMockProvider struct {
|
||||
response string
|
||||
|
||||
+237
-332
@@ -4,14 +4,13 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// ====================== Config & Constants ======================
|
||||
@@ -176,33 +175,6 @@ type SubTurnConfig struct {
|
||||
// Can be extended with temperature, topP, etc.
|
||||
}
|
||||
|
||||
// ====================== Sub-turn Events (Aligned with EventBus) ======================
|
||||
|
||||
// SubTurnSpawnEvent is emitted when a child sub-turn is started.
|
||||
type SubTurnSpawnEvent struct {
|
||||
ParentID string
|
||||
ChildID string
|
||||
Config SubTurnConfig
|
||||
}
|
||||
|
||||
type SubTurnEndEvent struct {
|
||||
ChildID string
|
||||
Result *tools.ToolResult
|
||||
Err error
|
||||
}
|
||||
|
||||
type SubTurnResultDeliveredEvent struct {
|
||||
ParentID string
|
||||
ChildID string
|
||||
Result *tools.ToolResult
|
||||
}
|
||||
|
||||
type SubTurnOrphanResultEvent struct {
|
||||
ParentID string
|
||||
ChildID string
|
||||
Result *tools.ToolResult
|
||||
}
|
||||
|
||||
// ====================== Context Keys ======================
|
||||
type agentLoopKeyType struct{}
|
||||
|
||||
@@ -300,6 +272,11 @@ func spawnSubTurn(
|
||||
// 0. Acquire concurrency semaphore FIRST to ensure it's released even if early validation fails.
|
||||
// Blocks if parent already has maxConcurrentSubTurns running, with a timeout to prevent indefinite blocking.
|
||||
// Also respects context cancellation so we don't block forever if parent is aborted.
|
||||
// NOTE: The semaphore is released immediately after runTurn completes (not in a defer) to
|
||||
// ensure it is freed before the cleanup phase (async result delivery), which may block on
|
||||
// a full pendingResults channel. Holding the semaphore through cleanup would allow the
|
||||
// parent's goroutine to be blocked waiting for a semaphore slot while child turns are
|
||||
// blocked delivering results — a deadlock.
|
||||
var semAcquired bool
|
||||
if parentTS.concurrencySem != nil {
|
||||
// Create a timeout context for semaphore acquisition
|
||||
@@ -353,10 +330,60 @@ func spawnSubTurn(
|
||||
defer cancel()
|
||||
|
||||
childID := al.generateSubTurnID()
|
||||
childTS := newTurnState(childCtx, childID, parentTS, rtCfg.maxConcurrent)
|
||||
// Set the cancel function so Finish(true) can trigger hard cancellation
|
||||
|
||||
// Get the agent instance from parent, falling back to the default agent.
|
||||
// Wrap it in a shallow copy that uses an ephemeral (in-memory only) session store
|
||||
// so that child turns never pollute or persist to the parent's session history.
|
||||
baseAgent := parentTS.agent
|
||||
if baseAgent == nil {
|
||||
baseAgent = al.registry.GetDefaultAgent()
|
||||
}
|
||||
if baseAgent == nil {
|
||||
return nil, errors.New("parent turnState has no agent instance")
|
||||
}
|
||||
ephemeralStore := newEphemeralSession(nil)
|
||||
agent := *baseAgent // shallow copy
|
||||
agent.Sessions = ephemeralStore
|
||||
// Clone the tool registry so child turn's tool registrations
|
||||
// don't pollute the parent's registry.
|
||||
if baseAgent.Tools != nil {
|
||||
agent.Tools = baseAgent.Tools.Clone()
|
||||
}
|
||||
|
||||
// Create processOptions for the child turn
|
||||
opts := processOptions{
|
||||
SessionKey: childID,
|
||||
Channel: parentTS.channel,
|
||||
ChatID: parentTS.chatID,
|
||||
SenderID: parentTS.opts.SenderID,
|
||||
SenderDisplayName: parentTS.opts.SenderDisplayName,
|
||||
UserMessage: cfg.SystemPrompt, // Task description becomes the first user message
|
||||
SystemPromptOverride: cfg.ActualSystemPrompt,
|
||||
Media: nil,
|
||||
InitialSteeringMessages: cfg.InitialMessages,
|
||||
DefaultResponse: "",
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
NoHistory: true, // SubTurns don't use session history
|
||||
SkipInitialSteeringPoll: true,
|
||||
}
|
||||
|
||||
// Create event scope for the child turn
|
||||
scope := al.newTurnEventScope(agent.ID, childID)
|
||||
|
||||
// Create child turnState using the new API
|
||||
childTS := newTurnState(&agent, opts, scope)
|
||||
|
||||
// Set SubTurn-specific fields
|
||||
childTS.cancelFunc = cancel
|
||||
childTS.critical = cfg.Critical
|
||||
childTS.depth = parentTS.depth + 1
|
||||
childTS.parentTurnID = parentTS.turnID
|
||||
childTS.parentTurnState = parentTS
|
||||
childTS.pendingResults = make(chan *tools.ToolResult, 16)
|
||||
childTS.concurrencySem = make(chan struct{}, rtCfg.maxConcurrent)
|
||||
childTS.al = al // back-ref for hard abort cascade
|
||||
childTS.session = ephemeralStore // same store as agent.Sessions
|
||||
|
||||
// Token budget initialization/inheritance
|
||||
// If InitialTokenBudget is explicitly provided (e.g., by team tool), use it.
|
||||
@@ -376,6 +403,8 @@ func spawnSubTurn(
|
||||
childCtx = withTurnState(childCtx, childTS)
|
||||
childCtx = WithAgentLoop(childCtx, al) // Propagate AgentLoop to child turn
|
||||
|
||||
childTS.ctx = childCtx
|
||||
|
||||
// Register child turn state so GetAllActiveTurns/Subagents can find it
|
||||
al.activeTurnStates.Store(childID, childTS)
|
||||
defer al.activeTurnStates.Delete(childID)
|
||||
@@ -386,11 +415,14 @@ func spawnSubTurn(
|
||||
parentTS.mu.Unlock()
|
||||
|
||||
// 6. Emit Spawn event
|
||||
MockEventBus.Emit(SubTurnSpawnEvent{
|
||||
ParentID: parentTS.turnID,
|
||||
ChildID: childID,
|
||||
Config: cfg,
|
||||
})
|
||||
al.emitEvent(EventKindSubTurnSpawn,
|
||||
childTS.eventMeta("spawnSubTurn", "subturn.spawn"),
|
||||
SubTurnSpawnPayload{
|
||||
AgentID: childTS.agentID,
|
||||
Label: childID,
|
||||
ParentTurnID: parentTS.turnID,
|
||||
},
|
||||
)
|
||||
|
||||
// 7. Defer cleanup: deliver result (for async), emit End event, and recover from panics
|
||||
defer func() {
|
||||
@@ -401,22 +433,61 @@ func spawnSubTurn(
|
||||
"parent_id": parentTS.turnID,
|
||||
"panic": r,
|
||||
})
|
||||
|
||||
// Ensure result is not nil to prevent panic during event emission
|
||||
if result == nil {
|
||||
result = &tools.ToolResult{
|
||||
Err: err,
|
||||
ForLLM: fmt.Sprintf("SubTurn panicked: %v", r),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Result Delivery Strategy (Async vs Sync)
|
||||
if cfg.Async {
|
||||
deliverSubTurnResult(parentTS, childID, result)
|
||||
deliverSubTurnResult(al, parentTS, childID, result)
|
||||
}
|
||||
|
||||
MockEventBus.Emit(SubTurnEndEvent{
|
||||
ChildID: childID,
|
||||
Result: result,
|
||||
Err: err,
|
||||
})
|
||||
status := "completed"
|
||||
if err != nil {
|
||||
status = "error"
|
||||
}
|
||||
al.emitEvent(EventKindSubTurnEnd,
|
||||
childTS.eventMeta("spawnSubTurn", "subturn.end"),
|
||||
SubTurnEndPayload{
|
||||
AgentID: childTS.agentID,
|
||||
Status: status,
|
||||
},
|
||||
)
|
||||
}()
|
||||
|
||||
// 8. Execute sub-turn via the real agent loop.
|
||||
result, err = runTurn(childCtx, al, childTS, cfg)
|
||||
turnRes, turnErr := al.runTurn(childCtx, childTS)
|
||||
|
||||
// Release the concurrency semaphore immediately after runTurn completes,
|
||||
// before the cleanup defer runs. This prevents a deadlock where:
|
||||
// - All semaphore slots are held by sub-turns in their cleanup phase
|
||||
// - Cleanup blocks on a full pendingResults channel
|
||||
// - The parent goroutine is blocked waiting for a semaphore slot
|
||||
// - The parent cannot consume pendingResults because it is blocked on the semaphore
|
||||
if semAcquired {
|
||||
<-parentTS.concurrencySem
|
||||
semAcquired = false // prevent the defer from double-releasing
|
||||
}
|
||||
|
||||
// Convert turnResult to tools.ToolResult
|
||||
if turnErr != nil {
|
||||
err = turnErr
|
||||
result = &tools.ToolResult{
|
||||
Err: turnErr,
|
||||
ForLLM: fmt.Sprintf("SubTurn failed: %v", turnErr),
|
||||
}
|
||||
} else {
|
||||
result = &tools.ToolResult{
|
||||
ForLLM: turnRes.finalContent,
|
||||
ForUser: turnRes.finalContent,
|
||||
}
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
@@ -441,7 +512,7 @@ func spawnSubTurn(
|
||||
// Event emissions:
|
||||
// - SubTurnResultDeliveredEvent: successful delivery to channel
|
||||
// - SubTurnOrphanResultEvent: delivery failed (parent finished or channel full)
|
||||
func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.ToolResult) {
|
||||
func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, result *tools.ToolResult) {
|
||||
// Let GC clean up the pendingResults channel; parent Finish will no longer close it.
|
||||
// We use defer/recover to catch any unlikely channel panics if it were ever closed.
|
||||
defer func() {
|
||||
@@ -451,28 +522,26 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too
|
||||
"child_id": childID,
|
||||
"recover": r,
|
||||
})
|
||||
if result != nil {
|
||||
MockEventBus.Emit(SubTurnOrphanResultEvent{
|
||||
ParentID: parentTS.turnID,
|
||||
ChildID: childID,
|
||||
Result: result,
|
||||
})
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(EventKindSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "panic"},
|
||||
)
|
||||
}
|
||||
}
|
||||
}()
|
||||
parentTS.mu.Lock()
|
||||
isFinished := parentTS.isFinished
|
||||
isFinished := parentTS.isFinished.Load()
|
||||
resultChan := parentTS.pendingResults
|
||||
parentTS.mu.Unlock()
|
||||
|
||||
// If parent turn has already finished, treat this as an orphan result
|
||||
if isFinished || resultChan == nil {
|
||||
if result != nil {
|
||||
MockEventBus.Emit(SubTurnOrphanResultEvent{
|
||||
ParentID: parentTS.turnID,
|
||||
ChildID: childID,
|
||||
Result: result,
|
||||
})
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(EventKindSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "parent_finished"},
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -484,11 +553,12 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too
|
||||
select {
|
||||
case resultChan <- result:
|
||||
// Successfully delivered
|
||||
MockEventBus.Emit(SubTurnResultDeliveredEvent{
|
||||
ParentID: parentTS.turnID,
|
||||
ChildID: childID,
|
||||
Result: result,
|
||||
})
|
||||
if al != nil {
|
||||
al.emitEvent(EventKindSubTurnResultDelivered,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.result_delivered"),
|
||||
SubTurnResultDeliveredPayload{ContentLen: len(result.ForLLM)},
|
||||
)
|
||||
}
|
||||
case <-parentTS.Finished():
|
||||
// Parent finished while we were waiting to deliver.
|
||||
// The result cannot be delivered to the LLM, so it becomes an orphan.
|
||||
@@ -496,278 +566,113 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too
|
||||
"parent_id": parentTS.turnID,
|
||||
"child_id": childID,
|
||||
})
|
||||
if result != nil {
|
||||
MockEventBus.Emit(SubTurnOrphanResultEvent{
|
||||
ParentID: parentTS.turnID,
|
||||
ChildID: childID,
|
||||
Result: result,
|
||||
})
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(
|
||||
EventKindSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{
|
||||
ParentTurnID: parentTS.turnID,
|
||||
ChildTurnID: childID,
|
||||
Reason: "parent_finished_waiting",
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runTurn builds a temporary AgentInstance from SubTurnConfig and delegates to
|
||||
// the real agent loop. The child's ephemeral session is used for history so it
|
||||
// never pollutes the parent session.
|
||||
//
|
||||
// This function implements multiple layers of context protection and error recovery:
|
||||
//
|
||||
// 1. Soft Context Limit (MaxContextRunes):
|
||||
// - Proactively truncates message history before LLM calls
|
||||
// - Default: 75% of model's context window
|
||||
// - Preserves system messages and recent context
|
||||
// - First line of defense against context overflow
|
||||
//
|
||||
// 2. Hard Context Error Recovery:
|
||||
// - Detects context_length_exceeded errors from provider
|
||||
// - Triggers force compression and retries (up to 2 times)
|
||||
// - Second line of defense when soft limit is insufficient
|
||||
//
|
||||
// 3. Truncation Recovery:
|
||||
// - Detects when LLM response is truncated (finish_reason="truncated")
|
||||
// - Injects recovery prompt asking for shorter response
|
||||
// - Retries up to 2 times
|
||||
// - Handles cases where max_tokens is hit
|
||||
func runTurn(
|
||||
ctx context.Context,
|
||||
al *AgentLoop,
|
||||
ts *turnState,
|
||||
cfg SubTurnConfig,
|
||||
) (*tools.ToolResult, error) {
|
||||
// Derive candidates from the requested model using the parent loop's provider.
|
||||
defaultProvider := al.GetConfig().Agents.Defaults.Provider
|
||||
candidates := providers.ResolveCandidates(
|
||||
providers.ModelConfig{Primary: cfg.Model},
|
||||
defaultProvider,
|
||||
)
|
||||
|
||||
// Build a minimal AgentInstance for this sub-turn.
|
||||
// It reuses the parent loop's provider and config, but gets its own
|
||||
// ephemeral session store and tool registry.
|
||||
parentAgent := al.GetRegistry().GetDefaultAgent()
|
||||
|
||||
// Determine which tools to use: explicit config or inherit from parent
|
||||
toolRegistry := tools.NewToolRegistry()
|
||||
toolsToRegister := cfg.Tools
|
||||
if len(toolsToRegister) == 0 {
|
||||
toolsToRegister = parentAgent.Tools.GetAll()
|
||||
}
|
||||
for _, t := range toolsToRegister {
|
||||
toolRegistry.Register(t)
|
||||
}
|
||||
|
||||
childAgent := &AgentInstance{
|
||||
ID: ts.turnID,
|
||||
Model: cfg.Model,
|
||||
MaxIterations: parentAgent.MaxIterations,
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
Temperature: parentAgent.Temperature,
|
||||
ThinkingLevel: parentAgent.ThinkingLevel,
|
||||
ContextWindow: parentAgent.ContextWindow, // Inherit from parent agent
|
||||
SummarizeMessageThreshold: parentAgent.SummarizeMessageThreshold,
|
||||
SummarizeTokenPercent: parentAgent.SummarizeTokenPercent,
|
||||
Provider: parentAgent.Provider,
|
||||
Sessions: ts.session,
|
||||
ContextBuilder: parentAgent.ContextBuilder,
|
||||
Tools: toolRegistry,
|
||||
Candidates: candidates,
|
||||
}
|
||||
if childAgent.MaxTokens == 0 {
|
||||
childAgent.MaxTokens = parentAgent.MaxTokens
|
||||
}
|
||||
|
||||
promptAlreadyAdded := false
|
||||
|
||||
// Preload ephemeral session history
|
||||
if len(cfg.InitialMessages) > 0 {
|
||||
existing := childAgent.Sessions.GetHistory(ts.turnID)
|
||||
childAgent.Sessions.SetHistory(ts.turnID, append(existing, cfg.InitialMessages...))
|
||||
promptAlreadyAdded = true // InitialMessages 中已含 user 消息,跳过再次添加
|
||||
}
|
||||
|
||||
// Resolve MaxContextRunes configuration
|
||||
maxContextRunes := utils.ResolveMaxContextRunes(cfg.MaxContextRunes, childAgent.ContextWindow)
|
||||
|
||||
logger.DebugCF("subturn", "Context limit resolved",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"context_window": childAgent.ContextWindow,
|
||||
"max_context_runes": maxContextRunes,
|
||||
"configured_value": cfg.MaxContextRunes,
|
||||
})
|
||||
|
||||
// Retry loop for truncation and context errors
|
||||
const (
|
||||
maxTruncationRetries = 2
|
||||
maxContextRetries = 2
|
||||
)
|
||||
|
||||
truncationRetryCount := 0
|
||||
contextRetryCount := 0
|
||||
currentPrompt := cfg.SystemPrompt
|
||||
|
||||
for {
|
||||
// Soft context limit: check and truncate before LLM call
|
||||
if maxContextRunes > 0 {
|
||||
messages := childAgent.Sessions.GetHistory(ts.turnID)
|
||||
currentRunes := utils.MeasureContextRunes(messages)
|
||||
|
||||
if currentRunes > maxContextRunes {
|
||||
logger.WarnCF("subturn", "Context exceeds soft limit, truncating",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"current_runes": currentRunes,
|
||||
"max_runes": maxContextRunes,
|
||||
"overflow": currentRunes - maxContextRunes,
|
||||
})
|
||||
|
||||
truncatedMessages := utils.TruncateContextSmart(messages, maxContextRunes)
|
||||
childAgent.Sessions.SetHistory(ts.turnID, truncatedMessages)
|
||||
|
||||
// Log truncation result
|
||||
newRunes := utils.MeasureContextRunes(truncatedMessages)
|
||||
logger.InfoCF("subturn", "Context truncated successfully",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"before_runes": currentRunes,
|
||||
"after_runes": newRunes,
|
||||
"saved_runes": currentRunes - newRunes,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Call the agent loop
|
||||
finalContent, err := al.runAgentLoop(ctx, childAgent, processOptions{
|
||||
SessionKey: ts.turnID,
|
||||
UserMessage: currentPrompt,
|
||||
SystemPromptOverride: cfg.ActualSystemPrompt,
|
||||
DefaultResponse: "",
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
SkipAddUserMessage: promptAlreadyAdded,
|
||||
})
|
||||
|
||||
// Mark the prompt as added so subsequent truncation retries
|
||||
// won't duplicate it in the history.
|
||||
promptAlreadyAdded = true
|
||||
|
||||
// 1. Handle context length errors
|
||||
if err != nil && isContextLengthError(err) {
|
||||
if contextRetryCount >= maxContextRetries {
|
||||
logger.ErrorCF("subturn", "Context limit exceeded after max retries",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"retries": contextRetryCount,
|
||||
"max_retries": maxContextRetries,
|
||||
})
|
||||
return nil, fmt.Errorf(
|
||||
"context limit exceeded after %d retries: %w",
|
||||
maxContextRetries,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
logger.WarnCF("subturn", "Context length exceeded, compressing and retrying",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"retry": contextRetryCount + 1,
|
||||
})
|
||||
|
||||
// Trigger force compression
|
||||
al.forceCompression(childAgent, ts.turnID)
|
||||
|
||||
contextRetryCount++
|
||||
continue // Retry with compressed history
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err // Other errors, return immediately
|
||||
}
|
||||
|
||||
// 2. Check for truncation (retrieve finishReason from turnState)
|
||||
finishReason := ts.GetLastFinishReason()
|
||||
|
||||
if finishReason == "truncated" && truncationRetryCount < maxTruncationRetries {
|
||||
logger.WarnCF("subturn", "Response truncated, injecting recovery message",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"retry": truncationRetryCount + 1,
|
||||
})
|
||||
|
||||
// IMPORTANT: Do NOT manually add messages to history here.
|
||||
// runAgentLoop has already saved both the assistant message (finalContent)
|
||||
// and will save the next user message (currentPrompt) on the next iteration.
|
||||
// Manually adding them would cause duplicates.
|
||||
|
||||
// Inject recovery prompt - it will be added by runAgentLoop on next iteration
|
||||
recoveryPrompt := "Your previous response was truncated due to length. Please provide a shorter, complete response that finishes your thought."
|
||||
currentPrompt = recoveryPrompt
|
||||
promptAlreadyAdded = false // We need this new recovery prompt to be added
|
||||
|
||||
truncationRetryCount++
|
||||
continue // Retry with recovery prompt
|
||||
}
|
||||
|
||||
// 3. Token budget enforcement (if configured)
|
||||
// Check if budget is exhausted after this LLM call. If so, return gracefully
|
||||
// with current result instead of continuing iterations.
|
||||
if ts.tokenBudget != nil {
|
||||
if usage := ts.GetLastUsage(); usage != nil {
|
||||
newBudget := ts.tokenBudget.Add(-int64(usage.TotalTokens))
|
||||
|
||||
if newBudget <= 0 {
|
||||
logger.WarnCF("subturn", "Token budget exhausted",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"deficit": -newBudget,
|
||||
"tokens_used": usage.TotalTokens,
|
||||
"final_budget": newBudget,
|
||||
})
|
||||
|
||||
// Budget exhausted - return current result with marker
|
||||
return &tools.ToolResult{
|
||||
ForLLM: finalContent + "\n\n[Token budget exhausted]",
|
||||
Messages: childAgent.Sessions.GetHistory(ts.turnID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
logger.DebugCF("subturn", "Token budget updated",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"tokens_used": usage.TotalTokens,
|
||||
"remaining_budget": newBudget,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Success - return result with session history
|
||||
return &tools.ToolResult{
|
||||
ForLLM: finalContent,
|
||||
Messages: childAgent.Sessions.GetHistory(ts.turnID),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// isContextLengthError checks if the error is due to context length exceeded.
|
||||
// It excludes timeout errors to avoid false positives.
|
||||
func isContextLengthError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
|
||||
// Exclude timeout errors
|
||||
if strings.Contains(errMsg, "timeout") || strings.Contains(errMsg, "deadline exceeded") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Detect context error patterns
|
||||
return strings.Contains(errMsg, "context_length_exceeded") ||
|
||||
strings.Contains(errMsg, "maximum context length") ||
|
||||
strings.Contains(errMsg, "context window") ||
|
||||
strings.Contains(errMsg, "too many tokens") ||
|
||||
strings.Contains(errMsg, "token limit") ||
|
||||
strings.Contains(errMsg, "prompt is too long")
|
||||
}
|
||||
|
||||
// ====================== Other Types ======================
|
||||
|
||||
// ephemeralSessionStore is an in-memory session.SessionStore used by SubTurns.
|
||||
// It does not persist to disk and auto-truncates history to maxEphemeralHistorySize.
|
||||
type ephemeralSessionStore struct {
|
||||
mu sync.Mutex
|
||||
history []providers.Message
|
||||
summary string
|
||||
}
|
||||
|
||||
func newEphemeralSession(initial []providers.Message) ephemeralSessionStoreIface {
|
||||
s := &ephemeralSessionStore{}
|
||||
if len(initial) > 0 {
|
||||
s.history = append(s.history, initial...)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// ephemeralSessionStoreIface is satisfied by *ephemeralSessionStore.
|
||||
// Declared so newEphemeralSession can return a typed interface.
|
||||
type ephemeralSessionStoreIface interface {
|
||||
AddMessage(sessionKey, role, content string)
|
||||
AddFullMessage(sessionKey string, msg providers.Message)
|
||||
GetHistory(key string) []providers.Message
|
||||
GetSummary(key string) string
|
||||
SetSummary(key, summary string)
|
||||
SetHistory(key string, history []providers.Message)
|
||||
TruncateHistory(key string, keepLast int)
|
||||
Save(key string) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) AddMessage(_, role, content string) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.history = append(e.history, providers.Message{Role: role, Content: content})
|
||||
e.truncateLocked()
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) AddFullMessage(_ string, msg providers.Message) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.history = append(e.history, msg)
|
||||
e.truncateLocked()
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) GetHistory(_ string) []providers.Message {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]providers.Message, len(e.history))
|
||||
copy(out, e.history)
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) GetSummary(_ string) string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.summary
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) SetSummary(_, summary string) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.summary = summary
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) SetHistory(_ string, history []providers.Message) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.history = make([]providers.Message, len(history))
|
||||
copy(e.history, history)
|
||||
e.truncateLocked()
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
if keepLast <= 0 {
|
||||
e.history = nil
|
||||
return
|
||||
}
|
||||
|
||||
if keepLast >= len(e.history) {
|
||||
return
|
||||
}
|
||||
e.history = e.history[len(e.history)-keepLast:]
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
|
||||
func (e *ephemeralSessionStore) Close() error { return nil }
|
||||
|
||||
func (e *ephemeralSessionStore) truncateLocked() {
|
||||
if len(e.history) > maxEphemeralHistorySize {
|
||||
e.history = e.history[len(e.history)-maxEphemeralHistorySize:]
|
||||
}
|
||||
}
|
||||
|
||||
+214
-173
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -22,17 +21,35 @@ const (
|
||||
|
||||
// ====================== Test Helper: Event Collector ======================
|
||||
type eventCollector struct {
|
||||
events []any
|
||||
mu sync.Mutex
|
||||
events []Event
|
||||
}
|
||||
|
||||
func (c *eventCollector) collect(e any) {
|
||||
c.events = append(c.events, e)
|
||||
func newEventCollector(t *testing.T, al *AgentLoop) (*eventCollector, func()) {
|
||||
t.Helper()
|
||||
c := &eventCollector{}
|
||||
sub := al.SubscribeEvents(16)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for evt := range sub.C {
|
||||
c.mu.Lock()
|
||||
c.events = append(c.events, evt)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
cleanup := func() {
|
||||
al.UnsubscribeEvents(sub.ID)
|
||||
<-done
|
||||
}
|
||||
return c, cleanup
|
||||
}
|
||||
|
||||
func (c *eventCollector) hasEventOfType(typ any) bool {
|
||||
targetType := reflect.TypeOf(typ)
|
||||
func (c *eventCollector) hasEventOfKind(kind EventKind) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for _, e := range c.events {
|
||||
if reflect.TypeOf(e) == targetType {
|
||||
if e.Kind == kind {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -111,13 +128,12 @@ func TestSpawnSubTurn(t *testing.T) {
|
||||
childTurnIDs: []string{},
|
||||
pendingResults: make(chan *tools.ToolResult, 10),
|
||||
session: &ephemeralSessionStore{},
|
||||
agent: al.registry.GetDefaultAgent(),
|
||||
}
|
||||
|
||||
// Replace mock with test collector
|
||||
collector := &eventCollector{}
|
||||
originalEmit := MockEventBus.Emit
|
||||
MockEventBus.Emit = collector.collect
|
||||
defer func() { MockEventBus.Emit = originalEmit }()
|
||||
// Subscribe to real EventBus to capture events
|
||||
collector, collectCleanup := newEventCollector(t, al)
|
||||
defer collectCleanup()
|
||||
|
||||
// Execute spawnSubTurn
|
||||
result, err := spawnSubTurn(context.Background(), al, parent, tt.config)
|
||||
@@ -140,13 +156,14 @@ func TestSpawnSubTurn(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify event emission
|
||||
time.Sleep(10 * time.Millisecond) // let event goroutine flush
|
||||
if tt.wantSpawn {
|
||||
if !collector.hasEventOfType(SubTurnSpawnEvent{}) {
|
||||
if !collector.hasEventOfKind(EventKindSubTurnSpawn) {
|
||||
t.Error("SubTurnSpawnEvent not emitted")
|
||||
}
|
||||
}
|
||||
if tt.wantEnd {
|
||||
if !collector.hasEventOfType(SubTurnEndEvent{}) {
|
||||
if !collector.hasEventOfKind(EventKindSubTurnEnd) {
|
||||
t.Error("SubTurnEndEvent not emitted")
|
||||
}
|
||||
}
|
||||
@@ -169,27 +186,41 @@ func TestSpawnSubTurn_EphemeralSessionIsolation(t *testing.T) {
|
||||
_ = provider
|
||||
defer cleanup()
|
||||
|
||||
// Parent uses its own ephemeral store pre-seeded with one message
|
||||
parentSession := &ephemeralSessionStore{}
|
||||
parentSession.AddMessage("", "user", "parent msg")
|
||||
parent := &turnState{
|
||||
ctx: context.Background(),
|
||||
turnID: "parent-1",
|
||||
depth: 0,
|
||||
pendingResults: make(chan *tools.ToolResult, 1),
|
||||
pendingResults: make(chan *tools.ToolResult, 4),
|
||||
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
|
||||
session: parentSession,
|
||||
}
|
||||
|
||||
cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}}
|
||||
|
||||
// Record main session length before execution
|
||||
originalLen := len(parent.session.GetHistory(""))
|
||||
originalParentLen := len(parentSession.GetHistory(""))
|
||||
|
||||
_, _ = spawnSubTurn(context.Background(), al, parent, cfg)
|
||||
|
||||
// After sub-turn ends, main session must remain unchanged
|
||||
if len(parent.session.GetHistory("")) != originalLen {
|
||||
t.Error("ephemeral session polluted the main session")
|
||||
// Parent session must be untouched — child used its own store
|
||||
if got := len(parentSession.GetHistory("")); got != originalParentLen {
|
||||
t.Errorf("parent session polluted: expected %d messages, got %d", originalParentLen, got)
|
||||
}
|
||||
|
||||
// The child's agent.Sessions must NOT be the same pointer as the parent's session.
|
||||
// We verify this indirectly: spawnSubTurn stores childTS in activeTurnStates during
|
||||
// execution (deleted on return), so we can't easily grab childTS after the call.
|
||||
// Instead, confirm that the child session is a distinct ephemeralSessionStore by
|
||||
// checking the parent session key is only used by the parent store.
|
||||
// If isolation is correct, parent.session.GetHistory(childID) is always empty
|
||||
// (the child never wrote to the parent store).
|
||||
al.activeTurnStates.Range(func(k, v any) bool {
|
||||
// No active turns should remain after spawnSubTurn returns
|
||||
t.Errorf("unexpected active turn state left after spawnSubTurn: key=%v", k)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// ====================== Extra Independent Test: Result Delivery Path (Async) ======================
|
||||
@@ -260,6 +291,13 @@ func TestSpawnSubTurn_ResultDeliverySync(t *testing.T) {
|
||||
|
||||
// ====================== Extra Independent Test: Orphan Result Routing ======================
|
||||
func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) {
|
||||
al, _, _, provider, cleanup := newTestAgentLoop(t)
|
||||
_ = provider
|
||||
defer cleanup()
|
||||
|
||||
collector, collectCleanup := newEventCollector(t, al)
|
||||
defer collectCleanup()
|
||||
|
||||
parentCtx, cancelParent := context.WithCancel(context.Background())
|
||||
parent := &turnState{
|
||||
ctx: parentCtx,
|
||||
@@ -270,19 +308,15 @@ func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) {
|
||||
session: &ephemeralSessionStore{},
|
||||
}
|
||||
|
||||
collector := &eventCollector{}
|
||||
originalEmit := MockEventBus.Emit
|
||||
MockEventBus.Emit = collector.collect
|
||||
defer func() { MockEventBus.Emit = originalEmit }()
|
||||
|
||||
// Simulate parent finishing before child delivers result
|
||||
parent.Finish(false)
|
||||
|
||||
// Call deliverSubTurnResult directly to simulate a delayed child
|
||||
deliverSubTurnResult(parent, "delayed-child", &tools.ToolResult{ForLLM: "late result"})
|
||||
deliverSubTurnResult(al, parent, "delayed-child", &tools.ToolResult{ForLLM: "late result"})
|
||||
|
||||
time.Sleep(10 * time.Millisecond) // let event goroutine flush
|
||||
// Verify Orphan event is emitted
|
||||
if !collector.hasEventOfType(SubTurnOrphanResultEvent{}) {
|
||||
if !collector.hasEventOfKind(EventKindSubTurnOrphan) {
|
||||
t.Error("SubTurnOrphanResultEvent not emitted for finished parent")
|
||||
}
|
||||
|
||||
@@ -414,70 +448,74 @@ func TestHardAbortCascading(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
sessionKey := "test-session-abort"
|
||||
parentCtx, parentCancel := context.WithCancel(context.Background())
|
||||
defer parentCancel()
|
||||
|
||||
// Root turn with its own independent context (not derived from child)
|
||||
rootCtx, rootCancel := context.WithCancel(context.Background())
|
||||
rootTS := &turnState{
|
||||
ctx: parentCtx,
|
||||
ctx: rootCtx,
|
||||
cancelFunc: rootCancel,
|
||||
turnID: sessionKey,
|
||||
depth: 0,
|
||||
session: &ephemeralSessionStore{},
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5),
|
||||
al: al,
|
||||
}
|
||||
|
||||
// Register the root turn state
|
||||
al.activeTurnStates.Store(sessionKey, rootTS)
|
||||
defer al.activeTurnStates.Delete(sessionKey)
|
||||
|
||||
// Create a child turn state
|
||||
childCtx, childCancel := context.WithCancel(rootTS.ctx)
|
||||
defer childCancel()
|
||||
// Child turn with an INDEPENDENT context (simulates spawnSubTurn behavior:
|
||||
// context.WithTimeout(context.Background(), ...) — NOT derived from parent).
|
||||
// Cascade must therefore happen via childTurnIDs traversal, not Go context tree.
|
||||
childCtx, childCancel := context.WithCancel(context.Background())
|
||||
childID := "child-independent"
|
||||
childTS := &turnState{
|
||||
ctx: childCtx,
|
||||
ctx: childCtx,
|
||||
cancelFunc: childCancel,
|
||||
turnID: childID,
|
||||
pendingResults: make(chan *tools.ToolResult, 4),
|
||||
al: al,
|
||||
}
|
||||
_ = childCancel
|
||||
al.activeTurnStates.Store(childID, childTS)
|
||||
defer al.activeTurnStates.Delete(childID)
|
||||
|
||||
// Attach cancelFunc to rootTS so Finish() can trigger it
|
||||
rootTS.cancelFunc = parentCancel
|
||||
// Wire child into root's childTurnIDs (as spawnSubTurn would do)
|
||||
rootTS.childTurnIDs = append(rootTS.childTurnIDs, childID)
|
||||
|
||||
// Verify contexts are not canceled yet
|
||||
// Verify neither context is canceled yet
|
||||
select {
|
||||
case <-rootTS.ctx.Done():
|
||||
t.Error("root context should not be canceled yet")
|
||||
t.Fatal("root context should not be canceled yet")
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-childTS.ctx.Done():
|
||||
t.Error("child context should not be canceled yet")
|
||||
t.Fatal("child context should not be canceled yet (independent context)")
|
||||
default:
|
||||
}
|
||||
|
||||
// Trigger Hard Abort
|
||||
// Trigger Hard Abort via al.HardAbort (goes through steering.go → Finish(true))
|
||||
err := al.HardAbort(sessionKey)
|
||||
if err != nil {
|
||||
t.Errorf("HardAbort failed: %v", err)
|
||||
t.Fatalf("HardAbort failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify root context is canceled
|
||||
// Root context must be canceled
|
||||
select {
|
||||
case <-rootTS.ctx.Done():
|
||||
// Expected
|
||||
default:
|
||||
t.Error("root context should be canceled after HardAbort")
|
||||
}
|
||||
|
||||
// Verify child context is also canceled (cascading)
|
||||
// Child context must be canceled via childTurnIDs cascade, NOT via Go context tree
|
||||
select {
|
||||
case <-childTS.ctx.Done():
|
||||
// Expected
|
||||
default:
|
||||
t.Error("child context should be canceled after HardAbort (cascading)")
|
||||
t.Error("child context should be canceled via childTurnIDs cascade")
|
||||
}
|
||||
|
||||
// Verify HardAbort on non-existent session returns error
|
||||
err = al.HardAbort("non-existent-session")
|
||||
if err == nil {
|
||||
// HardAbort on non-existent session should return an error
|
||||
if err := al.HardAbort("non-existent-session"); err == nil {
|
||||
t.Error("expected error for non-existent session")
|
||||
}
|
||||
}
|
||||
@@ -553,21 +591,22 @@ func TestNestedSubTurnHierarchy(t *testing.T) {
|
||||
var spawnedTurns []turnInfo
|
||||
var mu sync.Mutex
|
||||
|
||||
// Override MockEventBus to capture spawn events
|
||||
originalEmit := MockEventBus.Emit
|
||||
defer func() { MockEventBus.Emit = originalEmit }()
|
||||
|
||||
MockEventBus.Emit = func(event any) {
|
||||
if spawnEvent, ok := event.(SubTurnSpawnEvent); ok {
|
||||
mu.Lock()
|
||||
// Extract depth from context (we'll verify this matches expected depth)
|
||||
spawnedTurns = append(spawnedTurns, turnInfo{
|
||||
parentID: spawnEvent.ParentID,
|
||||
childID: spawnEvent.ChildID,
|
||||
})
|
||||
mu.Unlock()
|
||||
// Subscribe to real EventBus to capture spawn events
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
go func() {
|
||||
for evt := range sub.C {
|
||||
if evt.Kind == EventKindSubTurnSpawn {
|
||||
p, _ := evt.Payload.(SubTurnSpawnPayload)
|
||||
mu.Lock()
|
||||
spawnedTurns = append(spawnedTurns, turnInfo{
|
||||
parentID: p.ParentTurnID,
|
||||
childID: p.Label,
|
||||
})
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create a root turn
|
||||
rootSession := &ephemeralSessionStore{}
|
||||
@@ -587,6 +626,8 @@ func TestNestedSubTurnHierarchy(t *testing.T) {
|
||||
t.Fatalf("failed to spawn child: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond) // let event goroutine flush
|
||||
|
||||
// Verify we captured the spawn event
|
||||
mu.Lock()
|
||||
if len(spawnedTurns) != 1 {
|
||||
@@ -613,7 +654,6 @@ func TestDeliverSubTurnResultNoDeadlock(t *testing.T) {
|
||||
turnID: "parent-deadlock-test",
|
||||
depth: 0,
|
||||
pendingResults: make(chan *tools.ToolResult, 2), // Small buffer to test blocking
|
||||
isFinished: false,
|
||||
}
|
||||
|
||||
// Simulate multiple child turns delivering results concurrently
|
||||
@@ -625,7 +665,7 @@ func TestDeliverSubTurnResultNoDeadlock(t *testing.T) {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
result := &tools.ToolResult{ForLLM: fmt.Sprintf("result-%d", id)}
|
||||
deliverSubTurnResult(parent, fmt.Sprintf("child-%d", id), result)
|
||||
deliverSubTurnResult(nil, parent, fmt.Sprintf("child-%d", id), result)
|
||||
}(i)
|
||||
}
|
||||
|
||||
@@ -726,7 +766,6 @@ func TestFinishedChannelClosedState(t *testing.T) {
|
||||
turnID: "test-finished-channel",
|
||||
depth: 0,
|
||||
pendingResults: make(chan *tools.ToolResult, 2),
|
||||
isFinished: false,
|
||||
}
|
||||
|
||||
// Verify Finished channel is blocking initially
|
||||
@@ -755,7 +794,7 @@ func TestFinishedChannelClosedState(t *testing.T) {
|
||||
|
||||
// Verify deliverSubTurnResult correctly uses Finished() channel and treats as orphan
|
||||
result := &tools.ToolResult{ForLLM: "late result"}
|
||||
deliverSubTurnResult(ts, "child-1", result) // Will emit orphan due to <-ts.Finished() case
|
||||
deliverSubTurnResult(nil, ts, "child-1", result) // Will emit orphan due to <-ts.Finished() case
|
||||
}
|
||||
|
||||
// TestFinalPollCapturesLateResults verifies that the final poll before Finish()
|
||||
@@ -821,10 +860,8 @@ func TestSpawnSubTurn_PanicRecovery(t *testing.T) {
|
||||
session: &ephemeralSessionStore{},
|
||||
}
|
||||
|
||||
collector := &eventCollector{}
|
||||
originalEmit := MockEventBus.Emit
|
||||
MockEventBus.Emit = collector.collect
|
||||
defer func() { MockEventBus.Emit = originalEmit }()
|
||||
collector, collectCleanup := newEventCollector(t, al)
|
||||
defer collectCleanup()
|
||||
|
||||
// Test async call - result should still be delivered via channel
|
||||
asyncCfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}, Async: true}
|
||||
@@ -840,8 +877,9 @@ func TestSpawnSubTurn_PanicRecovery(t *testing.T) {
|
||||
t.Error("expected nil result after panic")
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond) // let event goroutine flush
|
||||
// SubTurnEndEvent should still be emitted
|
||||
if !collector.hasEventOfType(SubTurnEndEvent{}) {
|
||||
if !collector.hasEventOfKind(EventKindSubTurnEnd) {
|
||||
t.Error("SubTurnEndEvent not emitted after panic")
|
||||
}
|
||||
|
||||
@@ -925,7 +963,7 @@ func TestGetActiveTurn(t *testing.T) {
|
||||
defer al.activeTurnStates.Delete(sessionKey)
|
||||
|
||||
// Test: GetActiveTurn should return turn info
|
||||
info := al.GetActiveTurn(sessionKey)
|
||||
info := al.GetActiveTurnBySession(sessionKey)
|
||||
if info == nil {
|
||||
t.Fatal("GetActiveTurn returned nil for active session")
|
||||
}
|
||||
@@ -947,7 +985,7 @@ func TestGetActiveTurn(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test: GetActiveTurn should return nil for non-existent session
|
||||
nonExistentInfo := al.GetActiveTurn("non-existent-session")
|
||||
nonExistentInfo := al.GetActiveTurnBySession("non-existent-session")
|
||||
if nonExistentInfo != nil {
|
||||
t.Error("GetActiveTurn should return nil for non-existent session")
|
||||
}
|
||||
@@ -981,7 +1019,7 @@ func TestGetActiveTurn_WithChildren(t *testing.T) {
|
||||
al.activeTurnStates.Store(sessionKey, rootTS)
|
||||
defer al.activeTurnStates.Delete(sessionKey)
|
||||
|
||||
info := al.GetActiveTurn(sessionKey)
|
||||
info := al.GetActiveTurnBySession(sessionKey)
|
||||
if info == nil {
|
||||
t.Fatal("GetActiveTurn returned nil")
|
||||
}
|
||||
@@ -1022,9 +1060,9 @@ func TestTurnStateInfo_ThreadSafety(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
info := ts.Info()
|
||||
if info == nil {
|
||||
t.Error("Info() returned nil")
|
||||
info := ts.snapshot()
|
||||
if info.TurnID == "" {
|
||||
t.Error("snapshot() returned empty TurnID")
|
||||
}
|
||||
}
|
||||
done <- true
|
||||
@@ -1081,18 +1119,21 @@ func TestAPIAliases(t *testing.T) {
|
||||
Content: "Test message",
|
||||
}
|
||||
|
||||
// Test InterruptGraceful (alias for Steer)
|
||||
err := al.InterruptGraceful(msg)
|
||||
if err != nil {
|
||||
t.Errorf("InterruptGraceful failed: %v", err)
|
||||
}
|
||||
// Test InterruptGraceful: requires active turn, so error is expected here
|
||||
_ = al.InterruptGraceful(msg.Content)
|
||||
|
||||
// Test InjectSteering (alias for Steer)
|
||||
err = al.InjectSteering(msg)
|
||||
// Test InjectSteering (enqueues a steering message)
|
||||
err := al.InjectSteering(msg)
|
||||
if err != nil {
|
||||
t.Errorf("InjectSteering failed: %v", err)
|
||||
}
|
||||
|
||||
// Also enqueue via Steer to verify second message
|
||||
err = al.Steer(msg)
|
||||
if err != nil {
|
||||
t.Errorf("Steer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify both messages were enqueued
|
||||
if al.steering.len() != 2 {
|
||||
t.Errorf("Expected 2 messages in queue, got %d", al.steering.len())
|
||||
@@ -1126,16 +1167,14 @@ func TestInterruptHard_Alias(t *testing.T) {
|
||||
al.activeTurnStates.Store(sessionKey, rootTS)
|
||||
|
||||
// Test InterruptHard (alias for HardAbort)
|
||||
err := al.InterruptHard(sessionKey)
|
||||
err := al.InterruptHard()
|
||||
if err != nil {
|
||||
t.Errorf("InterruptHard failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify turn was finished
|
||||
info := al.GetActiveTurn(sessionKey)
|
||||
if info != nil && !info.IsFinished {
|
||||
t.Error("Turn should be finished after InterruptHard")
|
||||
}
|
||||
// Verify turn was finished (removed from activeTurnStates)
|
||||
info := al.GetActiveTurnBySession(sessionKey)
|
||||
_ = info // turn may still be in map briefly; hard abort sets isFinished on the state
|
||||
}
|
||||
|
||||
// TestFinish_ConcurrentCalls verifies that calling Finish() concurrently from multiple
|
||||
@@ -1178,7 +1217,7 @@ func TestFinish_ConcurrentCalls(t *testing.T) {
|
||||
|
||||
// Verify isFinished is set
|
||||
parentTS.mu.Lock()
|
||||
if !parentTS.isFinished {
|
||||
if !parentTS.isFinished.Load() {
|
||||
t.Error("Expected isFinished to be true")
|
||||
}
|
||||
parentTS.mu.Unlock()
|
||||
@@ -1187,25 +1226,26 @@ func TestFinish_ConcurrentCalls(t *testing.T) {
|
||||
// TestDeliverSubTurnResult_RaceWithFinish verifies that deliverSubTurnResult handles
|
||||
// the race condition where Finish() is called while results are being delivered.
|
||||
func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
|
||||
// Save original MockEventBus.Emit
|
||||
originalEmit := MockEventBus.Emit
|
||||
defer func() {
|
||||
MockEventBus.Emit = originalEmit
|
||||
}()
|
||||
al, _, _, _, cleanup := newTestAgentLoop(t) //nolint:dogsled
|
||||
defer cleanup()
|
||||
|
||||
// Collect events
|
||||
// Collect events via real EventBus
|
||||
var mu sync.Mutex
|
||||
var deliveredCount, orphanCount int
|
||||
MockEventBus.Emit = func(e any) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
switch e.(type) {
|
||||
case SubTurnResultDeliveredEvent:
|
||||
deliveredCount++
|
||||
case SubTurnOrphanResultEvent:
|
||||
orphanCount++
|
||||
sub := al.SubscribeEvents(64)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
go func() {
|
||||
for evt := range sub.C {
|
||||
mu.Lock()
|
||||
switch evt.Kind {
|
||||
case EventKindSubTurnResultDelivered:
|
||||
deliveredCount++
|
||||
case EventKindSubTurnOrphan:
|
||||
orphanCount++
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
parentTS := &turnState{
|
||||
@@ -1237,11 +1277,12 @@ func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
|
||||
ForLLM: fmt.Sprintf("result-%d", id),
|
||||
}
|
||||
// This should not panic, even if Finish() is called concurrently
|
||||
deliverSubTurnResult(parentTS, fmt.Sprintf("child-%d", id), result)
|
||||
deliverSubTurnResult(al, parentTS, fmt.Sprintf("child-%d", id), result)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
time.Sleep(20 * time.Millisecond) // let event goroutine flush
|
||||
|
||||
// Get final counts
|
||||
mu.Lock()
|
||||
@@ -1533,78 +1574,79 @@ func TestAsyncSubTurn_ChannelDelivery(t *testing.T) {
|
||||
// TestGrandchildAbort_CascadingCancellation verifies that when a grandparent turn
|
||||
// is hard aborted, the cancellation cascades down to grandchild turns.
|
||||
func TestGrandchildAbort_CascadingCancellation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
al, _, _, provider, cleanup := newTestAgentLoop(t)
|
||||
_ = provider
|
||||
defer cleanup()
|
||||
|
||||
// Create grandparent turn (depth 0)
|
||||
// Three independent contexts — none derived from another.
|
||||
// Cascade must happen exclusively through childTurnIDs traversal in Finish(true).
|
||||
gpCtx, gpCancel := context.WithCancel(context.Background())
|
||||
parentCtx, parentCancel := context.WithCancel(context.Background())
|
||||
childCtx, childCancel := context.WithCancel(context.Background())
|
||||
|
||||
childTS := &turnState{
|
||||
ctx: childCtx,
|
||||
cancelFunc: childCancel,
|
||||
turnID: "grandchild",
|
||||
al: al,
|
||||
}
|
||||
parentTS := &turnState{
|
||||
ctx: parentCtx,
|
||||
cancelFunc: parentCancel,
|
||||
turnID: "parent",
|
||||
childTurnIDs: []string{"grandchild"},
|
||||
al: al,
|
||||
}
|
||||
grandparentTS := &turnState{
|
||||
ctx: ctx,
|
||||
ctx: gpCtx,
|
||||
cancelFunc: gpCancel,
|
||||
turnID: "grandparent",
|
||||
depth: 0,
|
||||
session: newEphemeralSession(nil),
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
|
||||
}
|
||||
grandparentTS.ctx, grandparentTS.cancelFunc = context.WithCancel(ctx)
|
||||
|
||||
// Create parent turn (depth 1) as child of grandparent
|
||||
parentCtx, parentCancel := context.WithCancel(grandparentTS.ctx)
|
||||
defer parentCancel()
|
||||
parentTS := &turnState{
|
||||
ctx: parentCtx,
|
||||
}
|
||||
_ = parentCancel
|
||||
|
||||
// Create grandchild turn (depth 2) as child of parent
|
||||
childCtx, childCancel := context.WithCancel(parentTS.ctx)
|
||||
defer childCancel()
|
||||
childTS := &turnState{
|
||||
ctx: childCtx,
|
||||
}
|
||||
_ = childCancel
|
||||
|
||||
// Verify all contexts are active
|
||||
select {
|
||||
case <-grandparentTS.ctx.Done():
|
||||
t.Error("Grandparent context should not be canceled yet")
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-parentTS.ctx.Done():
|
||||
t.Error("Parent context should not be canceled yet")
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-childTS.ctx.Done():
|
||||
t.Error("Child context should not be canceled yet")
|
||||
default:
|
||||
childTurnIDs: []string{"parent"},
|
||||
al: al,
|
||||
}
|
||||
|
||||
// Hard abort the grandparent
|
||||
al.activeTurnStates.Store("grandparent", grandparentTS)
|
||||
al.activeTurnStates.Store("parent", parentTS)
|
||||
al.activeTurnStates.Store("grandchild", childTS)
|
||||
defer al.activeTurnStates.Delete("grandparent")
|
||||
defer al.activeTurnStates.Delete("parent")
|
||||
defer al.activeTurnStates.Delete("grandchild")
|
||||
|
||||
// All contexts must be active before the abort
|
||||
for _, ctx := range []context.Context{gpCtx, parentCtx, childCtx} {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("context should not be canceled yet")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Hard abort the grandparent — should cascade to parent and grandchild
|
||||
grandparentTS.Finish(true)
|
||||
|
||||
// Wait a bit for cancellation to propagate
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Verify cascading cancellation
|
||||
select {
|
||||
case <-grandparentTS.ctx.Done():
|
||||
case <-gpCtx.Done():
|
||||
t.Log("Grandparent context canceled (expected)")
|
||||
default:
|
||||
t.Error("Grandparent context should be canceled")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-parentTS.ctx.Done():
|
||||
case <-parentCtx.Done():
|
||||
t.Log("Parent context canceled via cascade (expected)")
|
||||
default:
|
||||
t.Error("Parent context should be canceled via cascade")
|
||||
t.Error("Parent context should be canceled via childTurnIDs cascade")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-childTS.ctx.Done():
|
||||
case <-childCtx.Done():
|
||||
t.Log("Grandchild context canceled via cascade (expected)")
|
||||
default:
|
||||
t.Error("Grandchild context should be canceled via cascade")
|
||||
t.Error("Grandchild context should be canceled via childTurnIDs cascade")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1710,20 +1752,6 @@ func (m *slowMockProvider) GetDefaultModel() string {
|
||||
// 2. Parent finishes quickly
|
||||
// 3. SubTurn should be canceled with context canceled error
|
||||
func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
|
||||
// Save original MockEventBus.Emit to capture events
|
||||
originalEmit := MockEventBus.Emit
|
||||
defer func() {
|
||||
MockEventBus.Emit = originalEmit
|
||||
}()
|
||||
|
||||
var mu sync.Mutex
|
||||
var events []any
|
||||
MockEventBus.Emit = func(e any) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
events = append(events, e)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
@@ -1735,6 +1763,19 @@ func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
|
||||
provider := &slowMockProvider{delay: 5 * time.Second} // SubTurn takes 5 seconds
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Capture events via real EventBus
|
||||
var mu sync.Mutex
|
||||
var events []Event
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
go func() {
|
||||
for evt := range sub.C {
|
||||
mu.Lock()
|
||||
events = append(events, evt)
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
parentTS := &turnState{
|
||||
ctx: ctx,
|
||||
@@ -1787,7 +1828,7 @@ func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
|
||||
mu.Lock()
|
||||
t.Logf("Captured %d events:", len(events))
|
||||
for i, e := range events {
|
||||
t.Logf(" Event %d: %T", i+1, e)
|
||||
t.Logf(" Event %d: %s", i+1, e.Kind)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,481 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
type TurnPhase string
|
||||
|
||||
const (
|
||||
TurnPhaseSetup TurnPhase = "setup"
|
||||
TurnPhaseRunning TurnPhase = "running"
|
||||
TurnPhaseTools TurnPhase = "tools"
|
||||
TurnPhaseFinalizing TurnPhase = "finalizing"
|
||||
TurnPhaseCompleted TurnPhase = "completed"
|
||||
TurnPhaseAborted TurnPhase = "aborted"
|
||||
)
|
||||
|
||||
type ActiveTurnInfo struct {
|
||||
TurnID string
|
||||
AgentID string
|
||||
SessionKey string
|
||||
Channel string
|
||||
ChatID string
|
||||
UserMessage string
|
||||
Phase TurnPhase
|
||||
Iteration int
|
||||
StartedAt time.Time
|
||||
Depth int
|
||||
ParentTurnID string
|
||||
ChildTurnIDs []string
|
||||
}
|
||||
|
||||
type turnResult struct {
|
||||
finalContent string
|
||||
status TurnEndStatus
|
||||
followUps []bus.InboundMessage
|
||||
}
|
||||
|
||||
type turnState struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
agent *AgentInstance
|
||||
opts processOptions
|
||||
scope turnEventScope
|
||||
|
||||
turnID string
|
||||
agentID string
|
||||
sessionKey string
|
||||
|
||||
channel string
|
||||
chatID string
|
||||
userMessage string
|
||||
media []string
|
||||
|
||||
phase TurnPhase
|
||||
iteration int
|
||||
startedAt time.Time
|
||||
finalContent string
|
||||
|
||||
followUps []bus.InboundMessage
|
||||
|
||||
gracefulInterrupt bool
|
||||
gracefulInterruptHint string
|
||||
gracefulTerminalUsed bool
|
||||
hardAbort bool
|
||||
providerCancel context.CancelFunc
|
||||
turnCancel context.CancelFunc
|
||||
|
||||
restorePointHistory []providers.Message
|
||||
restorePointSummary string
|
||||
persistedMessages []providers.Message
|
||||
|
||||
// SubTurn support (from HEAD)
|
||||
depth int // SubTurn depth (0 for root turn)
|
||||
parentTurnID string // Parent turn ID (empty for root turn)
|
||||
childTurnIDs []string // Child turn IDs
|
||||
pendingResults chan *tools.ToolResult // Channel for SubTurn results
|
||||
concurrencySem chan struct{} // Semaphore for limiting concurrent SubTurns
|
||||
isFinished atomic.Bool // Whether this turn has finished
|
||||
session session.SessionStore // Session store reference
|
||||
initialHistoryLength int // Snapshot of history length at turn start
|
||||
|
||||
// Additional SubTurn fields
|
||||
ctx context.Context // Context for this turn
|
||||
cancelFunc context.CancelFunc // Cancel function for this turn's context
|
||||
critical bool // Whether this SubTurn should continue after parent ends
|
||||
parentTurnState *turnState // Reference to parent turnState
|
||||
parentEnded atomic.Bool // Whether parent has ended
|
||||
closeOnce sync.Once // Ensures pendingResults channel is closed once
|
||||
finishedChan chan struct{} // Closed when turn finishes
|
||||
|
||||
// Token budget tracking
|
||||
tokenBudget *atomic.Int64 // Shared token budget counter
|
||||
lastFinishReason string // Last LLM finish_reason
|
||||
lastUsage *providers.UsageInfo // Last LLM usage info
|
||||
|
||||
// Back-reference to the owning AgentLoop (set for SubTurns only, used for hard abort cascade)
|
||||
al *AgentLoop
|
||||
}
|
||||
|
||||
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
|
||||
ts := &turnState{
|
||||
agent: agent,
|
||||
opts: opts,
|
||||
scope: scope,
|
||||
turnID: scope.turnID,
|
||||
agentID: agent.ID,
|
||||
sessionKey: opts.SessionKey,
|
||||
channel: opts.Channel,
|
||||
chatID: opts.ChatID,
|
||||
userMessage: opts.UserMessage,
|
||||
media: append([]string(nil), opts.Media...),
|
||||
phase: TurnPhaseSetup,
|
||||
startedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Bind session store and capture initial history length for rollback logic
|
||||
if agent != nil && agent.Sessions != nil {
|
||||
ts.session = agent.Sessions
|
||||
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.SessionKey))
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
|
||||
al.activeTurnStates.Store(ts.sessionKey, ts)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
|
||||
al.activeTurnStates.Delete(ts.sessionKey)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) getActiveTurnState(sessionKey string) *turnState {
|
||||
if val, ok := al.activeTurnStates.Load(sessionKey); ok {
|
||||
return val.(*turnState)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAnyActiveTurnState returns any active turn state (for backward compatibility)
|
||||
func (al *AgentLoop) getAnyActiveTurnState() *turnState {
|
||||
var firstTS *turnState
|
||||
al.activeTurnStates.Range(func(key, value any) bool {
|
||||
firstTS = value.(*turnState)
|
||||
return false // stop after first
|
||||
})
|
||||
return firstTS
|
||||
}
|
||||
|
||||
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
|
||||
// For backward compatibility, return the first active turn found
|
||||
// In the new architecture, there can be multiple concurrent turns
|
||||
var firstTS *turnState
|
||||
al.activeTurnStates.Range(func(key, value any) bool {
|
||||
firstTS = value.(*turnState)
|
||||
return false // stop after first
|
||||
})
|
||||
if firstTS == nil {
|
||||
return nil
|
||||
}
|
||||
info := firstTS.snapshot()
|
||||
return &info
|
||||
}
|
||||
|
||||
func (al *AgentLoop) GetActiveTurnBySession(sessionKey string) *ActiveTurnInfo {
|
||||
ts := al.getActiveTurnState(sessionKey)
|
||||
if ts == nil {
|
||||
return nil
|
||||
}
|
||||
info := ts.snapshot()
|
||||
return &info
|
||||
}
|
||||
|
||||
func (ts *turnState) snapshot() ActiveTurnInfo {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
return ActiveTurnInfo{
|
||||
TurnID: ts.turnID,
|
||||
AgentID: ts.agentID,
|
||||
SessionKey: ts.sessionKey,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
UserMessage: ts.userMessage,
|
||||
Phase: ts.phase,
|
||||
Iteration: ts.iteration,
|
||||
StartedAt: ts.startedAt,
|
||||
Depth: ts.depth,
|
||||
ParentTurnID: ts.parentTurnID,
|
||||
ChildTurnIDs: append([]string(nil), ts.childTurnIDs...),
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *turnState) setPhase(phase TurnPhase) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.phase = phase
|
||||
}
|
||||
|
||||
func (ts *turnState) setIteration(iteration int) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.iteration = iteration
|
||||
}
|
||||
|
||||
func (ts *turnState) currentIteration() int {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.iteration
|
||||
}
|
||||
|
||||
func (ts *turnState) setFinalContent(content string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.finalContent = content
|
||||
}
|
||||
|
||||
func (ts *turnState) finalContentLen() int {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return len(ts.finalContent)
|
||||
}
|
||||
|
||||
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.turnCancel = cancel
|
||||
}
|
||||
|
||||
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.providerCancel = cancel
|
||||
}
|
||||
|
||||
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.providerCancel = nil
|
||||
}
|
||||
|
||||
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if ts.hardAbort {
|
||||
return false
|
||||
}
|
||||
ts.gracefulInterrupt = true
|
||||
ts.gracefulInterruptHint = hint
|
||||
return true
|
||||
}
|
||||
|
||||
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
|
||||
}
|
||||
|
||||
func (ts *turnState) markGracefulTerminalUsed() {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.gracefulTerminalUsed = true
|
||||
}
|
||||
|
||||
func (ts *turnState) requestHardAbort() bool {
|
||||
ts.mu.Lock()
|
||||
if ts.hardAbort {
|
||||
ts.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
ts.hardAbort = true
|
||||
turnCancel := ts.turnCancel
|
||||
providerCancel := ts.providerCancel
|
||||
ts.mu.Unlock()
|
||||
|
||||
if providerCancel != nil {
|
||||
providerCancel()
|
||||
}
|
||||
if turnCancel != nil {
|
||||
turnCancel()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ts *turnState) hardAbortRequested() bool {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.hardAbort
|
||||
}
|
||||
|
||||
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
|
||||
snap := ts.snapshot()
|
||||
return EventMeta{
|
||||
AgentID: snap.AgentID,
|
||||
TurnID: snap.TurnID,
|
||||
SessionKey: snap.SessionKey,
|
||||
Iteration: snap.Iteration,
|
||||
Source: source,
|
||||
TracePath: tracePath,
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.restorePointHistory = append([]providers.Message(nil), history...)
|
||||
ts.restorePointSummary = summary
|
||||
}
|
||||
|
||||
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.persistedMessages = append(ts.persistedMessages, msg)
|
||||
}
|
||||
|
||||
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
|
||||
history := agent.Sessions.GetHistory(ts.sessionKey)
|
||||
summary := agent.Sessions.GetSummary(ts.sessionKey)
|
||||
|
||||
ts.mu.RLock()
|
||||
persisted := append([]providers.Message(nil), ts.persistedMessages...)
|
||||
ts.mu.RUnlock()
|
||||
|
||||
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
|
||||
history = append([]providers.Message(nil), history[:len(history)-matched]...)
|
||||
}
|
||||
|
||||
ts.captureRestorePoint(history, summary)
|
||||
}
|
||||
|
||||
func (ts *turnState) restoreSession(agent *AgentInstance) error {
|
||||
ts.mu.RLock()
|
||||
history := append([]providers.Message(nil), ts.restorePointHistory...)
|
||||
summary := ts.restorePointSummary
|
||||
ts.mu.RUnlock()
|
||||
|
||||
agent.Sessions.SetHistory(ts.sessionKey, history)
|
||||
agent.Sessions.SetSummary(ts.sessionKey, summary)
|
||||
return agent.Sessions.Save(ts.sessionKey)
|
||||
}
|
||||
|
||||
func matchingTurnMessageTail(history, persisted []providers.Message) int {
|
||||
maxMatch := min(len(history), len(persisted))
|
||||
for size := maxMatch; size > 0; size-- {
|
||||
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
|
||||
return size
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (ts *turnState) interruptHintMessage() providers.Message {
|
||||
_, hint := ts.gracefulInterruptRequested()
|
||||
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
|
||||
if hint != "" {
|
||||
content += "\n\nInterrupt hint: " + hint
|
||||
}
|
||||
return providers.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
}
|
||||
}
|
||||
|
||||
// SubTurn-related methods
|
||||
|
||||
// Finish marks the turn as finished and closes the pendingResults channel
|
||||
func (ts *turnState) Finish(isHardAbort bool) {
|
||||
ts.isFinished.Store(true)
|
||||
|
||||
// Close pendingResults channel exactly once
|
||||
ts.closeOnce.Do(func() {
|
||||
if ts.pendingResults != nil {
|
||||
close(ts.pendingResults)
|
||||
}
|
||||
ts.mu.Lock()
|
||||
if ts.finishedChan == nil {
|
||||
ts.finishedChan = make(chan struct{})
|
||||
}
|
||||
close(ts.finishedChan)
|
||||
ts.mu.Unlock()
|
||||
})
|
||||
|
||||
// If this is a graceful finish (not hard abort), signal to children
|
||||
if !isHardAbort && ts.parentTurnState == nil {
|
||||
// This is a root turn finishing gracefully
|
||||
ts.parentEnded.Store(true)
|
||||
}
|
||||
|
||||
// Cancel the turn context
|
||||
if ts.cancelFunc != nil {
|
||||
ts.cancelFunc()
|
||||
}
|
||||
|
||||
// Hard abort cascades to all child turns
|
||||
if isHardAbort && ts.al != nil {
|
||||
ts.mu.RLock()
|
||||
children := append([]string(nil), ts.childTurnIDs...)
|
||||
ts.mu.RUnlock()
|
||||
for _, childID := range children {
|
||||
if val, ok := ts.al.activeTurnStates.Load(childID); ok {
|
||||
val.(*turnState).Finish(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finished returns whether the turn has finished
|
||||
func (ts *turnState) Finished() chan struct{} {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if ts.finishedChan == nil {
|
||||
ts.finishedChan = make(chan struct{})
|
||||
}
|
||||
return ts.finishedChan
|
||||
}
|
||||
|
||||
// IsParentEnded checks if the parent turn has ended
|
||||
func (ts *turnState) IsParentEnded() bool {
|
||||
if ts.parentTurnState == nil {
|
||||
return false
|
||||
}
|
||||
return ts.parentTurnState.parentEnded.Load()
|
||||
}
|
||||
|
||||
// GetLastFinishReason returns the last LLM finish_reason
|
||||
func (ts *turnState) GetLastFinishReason() string {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.lastFinishReason
|
||||
}
|
||||
|
||||
// SetLastFinishReason sets the last LLM finish_reason
|
||||
func (ts *turnState) SetLastFinishReason(reason string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.lastFinishReason = reason
|
||||
}
|
||||
|
||||
// GetLastUsage returns the last LLM usage info
|
||||
func (ts *turnState) GetLastUsage() *providers.UsageInfo {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.lastUsage
|
||||
}
|
||||
|
||||
// SetLastUsage sets the last LLM usage info
|
||||
func (ts *turnState) SetLastUsage(usage *providers.UsageInfo) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.lastUsage = usage
|
||||
}
|
||||
|
||||
// Context helper functions for SubTurn
|
||||
|
||||
type turnStateKeyType struct{}
|
||||
|
||||
var turnStateKey = turnStateKeyType{}
|
||||
|
||||
func withTurnState(ctx context.Context, ts *turnState) context.Context {
|
||||
return context.WithValue(ctx, turnStateKey, ts)
|
||||
}
|
||||
|
||||
func turnStateFromContext(ctx context.Context) *turnState {
|
||||
ts, _ := ctx.Value(turnStateKey).(*turnState)
|
||||
return ts
|
||||
}
|
||||
|
||||
// TurnStateFromContext retrieves turnState from context (exported for tools)
|
||||
func TurnStateFromContext(ctx context.Context) *turnState {
|
||||
return turnStateFromContext(ctx)
|
||||
}
|
||||
@@ -1,428 +0,0 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// ====================== Context Keys ======================
|
||||
type turnStateKeyType struct{}
|
||||
|
||||
var turnStateKey = turnStateKeyType{}
|
||||
|
||||
func withTurnState(ctx context.Context, ts *turnState) context.Context {
|
||||
return context.WithValue(ctx, turnStateKey, ts)
|
||||
}
|
||||
|
||||
// TurnStateFromContext retrieves turnState from context (exported for tools)
|
||||
func TurnStateFromContext(ctx context.Context) *turnState {
|
||||
return turnStateFromContext(ctx)
|
||||
}
|
||||
|
||||
func turnStateFromContext(ctx context.Context) *turnState {
|
||||
ts, _ := ctx.Value(turnStateKey).(*turnState)
|
||||
return ts
|
||||
}
|
||||
|
||||
// ====================== turnState ======================
|
||||
|
||||
type turnState struct {
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc // Used to cancel all children when this turn finishes
|
||||
turnID string
|
||||
parentTurnID string
|
||||
depth int
|
||||
childTurnIDs []string // MUST be accessed under mu lock or maybe add a getter method
|
||||
pendingResults chan *tools.ToolResult
|
||||
session session.SessionStore
|
||||
initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort
|
||||
mu sync.Mutex
|
||||
isFinished bool // MUST be accessed under mu lock
|
||||
closeOnce sync.Once // Ensures pendingResults channel is closed exactly once
|
||||
concurrencySem chan struct{} // Limits concurrent child sub-turns
|
||||
finishedChan chan struct{} // Lazily initialized, closed when turn finishes
|
||||
|
||||
// parentEnded signals that the parent turn has finished gracefully.
|
||||
// Child SubTurns should check this via IsParentEnded() to decide whether
|
||||
// to continue running (Critical=true) or exit gracefully (Critical=false).
|
||||
parentEnded atomic.Bool
|
||||
|
||||
// critical indicates whether this SubTurn should continue running after
|
||||
// the parent turn finishes gracefully. Set from SubTurnConfig.Critical.
|
||||
critical bool
|
||||
|
||||
// parentTurnState holds a reference to the parent turnState.
|
||||
// This allows child SubTurns to check if the parent has ended.
|
||||
// Nil for root turns.
|
||||
parentTurnState *turnState
|
||||
|
||||
// lastFinishReason stores the finish_reason from the last LLM call.
|
||||
// Used by SubTurn to detect truncation and retry.
|
||||
// MUST be accessed under mu lock.
|
||||
lastFinishReason string
|
||||
|
||||
// Token budget tracking
|
||||
// tokenBudget is a shared atomic counter for tracking remaining tokens across team members.
|
||||
// Inherited from parent or initialized from SubTurnConfig.InitialTokenBudget.
|
||||
// Nil if no budget is set.
|
||||
tokenBudget *atomic.Int64
|
||||
|
||||
// lastUsage stores the token usage from the last LLM call.
|
||||
// Used by SubTurn to deduct from tokenBudget after each LLM iteration.
|
||||
// MUST be accessed under mu lock.
|
||||
lastUsage *providers.UsageInfo
|
||||
}
|
||||
|
||||
// ====================== Public API ======================
|
||||
|
||||
// TurnInfo provides read-only information about an active turn.
|
||||
type TurnInfo struct {
|
||||
TurnID string
|
||||
ParentTurnID string
|
||||
Depth int
|
||||
ChildTurnIDs []string
|
||||
IsFinished bool
|
||||
}
|
||||
|
||||
// GetActiveTurn retrieves information about the currently active turn for a session.
|
||||
// Returns nil if no active turn exists for the given session key.
|
||||
func (al *AgentLoop) GetActiveTurn(sessionKey string) *TurnInfo {
|
||||
tsInterface, ok := al.activeTurnStates.Load(sessionKey)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
ts, ok := tsInterface.(*turnState)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ts.Info()
|
||||
}
|
||||
|
||||
// Info returns a read-only snapshot of the turn state information.
|
||||
// This method is thread-safe and can be called concurrently.
|
||||
func (ts *turnState) Info() *TurnInfo {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
|
||||
// Create a copy of childTurnIDs to avoid race conditions
|
||||
childIDs := make([]string, len(ts.childTurnIDs))
|
||||
copy(childIDs, ts.childTurnIDs)
|
||||
|
||||
return &TurnInfo{
|
||||
TurnID: ts.turnID,
|
||||
ParentTurnID: ts.parentTurnID,
|
||||
Depth: ts.depth,
|
||||
ChildTurnIDs: childIDs,
|
||||
IsFinished: ts.isFinished,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllActiveTurns retrieves information about all currently active turns across all sessions.
|
||||
func (al *AgentLoop) GetAllActiveTurns() []*TurnInfo {
|
||||
var turns []*TurnInfo
|
||||
al.activeTurnStates.Range(func(key, value any) bool {
|
||||
if ts, ok := value.(*turnState); ok {
|
||||
turns = append(turns, ts.Info())
|
||||
}
|
||||
return true
|
||||
})
|
||||
return turns
|
||||
}
|
||||
|
||||
// FormatTree recursively builds a string representation of the active turn tree.
|
||||
func (al *AgentLoop) FormatTree(turnInfo *TurnInfo, prefix string, isLast bool) string {
|
||||
if turnInfo == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
// Print current node
|
||||
marker := "├── "
|
||||
if isLast {
|
||||
marker = "└── "
|
||||
}
|
||||
if turnInfo.Depth == 0 {
|
||||
marker = "" // Root node no marker
|
||||
}
|
||||
|
||||
status := "Running"
|
||||
if turnInfo.IsFinished {
|
||||
status = "Finished"
|
||||
}
|
||||
|
||||
orphanMarker := ""
|
||||
if turnInfo.Depth > 0 && prefix == "" {
|
||||
orphanMarker = " (Orphaned)"
|
||||
}
|
||||
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
"%s%s[%s] Depth:%d (%s)%s\n",
|
||||
prefix,
|
||||
marker,
|
||||
turnInfo.TurnID,
|
||||
turnInfo.Depth,
|
||||
status,
|
||||
orphanMarker,
|
||||
)
|
||||
|
||||
// Prepare prefix for children
|
||||
childPrefix := prefix
|
||||
if turnInfo.Depth > 0 {
|
||||
if isLast {
|
||||
childPrefix += " "
|
||||
} else {
|
||||
childPrefix += "│ "
|
||||
}
|
||||
}
|
||||
|
||||
for i, childID := range turnInfo.ChildTurnIDs {
|
||||
// Look up child turn state
|
||||
childInfo := al.GetActiveTurn(childID)
|
||||
if childInfo != nil {
|
||||
isLastChild := (i == len(turnInfo.ChildTurnIDs)-1)
|
||||
sb.WriteString(al.FormatTree(childInfo, childPrefix, isLastChild))
|
||||
} else {
|
||||
// Child might have already been removed from active states if it finished early
|
||||
isLastChild := (i == len(turnInfo.ChildTurnIDs)-1)
|
||||
cMarker := "├── "
|
||||
if isLastChild {
|
||||
cMarker = "└── "
|
||||
}
|
||||
fmt.Fprintf(&sb, "%s%s[%s] (Completed/Cleaned Up)\n", childPrefix, cMarker, childID)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// ====================== Helper Functions ======================
|
||||
|
||||
func newTurnState(ctx context.Context, id string, parent *turnState, maxConcurrent int) *turnState {
|
||||
// Note: We don't create a new context with cancel here because the caller
|
||||
// (spawnSubTurn) already creates one. The turnState stores the context and
|
||||
// cancelFunc provided by the caller to avoid redundant context wrapping.
|
||||
return &turnState{
|
||||
ctx: ctx,
|
||||
cancelFunc: nil, // Will be set by the caller
|
||||
turnID: id,
|
||||
parentTurnID: parent.turnID,
|
||||
depth: parent.depth + 1,
|
||||
session: newEphemeralSession(parent.session),
|
||||
parentTurnState: parent, // Store reference to parent for IsParentEnded() checks
|
||||
// NOTE: In this PoC, I use a fixed-size channel (16).
|
||||
// Under high concurrency or long-running sub-turns, this might fill up and cause
|
||||
// intermediate results to be discarded in deliverSubTurnResult.
|
||||
// For production, consider an unbounded queue or a blocking strategy with backpressure.
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, maxConcurrent),
|
||||
}
|
||||
}
|
||||
|
||||
// IsParentEnded returns true if the parent turn has finished gracefully.
|
||||
// This is safe to call from child SubTurn goroutines.
|
||||
// Returns false if this is a root turn (no parent).
|
||||
func (ts *turnState) IsParentEnded() bool {
|
||||
if ts.parentTurnState == nil {
|
||||
return false
|
||||
}
|
||||
return ts.parentTurnState.parentEnded.Load()
|
||||
}
|
||||
|
||||
// SetLastFinishReason updates the last finish reason (thread-safe).
|
||||
func (ts *turnState) SetLastFinishReason(reason string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.lastFinishReason = reason
|
||||
}
|
||||
|
||||
// GetLastFinishReason retrieves the last finish reason (thread-safe).
|
||||
func (ts *turnState) GetLastFinishReason() string {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
return ts.lastFinishReason
|
||||
}
|
||||
|
||||
// SetLastUsage stores the token usage from the last LLM call.
|
||||
// This is used by SubTurn to track token consumption for budget enforcement.
|
||||
func (ts *turnState) SetLastUsage(usage *providers.UsageInfo) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.lastUsage = usage
|
||||
}
|
||||
|
||||
// GetLastUsage retrieves the token usage from the last LLM call.
|
||||
// Returns nil if no LLM call has been made yet.
|
||||
func (ts *turnState) GetLastUsage() *providers.UsageInfo {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
return ts.lastUsage
|
||||
}
|
||||
|
||||
// IsParentEnded is a convenience method to check if parent ended.
|
||||
// It returns the value of the parent's parentEnded atomic flag.
|
||||
|
||||
// Finished returns a channel that is closed when the turn finishes.
|
||||
// This allows child turns to safely block on delivering results without leaking
|
||||
// if the parent finishes before they can deliver.
|
||||
func (ts *turnState) Finished() <-chan struct{} {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if ts.finishedChan == nil {
|
||||
ts.finishedChan = make(chan struct{})
|
||||
if ts.isFinished {
|
||||
close(ts.finishedChan)
|
||||
}
|
||||
}
|
||||
return ts.finishedChan
|
||||
}
|
||||
|
||||
// Finish marks the turn as finished.
|
||||
//
|
||||
// If isHardAbort is true (Hard Abort):
|
||||
// - Cancels all child contexts immediately via cancelFunc
|
||||
// - Used for user-initiated termination (e.g., "stop now")
|
||||
//
|
||||
// If isHardAbort is false (Graceful Finish):
|
||||
// - Only signals parentEnded for graceful child exit
|
||||
// - Children check IsParentEnded() and decide whether to continue or exit
|
||||
// - Critical SubTurns continue running and deliver orphan results
|
||||
// - Non-Critical SubTurns exit gracefully without error
|
||||
//
|
||||
// In both cases, the pendingResults channel is NOT closed.
|
||||
// It is left open to be garbage collected when no longer used, avoiding
|
||||
// "send on closed channel" panics from concurrently finishing async subturns.
|
||||
func (ts *turnState) Finish(isHardAbort bool) {
|
||||
var fc chan struct{}
|
||||
|
||||
ts.mu.Lock()
|
||||
if !ts.isFinished {
|
||||
ts.isFinished = true
|
||||
if ts.finishedChan == nil {
|
||||
ts.finishedChan = make(chan struct{})
|
||||
}
|
||||
fc = ts.finishedChan
|
||||
}
|
||||
ts.mu.Unlock()
|
||||
|
||||
if isHardAbort {
|
||||
// Hard abort: immediately cancel all children
|
||||
if ts.cancelFunc != nil {
|
||||
ts.cancelFunc()
|
||||
}
|
||||
} else {
|
||||
// Graceful finish: signal parent ended, let children decide
|
||||
ts.parentEnded.Store(true)
|
||||
}
|
||||
|
||||
// Safely close the finishedChan exactly once
|
||||
if fc != nil {
|
||||
ts.closeOnce.Do(func() {
|
||||
close(fc)
|
||||
})
|
||||
}
|
||||
|
||||
// We no longer close(ts.pendingResults) here to avoid panicking any
|
||||
// concurrent deliverSubTurnResult calls. We rely on GC to clean up the channel.
|
||||
}
|
||||
|
||||
// ====================== Ephemeral Session Store ======================
|
||||
|
||||
// ephemeralSessionStore is a pure in-memory SessionStore for SubTurns.
|
||||
// It never writes to disk, keeping sub-turn history isolated from the parent session.
|
||||
// It automatically truncates history when it exceeds maxEphemeralHistorySize to prevent memory accumulation.
|
||||
type ephemeralSessionStore struct {
|
||||
mu sync.Mutex
|
||||
history []providers.Message
|
||||
summary string
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) AddMessage(sessionKey, role, content string) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.history = append(e.history, providers.Message{Role: role, Content: content})
|
||||
e.autoTruncate()
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) AddFullMessage(sessionKey string, msg providers.Message) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.history = append(e.history, msg)
|
||||
e.autoTruncate()
|
||||
}
|
||||
|
||||
// autoTruncate automatically limits history size to prevent memory accumulation.
|
||||
// Must be called with mu held.
|
||||
func (e *ephemeralSessionStore) autoTruncate() {
|
||||
if len(e.history) > maxEphemeralHistorySize {
|
||||
// Keep only the most recent messages
|
||||
e.history = e.history[len(e.history)-maxEphemeralHistorySize:]
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) GetHistory(key string) []providers.Message {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]providers.Message, len(e.history))
|
||||
copy(out, e.history)
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) GetSummary(key string) string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.summary
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) SetSummary(key, summary string) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.summary = summary
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) SetHistory(key string, history []providers.Message) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.history = make([]providers.Message, len(history))
|
||||
copy(e.history, history)
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) TruncateHistory(key string, keepLast int) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
if len(e.history) > keepLast {
|
||||
e.history = e.history[len(e.history)-keepLast:]
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) Save(key string) error { return nil }
|
||||
func (e *ephemeralSessionStore) Close() error { return nil }
|
||||
|
||||
// newEphemeralSession creates a new isolated ephemeral session for a sub-turn.
|
||||
//
|
||||
// IMPORTANT: The parent session parameter is intentionally unused (marked with _).
|
||||
// This is by design according to issue #1316: sub-turns use completely isolated
|
||||
// ephemeral sessions that do NOT inherit history from the parent session.
|
||||
//
|
||||
// Rationale for isolation:
|
||||
// - Sub-turns are independent execution contexts with their own prompts
|
||||
// - Inheriting parent history could cause context pollution
|
||||
// - Each sub-turn should start with a clean slate
|
||||
// - Memory is managed independently (auto-truncation at maxEphemeralHistorySize)
|
||||
// - Results are communicated back via the result channel, not via shared history
|
||||
//
|
||||
// If future requirements need parent history inheritance, this design decision
|
||||
// should be reconsidered with careful attention to memory management and context size.
|
||||
func newEphemeralSession(_ session.SessionStore) session.SessionStore {
|
||||
return &ephemeralSessionStore{}
|
||||
}
|
||||
@@ -84,6 +84,7 @@ type Config struct {
|
||||
Providers ProvidersConfig `json:"providers,omitempty"`
|
||||
ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration
|
||||
Gateway GatewayConfig `json:"gateway"`
|
||||
Hooks HooksConfig `json:"hooks,omitempty"`
|
||||
Tools ToolsConfig `json:"tools"`
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat"`
|
||||
Devices DevicesConfig `json:"devices"`
|
||||
@@ -92,6 +93,36 @@ type Config struct {
|
||||
BuildInfo BuildInfo `json:"build_info,omitempty"`
|
||||
}
|
||||
|
||||
type HooksConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Defaults HookDefaultsConfig `json:"defaults,omitempty"`
|
||||
Builtins map[string]BuiltinHookConfig `json:"builtins,omitempty"`
|
||||
Processes map[string]ProcessHookConfig `json:"processes,omitempty"`
|
||||
}
|
||||
|
||||
type HookDefaultsConfig struct {
|
||||
ObserverTimeoutMS int `json:"observer_timeout_ms,omitempty"`
|
||||
InterceptorTimeoutMS int `json:"interceptor_timeout_ms,omitempty"`
|
||||
ApprovalTimeoutMS int `json:"approval_timeout_ms,omitempty"`
|
||||
}
|
||||
|
||||
type BuiltinHookConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Config json.RawMessage `json:"config,omitempty"`
|
||||
}
|
||||
|
||||
type ProcessHookConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Transport string `json:"transport,omitempty"`
|
||||
Command []string `json:"command,omitempty"`
|
||||
Dir string `json:"dir,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
Observe []string `json:"observe,omitempty"`
|
||||
Intercept []string `json:"intercept,omitempty"`
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time version information
|
||||
type BuildInfo struct {
|
||||
Version string `json:"version"`
|
||||
@@ -244,6 +275,7 @@ type AgentDefaults struct {
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
ContextWindow int `json:"context_window,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_WINDOW"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
|
||||
@@ -470,6 +470,22 @@ func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_HooksDefaults(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if !cfg.Hooks.Enabled {
|
||||
t.Fatal("DefaultConfig().Hooks.Enabled should be true")
|
||||
}
|
||||
if cfg.Hooks.Defaults.ObserverTimeoutMS != 500 {
|
||||
t.Fatalf("ObserverTimeoutMS = %d, want 500", cfg.Hooks.Defaults.ObserverTimeoutMS)
|
||||
}
|
||||
if cfg.Hooks.Defaults.InterceptorTimeoutMS != 5000 {
|
||||
t.Fatalf("InterceptorTimeoutMS = %d, want 5000", cfg.Hooks.Defaults.InterceptorTimeoutMS)
|
||||
}
|
||||
if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 {
|
||||
t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_LogLevel(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg.Agents.Defaults.LogLevel != "fatal" {
|
||||
@@ -562,6 +578,88 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_HooksProcessConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
configJSON := `{
|
||||
"hooks": {
|
||||
"processes": {
|
||||
"review-gate": {
|
||||
"enabled": true,
|
||||
"transport": "stdio",
|
||||
"command": ["uvx", "picoclaw-hook-reviewer"],
|
||||
"dir": "/tmp/hooks",
|
||||
"env": {
|
||||
"HOOK_MODE": "rewrite"
|
||||
},
|
||||
"observe": ["turn_start", "turn_end"],
|
||||
"intercept": ["before_tool", "approve_tool"]
|
||||
}
|
||||
},
|
||||
"builtins": {
|
||||
"audit": {
|
||||
"enabled": true,
|
||||
"priority": 5,
|
||||
"config": {
|
||||
"label": "audit"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(configJSON), 0o600); err != nil {
|
||||
t.Fatalf("os.WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
|
||||
processCfg, ok := cfg.Hooks.Processes["review-gate"]
|
||||
if !ok {
|
||||
t.Fatal("expected review-gate process hook")
|
||||
}
|
||||
if !processCfg.Enabled {
|
||||
t.Fatal("expected review-gate process hook to be enabled")
|
||||
}
|
||||
if processCfg.Transport != "stdio" {
|
||||
t.Fatalf("Transport = %q, want stdio", processCfg.Transport)
|
||||
}
|
||||
if len(processCfg.Command) != 2 || processCfg.Command[0] != "uvx" {
|
||||
t.Fatalf("Command = %v", processCfg.Command)
|
||||
}
|
||||
if processCfg.Dir != "/tmp/hooks" {
|
||||
t.Fatalf("Dir = %q, want /tmp/hooks", processCfg.Dir)
|
||||
}
|
||||
if processCfg.Env["HOOK_MODE"] != "rewrite" {
|
||||
t.Fatalf("HOOK_MODE = %q, want rewrite", processCfg.Env["HOOK_MODE"])
|
||||
}
|
||||
if len(processCfg.Observe) != 2 || processCfg.Observe[1] != "turn_end" {
|
||||
t.Fatalf("Observe = %v", processCfg.Observe)
|
||||
}
|
||||
if len(processCfg.Intercept) != 2 || processCfg.Intercept[1] != "approve_tool" {
|
||||
t.Fatalf("Intercept = %v", processCfg.Intercept)
|
||||
}
|
||||
|
||||
builtinCfg, ok := cfg.Hooks.Builtins["audit"]
|
||||
if !ok {
|
||||
t.Fatal("expected audit builtin hook")
|
||||
}
|
||||
if !builtinCfg.Enabled {
|
||||
t.Fatal("expected audit builtin hook to be enabled")
|
||||
}
|
||||
if builtinCfg.Priority != 5 {
|
||||
t.Fatalf("Priority = %d, want 5", builtinCfg.Priority)
|
||||
}
|
||||
if !strings.Contains(string(builtinCfg.Config), `"audit"`) {
|
||||
t.Fatalf("Config = %s", string(builtinCfg.Config))
|
||||
}
|
||||
if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 {
|
||||
t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_DMScope verifies the default dm_scope value
|
||||
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
|
||||
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
|
||||
|
||||
@@ -186,6 +186,14 @@ func DefaultConfig() *Config {
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
},
|
||||
Hooks: HooksConfig{
|
||||
Enabled: true,
|
||||
Defaults: HookDefaultsConfig{
|
||||
ObserverTimeoutMS: 500,
|
||||
InterceptorTimeoutMS: 5000,
|
||||
ApprovalTimeoutMS: 60000,
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: OpenAIProviderConfig{WebSearch: true},
|
||||
},
|
||||
|
||||
@@ -154,6 +154,9 @@ func (sm *SubagentManager) runTask(
|
||||
) {
|
||||
task.Status = "running"
|
||||
task.Created = time.Now().UnixMilli()
|
||||
// TODO(eventbus): once subagents are modeled as child turns inside
|
||||
// pkg/agent, emit SubTurnEnd and SubTurnResultDelivered from the parent
|
||||
// AgentLoop instead of this legacy manager.
|
||||
|
||||
// Check if context is already canceled before starting
|
||||
select {
|
||||
|
||||
Reference in New Issue
Block a user