mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into feat/skill-channel-commands
# Conflicts: # pkg/agent/loop.go
This commit is contained in:
+27
-16
@@ -222,13 +222,10 @@ func (cb *ContextBuilder) InvalidateCache() {
|
||||
// invalidation (bootstrap files + memory). Skill roots are handled separately
|
||||
// because they require both directory-level and recursive file-level checks.
|
||||
func (cb *ContextBuilder) sourcePaths() []string {
|
||||
return []string{
|
||||
filepath.Join(cb.workspace, "AGENTS.md"),
|
||||
filepath.Join(cb.workspace, "SOUL.md"),
|
||||
filepath.Join(cb.workspace, "USER.md"),
|
||||
filepath.Join(cb.workspace, "IDENTITY.md"),
|
||||
filepath.Join(cb.workspace, "memory", "MEMORY.md"),
|
||||
}
|
||||
agentDefinition := cb.LoadAgentDefinition()
|
||||
paths := agentDefinition.trackedPaths(cb.workspace)
|
||||
paths = append(paths, filepath.Join(cb.workspace, "memory", "MEMORY.md"))
|
||||
return uniquePaths(paths)
|
||||
}
|
||||
|
||||
// skillRoots returns all skill root directories that can affect
|
||||
@@ -432,18 +429,32 @@ func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Ti
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) LoadBootstrapFiles() string {
|
||||
bootstrapFiles := []string{
|
||||
"AGENTS.md",
|
||||
"SOUL.md",
|
||||
"USER.md",
|
||||
"IDENTITY.md",
|
||||
var sb strings.Builder
|
||||
|
||||
agentDefinition := cb.LoadAgentDefinition()
|
||||
if agentDefinition.Agent != nil {
|
||||
label := string(agentDefinition.Source)
|
||||
if label == "" {
|
||||
label = relativeWorkspacePath(cb.workspace, agentDefinition.Agent.Path)
|
||||
}
|
||||
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", label, agentDefinition.Agent.Body)
|
||||
}
|
||||
if agentDefinition.Soul != nil {
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
"## %s\n\n%s\n\n",
|
||||
relativeWorkspacePath(cb.workspace, agentDefinition.Soul.Path),
|
||||
agentDefinition.Soul.Content,
|
||||
)
|
||||
}
|
||||
if agentDefinition.User != nil {
|
||||
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", "USER.md", agentDefinition.User.Content)
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for _, filename := range bootstrapFiles {
|
||||
filePath := filepath.Join(cb.workspace, filename)
|
||||
if agentDefinition.Source != AgentDefinitionSourceAgent {
|
||||
filePath := filepath.Join(cb.workspace, "IDENTITY.md")
|
||||
if data, err := os.ReadFile(filePath); err == nil {
|
||||
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", filename, data)
|
||||
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", "IDENTITY.md", data)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// parseTurnBoundaries returns the starting index of each Turn in the history.
|
||||
// A Turn is a complete "user input → LLM iterations → final response" cycle
|
||||
// (as defined in #1316). Each Turn begins at a user message and extends
|
||||
// through all subsequent assistant/tool messages until the next user message.
|
||||
//
|
||||
// Cutting at a Turn boundary guarantees that no tool-call sequence
|
||||
// (assistant+ToolCalls → tool results) is split across the cut.
|
||||
func parseTurnBoundaries(history []providers.Message) []int {
|
||||
var starts []int
|
||||
for i, msg := range history {
|
||||
if msg.Role == "user" {
|
||||
starts = append(starts, i)
|
||||
}
|
||||
}
|
||||
return starts
|
||||
}
|
||||
|
||||
// isSafeBoundary reports whether index is a valid Turn boundary — i.e.,
|
||||
// a position where the kept portion (history[index:]) begins at a user
|
||||
// message, so no tool-call sequence is torn apart.
|
||||
func isSafeBoundary(history []providers.Message, index int) bool {
|
||||
if index <= 0 || index >= len(history) {
|
||||
return true
|
||||
}
|
||||
return history[index].Role == "user"
|
||||
}
|
||||
|
||||
// findSafeBoundary locates the nearest Turn boundary to targetIndex.
|
||||
// It prefers the boundary at or before targetIndex (preserving more recent
|
||||
// context). Falls back to the nearest boundary after targetIndex, and
|
||||
// returns targetIndex unchanged only when no Turn boundary exists at all.
|
||||
func findSafeBoundary(history []providers.Message, targetIndex int) int {
|
||||
if len(history) == 0 {
|
||||
return 0
|
||||
}
|
||||
if targetIndex <= 0 {
|
||||
return 0
|
||||
}
|
||||
if targetIndex >= len(history) {
|
||||
return len(history)
|
||||
}
|
||||
|
||||
turns := parseTurnBoundaries(history)
|
||||
if len(turns) == 0 {
|
||||
return targetIndex
|
||||
}
|
||||
|
||||
// Find the last Turn boundary at or before targetIndex.
|
||||
// Prefer backward: keeps more recent messages.
|
||||
backward := -1
|
||||
for _, t := range turns {
|
||||
if t <= targetIndex {
|
||||
backward = t
|
||||
}
|
||||
}
|
||||
if backward > 0 {
|
||||
return backward
|
||||
}
|
||||
|
||||
// No valid Turn boundary before target (or only at index 0 which
|
||||
// would keep everything). Use the first Turn after targetIndex.
|
||||
for _, t := range turns {
|
||||
if t > targetIndex {
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
// No Turn boundary after targetIndex either. The only boundary is at
|
||||
// index 0, meaning the entire history is a single Turn. Return 0 to
|
||||
// signal that safe compression is not possible — callers check for
|
||||
// mid <= 0 and skip compression in that case.
|
||||
return 0
|
||||
}
|
||||
|
||||
// estimateMessageTokens estimates the token count for a single message,
|
||||
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
|
||||
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
|
||||
func estimateMessageTokens(msg providers.Message) int {
|
||||
chars := utf8.RuneCountInString(msg.Content)
|
||||
|
||||
// ReasoningContent (extended thinking / chain-of-thought) can be
|
||||
// substantial and is stored in session history via AddFullMessage.
|
||||
if msg.ReasoningContent != "" {
|
||||
chars += utf8.RuneCountInString(msg.ReasoningContent)
|
||||
}
|
||||
|
||||
for _, tc := range msg.ToolCalls {
|
||||
chars += len(tc.ID) + len(tc.Type)
|
||||
if tc.Function != nil {
|
||||
// Count function name + arguments (the wire format for most providers).
|
||||
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
|
||||
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||
} else {
|
||||
// Fallback: some provider formats use top-level Name without Function.
|
||||
chars += len(tc.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if msg.ToolCallID != "" {
|
||||
chars += len(msg.ToolCallID)
|
||||
}
|
||||
|
||||
// Per-message overhead for role label, JSON structure, separators.
|
||||
const messageOverhead = 12
|
||||
chars += messageOverhead
|
||||
|
||||
tokens := chars * 2 / 5
|
||||
|
||||
// Media items (images, files) are serialized by provider adapters into
|
||||
// multipart or image_url payloads. Add a fixed per-item token estimate
|
||||
// directly (not through the chars heuristic) since actual cost depends
|
||||
// on resolution and provider-specific image tokenization.
|
||||
const mediaTokensPerItem = 256
|
||||
tokens += len(msg.Media) * mediaTokensPerItem
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// estimateToolDefsTokens estimates the total token cost of tool definitions
|
||||
// as they appear in the LLM request. Each tool's name, description, and
|
||||
// JSON schema parameters contribute to the context window budget.
|
||||
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||
if len(defs) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
totalChars := 0
|
||||
for _, d := range defs {
|
||||
totalChars += len(d.Function.Name) + len(d.Function.Description)
|
||||
|
||||
if d.Function.Parameters != nil {
|
||||
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
|
||||
totalChars += len(paramJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// Per-tool overhead: type field, JSON structure, separators.
|
||||
totalChars += 20
|
||||
}
|
||||
|
||||
return totalChars * 2 / 5
|
||||
}
|
||||
|
||||
// isOverContextBudget checks whether the assembled messages plus tool definitions
|
||||
// and output reserve would exceed the model's context window. This enables
|
||||
// proactive compression before calling the LLM, rather than reacting to 400 errors.
|
||||
func isOverContextBudget(
|
||||
contextWindow int,
|
||||
messages []providers.Message,
|
||||
toolDefs []providers.ToolDefinition,
|
||||
maxTokens int,
|
||||
) bool {
|
||||
msgTokens := 0
|
||||
for _, m := range messages {
|
||||
msgTokens += estimateMessageTokens(m)
|
||||
}
|
||||
|
||||
toolTokens := estimateToolDefsTokens(toolDefs)
|
||||
total := msgTokens + toolTokens + maxTokens
|
||||
|
||||
return total > contextWindow
|
||||
}
|
||||
@@ -0,0 +1,826 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// msgUser creates a user message.
|
||||
func msgUser(content string) providers.Message {
|
||||
return providers.Message{Role: "user", Content: content}
|
||||
}
|
||||
|
||||
// msgAssistant creates a plain assistant message (no tool calls).
|
||||
func msgAssistant(content string) providers.Message {
|
||||
return providers.Message{Role: "assistant", Content: content}
|
||||
}
|
||||
|
||||
// msgAssistantTC creates an assistant message with tool calls.
|
||||
func msgAssistantTC(toolIDs ...string) providers.Message {
|
||||
tcs := make([]providers.ToolCall, len(toolIDs))
|
||||
for i, id := range toolIDs {
|
||||
tcs[i] = providers.ToolCall{
|
||||
ID: id,
|
||||
Type: "function",
|
||||
Name: "tool_" + id,
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_" + id,
|
||||
Arguments: `{"key":"value"}`,
|
||||
},
|
||||
}
|
||||
}
|
||||
return providers.Message{Role: "assistant", ToolCalls: tcs}
|
||||
}
|
||||
|
||||
// msgTool creates a tool result message.
|
||||
func msgTool(callID, content string) providers.Message {
|
||||
return providers.Message{Role: "tool", ToolCallID: callID, Content: content}
|
||||
}
|
||||
|
||||
func TestParseTurnBoundaries(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
history []providers.Message
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
name: "empty history",
|
||||
history: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "simple exchange",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistant("a2"),
|
||||
},
|
||||
want: []int{0, 2},
|
||||
},
|
||||
{
|
||||
name: "tool-call Turn",
|
||||
history: []providers.Message{
|
||||
msgUser("search"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "result"),
|
||||
msgAssistant("found it"),
|
||||
msgUser("thanks"),
|
||||
msgAssistant("welcome"),
|
||||
},
|
||||
want: []int{0, 4},
|
||||
},
|
||||
{
|
||||
name: "chained tool calls in single Turn",
|
||||
history: []providers.Message{
|
||||
msgUser("save and notify"),
|
||||
msgAssistantTC("tc_save"),
|
||||
msgTool("tc_save", "saved"),
|
||||
msgAssistantTC("tc_notify"),
|
||||
msgTool("tc_notify", "notified"),
|
||||
msgAssistant("done"),
|
||||
},
|
||||
want: []int{0},
|
||||
},
|
||||
{
|
||||
name: "no user messages",
|
||||
history: []providers.Message{
|
||||
msgAssistant("a1"),
|
||||
msgAssistant("a2"),
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "leading non-user messages",
|
||||
history: []providers.Message{
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgAssistant("greeting"),
|
||||
msgUser("hello"),
|
||||
msgAssistant("hi"),
|
||||
},
|
||||
want: []int{3},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := parseTurnBoundaries(tt.history)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("parseTurnBoundaries() = %v, want %v", got, tt.want)
|
||||
return
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("parseTurnBoundaries()[%d] = %d, want %d", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeBoundary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
history []providers.Message
|
||||
index int
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty history, index 0",
|
||||
history: nil,
|
||||
index: 0,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "single user message, index 0",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
index: 0,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "single user message, index 1 (end)",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
index: 1,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "at user message",
|
||||
history: []providers.Message{
|
||||
msgAssistant("hello"),
|
||||
msgUser("how are you"),
|
||||
msgAssistant("fine"),
|
||||
},
|
||||
index: 1,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "at assistant without tool calls",
|
||||
history: []providers.Message{
|
||||
msgUser("hello"),
|
||||
msgAssistant("response"),
|
||||
msgUser("follow up"),
|
||||
},
|
||||
index: 1,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "at assistant with tool calls",
|
||||
history: []providers.Message{
|
||||
msgUser("search something"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "result"),
|
||||
msgAssistant("here is what I found"),
|
||||
},
|
||||
index: 1,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "at tool result",
|
||||
history: []providers.Message{
|
||||
msgUser("do something"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "done"),
|
||||
msgAssistant("completed"),
|
||||
},
|
||||
index: 2,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "negative index",
|
||||
history: []providers.Message{
|
||||
msgUser("hello"),
|
||||
},
|
||||
index: -1,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "index beyond length",
|
||||
history: []providers.Message{
|
||||
msgUser("hello"),
|
||||
},
|
||||
index: 5,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isSafeBoundary(tt.history, tt.index)
|
||||
if got != tt.want {
|
||||
t.Errorf("isSafeBoundary(history, %d) = %v, want %v", tt.index, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSafeBoundary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
history []providers.Message
|
||||
targetIndex int
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "empty history",
|
||||
history: nil,
|
||||
targetIndex: 0,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "target at 0",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
targetIndex: 0,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "target beyond length",
|
||||
history: []providers.Message{msgUser("hi")},
|
||||
targetIndex: 5,
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "target already at user message",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistant("a2"),
|
||||
},
|
||||
targetIndex: 2,
|
||||
want: 2,
|
||||
},
|
||||
{
|
||||
name: "target at assistant, scan backward finds user",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistant("a2"),
|
||||
msgUser("q3"),
|
||||
},
|
||||
targetIndex: 3, // assistant "a2"
|
||||
want: 2, // backward to user "q2"
|
||||
},
|
||||
{
|
||||
name: "target inside tool sequence, scan backward finds user",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistantTC("tc1", "tc2"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgTool("tc2", "r2"),
|
||||
msgAssistant("summary"),
|
||||
msgUser("q3"),
|
||||
},
|
||||
targetIndex: 4, // tool result "r1"
|
||||
want: 2, // backward: 3=assistant+TC (not safe), 2=user → safe
|
||||
},
|
||||
{
|
||||
name: "target inside tool sequence, backward finds user before chain",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistantTC("tc1", "tc2"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgTool("tc2", "r2"),
|
||||
msgAssistant("summary"),
|
||||
msgUser("q3"),
|
||||
},
|
||||
targetIndex: 5, // tool result "r2"
|
||||
want: 2, // backward: 4=tool, 3=assistant+TC, 2=user → safe
|
||||
},
|
||||
{
|
||||
name: "no backward user, scan forward finds one",
|
||||
history: []providers.Message{
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q1"),
|
||||
},
|
||||
targetIndex: 1, // tool result
|
||||
want: 3, // forward to user "q1"
|
||||
},
|
||||
{
|
||||
name: "multi-step tool chain preserves atomicity",
|
||||
history: []providers.Message{
|
||||
msgUser("q1"),
|
||||
msgAssistant("a1"),
|
||||
msgUser("q2"),
|
||||
msgAssistantTC("tc1"),
|
||||
msgTool("tc1", "r1"),
|
||||
msgAssistantTC("tc2"),
|
||||
msgTool("tc2", "r2"),
|
||||
msgAssistant("final"),
|
||||
msgUser("q3"),
|
||||
msgAssistant("a3"),
|
||||
},
|
||||
targetIndex: 5, // second assistant+TC
|
||||
want: 2, // backward: 4=tool, 3=assistant+TC, 2=user → safe
|
||||
},
|
||||
{
|
||||
name: "all non-user messages returns target unchanged",
|
||||
history: []providers.Message{
|
||||
msgAssistant("a1"),
|
||||
msgAssistant("a2"),
|
||||
msgAssistant("a3"),
|
||||
},
|
||||
targetIndex: 1,
|
||||
want: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := findSafeBoundary(tt.history, tt.targetIndex)
|
||||
if got != tt.want {
|
||||
t.Errorf("findSafeBoundary(history, %d) = %d, want %d",
|
||||
tt.targetIndex, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSafeBoundary_SingleTurnReturnsZero(t *testing.T) {
|
||||
// A single Turn with no subsequent user message. The only Turn boundary
|
||||
// is at index 0; cutting anywhere else would split the Turn's tool
|
||||
// sequence. findSafeBoundary must return 0 so callers skip compression.
|
||||
history := []providers.Message{
|
||||
msgUser("do everything"), // 0 ← only Turn boundary
|
||||
msgAssistantTC("tc1"), // 1
|
||||
msgTool("tc1", "result"), // 2
|
||||
msgAssistant("all done"), // 3
|
||||
}
|
||||
|
||||
got := findSafeBoundary(history, 2)
|
||||
if got != 0 {
|
||||
t.Errorf("findSafeBoundary(single_turn, 2) = %d, want 0 (cannot split single Turn)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSafeBoundary_BackwardScanSkipsToolSequence(t *testing.T) {
|
||||
// A long tool-call chain: user → assistant+TC → tool → tool → ... → assistant → user
|
||||
// Target is inside the chain; boundary should skip the entire chain backward.
|
||||
history := []providers.Message{
|
||||
msgUser("start"), // 0
|
||||
msgAssistant("before chain"), // 1
|
||||
msgUser("trigger"), // 2 ← expected safe boundary
|
||||
msgAssistantTC("t1", "t2", "t3"), // 3
|
||||
msgTool("t1", "r1"), // 4
|
||||
msgTool("t2", "r2"), // 5
|
||||
msgTool("t3", "r3"), // 6
|
||||
msgAssistantTC("t4"), // 7
|
||||
msgTool("t4", "r4"), // 8
|
||||
msgAssistant("chain done"), // 9
|
||||
msgUser("next"), // 10
|
||||
}
|
||||
|
||||
// Target at index 6 (middle of tool results)
|
||||
got := findSafeBoundary(history, 6)
|
||||
if got != 2 {
|
||||
t.Errorf("findSafeBoundary(history, 6) = %d, want 2 (user before chain)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg providers.Message
|
||||
want int // minimum expected tokens (exact value depends on overhead)
|
||||
}{
|
||||
{
|
||||
name: "plain user message",
|
||||
msg: msgUser("Hello, world!"),
|
||||
want: 1, // at least some tokens
|
||||
},
|
||||
{
|
||||
name: "empty message still has overhead",
|
||||
msg: providers.Message{Role: "user"},
|
||||
want: 1, // message overhead alone
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls",
|
||||
msg: msgAssistantTC("tc_123"),
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "tool result with ID",
|
||||
msg: msgTool("call_abc", "Here is the search result with lots of content"),
|
||||
want: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateMessageTokens(tt.msg)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
|
||||
plain := msgAssistant("thinking")
|
||||
withTC := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "thinking",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "web_search",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "web_search",
|
||||
Arguments: `{"query":"picoclaw agent framework","max_results":5}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
withTCTokens := estimateMessageTokens(withTC)
|
||||
|
||||
if withTCTokens <= plainTokens {
|
||||
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
|
||||
withTCTokens, plainTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
|
||||
// Multi-byte characters (e.g. emoji, accented letters) are single runes
|
||||
// but may map to different token counts. The heuristic should still produce
|
||||
// reasonable estimates via RuneCountInString.
|
||||
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
|
||||
tokens := estimateMessageTokens(msg)
|
||||
if tokens <= 0 {
|
||||
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
|
||||
// Simulate a tool call with large JSON arguments.
|
||||
largeArgs := fmt.Sprintf(`{"content":"%s"}`, strings.Repeat("x", 5000))
|
||||
msg := providers.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_large",
|
||||
Type: "function",
|
||||
Name: "write_file",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "write_file",
|
||||
Arguments: largeArgs,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
|
||||
if tokens < 2000 {
|
||||
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
|
||||
plain := msgAssistant("result")
|
||||
withReasoning := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "result",
|
||||
ReasoningContent: strings.Repeat("thinking step ", 200),
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
reasoningTokens := estimateMessageTokens(withReasoning)
|
||||
|
||||
if reasoningTokens <= plainTokens {
|
||||
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
|
||||
reasoningTokens, plainTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_MediaItems(t *testing.T) {
|
||||
plain := msgUser("describe this")
|
||||
withMedia := providers.Message{
|
||||
Role: "user",
|
||||
Content: "describe this",
|
||||
Media: []string{"media://img1.png", "media://img2.png"},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
mediaTokens := estimateMessageTokens(withMedia)
|
||||
|
||||
if mediaTokens <= plainTokens {
|
||||
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
|
||||
mediaTokens, plainTokens)
|
||||
}
|
||||
|
||||
// Each media item should add exactly 256 tokens (not run through chars*2/5).
|
||||
expectedDelta := 256 * 2
|
||||
actualDelta := mediaTokens - plainTokens
|
||||
if actualDelta != expectedDelta {
|
||||
t.Errorf("2 media items should add %d tokens, got delta %d", expectedDelta, actualDelta)
|
||||
}
|
||||
}
|
||||
|
||||
// --- estimateToolDefsTokens tests ---
|
||||
|
||||
func TestEstimateToolDefsTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
defs []providers.ToolDefinition
|
||||
want int // minimum expected tokens
|
||||
}{
|
||||
{
|
||||
name: "empty tool list",
|
||||
defs: nil,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "single tool with params",
|
||||
defs: []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "web_search",
|
||||
Description: "Search the web for information",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "tool without params",
|
||||
defs: []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "list_dir",
|
||||
Description: "List directory contents",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateToolDefsTokens(tt.defs)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
|
||||
makeTool := func(name string) providers.ToolDefinition {
|
||||
return providers.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: name,
|
||||
Description: "A test tool that does something useful",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"input": map[string]any{"type": "string", "description": "Input value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
||||
three := estimateToolDefsTokens([]providers.ToolDefinition{
|
||||
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
|
||||
})
|
||||
|
||||
if three <= one {
|
||||
t.Errorf("3 tools (%d tokens) should exceed 1 tool (%d tokens)", three, one)
|
||||
}
|
||||
}
|
||||
|
||||
// --- isOverContextBudget tests ---
|
||||
|
||||
func TestIsOverContextBudget(t *testing.T) {
|
||||
systemMsg := providers.Message{Role: "system", Content: strings.Repeat("x", 1000)}
|
||||
userMsg := msgUser("hello")
|
||||
smallHistory := []providers.Message{systemMsg, msgUser("q1"), msgAssistant("a1"), userMsg}
|
||||
|
||||
tools := []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "test_tool",
|
||||
Description: "A test tool",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
contextWindow int
|
||||
messages []providers.Message
|
||||
toolDefs []providers.ToolDefinition
|
||||
maxTokens int
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "within budget",
|
||||
contextWindow: 100000,
|
||||
messages: smallHistory,
|
||||
toolDefs: tools,
|
||||
maxTokens: 4096,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "over budget with small window",
|
||||
contextWindow: 100, // very small window
|
||||
messages: smallHistory,
|
||||
toolDefs: tools,
|
||||
maxTokens: 4096,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "large max_tokens eats budget",
|
||||
contextWindow: 2000,
|
||||
messages: smallHistory,
|
||||
toolDefs: tools,
|
||||
maxTokens: 1800, // leaves almost no room
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty messages within budget",
|
||||
contextWindow: 10000,
|
||||
messages: nil,
|
||||
toolDefs: nil,
|
||||
maxTokens: 4096,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isOverContextBudget(tt.contextWindow, tt.messages, tt.toolDefs, tt.maxTokens)
|
||||
if got != tt.want {
|
||||
t.Errorf("isOverContextBudget() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests reflecting actual session data shape ---
|
||||
// Session history never contains system messages. The system prompt is
|
||||
// built dynamically by BuildMessages. These tests use realistic history
|
||||
// shapes: user/assistant/tool only, with tool chains and reasoning content.
|
||||
|
||||
func TestFindSafeBoundary_SessionHistoryNoSystem(t *testing.T) {
|
||||
// Real session history starts with a user message, not a system message.
|
||||
history := []providers.Message{
|
||||
msgUser("hello"), // 0
|
||||
msgAssistant("hi there"), // 1
|
||||
msgUser("search for X"), // 2
|
||||
msgAssistantTC("tc1"), // 3
|
||||
msgTool("tc1", "found X"), // 4
|
||||
msgAssistant("here is X"), // 5
|
||||
msgUser("thanks"), // 6
|
||||
msgAssistant("you're welcome"), // 7
|
||||
}
|
||||
|
||||
// Mid-point is 4 (tool result). Should snap backward to 2 (user).
|
||||
got := findSafeBoundary(history, 4)
|
||||
if got != 2 {
|
||||
t.Errorf("findSafeBoundary(session_history, 4) = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSafeBoundary_SessionWithChainedTools(t *testing.T) {
|
||||
// Session with chained tool calls (save then notify).
|
||||
history := []providers.Message{
|
||||
msgUser("save and notify"), // 0
|
||||
msgAssistantTC("tc_save"), // 1
|
||||
msgTool("tc_save", "saved"), // 2
|
||||
msgAssistantTC("tc_notify"), // 3
|
||||
msgTool("tc_notify", "notified"), // 4
|
||||
msgAssistant("done"), // 5
|
||||
msgUser("check status"), // 6
|
||||
msgAssistant("all good"), // 7
|
||||
}
|
||||
|
||||
// Target at 3 (inside chain). Should find user at 0, but backward
|
||||
// scan stops at i>0, so forward scan finds user at 6.
|
||||
// Actually: backward from 3: 2=tool (no), 1=assistantTC (no). Forward: 4=tool, 5=asst, 6=user ✓
|
||||
got := findSafeBoundary(history, 3)
|
||||
if got != 6 {
|
||||
t.Errorf("findSafeBoundary(chained_tools, 3) = %d, want 6", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
||||
// Message with all fields populated — mirrors what AddFullMessage stores.
|
||||
msg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "Here is the analysis.",
|
||||
ReasoningContent: strings.Repeat("Let me think about this carefully. ", 50),
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "analyze",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "analyze",
|
||||
Arguments: `{"data":"sample","depth":3}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
|
||||
// ReasoningContent alone is ~1700 chars → ~680 tokens.
|
||||
// Content + TC + overhead adds more. Should be well above 500.
|
||||
if tokens < 500 {
|
||||
t.Errorf("message with reasoning+toolcalls should have significant tokens, got %d", tokens)
|
||||
}
|
||||
|
||||
// Compare without reasoning to ensure it's counted.
|
||||
msgNoReasoning := msg
|
||||
msgNoReasoning.ReasoningContent = ""
|
||||
tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
|
||||
|
||||
if tokens <= tokensNoReasoning {
|
||||
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOverContextBudget_RealisticSession(t *testing.T) {
|
||||
// Simulate what BuildMessages produces: system + session history + current user.
|
||||
// System message is built by BuildMessages, not stored in session.
|
||||
systemMsg := providers.Message{
|
||||
Role: "system",
|
||||
Content: strings.Repeat("system prompt content ", 100),
|
||||
}
|
||||
sessionHistory := []providers.Message{
|
||||
msgUser("first question"),
|
||||
msgAssistant("first answer"),
|
||||
msgUser("use tool X"),
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I'll use tool X",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "tc1", Type: "function", Name: "tool_x",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_x",
|
||||
Arguments: `{"query":"test","verbose":true}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: strings.Repeat("result data ", 200), ToolCallID: "tc1"},
|
||||
msgAssistant("Here are the results from tool X."),
|
||||
}
|
||||
currentUser := msgUser("follow up question")
|
||||
|
||||
// Assemble as BuildMessages would.
|
||||
messages := make([]providers.Message, 0, 1+len(sessionHistory)+1)
|
||||
messages = append(messages, systemMsg)
|
||||
messages = append(messages, sessionHistory...)
|
||||
messages = append(messages, currentUser)
|
||||
|
||||
tools := []providers.ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: "tool_x",
|
||||
Description: "A useful tool",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// With a large context window, should be within budget.
|
||||
if isOverContextBudget(131072, messages, tools, 32768) {
|
||||
t.Error("realistic session should be within 131072 context window")
|
||||
}
|
||||
|
||||
// With a tiny context window, should exceed budget.
|
||||
if !isOverContextBudget(500, messages, tools, 32768) {
|
||||
t.Error("realistic session should exceed 500 context window")
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func setupWorkspace(t *testing.T, files map[string]string) string {
|
||||
// Codex (only reads last system message as instructions).
|
||||
func TestSingleSystemMessage(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Identity\nTest agent.",
|
||||
"AGENT.md": "# Agent\nTest agent.",
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
@@ -202,10 +202,10 @@ func TestMtimeAutoInvalidation(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "bootstrap file change",
|
||||
file: "IDENTITY.md",
|
||||
contentV1: "# Original Identity",
|
||||
contentV2: "# Updated Identity",
|
||||
checkField: "Updated Identity",
|
||||
file: "AGENT.md",
|
||||
contentV1: "# Original Agent",
|
||||
contentV2: "# Updated Agent",
|
||||
checkField: "Updated Agent",
|
||||
},
|
||||
{
|
||||
name: "memory file change",
|
||||
@@ -280,7 +280,7 @@ func TestMtimeAutoInvalidation(t *testing.T) {
|
||||
// even when source files haven't changed (useful for tests and reload commands).
|
||||
func TestExplicitInvalidateCache(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Test Identity",
|
||||
"AGENT.md": "# Test Agent",
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
@@ -307,8 +307,8 @@ func TestExplicitInvalidateCache(t *testing.T) {
|
||||
// when no files change (regression test for issue #607).
|
||||
func TestCacheStability(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Identity\nContent",
|
||||
"SOUL.md": "# Soul\nContent",
|
||||
"AGENT.md": "# Agent\nContent",
|
||||
"SOUL.md": "# Soul\nContent",
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
@@ -607,7 +607,7 @@ description: delete-me-v1
|
||||
// Run with: go test -race ./pkg/agent/ -run TestConcurrentBuildSystemPromptWithCache
|
||||
func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Identity\nConcurrency test agent.",
|
||||
"AGENT.md": "# Agent\nConcurrency test agent.",
|
||||
"SOUL.md": "# Soul\nBe helpful.",
|
||||
"memory/MEMORY.md": "# Memory\nUser prefers Go.",
|
||||
"skills/demo/SKILL.md": "---\nname: demo\ndescription: \"demo skill\"\n---\n# Demo",
|
||||
@@ -714,7 +714,7 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
|
||||
|
||||
os.MkdirAll(filepath.Join(tmpDir, "memory"), 0o755)
|
||||
os.MkdirAll(filepath.Join(tmpDir, "skills"), 0o755)
|
||||
for _, name := range []string{"IDENTITY.md", "SOUL.md", "USER.md"} {
|
||||
for _, name := range []string{"AGENT.md", "SOUL.md"} {
|
||||
os.WriteFile(filepath.Join(tmpDir, name), []byte(strings.Repeat("Content.\n", 10)), 0o644)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/gomarkdown/markdown/parser"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// AgentDefinitionSource names the workspace bootstrap file an agent definition
// was loaded from. The value doubles as the file name inside the workspace.
type AgentDefinitionSource string

// Recognized agent definition sources.
const (
	// AgentDefinitionSourceAgent is the structured AGENT.md format.
	AgentDefinitionSourceAgent AgentDefinitionSource = "AGENT.md"

	// AgentDefinitionSourceAgents is the legacy AGENTS.md format.
	AgentDefinitionSourceAgents AgentDefinitionSource = "AGENTS.md"
)
|
||||
|
||||
// AgentFrontmatter is the machine-readable configuration found in the
// frontmatter of an AGENT.md file.
//
// Well-known keys get their own typed field. Fields additionally carries the
// complete parsed frontmatter, so callers can read keys that have no typed
// field without the loader contract having to change.
type AgentFrontmatter struct {
	// Name is the "name" frontmatter key.
	Name string `json:"name"`
	// Description is the "description" frontmatter key.
	Description string `json:"description"`
	// Tools is the "tools" frontmatter list.
	Tools []string `json:"tools,omitempty"`
	// Model is the "model" frontmatter key.
	Model string `json:"model,omitempty"`
	// MaxTurns is the "maxTurns" frontmatter key; nil when the key is absent.
	MaxTurns *int `json:"maxTurns,omitempty"`
	// Skills is the "skills" frontmatter list.
	Skills []string `json:"skills,omitempty"`
	// MCPServers is the "mcpServers" frontmatter list.
	MCPServers []string `json:"mcpServers,omitempty"`
	// Fields is the full frontmatter decoded into a generic map.
	Fields map[string]any `json:"fields,omitempty"`
}
|
||||
|
||||
// AgentPromptDefinition represents the parsed AGENT.md or AGENTS.md prompt file.
|
||||
type AgentPromptDefinition struct {
|
||||
Path string `json:"path"`
|
||||
Raw string `json:"raw"`
|
||||
Body string `json:"body"`
|
||||
RawFrontmatter string `json:"raw_frontmatter,omitempty"`
|
||||
Frontmatter AgentFrontmatter `json:"frontmatter"`
|
||||
}
|
||||
|
||||
// SoulDefinition is the SOUL.md file that accompanies the agent definition.
type SoulDefinition struct {
	// Path is the location the file was read from.
	Path string `json:"path"`
	// Content is the file content exactly as read.
	Content string `json:"content"`
}
|
||||
|
||||
// UserDefinition is the USER.md file found in the workspace.
type UserDefinition struct {
	// Path is the location the file was read from.
	Path string `json:"path"`
	// Content is the file content exactly as read.
	Content string `json:"content"`
}
|
||||
|
||||
// AgentContextDefinition captures the workspace agent definition in a runtime-friendly shape.
|
||||
type AgentContextDefinition struct {
|
||||
Source AgentDefinitionSource `json:"source,omitempty"`
|
||||
Agent *AgentPromptDefinition `json:"agent,omitempty"`
|
||||
Soul *SoulDefinition `json:"soul,omitempty"`
|
||||
User *UserDefinition `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// LoadAgentDefinition parses the workspace agent bootstrap files.
|
||||
//
|
||||
// It prefers the new AGENT.md format and its paired SOUL.md file. When the
|
||||
// structured files are absent, it falls back to the legacy AGENTS.md layout so
|
||||
// the current runtime can transition incrementally.
|
||||
func (cb *ContextBuilder) LoadAgentDefinition() AgentContextDefinition {
|
||||
return loadAgentDefinition(cb.workspace)
|
||||
}
|
||||
|
||||
func loadAgentDefinition(workspace string) AgentContextDefinition {
|
||||
definition := AgentContextDefinition{}
|
||||
definition.User = loadUserDefinition(workspace)
|
||||
agentPath := filepath.Join(workspace, string(AgentDefinitionSourceAgent))
|
||||
if content, err := os.ReadFile(agentPath); err == nil {
|
||||
prompt := parseAgentPromptDefinition(agentPath, string(content))
|
||||
definition.Source = AgentDefinitionSourceAgent
|
||||
definition.Agent = &prompt
|
||||
soulPath := filepath.Join(workspace, "SOUL.md")
|
||||
if content, err := os.ReadFile(soulPath); err == nil {
|
||||
definition.Soul = &SoulDefinition{
|
||||
Path: soulPath,
|
||||
Content: string(content),
|
||||
}
|
||||
}
|
||||
return definition
|
||||
}
|
||||
|
||||
legacyPath := filepath.Join(workspace, string(AgentDefinitionSourceAgents))
|
||||
if content, err := os.ReadFile(legacyPath); err == nil {
|
||||
definition.Source = AgentDefinitionSourceAgents
|
||||
definition.Agent = &AgentPromptDefinition{
|
||||
Path: legacyPath,
|
||||
Raw: string(content),
|
||||
Body: string(content),
|
||||
}
|
||||
}
|
||||
|
||||
defaultSoulPath := filepath.Join(workspace, "SOUL.md")
|
||||
if definition.Source != "" || fileExists(defaultSoulPath) {
|
||||
if content, err := os.ReadFile(defaultSoulPath); err == nil {
|
||||
definition.Soul = &SoulDefinition{
|
||||
Path: defaultSoulPath,
|
||||
Content: string(content),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return definition
|
||||
}
|
||||
|
||||
func (definition AgentContextDefinition) trackedPaths(workspace string) []string {
|
||||
paths := []string{
|
||||
filepath.Join(workspace, string(AgentDefinitionSourceAgent)),
|
||||
filepath.Join(workspace, "SOUL.md"),
|
||||
filepath.Join(workspace, "USER.md"),
|
||||
}
|
||||
if definition.Source != AgentDefinitionSourceAgent {
|
||||
paths = append(paths,
|
||||
filepath.Join(workspace, string(AgentDefinitionSourceAgents)),
|
||||
filepath.Join(workspace, "IDENTITY.md"),
|
||||
)
|
||||
}
|
||||
return uniquePaths(paths)
|
||||
}
|
||||
|
||||
func loadUserDefinition(workspace string) *UserDefinition {
|
||||
userPath := filepath.Join(workspace, "USER.md")
|
||||
if content, err := os.ReadFile(userPath); err == nil {
|
||||
return &UserDefinition{
|
||||
Path: userPath,
|
||||
Content: string(content),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAgentPromptDefinition(path, content string) AgentPromptDefinition {
|
||||
frontmatter, body := splitAgentFrontmatter(content)
|
||||
return AgentPromptDefinition{
|
||||
Path: path,
|
||||
Raw: content,
|
||||
Body: body,
|
||||
RawFrontmatter: frontmatter,
|
||||
Frontmatter: parseAgentFrontmatter(path, frontmatter),
|
||||
}
|
||||
}
|
||||
|
||||
func parseAgentFrontmatter(path, frontmatter string) AgentFrontmatter {
|
||||
frontmatter = strings.TrimSpace(frontmatter)
|
||||
if frontmatter == "" {
|
||||
return AgentFrontmatter{}
|
||||
}
|
||||
|
||||
rawFields := make(map[string]any)
|
||||
if err := yaml.Unmarshal([]byte(frontmatter), &rawFields); err != nil {
|
||||
logger.WarnCF("agent", "Failed to parse AGENT.md frontmatter", map[string]any{
|
||||
"path": path,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return AgentFrontmatter{}
|
||||
}
|
||||
|
||||
var typed struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
Tools []string `yaml:"tools"`
|
||||
Model string `yaml:"model"`
|
||||
MaxTurns *int `yaml:"maxTurns"`
|
||||
Skills []string `yaml:"skills"`
|
||||
MCPServers []string `yaml:"mcpServers"`
|
||||
}
|
||||
if err := yaml.Unmarshal([]byte(frontmatter), &typed); err != nil {
|
||||
logger.WarnCF("agent", "Failed to decode AGENT.md frontmatter fields", map[string]any{
|
||||
"path": path,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return AgentFrontmatter{}
|
||||
}
|
||||
|
||||
return AgentFrontmatter{
|
||||
Name: strings.TrimSpace(typed.Name),
|
||||
Description: strings.TrimSpace(typed.Description),
|
||||
Tools: append([]string(nil), typed.Tools...),
|
||||
Model: strings.TrimSpace(typed.Model),
|
||||
MaxTurns: typed.MaxTurns,
|
||||
Skills: append([]string(nil), typed.Skills...),
|
||||
MCPServers: append([]string(nil), typed.MCPServers...),
|
||||
Fields: rawFields,
|
||||
}
|
||||
}
|
||||
|
||||
func splitAgentFrontmatter(content string) (frontmatter, body string) {
|
||||
normalized := string(parser.NormalizeNewlines([]byte(content)))
|
||||
lines := strings.Split(normalized, "\n")
|
||||
if len(lines) == 0 || lines[0] != "---" {
|
||||
return "", content
|
||||
}
|
||||
|
||||
end := -1
|
||||
for i := 1; i < len(lines); i++ {
|
||||
if lines[i] == "---" {
|
||||
end = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if end == -1 {
|
||||
return "", content
|
||||
}
|
||||
|
||||
frontmatter = strings.Join(lines[1:end], "\n")
|
||||
body = strings.Join(lines[end+1:], "\n")
|
||||
body = strings.TrimLeft(body, "\n")
|
||||
return frontmatter, body
|
||||
}
|
||||
|
||||
// relativeWorkspacePath renders path for display relative to workspace.
//
// A blank path yields "". A path located inside workspace is returned relative
// to it using forward slashes. Anything else (the workspace itself, a path
// outside of it, or a path that cannot be made relative) is returned cleaned
// but otherwise unchanged.
func relativeWorkspacePath(workspace, path string) string {
	if strings.TrimSpace(path) == "" {
		return ""
	}

	rel, err := filepath.Rel(workspace, path)
	if err != nil || rel == "." {
		return filepath.Clean(path)
	}

	// The path escapes the workspace only when ".." is a whole leading path
	// element. A plain prefix test would also reject in-workspace entries
	// whose name merely begins with two dots, such as "..hidden/SOUL.md".
	const parent = ".."
	if rel == parent || strings.HasPrefix(rel, parent+string(filepath.Separator)) {
		return filepath.Clean(path)
	}

	return filepath.ToSlash(rel)
}
|
||||
|
||||
// uniquePaths returns the cleaned form of every non-blank entry of paths,
// keeping only the first occurrence of each cleaned path and preserving input
// order. The result is never nil.
func uniquePaths(paths []string) []string {
	unique := make([]string, 0, len(paths))
	for _, candidate := range paths {
		if strings.TrimSpace(candidate) == "" {
			continue
		}
		if normalized := filepath.Clean(candidate); !slices.Contains(unique, normalized) {
			unique = append(unique, normalized)
		}
	}
	return unique
}
|
||||
|
||||
// fileExists reports whether os.Stat succeeds for path. Any stat error,
// including errors other than "not found", counts as the path not existing.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
|
||||
@@ -0,0 +1,302 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoadAgentDefinitionParsesFrontmatterAndSoul(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": `---
|
||||
name: pico
|
||||
description: Structured agent
|
||||
model: claude-3-7-sonnet
|
||||
tools:
|
||||
- shell
|
||||
- search
|
||||
maxTurns: 8
|
||||
skills:
|
||||
- review
|
||||
- search-docs
|
||||
mcpServers:
|
||||
- github
|
||||
metadata:
|
||||
mode: strict
|
||||
---
|
||||
# Agent
|
||||
|
||||
Act directly and use tools first.
|
||||
`,
|
||||
"SOUL.md": "# Soul\nStay precise.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
definition := cb.LoadAgentDefinition()
|
||||
|
||||
if definition.Source != AgentDefinitionSourceAgent {
|
||||
t.Fatalf("expected source %q, got %q", AgentDefinitionSourceAgent, definition.Source)
|
||||
}
|
||||
if definition.Agent == nil {
|
||||
t.Fatal("expected AGENT.md definition to be loaded")
|
||||
}
|
||||
if definition.Agent.Body == "" || !strings.Contains(definition.Agent.Body, "Act directly") {
|
||||
t.Fatalf("expected AGENT.md body to be preserved, got %q", definition.Agent.Body)
|
||||
}
|
||||
if definition.Agent.Frontmatter.Name != "pico" {
|
||||
t.Fatalf("expected name to be parsed, got %q", definition.Agent.Frontmatter.Name)
|
||||
}
|
||||
if definition.Agent.Frontmatter.Model != "claude-3-7-sonnet" {
|
||||
t.Fatalf("expected model to be parsed, got %q", definition.Agent.Frontmatter.Model)
|
||||
}
|
||||
if len(definition.Agent.Frontmatter.Tools) != 2 {
|
||||
t.Fatalf("expected tools to be parsed, got %v", definition.Agent.Frontmatter.Tools)
|
||||
}
|
||||
if definition.Agent.Frontmatter.MaxTurns == nil || *definition.Agent.Frontmatter.MaxTurns != 8 {
|
||||
t.Fatalf("expected maxTurns to be parsed, got %v", definition.Agent.Frontmatter.MaxTurns)
|
||||
}
|
||||
if len(definition.Agent.Frontmatter.Skills) != 2 {
|
||||
t.Fatalf("expected skills to be parsed, got %v", definition.Agent.Frontmatter.Skills)
|
||||
}
|
||||
if len(definition.Agent.Frontmatter.MCPServers) != 1 || definition.Agent.Frontmatter.MCPServers[0] != "github" {
|
||||
t.Fatalf("expected mcpServers to be parsed, got %v", definition.Agent.Frontmatter.MCPServers)
|
||||
}
|
||||
if definition.Agent.Frontmatter.Fields["metadata"] == nil {
|
||||
t.Fatal("expected arbitrary frontmatter fields to remain available")
|
||||
}
|
||||
|
||||
if definition.Soul == nil {
|
||||
t.Fatal("expected SOUL.md to be loaded")
|
||||
}
|
||||
if !strings.Contains(definition.Soul.Content, "Stay precise") {
|
||||
t.Fatalf("expected soul content to be loaded, got %q", definition.Soul.Content)
|
||||
}
|
||||
if definition.Soul.Path != filepath.Join(tmpDir, "SOUL.md") {
|
||||
t.Fatalf("expected default SOUL.md path, got %q", definition.Soul.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAgentDefinitionFallsBackToLegacyAgentsMarkdown(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENTS.md": "# Legacy Agent\nKeep compatibility.",
|
||||
"SOUL.md": "# Soul\nLegacy soul.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
definition := cb.LoadAgentDefinition()
|
||||
|
||||
if definition.Source != AgentDefinitionSourceAgents {
|
||||
t.Fatalf("expected source %q, got %q", AgentDefinitionSourceAgents, definition.Source)
|
||||
}
|
||||
if definition.Agent == nil {
|
||||
t.Fatal("expected AGENTS.md to be loaded")
|
||||
}
|
||||
if definition.Agent.RawFrontmatter != "" {
|
||||
t.Fatalf("legacy AGENTS.md should not have frontmatter, got %q", definition.Agent.RawFrontmatter)
|
||||
}
|
||||
if !strings.Contains(definition.Agent.Body, "Keep compatibility") {
|
||||
t.Fatalf("expected legacy body to be preserved, got %q", definition.Agent.Body)
|
||||
}
|
||||
if definition.Soul == nil || !strings.Contains(definition.Soul.Content, "Legacy soul") {
|
||||
t.Fatal("expected default SOUL.md to be loaded for legacy format")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAgentDefinitionLoadsWorkspaceUserMarkdown(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": "# Agent\nStructured agent.",
|
||||
"USER.md": "# User\nWorkspace preferences.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
definition := cb.LoadAgentDefinition()
|
||||
|
||||
if definition.User == nil {
|
||||
t.Fatal("expected USER.md to be loaded")
|
||||
}
|
||||
if definition.User.Path != filepath.Join(tmpDir, "USER.md") {
|
||||
t.Fatalf("expected workspace USER.md path, got %q", definition.User.Path)
|
||||
}
|
||||
if !strings.Contains(definition.User.Content, "Workspace preferences") {
|
||||
t.Fatalf("expected workspace USER.md content, got %q", definition.User.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAgentDefinitionInvalidFrontmatterFallsBackToEmptyStructuredFields(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": `---
|
||||
name: pico
|
||||
tools:
|
||||
- shell
|
||||
broken
|
||||
---
|
||||
# Agent
|
||||
|
||||
Keep going.
|
||||
`,
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
definition := cb.LoadAgentDefinition()
|
||||
|
||||
if definition.Agent == nil {
|
||||
t.Fatal("expected AGENT.md definition to be loaded")
|
||||
}
|
||||
if !strings.Contains(definition.Agent.Body, "Keep going.") {
|
||||
t.Fatalf("expected AGENT.md body to be preserved, got %q", definition.Agent.Body)
|
||||
}
|
||||
if definition.Agent.Frontmatter.Name != "" ||
|
||||
definition.Agent.Frontmatter.Description != "" ||
|
||||
definition.Agent.Frontmatter.Model != "" ||
|
||||
definition.Agent.Frontmatter.MaxTurns != nil ||
|
||||
len(definition.Agent.Frontmatter.Tools) != 0 ||
|
||||
len(definition.Agent.Frontmatter.Skills) != 0 ||
|
||||
len(definition.Agent.Frontmatter.MCPServers) != 0 ||
|
||||
len(definition.Agent.Frontmatter.Fields) != 0 {
|
||||
t.Fatalf("expected invalid frontmatter to decode as empty struct, got %+v", definition.Agent.Frontmatter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBootstrapFilesUsesAgentBodyNotFrontmatter(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": `---
|
||||
name: pico
|
||||
model: codex-mini
|
||||
---
|
||||
# Agent
|
||||
|
||||
Follow the body prompt.
|
||||
`,
|
||||
"SOUL.md": "# Soul\nSpeak plainly.",
|
||||
"IDENTITY.md": "# Identity\nWorkspace identity.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
bootstrap := cb.LoadBootstrapFiles()
|
||||
|
||||
if !strings.Contains(bootstrap, "Follow the body prompt") {
|
||||
t.Fatalf("expected AGENT.md body in bootstrap, got %q", bootstrap)
|
||||
}
|
||||
if !strings.Contains(bootstrap, "Speak plainly") {
|
||||
t.Fatalf("expected resolved soul content in bootstrap, got %q", bootstrap)
|
||||
}
|
||||
if strings.Contains(bootstrap, "name: pico") {
|
||||
t.Fatalf("bootstrap should not expose raw frontmatter, got %q", bootstrap)
|
||||
}
|
||||
if strings.Contains(bootstrap, "model: codex-mini") {
|
||||
t.Fatalf("bootstrap should not expose raw frontmatter, got %q", bootstrap)
|
||||
}
|
||||
if !strings.Contains(bootstrap, "SOUL.md") {
|
||||
t.Fatalf("expected bootstrap to label SOUL.md, got %q", bootstrap)
|
||||
}
|
||||
if strings.Contains(bootstrap, "Workspace identity") {
|
||||
t.Fatalf("structured bootstrap should ignore IDENTITY.md, got %q", bootstrap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBootstrapFilesIncludesWorkspaceUserMarkdown(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": "# Agent\nFollow the new structure.",
|
||||
"SOUL.md": "# Soul\nSpeak plainly.",
|
||||
"USER.md": "# User\nShared profile.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
bootstrap := cb.LoadBootstrapFiles()
|
||||
|
||||
if !strings.Contains(bootstrap, "Shared profile") {
|
||||
t.Fatalf("expected workspace USER.md in bootstrap, got %q", bootstrap)
|
||||
}
|
||||
if !strings.Contains(bootstrap, "## USER.md") {
|
||||
t.Fatalf("expected USER.md heading in bootstrap, got %q", bootstrap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructuredAgentIgnoresIdentityChanges(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": "# Agent\nFollow the new structure.",
|
||||
"SOUL.md": "# Soul\nVersion one.",
|
||||
"IDENTITY.md": "# Identity\nLegacy identity.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
|
||||
promptV1 := cb.BuildSystemPromptWithCache()
|
||||
if strings.Contains(promptV1, "Legacy identity") {
|
||||
t.Fatalf("structured prompt should not include IDENTITY.md, got %q", promptV1)
|
||||
}
|
||||
|
||||
identityPath := filepath.Join(tmpDir, "IDENTITY.md")
|
||||
if err := os.WriteFile(identityPath, []byte("# Identity\nVersion two."), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(identityPath, future, future); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if changed {
|
||||
t.Fatal("IDENTITY.md should not invalidate cache for structured agent definitions")
|
||||
}
|
||||
|
||||
promptV2 := cb.BuildSystemPromptWithCache()
|
||||
if promptV1 != promptV2 {
|
||||
t.Fatal("structured prompt should remain stable after IDENTITY.md changes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructuredAgentUserChangesInvalidateCache(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"AGENT.md": "# Agent\nFollow the new structure.",
|
||||
"SOUL.md": "# Soul\nVersion one.",
|
||||
"USER.md": "# User\nInitial workspace preferences.",
|
||||
})
|
||||
defer cleanupWorkspace(t, tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
|
||||
promptV1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(promptV1, "Initial workspace preferences") {
|
||||
t.Fatalf("expected workspace USER.md in prompt, got %q", promptV1)
|
||||
}
|
||||
|
||||
userPath := filepath.Join(tmpDir, "USER.md")
|
||||
if err := os.WriteFile(userPath, []byte("# User\nUpdated workspace preferences."), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(userPath, future, future); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("workspace USER.md changes should invalidate cache")
|
||||
}
|
||||
|
||||
promptV2 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(promptV2, "Updated workspace preferences") {
|
||||
t.Fatalf("expected updated workspace USER.md in prompt, got %q", promptV2)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupWorkspace removes the test workspace at path, including everything
// below it, and fails the test immediately when removal does not succeed.
func cleanupWorkspace(t *testing.T, path string) {
	t.Helper()

	err := os.RemoveAll(path)
	if err == nil {
		return
	}
	t.Fatalf("failed to clean up workspace %s: %v", path, err)
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultEventSubscriberBuffer = 16
|
||||
|
||||
// EventSubscription identifies a subscriber channel returned by EventBus.Subscribe.
|
||||
type EventSubscription struct {
|
||||
ID uint64
|
||||
C <-chan Event
|
||||
}
|
||||
|
||||
type eventSubscriber struct {
|
||||
ch chan Event
|
||||
}
|
||||
|
||||
// EventBus is a lightweight multi-subscriber broadcaster for agent-loop events.
|
||||
type EventBus struct {
|
||||
mu sync.RWMutex
|
||||
subs map[uint64]eventSubscriber
|
||||
nextID uint64
|
||||
closed bool
|
||||
dropped [eventKindCount]atomic.Int64
|
||||
}
|
||||
|
||||
// NewEventBus creates a new in-process event broadcaster.
|
||||
func NewEventBus() *EventBus {
|
||||
return &EventBus{
|
||||
subs: make(map[uint64]eventSubscriber),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe registers a new subscriber with the requested channel buffer size.
|
||||
// A non-positive buffer uses the default size.
|
||||
func (b *EventBus) Subscribe(buffer int) EventSubscription {
|
||||
if buffer <= 0 {
|
||||
buffer = defaultEventSubscriberBuffer
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
ch := make(chan Event)
|
||||
close(ch)
|
||||
return EventSubscription{C: ch}
|
||||
}
|
||||
|
||||
b.nextID++
|
||||
id := b.nextID
|
||||
ch := make(chan Event, buffer)
|
||||
b.subs[id] = eventSubscriber{ch: ch}
|
||||
return EventSubscription{ID: id, C: ch}
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscriber and closes its channel.
|
||||
func (b *EventBus) Unsubscribe(id uint64) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
sub, ok := b.subs[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(b.subs, id)
|
||||
close(sub.ch)
|
||||
}
|
||||
|
||||
// Emit broadcasts an event to all current subscribers without blocking.
|
||||
// When a subscriber channel is full, the event is dropped for that subscriber.
|
||||
func (b *EventBus) Emit(evt Event) {
|
||||
if evt.Time.IsZero() {
|
||||
evt.Time = time.Now()
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
|
||||
for _, sub := range b.subs {
|
||||
select {
|
||||
case sub.ch <- evt:
|
||||
default:
|
||||
if evt.Kind < eventKindCount {
|
||||
b.dropped[evt.Kind].Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dropped returns the number of dropped events for a given kind.
|
||||
func (b *EventBus) Dropped(kind EventKind) int64 {
|
||||
if kind >= eventKindCount {
|
||||
return 0
|
||||
}
|
||||
return b.dropped[kind].Load()
|
||||
}
|
||||
|
||||
// Close closes all subscriber channels and stops future broadcasts.
|
||||
func (b *EventBus) Close() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
|
||||
b.closed = true
|
||||
for id, sub := range b.subs {
|
||||
close(sub.ch)
|
||||
delete(b.subs, id)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,684 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
func TestEventBus_SubscribeEmitUnsubscribeClose(t *testing.T) {
|
||||
eventBus := NewEventBus()
|
||||
sub := eventBus.Subscribe(1)
|
||||
|
||||
eventBus.Emit(Event{
|
||||
Kind: EventKindTurnStart,
|
||||
Meta: EventMeta{TurnID: "turn-1"},
|
||||
})
|
||||
|
||||
select {
|
||||
case evt := <-sub.C:
|
||||
if evt.Kind != EventKindTurnStart {
|
||||
t.Fatalf("expected %v, got %v", EventKindTurnStart, evt.Kind)
|
||||
}
|
||||
if evt.Meta.TurnID != "turn-1" {
|
||||
t.Fatalf("expected turn id turn-1, got %q", evt.Meta.TurnID)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for event")
|
||||
}
|
||||
|
||||
eventBus.Unsubscribe(sub.ID)
|
||||
if _, ok := <-sub.C; ok {
|
||||
t.Fatal("expected subscriber channel to be closed after unsubscribe")
|
||||
}
|
||||
|
||||
eventBus.Close()
|
||||
closedSub := eventBus.Subscribe(1)
|
||||
if _, ok := <-closedSub.C; ok {
|
||||
t.Fatal("expected closed bus to return a closed subscriber channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBus_DropsWhenSubscriberIsFull(t *testing.T) {
|
||||
eventBus := NewEventBus()
|
||||
sub := eventBus.Subscribe(1)
|
||||
defer eventBus.Unsubscribe(sub.ID)
|
||||
|
||||
start := time.Now()
|
||||
for i := 0; i < 1000; i++ {
|
||||
eventBus.Emit(Event{Kind: EventKindLLMRequest})
|
||||
}
|
||||
|
||||
if elapsed := time.Since(start); elapsed > 100*time.Millisecond {
|
||||
t.Fatalf("Emit took too long with a blocked subscriber: %s", elapsed)
|
||||
}
|
||||
|
||||
if got := eventBus.Dropped(EventKindLLMRequest); got != 999 {
|
||||
t.Fatalf("expected 999 dropped events, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// scriptedToolProvider is a test double for an LLM provider that replies with
// a fixed script based on how many times Chat has been called.
type scriptedToolProvider struct {
	// calls is the number of Chat invocations so far.
	calls int
}
|
||||
|
||||
func (m *scriptedToolProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
toolDefs []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Name: "mock_custom",
|
||||
Arguments: map[string]any{"task": "ping"},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: "done",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *scriptedToolProvider) GetDefaultModel() string {
|
||||
return "scripted-tool-model"
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-eventbus-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &scriptedToolProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(&mockCustomTool{})
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if response != "done" {
|
||||
t.Fatalf("expected final response 'done', got %q", response)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
if len(events) != 8 {
|
||||
t.Fatalf("expected 8 events, got %d", len(events))
|
||||
}
|
||||
|
||||
kinds := make([]EventKind, 0, len(events))
|
||||
for _, evt := range events {
|
||||
kinds = append(kinds, evt.Kind)
|
||||
}
|
||||
|
||||
expectedKinds := []EventKind{
|
||||
EventKindTurnStart,
|
||||
EventKindLLMRequest,
|
||||
EventKindLLMResponse,
|
||||
EventKindToolExecStart,
|
||||
EventKindToolExecEnd,
|
||||
EventKindLLMRequest,
|
||||
EventKindLLMResponse,
|
||||
EventKindTurnEnd,
|
||||
}
|
||||
if !slices.Equal(kinds, expectedKinds) {
|
||||
t.Fatalf("unexpected event sequence: got %v want %v", kinds, expectedKinds)
|
||||
}
|
||||
|
||||
turnID := events[0].Meta.TurnID
|
||||
for i, evt := range events {
|
||||
if evt.Meta.TurnID != turnID {
|
||||
t.Fatalf("event %d has mismatched turn id %q, want %q", i, evt.Meta.TurnID, turnID)
|
||||
}
|
||||
if evt.Meta.SessionKey != "session-1" {
|
||||
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
startPayload, ok := events[0].Payload.(TurnStartPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected TurnStartPayload, got %T", events[0].Payload)
|
||||
}
|
||||
if startPayload.UserMessage != "run tool" {
|
||||
t.Fatalf("expected user message 'run tool', got %q", startPayload.UserMessage)
|
||||
}
|
||||
|
||||
toolStartPayload, ok := events[3].Payload.(ToolExecStartPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecStartPayload, got %T", events[3].Payload)
|
||||
}
|
||||
if toolStartPayload.Tool != "mock_custom" {
|
||||
t.Fatalf("expected tool name mock_custom, got %q", toolStartPayload.Tool)
|
||||
}
|
||||
|
||||
toolEndPayload, ok := events[4].Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", events[4].Payload)
|
||||
}
|
||||
if toolEndPayload.Tool != "mock_custom" {
|
||||
t.Fatalf("expected tool end payload for mock_custom, got %q", toolEndPayload.Tool)
|
||||
}
|
||||
if toolEndPayload.IsError {
|
||||
t.Fatal("expected mock_custom tool to succeed")
|
||||
}
|
||||
|
||||
turnEndPayload, ok := events[len(events)-1].Payload.(TurnEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected TurnEndPayload, got %T", events[len(events)-1].Payload)
|
||||
}
|
||||
if turnEndPayload.Status != TurnEndStatusCompleted {
|
||||
t.Fatalf("expected completed turn, got %q", turnEndPayload.Status)
|
||||
}
|
||||
if turnEndPayload.Iterations != 2 {
|
||||
t.Fatalf("expected 2 iterations, got %d", turnEndPayload.Iterations)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-eventbus-steering-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool1ExecCh := make(chan struct{})
|
||||
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
|
||||
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
|
||||
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "tool_one",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_one",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Type: "function",
|
||||
Name: "tool_two",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_two",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "steered response",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resultCh := make(chan string, 1)
|
||||
go func() {
|
||||
resp, _ := al.ProcessDirectWithChannel(context.Background(), "do something", "test-session", "test", "chat1")
|
||||
resultCh <- resp
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-tool1ExecCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for tool_one to start")
|
||||
}
|
||||
|
||||
if err := al.Steer(providers.Message{Role: "user", Content: "change course"}); err != nil {
|
||||
t.Fatalf("Steer failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case resp := <-resultCh:
|
||||
if resp != "steered response" {
|
||||
t.Fatalf("expected steered response, got %q", resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for steered response")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
steeringEvt, ok := findEvent(events, EventKindSteeringInjected)
|
||||
if !ok {
|
||||
t.Fatal("expected steering injected event")
|
||||
}
|
||||
steeringPayload, ok := steeringEvt.Payload.(SteeringInjectedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected SteeringInjectedPayload, got %T", steeringEvt.Payload)
|
||||
}
|
||||
if steeringPayload.Count != 1 {
|
||||
t.Fatalf("expected 1 steering message, got %d", steeringPayload.Count)
|
||||
}
|
||||
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected skipped tool event")
|
||||
}
|
||||
skippedPayload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
||||
}
|
||||
if skippedPayload.Tool != "tool_two" {
|
||||
t.Fatalf("expected skipped tool_two, got %q", skippedPayload.Tool)
|
||||
}
|
||||
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
||||
}
|
||||
if interruptPayload.Role != "user" {
|
||||
t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
|
||||
}
|
||||
if interruptPayload.Kind != InterruptKindSteering {
|
||||
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
|
||||
}
|
||||
if interruptPayload.ContentLen != len("change course") {
|
||||
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-eventbus-compress-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
contextErr := stringError("InvalidParameter: Total tokens of image and text exceed max message tokens")
|
||||
provider := &failFirstMockProvider{
|
||||
failures: 1,
|
||||
failError: contextErr,
|
||||
successResp: "Recovered from context error",
|
||||
}
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
|
||||
{Role: "user", Content: "Old message 1"},
|
||||
{Role: "assistant", Content: "Old response 1"},
|
||||
{Role: "user", Content: "Old message 2"},
|
||||
{Role: "assistant", Content: "Old response 2"},
|
||||
{Role: "user", Content: "Trigger message"},
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "Trigger message",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "Recovered from context error" {
|
||||
t.Fatalf("expected retry success, got %q", resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
retryEvt, ok := findEvent(events, EventKindLLMRetry)
|
||||
if !ok {
|
||||
t.Fatal("expected llm retry event")
|
||||
}
|
||||
retryPayload, ok := retryEvt.Payload.(LLMRetryPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected LLMRetryPayload, got %T", retryEvt.Payload)
|
||||
}
|
||||
if retryPayload.Reason != "context_limit" {
|
||||
t.Fatalf("expected context_limit retry reason, got %q", retryPayload.Reason)
|
||||
}
|
||||
if retryPayload.Attempt != 1 {
|
||||
t.Fatalf("expected retry attempt 1, got %d", retryPayload.Attempt)
|
||||
}
|
||||
|
||||
compressEvt, ok := findEvent(events, EventKindContextCompress)
|
||||
if !ok {
|
||||
t.Fatal("expected context compress event")
|
||||
}
|
||||
payload, ok := compressEvt.Payload.(ContextCompressPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
|
||||
}
|
||||
if payload.Reason != ContextCompressReasonRetry {
|
||||
t.Fatalf("expected retry compress reason, got %q", payload.Reason)
|
||||
}
|
||||
if payload.DroppedMessages == 0 {
|
||||
t.Fatal("expected dropped messages to be recorded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-eventbus-summary-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
ContextWindow: 8000,
|
||||
SummarizeMessageThreshold: 2,
|
||||
SummarizeTokenPercent: 75,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "summary text"})
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
|
||||
{Role: "user", Content: "Question one"},
|
||||
{Role: "assistant", Content: "Answer one"},
|
||||
{Role: "user", Content: "Question two"},
|
||||
{Role: "assistant", Content: "Answer two"},
|
||||
{Role: "user", Content: "Question three"},
|
||||
{Role: "assistant", Content: "Answer three"},
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
turnScope := al.newTurnEventScope(defaultAgent.ID, "session-1")
|
||||
al.summarizeSession(defaultAgent, "session-1", turnScope)
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
|
||||
if !ok {
|
||||
t.Fatal("expected session summarize event")
|
||||
}
|
||||
payload, ok := summaryEvt.Payload.(SessionSummarizePayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected SessionSummarizePayload, got %T", summaryEvt.Payload)
|
||||
}
|
||||
if payload.SummaryLen == 0 {
|
||||
t.Fatal("expected non-empty summary length")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-eventbus-followup-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_async_1",
|
||||
Type: "function",
|
||||
Name: "async_followup",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "async_followup",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "async launched",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
doneCh := make(chan struct{})
|
||||
al.RegisterTool(&asyncFollowUpTool{
|
||||
name: "async_followup",
|
||||
followUpText: "background result",
|
||||
completionSig: doneCh,
|
||||
})
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run async tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "async launched" {
|
||||
t.Fatalf("expected final response 'async launched', got %q", resp)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-doneCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for async tool completion")
|
||||
}
|
||||
|
||||
followUpEvt := waitForEvent(t, sub.C, 2*time.Second, func(evt Event) bool {
|
||||
return evt.Kind == EventKindFollowUpQueued
|
||||
})
|
||||
payload, ok := followUpEvt.Payload.(FollowUpQueuedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected FollowUpQueuedPayload, got %T", followUpEvt.Payload)
|
||||
}
|
||||
if payload.SourceTool != "async_followup" {
|
||||
t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
|
||||
}
|
||||
if payload.Channel != "cli" {
|
||||
t.Fatalf("expected channel cli, got %q", payload.Channel)
|
||||
}
|
||||
if payload.ChatID != "direct" {
|
||||
t.Fatalf("expected chat id direct, got %q", payload.ChatID)
|
||||
}
|
||||
if payload.ContentLen != len("background result") {
|
||||
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
|
||||
}
|
||||
if followUpEvt.Meta.SessionKey != "session-1" {
|
||||
t.Fatalf("expected session key session-1, got %q", followUpEvt.Meta.SessionKey)
|
||||
}
|
||||
if followUpEvt.Meta.TurnID == "" {
|
||||
t.Fatal("expected follow-up event to include turn id")
|
||||
}
|
||||
}
|
||||
|
||||
func collectEventStream(ch <-chan Event) []Event {
|
||||
var events []Event
|
||||
for {
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
return events
|
||||
}
|
||||
events = append(events, evt)
|
||||
default:
|
||||
return events
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitForEvent(t *testing.T, ch <-chan Event, timeout time.Duration, match func(Event) bool) Event {
|
||||
t.Helper()
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("event stream closed before expected event arrived")
|
||||
}
|
||||
if match(evt) {
|
||||
return evt
|
||||
}
|
||||
case <-timer.C:
|
||||
t.Fatal("timed out waiting for expected event")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findEvent(events []Event, kind EventKind) (Event, bool) {
|
||||
for _, evt := range events {
|
||||
if evt.Kind == kind {
|
||||
return evt, true
|
||||
}
|
||||
}
|
||||
return Event{}, false
|
||||
}
|
||||
|
||||
// stringError is a string that implements error, used to build fixed test
// errors without extra imports.
type stringError string

// Error returns the underlying string unchanged.
func (e stringError) Error() string {
	msg := string(e)
	return msg
}
|
||||
|
||||
// asyncFollowUpTool is a test tool that reports a follow-up result through
// the async callback and can signal when that callback has returned.
type asyncFollowUpTool struct {
	// name is returned by Name.
	name string
	// followUpText is delivered as the ForLLM text of the async result.
	followUpText string
	// completionSig, when non-nil, is closed after the callback has run.
	completionSig chan struct{}
}
|
||||
|
||||
func (t *asyncFollowUpTool) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *asyncFollowUpTool) Description() string {
|
||||
return "async follow-up tool for testing"
|
||||
}
|
||||
|
||||
func (t *asyncFollowUpTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *asyncFollowUpTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
return tools.AsyncResult("async follow-up scheduled")
|
||||
}
|
||||
|
||||
func (t *asyncFollowUpTool) ExecuteAsync(
|
||||
ctx context.Context,
|
||||
args map[string]any,
|
||||
cb tools.AsyncCallback,
|
||||
) *tools.ToolResult {
|
||||
go func() {
|
||||
cb(ctx, &tools.ToolResult{ForLLM: t.followUpText})
|
||||
if t.completionSig != nil {
|
||||
close(t.completionSig)
|
||||
}
|
||||
}()
|
||||
return tools.AsyncResult("async follow-up scheduled")
|
||||
}
|
||||
|
||||
var (
|
||||
_ tools.Tool = (*mockCustomTool)(nil)
|
||||
_ tools.AsyncExecutor = (*asyncFollowUpTool)(nil)
|
||||
)
|
||||
@@ -0,0 +1,271 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EventKind identifies a structured agent-loop event.
type EventKind uint8

const (
	// EventKindTurnStart is emitted when a turn begins processing.
	EventKindTurnStart EventKind = iota
	// EventKindTurnEnd is emitted when a turn finishes, successfully or with an error.
	EventKindTurnEnd
	// EventKindLLMRequest is emitted before a provider chat request is made.
	EventKindLLMRequest
	// EventKindLLMDelta is emitted when a streaming provider yields a partial delta.
	EventKindLLMDelta
	// EventKindLLMResponse is emitted after a provider chat response is received.
	EventKindLLMResponse
	// EventKindLLMRetry is emitted when an LLM request is retried.
	EventKindLLMRetry
	// EventKindContextCompress is emitted when session history is forcibly compressed.
	EventKindContextCompress
	// EventKindSessionSummarize is emitted when asynchronous summarization completes.
	EventKindSessionSummarize
	// EventKindToolExecStart is emitted immediately before a tool executes.
	EventKindToolExecStart
	// EventKindToolExecEnd is emitted immediately after a tool finishes executing.
	EventKindToolExecEnd
	// EventKindToolExecSkipped is emitted when a queued tool call is skipped.
	EventKindToolExecSkipped
	// EventKindSteeringInjected is emitted when queued steering is injected into context.
	EventKindSteeringInjected
	// EventKindFollowUpQueued is emitted when an async tool queues a follow-up system message.
	EventKindFollowUpQueued
	// EventKindInterruptReceived is emitted when a soft interrupt message is accepted.
	EventKindInterruptReceived
	// EventKindSubTurnSpawn is emitted when a sub-turn is spawned.
	EventKindSubTurnSpawn
	// EventKindSubTurnEnd is emitted when a sub-turn finishes.
	EventKindSubTurnEnd
	// EventKindSubTurnResultDelivered is emitted when a sub-turn result is delivered.
	EventKindSubTurnResultDelivered
	// EventKindSubTurnOrphan is emitted when a sub-turn result cannot be delivered.
	EventKindSubTurnOrphan
	// EventKindError is emitted when a turn encounters an execution error.
	EventKindError

	// eventKindCount is the number of defined kinds. It must remain the last
	// constant in this block.
	eventKindCount
)

// eventKindNames holds the stable string form of every EventKind. Entries are
// keyed by constant rather than by position, so reordering or inserting kinds
// cannot shift names onto the wrong kind, and the array length is tied to
// eventKindCount so an entry beyond the defined kinds fails to compile.
var eventKindNames = [eventKindCount]string{
	EventKindTurnStart:              "turn_start",
	EventKindTurnEnd:                "turn_end",
	EventKindLLMRequest:             "llm_request",
	EventKindLLMDelta:               "llm_delta",
	EventKindLLMResponse:            "llm_response",
	EventKindLLMRetry:               "llm_retry",
	EventKindContextCompress:        "context_compress",
	EventKindSessionSummarize:       "session_summarize",
	EventKindToolExecStart:          "tool_exec_start",
	EventKindToolExecEnd:            "tool_exec_end",
	EventKindToolExecSkipped:        "tool_exec_skipped",
	EventKindSteeringInjected:       "steering_injected",
	EventKindFollowUpQueued:         "follow_up_queued",
	EventKindInterruptReceived:      "interrupt_received",
	EventKindSubTurnSpawn:           "subturn_spawn",
	EventKindSubTurnEnd:             "subturn_end",
	EventKindSubTurnResultDelivered: "subturn_result_delivered",
	EventKindSubTurnOrphan:          "subturn_orphan",
	EventKindError:                  "error",
}

// String returns the stable string form of an EventKind.
//
// Values outside the defined range, and defined kinds that have no entry in
// eventKindNames, are rendered as "event_kind(N)" instead of panicking or
// returning an empty string.
func (k EventKind) String() string {
	if k < eventKindCount {
		if name := eventKindNames[k]; name != "" {
			return name
		}
	}
	return fmt.Sprintf("event_kind(%d)", k)
}
|
||||
|
||||
// Event is the structured envelope broadcast by the agent EventBus.
|
||||
type Event struct {
|
||||
Kind EventKind
|
||||
Time time.Time
|
||||
Meta EventMeta
|
||||
Payload any
|
||||
}
|
||||
|
||||
// EventMeta carries the correlation fields attached to every agent-loop
// event.
type EventMeta struct {
	AgentID, TurnID, ParentTurnID, SessionKey string

	Iteration int

	TracePath, Source string
}
|
||||
|
||||
// TurnEndStatus is the terminal state reported when a turn ends.
type TurnEndStatus string

// Terminal turn states.
const (
	// TurnEndStatusCompleted means the turn finished normally.
	TurnEndStatusCompleted TurnEndStatus = "completed"
	// TurnEndStatusError means the turn ended because of an error.
	TurnEndStatusError TurnEndStatus = "error"
	// TurnEndStatusAborted means the turn was hard-aborted and rolled back.
	TurnEndStatusAborted TurnEndStatus = "aborted"
)
|
||||
|
||||
// TurnStartPayload is the payload of a turn-start event.
type TurnStartPayload struct {
	Channel, ChatID, UserMessage string

	MediaCount int
}
|
||||
|
||||
// TurnEndPayload describes the completion of a turn.
|
||||
type TurnEndPayload struct {
|
||||
Status TurnEndStatus
|
||||
Iterations int
|
||||
Duration time.Duration
|
||||
FinalContentLen int
|
||||
}
|
||||
|
||||
// LLMRequestPayload is the payload of an event describing an outbound LLM
// request.
type LLMRequestPayload struct {
	Model string

	MessagesCount, ToolsCount, MaxTokens int

	Temperature float64
}
|
||||
|
||||
// LLMResponsePayload is the payload of an event describing an inbound LLM
// response.
type LLMResponsePayload struct {
	ContentLen, ToolCalls int

	HasReasoning bool
}
|
||||
|
||||
// LLMDeltaPayload is the payload of an event describing one streamed LLM
// delta.
type LLMDeltaPayload struct {
	ContentDeltaLen, ReasoningDeltaLen int
}
|
||||
|
||||
// LLMRetryPayload describes a retry of an LLM request.
|
||||
type LLMRetryPayload struct {
|
||||
Attempt int
|
||||
MaxRetries int
|
||||
Reason string
|
||||
Error string
|
||||
Backoff time.Duration
|
||||
}
|
||||
|
||||
// ContextCompressReason names the trigger for an emergency compression.
type ContextCompressReason string

// Compression triggers.
const (
	// ContextCompressReasonProactive marks compression done before the first LLM call.
	ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
	// ContextCompressReasonRetry marks compression done while handling a context-error retry.
	ContextCompressReasonRetry ContextCompressReason = "llm_retry"
)

// ContextCompressPayload is the payload of an event describing a forced
// history compression.
type ContextCompressPayload struct {
	Reason ContextCompressReason

	DroppedMessages, RemainingMessages int
}
|
||||
|
||||
// SessionSummarizePayload is the payload of an event describing a completed
// async session summarization.
type SessionSummarizePayload struct {
	SummarizedMessages, KeptMessages, SummaryLen int

	OmittedOversized bool
}
|
||||
|
||||
// ToolExecStartPayload is the payload of an event describing a tool execution
// request: the tool name and the arguments it is called with.
type ToolExecStartPayload struct {
	Tool      string
	Arguments map[string]any
}
|
||||
|
||||
// ToolExecEndPayload describes the outcome of a tool execution.
|
||||
type ToolExecEndPayload struct {
|
||||
Tool string
|
||||
Duration time.Duration
|
||||
ForLLMLen int
|
||||
ForUserLen int
|
||||
IsError bool
|
||||
Async bool
|
||||
}
|
||||
|
||||
// ToolExecSkippedPayload is the payload of an event describing a tool call
// that was skipped, with the reason.
type ToolExecSkippedPayload struct {
	Tool, Reason string
}
|
||||
|
||||
// SteeringInjectedPayload is the payload of an event describing steering
// messages appended before the next LLM call.
type SteeringInjectedPayload struct {
	Count, TotalContentLen int
}
|
||||
|
||||
// FollowUpQueuedPayload is the payload of an event describing an async
// follow-up queued back into the inbound bus.
type FollowUpQueuedPayload struct {
	SourceTool, Channel, ChatID string

	ContentLen int
}
|
||||
|
||||
// InterruptKind names the category of an accepted interrupt.
type InterruptKind string

// Interrupt categories.
const (
	// InterruptKindSteering is the kind with wire value "steering".
	InterruptKindSteering InterruptKind = "steering"
	// InterruptKindGraceful is the kind with wire value "graceful".
	// NOTE(review): presumably a request to wind the turn down cleanly; confirm at the emitter.
	InterruptKindGraceful InterruptKind = "graceful"
	// InterruptKindHard is the kind with wire value "hard_abort".
	// NOTE(review): presumably pairs with TurnEndStatusAborted; confirm at the emitter.
	InterruptKindHard InterruptKind = "hard_abort"
)

// InterruptReceivedPayload is the payload of an event describing accepted
// turn-control input.
type InterruptReceivedPayload struct {
	Kind InterruptKind
	Role string

	ContentLen, QueueDepth, HintLen int
}
|
||||
|
||||
// SubTurnSpawnPayload is the payload of an event describing the creation of a
// child turn.
type SubTurnSpawnPayload struct {
	AgentID, Label, ParentTurnID string
}
|
||||
|
||||
// SubTurnEndPayload is the payload of an event describing the completion of a
// child turn.
type SubTurnEndPayload struct {
	AgentID, Status string
}
|
||||
|
||||
// SubTurnResultDeliveredPayload is the payload of an event describing the
// delivery of a sub-turn result.
type SubTurnResultDeliveredPayload struct {
	TargetChannel, TargetChatID string

	ContentLen int
}
|
||||
|
||||
// SubTurnOrphanPayload is the payload of an event describing a sub-turn
// result that could not be delivered.
type SubTurnOrphanPayload struct {
	ParentTurnID, ChildTurnID, Reason string
}
|
||||
|
||||
// ErrorPayload is the payload of an event describing an execution error
// inside the agent loop: the stage it occurred in and its message.
type ErrorPayload struct {
	Stage, Message string
}
|
||||
@@ -0,0 +1,317 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// hookRuntime tracks one-time loading of configured hooks: the once-guard,
// the error recorded by that load, and the names of the hooks it mounted.
type hookRuntime struct {
	// initOnce guards the single load of configured hooks.
	initOnce sync.Once

	// mu protects the fields below.
	mu      sync.Mutex
	initErr error
	mounted []string
}
|
||||
|
||||
func (r *hookRuntime) setInitErr(err error) {
|
||||
r.mu.Lock()
|
||||
r.initErr = err
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *hookRuntime) getInitErr() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.initErr
|
||||
}
|
||||
|
||||
func (r *hookRuntime) setMounted(names []string) {
|
||||
r.mu.Lock()
|
||||
r.mounted = append([]string(nil), names...)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *hookRuntime) reset(al *AgentLoop) {
|
||||
r.mu.Lock()
|
||||
names := append([]string(nil), r.mounted...)
|
||||
r.mounted = nil
|
||||
r.initErr = nil
|
||||
r.initOnce = sync.Once{}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, name := range names {
|
||||
al.UnmountHook(name)
|
||||
}
|
||||
}
|
||||
|
||||
// BuiltinHookFactory constructs an in-process hook from config.
|
||||
type BuiltinHookFactory func(ctx context.Context, spec config.BuiltinHookConfig) (any, error)
|
||||
|
||||
var (
|
||||
builtinHookRegistryMu sync.RWMutex
|
||||
builtinHookRegistry = map[string]BuiltinHookFactory{}
|
||||
)
|
||||
|
||||
// RegisterBuiltinHook registers a named in-process hook factory for config-driven mounting.
|
||||
func RegisterBuiltinHook(name string, factory BuiltinHookFactory) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("builtin hook name is required")
|
||||
}
|
||||
if factory == nil {
|
||||
return fmt.Errorf("builtin hook %q factory is nil", name)
|
||||
}
|
||||
|
||||
builtinHookRegistryMu.Lock()
|
||||
defer builtinHookRegistryMu.Unlock()
|
||||
|
||||
if _, exists := builtinHookRegistry[name]; exists {
|
||||
return fmt.Errorf("builtin hook %q is already registered", name)
|
||||
}
|
||||
builtinHookRegistry[name] = factory
|
||||
return nil
|
||||
}
|
||||
|
||||
func unregisterBuiltinHook(name string) {
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
builtinHookRegistryMu.Lock()
|
||||
delete(builtinHookRegistry, name)
|
||||
builtinHookRegistryMu.Unlock()
|
||||
}
|
||||
|
||||
func lookupBuiltinHook(name string) (BuiltinHookFactory, bool) {
|
||||
builtinHookRegistryMu.RLock()
|
||||
defer builtinHookRegistryMu.RUnlock()
|
||||
|
||||
factory, ok := builtinHookRegistry[name]
|
||||
return factory, ok
|
||||
}
|
||||
|
||||
func configureHookManagerFromConfig(hm *HookManager, cfg *config.Config) {
|
||||
if hm == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
hm.ConfigureTimeouts(
|
||||
hookTimeoutFromMS(cfg.Hooks.Defaults.ObserverTimeoutMS),
|
||||
hookTimeoutFromMS(cfg.Hooks.Defaults.InterceptorTimeoutMS),
|
||||
hookTimeoutFromMS(cfg.Hooks.Defaults.ApprovalTimeoutMS),
|
||||
)
|
||||
}
|
||||
|
||||
func hookTimeoutFromMS(ms int) time.Duration {
|
||||
if ms <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
|
||||
func (al *AgentLoop) ensureHooksInitialized(ctx context.Context) error {
|
||||
if al == nil || al.cfg == nil || al.hooks == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
al.hookRuntime.initOnce.Do(func() {
|
||||
al.hookRuntime.setInitErr(al.loadConfiguredHooks(ctx))
|
||||
})
|
||||
|
||||
return al.hookRuntime.getInitErr()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) loadConfiguredHooks(ctx context.Context) (err error) {
|
||||
if al == nil || al.cfg == nil || !al.cfg.Hooks.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
mounted := make([]string, 0)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
for _, name := range mounted {
|
||||
al.UnmountHook(name)
|
||||
}
|
||||
return
|
||||
}
|
||||
al.hookRuntime.setMounted(mounted)
|
||||
}()
|
||||
|
||||
builtinNames := enabledBuiltinHookNames(al.cfg.Hooks.Builtins)
|
||||
for _, name := range builtinNames {
|
||||
spec := al.cfg.Hooks.Builtins[name]
|
||||
factory, ok := lookupBuiltinHook(name)
|
||||
if !ok {
|
||||
return fmt.Errorf("builtin hook %q is not registered", name)
|
||||
}
|
||||
|
||||
hook, factoryErr := factory(ctx, spec)
|
||||
if factoryErr != nil {
|
||||
return fmt.Errorf("build builtin hook %q: %w", name, factoryErr)
|
||||
}
|
||||
if err := al.MountHook(HookRegistration{
|
||||
Name: name,
|
||||
Priority: spec.Priority,
|
||||
Source: HookSourceInProcess,
|
||||
Hook: hook,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("mount builtin hook %q: %w", name, err)
|
||||
}
|
||||
mounted = append(mounted, name)
|
||||
}
|
||||
|
||||
processNames := enabledProcessHookNames(al.cfg.Hooks.Processes)
|
||||
for _, name := range processNames {
|
||||
spec := al.cfg.Hooks.Processes[name]
|
||||
opts, buildErr := processHookOptionsFromConfig(spec)
|
||||
if buildErr != nil {
|
||||
return fmt.Errorf("configure process hook %q: %w", name, buildErr)
|
||||
}
|
||||
|
||||
processHook, buildErr := NewProcessHook(ctx, name, opts)
|
||||
if buildErr != nil {
|
||||
return fmt.Errorf("start process hook %q: %w", name, buildErr)
|
||||
}
|
||||
if err := al.MountHook(HookRegistration{
|
||||
Name: name,
|
||||
Priority: spec.Priority,
|
||||
Source: HookSourceProcess,
|
||||
Hook: processHook,
|
||||
}); err != nil {
|
||||
_ = processHook.Close()
|
||||
return fmt.Errorf("mount process hook %q: %w", name, err)
|
||||
}
|
||||
mounted = append(mounted, name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enabledBuiltinHookNames(specs map[string]config.BuiltinHookConfig) []string {
|
||||
if len(specs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(specs))
|
||||
for name, spec := range specs {
|
||||
if spec.Enabled {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func enabledProcessHookNames(specs map[string]config.ProcessHookConfig) []string {
|
||||
if len(specs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(specs))
|
||||
for name, spec := range specs {
|
||||
if spec.Enabled {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func processHookOptionsFromConfig(spec config.ProcessHookConfig) (ProcessHookOptions, error) {
|
||||
transport := spec.Transport
|
||||
if transport == "" {
|
||||
transport = "stdio"
|
||||
}
|
||||
if transport != "stdio" {
|
||||
return ProcessHookOptions{}, fmt.Errorf("unsupported transport %q", transport)
|
||||
}
|
||||
if len(spec.Command) == 0 {
|
||||
return ProcessHookOptions{}, fmt.Errorf("command is required")
|
||||
}
|
||||
|
||||
opts := ProcessHookOptions{
|
||||
Command: append([]string(nil), spec.Command...),
|
||||
Dir: spec.Dir,
|
||||
Env: processHookEnvFromMap(spec.Env),
|
||||
}
|
||||
|
||||
observeKinds, observeEnabled, err := processHookObserveKindsFromConfig(spec.Observe)
|
||||
if err != nil {
|
||||
return ProcessHookOptions{}, err
|
||||
}
|
||||
opts.Observe = observeEnabled
|
||||
opts.ObserveKinds = observeKinds
|
||||
|
||||
for _, intercept := range spec.Intercept {
|
||||
switch intercept {
|
||||
case "before_llm", "after_llm":
|
||||
opts.InterceptLLM = true
|
||||
case "before_tool", "after_tool":
|
||||
opts.InterceptTool = true
|
||||
case "approve_tool":
|
||||
opts.ApproveTool = true
|
||||
case "":
|
||||
continue
|
||||
default:
|
||||
return ProcessHookOptions{}, fmt.Errorf("unsupported intercept %q", intercept)
|
||||
}
|
||||
}
|
||||
|
||||
if !opts.Observe && !opts.InterceptLLM && !opts.InterceptTool && !opts.ApproveTool {
|
||||
return ProcessHookOptions{}, fmt.Errorf("no hook modes enabled")
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// processHookEnvFromMap renders envMap as "KEY=value" strings ordered by
// key. It returns nil for an empty map.
func processHookEnvFromMap(envMap map[string]string) []string {
	if len(envMap) == 0 {
		return nil
	}

	names := make([]string, 0, len(envMap))
	for name := range envMap {
		names = append(names, name)
	}
	// Sort by key (not by rendered entry) for a deterministic result.
	sort.Strings(names)

	env := make([]string, len(names))
	for i, name := range names {
		env[i] = name + "=" + envMap[name]
	}
	return env
}
|
||||
|
||||
func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error) {
|
||||
if len(observe) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
validKinds := validHookEventKinds()
|
||||
normalized := make([]string, 0, len(observe))
|
||||
for _, kind := range observe {
|
||||
switch kind {
|
||||
case "", "*", "all":
|
||||
return nil, true, nil
|
||||
default:
|
||||
if _, ok := validKinds[kind]; !ok {
|
||||
return nil, false, fmt.Errorf("unsupported observe event %q", kind)
|
||||
}
|
||||
normalized = append(normalized, kind)
|
||||
}
|
||||
}
|
||||
|
||||
if len(normalized) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
return normalized, true, nil
|
||||
}
|
||||
|
||||
func validHookEventKinds() map[string]struct{} {
|
||||
kinds := make(map[string]struct{}, int(eventKindCount))
|
||||
for kind := EventKind(0); kind < eventKindCount; kind++ {
|
||||
kinds[kind.String()] = struct{}{}
|
||||
}
|
||||
return kinds
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// builtinAutoHookConfig is the JSON configuration payload decoded by the
// test builtin hook factory.
type builtinAutoHookConfig struct {
	// Model is the model name the hook writes into outgoing LLM requests.
	Model string `json:"model"`
	// Suffix is the text the hook appends to LLM response content.
	Suffix string `json:"suffix"`
}
|
||||
|
||||
// builtinAutoHook is a test LLM interceptor: it overrides the request model
// and appends a suffix to the response content.
type builtinAutoHook struct {
	model, suffix string
}
|
||||
|
||||
func (h *builtinAutoHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
next := req.Clone()
|
||||
next.Model = h.model
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *builtinAutoHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
next := resp.Clone()
|
||||
if next.Response != nil {
|
||||
next.Response.Content += h.suffix
|
||||
}
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func newConfiguredHookLoop(t *testing.T, provider *llmHookTestProvider, hooks config.HooksConfig) *AgentLoop {
|
||||
t.Helper()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
Hooks: hooks,
|
||||
}
|
||||
|
||||
return NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook(t *testing.T) {
|
||||
const hookName = "test-auto-builtin-hook"
|
||||
|
||||
if err := RegisterBuiltinHook(hookName, func(
|
||||
ctx context.Context,
|
||||
spec config.BuiltinHookConfig,
|
||||
) (any, error) {
|
||||
var hookCfg builtinAutoHookConfig
|
||||
if len(spec.Config) > 0 {
|
||||
if err := json.Unmarshal(spec.Config, &hookCfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &builtinAutoHook{
|
||||
model: hookCfg.Model,
|
||||
suffix: hookCfg.Suffix,
|
||||
}, nil
|
||||
}); err != nil {
|
||||
t.Fatalf("RegisterBuiltinHook failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
unregisterBuiltinHook(hookName)
|
||||
})
|
||||
|
||||
rawCfg, err := json.Marshal(builtinAutoHookConfig{
|
||||
Model: "builtin-model",
|
||||
Suffix: "|builtin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
provider := &llmHookTestProvider{}
|
||||
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
|
||||
Enabled: true,
|
||||
Builtins: map[string]config.BuiltinHookConfig{
|
||||
hookName: {
|
||||
Enabled: true,
|
||||
Config: rawCfg,
|
||||
},
|
||||
},
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|builtin" {
|
||||
t.Fatalf("expected builtin-hooked content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "builtin-model" {
|
||||
t.Fatalf("expected builtin model, got %q", lastModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
eventLog := filepath.Join(t.TempDir(), "events.log")
|
||||
|
||||
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
|
||||
Enabled: true,
|
||||
Processes: map[string]config.ProcessHookConfig{
|
||||
"ipc-auto": {
|
||||
Enabled: true,
|
||||
Command: processHookHelperCommand(),
|
||||
Env: map[string]string{
|
||||
"PICOCLAW_HOOK_HELPER": "1",
|
||||
"PICOCLAW_HOOK_MODE": "rewrite",
|
||||
"PICOCLAW_HOOK_EVENT_LOG": eventLog,
|
||||
},
|
||||
Observe: []string{"turn_end"},
|
||||
Intercept: []string{"before_llm", "after_llm"},
|
||||
},
|
||||
},
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|ipc" {
|
||||
t.Fatalf("expected process-hooked content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "process-model" {
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
|
||||
waitForFileContains(t, eventLog, "turn_end")
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
|
||||
Enabled: true,
|
||||
Processes: map[string]config.ProcessHookConfig{
|
||||
"bad-hook": {
|
||||
Enabled: true,
|
||||
Command: processHookHelperCommand(),
|
||||
Intercept: []string{"not_supported"},
|
||||
},
|
||||
},
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
_, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid configured hook error")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,511 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const (
	// processHookJSONRPCVersion is the "jsonrpc" value put on every message
	// sent to a hook process.
	processHookJSONRPCVersion = "2.0"
	// processHookReadBufferSize is the maximum line size (1 MiB) accepted
	// when scanning a hook process's stdout or stderr.
	processHookReadBufferSize = 1 << 20
	// processHookCloseTimeout is how long Close waits for the process to
	// exit after closing stdin before killing it.
	processHookCloseTimeout = 2 * time.Second
)
|
||||
|
||||
// ProcessHookOptions describes how to launch an external hook process and
// which hook callbacks are forwarded to it.
type ProcessHookOptions struct {
	// Command is the executable followed by its arguments; it must not be
	// empty.
	Command []string
	// Dir is the working directory assigned to the process.
	Dir string
	// Env holds extra "KEY=value" entries appended to the current
	// environment when non-empty.
	Env []string

	// Observe enables forwarding of events as "hook.event" notifications.
	Observe bool
	// ObserveKinds, when non-empty, restricts forwarded events to the named
	// event kinds.
	ObserveKinds []string

	// InterceptLLM enables the before/after LLM calls.
	InterceptLLM bool
	// InterceptTool enables the before/after tool calls.
	InterceptTool bool
	// ApproveTool enables the tool approval call.
	ApproveTool bool
}
|
||||
|
||||
type ProcessHook struct {
|
||||
name string
|
||||
opts ProcessHookOptions
|
||||
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
observeKinds map[string]struct{}
|
||||
|
||||
writeMu sync.Mutex
|
||||
|
||||
pendingMu sync.Mutex
|
||||
pending map[uint64]chan processHookRPCMessage
|
||||
nextID atomic.Uint64
|
||||
|
||||
closed atomic.Bool
|
||||
done chan struct{}
|
||||
closeErr error
|
||||
closeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type processHookRPCMessage struct {
|
||||
JSONRPC string `json:"jsonrpc,omitempty"`
|
||||
ID uint64 `json:"id,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *processHookRPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// processHookRPCError is the error object carried by a failed RPC response.
type processHookRPCError struct {
	// Code is the numeric error code.
	Code int `json:"code"`
	// Message is the human-readable description surfaced to callers.
	Message string `json:"message"`
}
|
||||
|
||||
// processHookHelloParams is the payload of the "hook.hello" handshake
// request.
type processHookHelloParams struct {
	// Name is the hook's name.
	Name string `json:"name"`
	// Version is the protocol version announced to the process.
	Version int `json:"version"`
	// Modes lists the enabled hook modes; it is omitted when empty.
	Modes []string `json:"modes,omitempty"`
}
|
||||
|
||||
type processHookDecisionResponse struct {
|
||||
Action HookAction `json:"action"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
type processHookBeforeLLMResponse struct {
|
||||
processHookDecisionResponse
|
||||
Request *LLMHookRequest `json:"request,omitempty"`
|
||||
}
|
||||
|
||||
type processHookAfterLLMResponse struct {
|
||||
processHookDecisionResponse
|
||||
Response *LLMHookResponse `json:"response,omitempty"`
|
||||
}
|
||||
|
||||
type processHookBeforeToolResponse struct {
|
||||
processHookDecisionResponse
|
||||
Call *ToolCallHookRequest `json:"call,omitempty"`
|
||||
}
|
||||
|
||||
type processHookAfterToolResponse struct {
|
||||
processHookDecisionResponse
|
||||
Result *ToolResultHookResponse `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (*ProcessHook, error) {
|
||||
if len(opts.Command) == 0 {
|
||||
return nil, fmt.Errorf("process hook command is required")
|
||||
}
|
||||
|
||||
cmd := exec.Command(opts.Command[0], opts.Command[1:]...)
|
||||
cmd.Dir = opts.Dir
|
||||
if len(opts.Env) > 0 {
|
||||
cmd.Env = append(os.Environ(), opts.Env...)
|
||||
}
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stdin: %w", err)
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stdout: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stderr: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start process hook: %w", err)
|
||||
}
|
||||
|
||||
ph := &ProcessHook{
|
||||
name: name,
|
||||
opts: opts,
|
||||
cmd: cmd,
|
||||
stdin: stdin,
|
||||
observeKinds: newProcessHookObserveKinds(opts.ObserveKinds),
|
||||
pending: make(map[uint64]chan processHookRPCMessage),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go ph.readLoop(stdout)
|
||||
go ph.readStderr(stderr)
|
||||
go ph.waitLoop()
|
||||
|
||||
helloCtx := ctx
|
||||
if helloCtx == nil {
|
||||
var cancel context.CancelFunc
|
||||
helloCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
if err := ph.hello(helloCtx); err != nil {
|
||||
_ = ph.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ph, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) Close() error {
|
||||
if ph == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ph.closeOnce.Do(func() {
|
||||
ph.closed.Store(true)
|
||||
if ph.stdin != nil {
|
||||
_ = ph.stdin.Close()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ph.done:
|
||||
case <-time.After(processHookCloseTimeout):
|
||||
if ph.cmd != nil && ph.cmd.Process != nil {
|
||||
_ = ph.cmd.Process.Kill()
|
||||
}
|
||||
<-ph.done
|
||||
}
|
||||
})
|
||||
|
||||
ph.closeMu.Lock()
|
||||
defer ph.closeMu.Unlock()
|
||||
return ph.closeErr
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
if ph == nil || !ph.opts.Observe {
|
||||
return nil
|
||||
}
|
||||
if len(ph.observeKinds) > 0 {
|
||||
if _, ok := ph.observeKinds[evt.Kind.String()]; !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ph.notify(ctx, "hook.event", evt)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptLLM {
|
||||
return req, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var resp processHookBeforeLLMResponse
|
||||
if err := ph.call(ctx, "hook.before_llm", req, &resp); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if resp.Request == nil {
|
||||
resp.Request = req
|
||||
}
|
||||
return resp.Request, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptLLM {
|
||||
return resp, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var result processHookAfterLLMResponse
|
||||
if err := ph.call(ctx, "hook.after_llm", resp, &result); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if result.Response == nil {
|
||||
result.Response = resp
|
||||
}
|
||||
return result.Response, HookDecision{Action: result.Action, Reason: result.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptTool {
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var resp processHookBeforeToolResponse
|
||||
if err := ph.call(ctx, "hook.before_tool", call, &resp); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if resp.Call == nil {
|
||||
resp.Call = call
|
||||
}
|
||||
return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptTool {
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var resp processHookAfterToolResponse
|
||||
if err := ph.call(ctx, "hook.after_tool", result, &resp); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if resp.Result == nil {
|
||||
resp.Result = result
|
||||
}
|
||||
return resp.Result, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
||||
if ph == nil || !ph.opts.ApproveTool {
|
||||
return ApprovalDecision{Approved: true}, nil
|
||||
}
|
||||
|
||||
var resp ApprovalDecision
|
||||
if err := ph.call(ctx, "hook.approve_tool", req, &resp); err != nil {
|
||||
return ApprovalDecision{}, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) hello(ctx context.Context) error {
|
||||
modes := make([]string, 0, 4)
|
||||
if ph.opts.Observe {
|
||||
modes = append(modes, "observe")
|
||||
}
|
||||
if ph.opts.InterceptLLM {
|
||||
modes = append(modes, "llm")
|
||||
}
|
||||
if ph.opts.InterceptTool {
|
||||
modes = append(modes, "tool")
|
||||
}
|
||||
if ph.opts.ApproveTool {
|
||||
modes = append(modes, "approve")
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
return ph.call(ctx, "hook.hello", processHookHelloParams{
|
||||
Name: ph.name,
|
||||
Version: 1,
|
||||
Modes: modes,
|
||||
}, &result)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) notify(ctx context.Context, method string, params any) error {
|
||||
msg := processHookRPCMessage{
|
||||
JSONRPC: processHookJSONRPCVersion,
|
||||
Method: method,
|
||||
}
|
||||
if params != nil {
|
||||
body, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg.Params = body
|
||||
}
|
||||
return ph.send(ctx, msg)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) call(ctx context.Context, method string, params any, out any) error {
|
||||
if ph.closed.Load() {
|
||||
return fmt.Errorf("process hook %q is closed", ph.name)
|
||||
}
|
||||
|
||||
id := ph.nextID.Add(1)
|
||||
respCh := make(chan processHookRPCMessage, 1)
|
||||
ph.pendingMu.Lock()
|
||||
ph.pending[id] = respCh
|
||||
ph.pendingMu.Unlock()
|
||||
|
||||
msg := processHookRPCMessage{
|
||||
JSONRPC: processHookJSONRPCVersion,
|
||||
ID: id,
|
||||
Method: method,
|
||||
}
|
||||
if params != nil {
|
||||
body, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
ph.removePending(id)
|
||||
return err
|
||||
}
|
||||
msg.Params = body
|
||||
}
|
||||
|
||||
if err := ph.send(ctx, msg); err != nil {
|
||||
ph.removePending(id)
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case resp, ok := <-respCh:
|
||||
if !ok {
|
||||
return fmt.Errorf("process hook %q closed while waiting for %s", ph.name, method)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("process hook %q %s failed: %s", ph.name, method, resp.Error.Message)
|
||||
}
|
||||
if out != nil && len(resp.Result) > 0 {
|
||||
if err := json.Unmarshal(resp.Result, out); err != nil {
|
||||
return fmt.Errorf("decode process hook %q %s result: %w", ph.name, method, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
ph.removePending(id)
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) send(ctx context.Context, msg processHookRPCMessage) error {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = append(body, '\n')
|
||||
|
||||
ph.writeMu.Lock()
|
||||
defer ph.writeMu.Unlock()
|
||||
|
||||
if ph.closed.Load() {
|
||||
return fmt.Errorf("process hook %q is closed", ph.name)
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, writeErr := ph.stdin.Write(body)
|
||||
done <- writeErr
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("write process hook %q message: %w", ph.name, err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) readLoop(stdout io.Reader) {
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
|
||||
|
||||
for scanner.Scan() {
|
||||
var msg processHookRPCMessage
|
||||
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
||||
logger.WarnCF("hooks", "Failed to decode process hook message", map[string]any{
|
||||
"hook": ph.name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if msg.ID == 0 {
|
||||
continue
|
||||
}
|
||||
ph.pendingMu.Lock()
|
||||
respCh, ok := ph.pending[msg.ID]
|
||||
if ok {
|
||||
delete(ph.pending, msg.ID)
|
||||
}
|
||||
ph.pendingMu.Unlock()
|
||||
if ok {
|
||||
respCh <- msg
|
||||
close(respCh)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) readStderr(stderr io.Reader) {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
scanner.Buffer(make([]byte, 0, 16*1024), processHookReadBufferSize)
|
||||
for scanner.Scan() {
|
||||
logger.WarnCF("hooks", "Process hook stderr", map[string]any{
|
||||
"hook": ph.name,
|
||||
"stderr": scanner.Text(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) waitLoop() {
|
||||
err := ph.cmd.Wait()
|
||||
ph.closeMu.Lock()
|
||||
ph.closeErr = err
|
||||
ph.closeMu.Unlock()
|
||||
ph.failPending(err)
|
||||
close(ph.done)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) failPending(err error) {
|
||||
ph.pendingMu.Lock()
|
||||
defer ph.pendingMu.Unlock()
|
||||
|
||||
msg := processHookRPCMessage{
|
||||
Error: &processHookRPCError{
|
||||
Code: -32000,
|
||||
Message: "process exited",
|
||||
},
|
||||
}
|
||||
if err != nil {
|
||||
msg.Error.Message = err.Error()
|
||||
}
|
||||
|
||||
for id, ch := range ph.pending {
|
||||
delete(ph.pending, id)
|
||||
ch <- msg
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) removePending(id uint64) {
|
||||
ph.pendingMu.Lock()
|
||||
defer ph.pendingMu.Unlock()
|
||||
|
||||
if ch, ok := ph.pending[id]; ok {
|
||||
delete(ph.pending, id)
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) MountProcessHook(ctx context.Context, name string, opts ProcessHookOptions) error {
|
||||
if al == nil {
|
||||
return fmt.Errorf("agent loop is nil")
|
||||
}
|
||||
processHook, err := NewProcessHook(ctx, name, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := al.MountHook(HookRegistration{
|
||||
Name: name,
|
||||
Source: HookSourceProcess,
|
||||
Hook: processHook,
|
||||
}); err != nil {
|
||||
_ = processHook.Close()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newProcessHookObserveKinds turns kinds into a set, ignoring empty strings.
// It returns nil when the resulting set would be empty.
func newProcessHookObserveKinds(kinds []string) map[string]struct{} {
	var set map[string]struct{}
	for _, kind := range kinds {
		if kind == "" {
			continue
		}
		// Allocate lazily so that an all-empty input yields nil.
		if set == nil {
			set = make(map[string]struct{}, len(kinds))
		}
		set[kind] = struct{}{}
	}
	return set
}
|
||||
@@ -0,0 +1,339 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestProcessHook_HelperProcess(t *testing.T) {
|
||||
if os.Getenv("PICOCLAW_HOOK_HELPER") != "1" {
|
||||
return
|
||||
}
|
||||
if err := runProcessHookHelper(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
eventLog := filepath.Join(t.TempDir(), "events.log")
|
||||
if err := al.MountProcessHook(context.Background(), "ipc-llm", ProcessHookOptions{
|
||||
Command: processHookHelperCommand(),
|
||||
Env: processHookHelperEnv("rewrite", eventLog),
|
||||
Observe: true,
|
||||
InterceptLLM: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("MountProcessHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|ipc" {
|
||||
t.Fatalf("expected process-hooked llm content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "process-model" {
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
|
||||
waitForFileContains(t, eventLog, "turn_end")
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountProcessHook(context.Background(), "ipc-tool", ProcessHookOptions{
|
||||
Command: processHookHelperCommand(),
|
||||
Env: processHookHelperEnv("rewrite", ""),
|
||||
InterceptTool: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("MountProcessHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "ipc:ipc" {
|
||||
t.Fatalf("expected rewritten process-hook tool result, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// blockedToolProvider is a test LLM provider that requests the tool
// "blocked_tool" on its first Chat call and echoes the last message on
// every later call.
type blockedToolProvider struct {
	// calls counts how many times Chat has been invoked.
	calls int
}
|
||||
|
||||
func (p *blockedToolProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.calls++
|
||||
if p.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Name: "blocked_tool",
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: messages[len(messages)-1].Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *blockedToolProvider) GetDefaultModel() string {
|
||||
return "blocked-tool-provider"
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
|
||||
provider := &blockedToolProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
if err := al.MountProcessHook(context.Background(), "ipc-approval", ProcessHookOptions{
|
||||
Command: processHookHelperCommand(),
|
||||
Env: processHookHelperEnv("deny", ""),
|
||||
ApproveTool: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("MountProcessHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run blocked tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
expected := "Tool execution denied by approval hook: blocked by ipc hook"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected tool skipped event")
|
||||
}
|
||||
payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
||||
}
|
||||
if payload.Reason != expected {
|
||||
t.Fatalf("expected reason %q, got %q", expected, payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
// processHookHelperCommand returns the command line that re-executes the
// current test binary restricted to TestProcessHook_HelperProcess.
func processHookHelperCommand() []string {
	self := os.Args[0]
	return []string{self, "-test.run=TestProcessHook_HelperProcess", "--"}
}
|
||||
|
||||
// processHookHelperEnv returns the environment entries that switch the
// re-executed test binary into helper mode with the given mode. The event
// log entry is added only when eventLog is non-empty.
func processHookHelperEnv(mode, eventLog string) []string {
	env := make([]string, 0, 3)
	env = append(env, "PICOCLAW_HOOK_HELPER=1", "PICOCLAW_HOOK_MODE="+mode)
	if eventLog == "" {
		return env
	}
	return append(env, "PICOCLAW_HOOK_EVENT_LOG="+eventLog)
}
|
||||
|
||||
// waitForFileContains polls the file at path every 20ms for up to three
// seconds and returns as soon as its content contains substring. Read
// errors (such as the file not existing yet) are retried. On timeout the
// test fails with the file's current content.
func waitForFileContains(t *testing.T, path, substring string) {
	t.Helper()

	const (
		timeout      = 3 * time.Second
		pollInterval = 20 * time.Millisecond
	)

	for deadline := time.Now().Add(timeout); time.Until(deadline) > 0; time.Sleep(pollInterval) {
		content, readErr := os.ReadFile(path)
		if readErr != nil {
			continue
		}
		if strings.Contains(string(content), substring) {
			return
		}
	}

	content, _ := os.ReadFile(path)
	t.Fatalf("timed out waiting for %q in %s; current content: %q", substring, path, string(content))
}
|
||||
|
||||
func runProcessHookHelper() error {
|
||||
mode := os.Getenv("PICOCLAW_HOOK_MODE")
|
||||
eventLog := os.Getenv("PICOCLAW_HOOK_EVENT_LOG")
|
||||
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
|
||||
encoder := json.NewEncoder(os.Stdout)
|
||||
|
||||
for scanner.Scan() {
|
||||
var msg processHookRPCMessage
|
||||
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.ID == 0 {
|
||||
if msg.Method == "hook.event" && eventLog != "" {
|
||||
var evt map[string]any
|
||||
if err := json.Unmarshal(msg.Params, &evt); err == nil {
|
||||
if rawKind, ok := evt["Kind"].(float64); ok {
|
||||
kind := EventKind(rawKind)
|
||||
_ = os.WriteFile(eventLog, []byte(kind.String()+"\n"), 0o644)
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
result, rpcErr := handleProcessHookRequest(mode, msg)
|
||||
resp := processHookRPCMessage{
|
||||
JSONRPC: processHookJSONRPCVersion,
|
||||
ID: msg.ID,
|
||||
}
|
||||
if rpcErr != nil {
|
||||
resp.Error = rpcErr
|
||||
} else if result != nil {
|
||||
body, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Result = body
|
||||
} else {
|
||||
resp.Result = []byte("{}")
|
||||
}
|
||||
|
||||
if err := encoder.Encode(resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func handleProcessHookRequest(mode string, msg processHookRPCMessage) (any, *processHookRPCError) {
|
||||
switch msg.Method {
|
||||
case "hook.hello":
|
||||
return map[string]any{"ok": true}, nil
|
||||
case "hook.before_llm":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var req map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &req)
|
||||
req["model"] = "process-model"
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"request": req,
|
||||
}, nil
|
||||
case "hook.after_llm":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var resp map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &resp)
|
||||
if rawResponse, ok := resp["response"].(map[string]any); ok {
|
||||
if content, ok := rawResponse["content"].(string); ok {
|
||||
rawResponse["content"] = content + "|ipc"
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"response": resp,
|
||||
}, nil
|
||||
case "hook.before_tool":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var call map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &call)
|
||||
rawArgs, ok := call["arguments"].(map[string]any)
|
||||
if !ok || rawArgs == nil {
|
||||
rawArgs = map[string]any{}
|
||||
}
|
||||
rawArgs["text"] = "ipc"
|
||||
call["arguments"] = rawArgs
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"call": call,
|
||||
}, nil
|
||||
case "hook.after_tool":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var result map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &result)
|
||||
if rawResult, ok := result["result"].(map[string]any); ok {
|
||||
if forLLM, ok := rawResult["for_llm"].(string); ok {
|
||||
rawResult["for_llm"] = "ipc:" + forLLM
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"result": result,
|
||||
}, nil
|
||||
case "hook.approve_tool":
|
||||
if mode == "deny" {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: "blocked by ipc hook",
|
||||
}, nil
|
||||
}
|
||||
return ApprovalDecision{Approved: true}, nil
|
||||
default:
|
||||
return nil, &processHookRPCError{
|
||||
Code: -32601,
|
||||
Message: "method not found",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,809 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// Defaults used by NewHookManager: one deadline per hook category, plus the
// buffer size requested for the manager's event-bus subscription.
const (
	defaultHookObserverTimeout    = 500 * time.Millisecond
	defaultHookInterceptorTimeout = 5 * time.Second
	defaultHookApprovalTimeout    = 60 * time.Second
	hookObserverBufferSize        = 64
)
|
||||
|
||||
// HookAction is the verdict an interceptor hook attaches to its result.
type HookAction string

// Actions understood by HookManager. Continue and Modify both let the chain
// proceed with the value the hook returned; the remaining actions stop the
// chain and are handed back to the caller unchanged. DenyTool is only
// accepted by the BeforeTool stage.
const (
	HookActionContinue  HookAction = "continue"
	HookActionModify    HookAction = "modify"
	HookActionDenyTool  HookAction = "deny_tool"
	HookActionAbortTurn HookAction = "abort_turn"
	HookActionHardAbort HookAction = "hard_abort"
)
|
||||
|
||||
type HookDecision struct {
|
||||
Action HookAction `json:"action"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
func (d HookDecision) normalizedAction() HookAction {
|
||||
if d.Action == "" {
|
||||
return HookActionContinue
|
||||
}
|
||||
return d.Action
|
||||
}
|
||||
|
||||
// ApprovalDecision is the answer of a ToolApprover: whether the tool call may
// run and, optionally, why.
type ApprovalDecision struct {
	Approved bool   `json:"approved"`
	Reason   string `json:"reason,omitempty"`
}
|
||||
|
||||
// HookSource records where a hook comes from. The numeric order matters:
// HookManager sorts hooks by source first, so in-process hooks run before
// process hooks regardless of priority.
type HookSource uint8

const (
	HookSourceInProcess HookSource = iota
	HookSourceProcess
)
|
||||
|
||||
type HookRegistration struct {
|
||||
Name string
|
||||
Priority int
|
||||
Source HookSource
|
||||
Hook any
|
||||
}
|
||||
|
||||
func NamedHook(name string, hook any) HookRegistration {
|
||||
return HookRegistration{
|
||||
Name: name,
|
||||
Source: HookSourceInProcess,
|
||||
Hook: hook,
|
||||
}
|
||||
}
|
||||
|
||||
type EventObserver interface {
|
||||
OnEvent(ctx context.Context, evt Event) error
|
||||
}
|
||||
|
||||
type LLMInterceptor interface {
|
||||
BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error)
|
||||
AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error)
|
||||
}
|
||||
|
||||
type ToolInterceptor interface {
|
||||
BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
|
||||
AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
|
||||
}
|
||||
|
||||
type ToolApprover interface {
|
||||
ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error)
|
||||
}
|
||||
|
||||
type LLMHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Model string `json:"model"`
|
||||
Messages []providers.Message `json:"messages,omitempty"`
|
||||
Tools []providers.ToolDefinition `json:"tools,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
|
||||
}
|
||||
|
||||
func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Messages = cloneProviderMessages(r.Messages)
|
||||
cloned.Tools = cloneToolDefinitions(r.Tools)
|
||||
cloned.Options = cloneStringAnyMap(r.Options)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type LLMHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Model string `json:"model"`
|
||||
Response *providers.LLMResponse `json:"response,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Response = cloneLLMResponse(r.Response)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolCallHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolApprovalRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolResultHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Result *tools.ToolResult `json:"result,omitempty"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.Result = cloneToolResult(r.Result)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type HookManager struct {
|
||||
eventBus *EventBus
|
||||
observerTimeout time.Duration
|
||||
interceptorTimeout time.Duration
|
||||
approvalTimeout time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
hooks map[string]HookRegistration
|
||||
ordered []HookRegistration
|
||||
|
||||
sub EventSubscription
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewHookManager(eventBus *EventBus) *HookManager {
|
||||
hm := &HookManager{
|
||||
eventBus: eventBus,
|
||||
observerTimeout: defaultHookObserverTimeout,
|
||||
interceptorTimeout: defaultHookInterceptorTimeout,
|
||||
approvalTimeout: defaultHookApprovalTimeout,
|
||||
hooks: make(map[string]HookRegistration),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if eventBus == nil {
|
||||
close(hm.done)
|
||||
return hm
|
||||
}
|
||||
|
||||
hm.sub = eventBus.Subscribe(hookObserverBufferSize)
|
||||
go hm.dispatchEvents()
|
||||
return hm
|
||||
}
|
||||
|
||||
func (hm *HookManager) Close() {
|
||||
if hm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
hm.closeOnce.Do(func() {
|
||||
if hm.eventBus != nil {
|
||||
hm.eventBus.Unsubscribe(hm.sub.ID)
|
||||
}
|
||||
<-hm.done
|
||||
hm.closeAllHooks()
|
||||
})
|
||||
}
|
||||
|
||||
func (hm *HookManager) ConfigureTimeouts(observer, interceptor, approval time.Duration) {
|
||||
if hm == nil {
|
||||
return
|
||||
}
|
||||
if observer > 0 {
|
||||
hm.observerTimeout = observer
|
||||
}
|
||||
if interceptor > 0 {
|
||||
hm.interceptorTimeout = interceptor
|
||||
}
|
||||
if approval > 0 {
|
||||
hm.approvalTimeout = approval
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) Mount(reg HookRegistration) error {
|
||||
if hm == nil {
|
||||
return fmt.Errorf("hook manager is nil")
|
||||
}
|
||||
if reg.Name == "" {
|
||||
return fmt.Errorf("hook name is required")
|
||||
}
|
||||
if reg.Hook == nil {
|
||||
return fmt.Errorf("hook %q is nil", reg.Name)
|
||||
}
|
||||
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if existing, ok := hm.hooks[reg.Name]; ok {
|
||||
closeHookIfPossible(existing.Hook)
|
||||
}
|
||||
hm.hooks[reg.Name] = reg
|
||||
hm.rebuildOrdered()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hm *HookManager) Unmount(name string) {
|
||||
if hm == nil || name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if existing, ok := hm.hooks[name]; ok {
|
||||
closeHookIfPossible(existing.Hook)
|
||||
}
|
||||
delete(hm.hooks, name)
|
||||
hm.rebuildOrdered()
|
||||
}
|
||||
|
||||
func (hm *HookManager) dispatchEvents() {
|
||||
defer close(hm.done)
|
||||
|
||||
for evt := range hm.sub.C {
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
observer, ok := reg.Hook.(EventObserver)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hm.runObserver(reg.Name, observer, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
|
||||
if hm == nil || req == nil {
|
||||
return req, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := req.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(LLMInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callBeforeLLM(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "before_llm", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
|
||||
if hm == nil || resp == nil {
|
||||
return resp, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := resp.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(LLMInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callAfterLLM(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "after_llm", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision) {
|
||||
if hm == nil || call == nil {
|
||||
return call, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := call.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(ToolInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callBeforeTool(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "before_tool", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision) {
|
||||
if hm == nil || result == nil {
|
||||
return result, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := result.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(ToolInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callAfterTool(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "after_tool", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision {
|
||||
if hm == nil || req == nil {
|
||||
return ApprovalDecision{Approved: true}
|
||||
}
|
||||
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
approver, ok := reg.Hook.(ToolApprover)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
decision, ok := hm.callApproveTool(ctx, reg.Name, approver, req.Clone())
|
||||
if !ok {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: fmt.Sprintf("tool approval hook %q failed", reg.Name),
|
||||
}
|
||||
}
|
||||
if !decision.Approved {
|
||||
return decision
|
||||
}
|
||||
}
|
||||
|
||||
return ApprovalDecision{Approved: true}
|
||||
}
|
||||
|
||||
func (hm *HookManager) rebuildOrdered() {
|
||||
hm.ordered = hm.ordered[:0]
|
||||
for _, reg := range hm.hooks {
|
||||
hm.ordered = append(hm.ordered, reg)
|
||||
}
|
||||
sort.SliceStable(hm.ordered, func(i, j int) bool {
|
||||
if hm.ordered[i].Source != hm.ordered[j].Source {
|
||||
return hm.ordered[i].Source < hm.ordered[j].Source
|
||||
}
|
||||
if hm.ordered[i].Priority == hm.ordered[j].Priority {
|
||||
return hm.ordered[i].Name < hm.ordered[j].Name
|
||||
}
|
||||
return hm.ordered[i].Priority < hm.ordered[j].Priority
|
||||
})
|
||||
}
|
||||
|
||||
func (hm *HookManager) snapshotHooks() []HookRegistration {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
|
||||
snapshot := make([]HookRegistration, len(hm.ordered))
|
||||
copy(snapshot, hm.ordered)
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (hm *HookManager) closeAllHooks() {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
for name, reg := range hm.hooks {
|
||||
closeHookIfPossible(reg.Hook)
|
||||
delete(hm.hooks, name)
|
||||
}
|
||||
hm.ordered = nil
|
||||
}
|
||||
|
||||
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- observer.OnEvent(ctx, evt)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.WarnCF("hooks", "Event observer failed", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Event observer timed out", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"timeout_ms": hm.observerTimeout.Milliseconds(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) callBeforeLLM(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor LLMInterceptor,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"before_llm",
|
||||
func(ctx context.Context) (*LLMHookRequest, HookDecision, error) {
|
||||
return interceptor.BeforeLLM(ctx, req)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callAfterLLM(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor LLMInterceptor,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"after_llm",
|
||||
func(ctx context.Context) (*LLMHookResponse, HookDecision, error) {
|
||||
return interceptor.AfterLLM(ctx, resp)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callBeforeTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor ToolInterceptor,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"before_tool",
|
||||
func(ctx context.Context) (*ToolCallHookRequest, HookDecision, error) {
|
||||
return interceptor.BeforeTool(ctx, call)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callAfterTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor ToolInterceptor,
|
||||
resultView *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"after_tool",
|
||||
func(ctx context.Context) (*ToolResultHookResponse, HookDecision, error) {
|
||||
return interceptor.AfterTool(ctx, resultView)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callApproveTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
approver ToolApprover,
|
||||
req *ToolApprovalRequest,
|
||||
) (ApprovalDecision, bool) {
|
||||
return runApprovalHook(
|
||||
parent,
|
||||
hm.approvalTimeout,
|
||||
name,
|
||||
"approve_tool",
|
||||
func(ctx context.Context) (ApprovalDecision, error) {
|
||||
return approver.ApproveTool(ctx, req)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func runInterceptorHook[T any](
|
||||
parent context.Context,
|
||||
timeout time.Duration,
|
||||
name string,
|
||||
stage string,
|
||||
fn func(ctx context.Context) (T, HookDecision, error),
|
||||
) (T, HookDecision, bool) {
|
||||
var zero T
|
||||
|
||||
ctx, cancel := context.WithTimeout(parent, timeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
value T
|
||||
decision HookDecision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
value, decision, err := fn(ctx)
|
||||
done <- result{value: value, decision: decision, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-done:
|
||||
if res.err != nil {
|
||||
logger.WarnCF("hooks", "Interceptor hook failed", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"error": res.err.Error(),
|
||||
})
|
||||
return zero, HookDecision{}, false
|
||||
}
|
||||
return res.value, res.decision, true
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Interceptor hook timed out", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
})
|
||||
return zero, HookDecision{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func runApprovalHook(
|
||||
parent context.Context,
|
||||
timeout time.Duration,
|
||||
name string,
|
||||
stage string,
|
||||
fn func(ctx context.Context) (ApprovalDecision, error),
|
||||
) (ApprovalDecision, bool) {
|
||||
ctx, cancel := context.WithTimeout(parent, timeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
decision ApprovalDecision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
decision, err := fn(ctx)
|
||||
done <- result{decision: decision, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-done:
|
||||
if res.err != nil {
|
||||
logger.WarnCF("hooks", "Approval hook failed", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"error": res.err.Error(),
|
||||
})
|
||||
return ApprovalDecision{}, false
|
||||
}
|
||||
return res.decision, true
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Approval hook timed out", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
})
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: fmt.Sprintf("tool approval hook %q timed out", name),
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) logUnsupportedAction(name, stage string, action HookAction) {
|
||||
logger.WarnCF("hooks", "Hook returned unsupported action for stage", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"action": action,
|
||||
})
|
||||
}
|
||||
|
||||
func cloneProviderMessages(messages []providers.Message) []providers.Message {
|
||||
if len(messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.Message, len(messages))
|
||||
for i, msg := range messages {
|
||||
cloned[i] = msg
|
||||
if len(msg.Media) > 0 {
|
||||
cloned[i].Media = append([]string(nil), msg.Media...)
|
||||
}
|
||||
if len(msg.SystemParts) > 0 {
|
||||
cloned[i].SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
cloned[i].ToolCalls = cloneProviderToolCalls(msg.ToolCalls)
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneProviderToolCalls(calls []providers.ToolCall) []providers.ToolCall {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.ToolCall, len(calls))
|
||||
for i, call := range calls {
|
||||
cloned[i] = call
|
||||
if call.Function != nil {
|
||||
fn := *call.Function
|
||||
cloned[i].Function = &fn
|
||||
}
|
||||
if call.Arguments != nil {
|
||||
cloned[i].Arguments = cloneStringAnyMap(call.Arguments)
|
||||
}
|
||||
if call.ExtraContent != nil {
|
||||
extra := *call.ExtraContent
|
||||
if call.ExtraContent.Google != nil {
|
||||
google := *call.ExtraContent.Google
|
||||
extra.Google = &google
|
||||
}
|
||||
cloned[i].ExtraContent = &extra
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
|
||||
if len(defs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.ToolDefinition, len(defs))
|
||||
for i, def := range defs {
|
||||
cloned[i] = def
|
||||
cloned[i].Function.Parameters = cloneStringAnyMap(def.Function.Parameters)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneLLMResponse(resp *providers.LLMResponse) *providers.LLMResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *resp
|
||||
cloned.ToolCalls = cloneProviderToolCalls(resp.ToolCalls)
|
||||
if len(resp.ReasoningDetails) > 0 {
|
||||
cloned.ReasoningDetails = append(cloned.ReasoningDetails[:0:0], resp.ReasoningDetails...)
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
usage := *resp.Usage
|
||||
cloned.Usage = &usage
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
// cloneStringAnyMap returns an independent copy of src.
//
// A nil map stays nil, while an empty non-nil map yields an empty non-nil
// map, so code that writes into a cloned map (for example a hook adding an
// argument) does not panic on a nil-map assignment. Nested map[string]any and
// []any values, the container shapes produced by JSON decoding, are copied
// recursively so mutating the clone cannot affect the source. All other
// values are copied as-is.
func cloneStringAnyMap(src map[string]any) map[string]any {
	if src == nil {
		return nil
	}

	cloned := make(map[string]any, len(src))
	for k, v := range src {
		cloned[k] = cloneJSONLikeValue(v)
	}
	return cloned
}

// cloneJSONLikeValue deep-copies map[string]any and []any containers and
// returns every other value unchanged.
func cloneJSONLikeValue(v any) any {
	switch typed := v.(type) {
	case map[string]any:
		return cloneStringAnyMap(typed)
	case []any:
		if typed == nil {
			return typed
		}
		items := make([]any, len(typed))
		for i, item := range typed {
			items[i] = cloneJSONLikeValue(item)
		}
		return items
	default:
		return v
	}
}
|
||||
|
||||
func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := *result
|
||||
if len(result.Media) > 0 {
|
||||
cloned.Media = append([]string(nil), result.Media...)
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func closeHookIfPossible(hook any) {
|
||||
closer, ok := hook.(io.Closer)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := closer.Close(); err != nil {
|
||||
logger.WarnCF("hooks", "Failed to close hook", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
func newHookTestLoop(
|
||||
t *testing.T,
|
||||
provider providers.LLMProvider,
|
||||
) (*AgentLoop, *AgentInstance, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "agent-hooks-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
return al, agent, func() {
|
||||
al.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookManager_SortsInProcessBeforeProcess(t *testing.T) {
|
||||
hm := NewHookManager(nil)
|
||||
defer hm.Close()
|
||||
|
||||
if err := hm.Mount(HookRegistration{
|
||||
Name: "process",
|
||||
Priority: -10,
|
||||
Source: HookSourceProcess,
|
||||
Hook: struct{}{},
|
||||
}); err != nil {
|
||||
t.Fatalf("mount process hook: %v", err)
|
||||
}
|
||||
if err := hm.Mount(HookRegistration{
|
||||
Name: "in-process",
|
||||
Priority: 100,
|
||||
Source: HookSourceInProcess,
|
||||
Hook: struct{}{},
|
||||
}); err != nil {
|
||||
t.Fatalf("mount in-process hook: %v", err)
|
||||
}
|
||||
|
||||
ordered := hm.snapshotHooks()
|
||||
if len(ordered) != 2 {
|
||||
t.Fatalf("expected 2 hooks, got %d", len(ordered))
|
||||
}
|
||||
if ordered[0].Name != "in-process" {
|
||||
t.Fatalf("expected in-process hook first, got %q", ordered[0].Name)
|
||||
}
|
||||
if ordered[1].Name != "process" {
|
||||
t.Fatalf("expected process hook second, got %q", ordered[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
type llmHookTestProvider struct {
|
||||
mu sync.Mutex
|
||||
lastModel string
|
||||
}
|
||||
|
||||
func (p *llmHookTestProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.lastModel = model
|
||||
p.mu.Unlock()
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: "provider content",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *llmHookTestProvider) GetDefaultModel() string {
|
||||
return "llm-hook-provider"
|
||||
}
|
||||
|
||||
type llmObserverHook struct {
|
||||
eventCh chan Event
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
if evt.Kind == EventKindTurnEnd {
|
||||
select {
|
||||
case h.eventCh <- evt:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
next := req.Clone()
|
||||
next.Model = "hook-model"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
next := resp.Clone()
|
||||
next.Response.Content = "hooked content"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
|
||||
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "hooked content" {
|
||||
t.Fatalf("expected hooked content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "hook-model" {
|
||||
t.Fatalf("expected model hook-model, got %q", lastModel)
|
||||
}
|
||||
|
||||
select {
|
||||
case evt := <-hook.eventCh:
|
||||
if evt.Kind != EventKindTurnEnd {
|
||||
t.Fatalf("expected turn end event, got %v", evt.Kind)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for hook observer event")
|
||||
}
|
||||
}
|
||||
|
||||
type toolHookProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
}
|
||||
|
||||
func (p *toolHookProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.calls++
|
||||
if p.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Name: "echo_text",
|
||||
Arguments: map[string]any{"text": "original"},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
last := messages[len(messages)-1]
|
||||
return &providers.LLMResponse{
|
||||
Content: last.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *toolHookProvider) GetDefaultModel() string {
|
||||
return "tool-hook-provider"
|
||||
}
|
||||
|
||||
// echoTextTool is a minimal test tool that accepts a single string argument
// named "text".
type echoTextTool struct{}

// Name returns the identifier the tool is registered and invoked under.
func (t *echoTextTool) Name() string {
	const toolName = "echo_text"
	return toolName
}

// Description returns the human-readable summary of the tool.
func (t *echoTextTool) Description() string {
	const summary = "echo a text argument"
	return summary
}

// Parameters returns the argument schema: an object with one string property
// called "text".
func (t *echoTextTool) Parameters() map[string]any {
	textSchema := map[string]any{"type": "string"}
	properties := map[string]any{"text": textSchema}

	return map[string]any{
		"type":       "object",
		"properties": properties,
	}
}
|
||||
|
||||
func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
text, _ := args["text"].(string)
|
||||
return tools.SilentResult(text)
|
||||
}
|
||||
|
||||
type toolRewriteHook struct{}
|
||||
|
||||
func (h *toolRewriteHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
next := call.Clone()
|
||||
next.Arguments["text"] = "modified"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *toolRewriteHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
next := result.Clone()
|
||||
next.Result.ForLLM = "after:" + next.Result.ForLLM
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("tool-rewrite", &toolRewriteHook{})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "after:modified" {
|
||||
t.Fatalf("expected rewritten tool result, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
type denyApprovalHook struct{}
|
||||
|
||||
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: "blocked",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("deny-approval", &denyApprovalHook{})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
expected := "Tool execution denied by approval hook: blocked"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected tool skipped event")
|
||||
}
|
||||
payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
||||
}
|
||||
if payload.Reason != expected {
|
||||
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
|
||||
}
|
||||
}
|
||||
+12
-1
@@ -130,6 +130,17 @@ func NewAgentInstance(
|
||||
maxTokens = 8192
|
||||
}
|
||||
|
||||
contextWindow := defaults.ContextWindow
|
||||
if contextWindow == 0 {
|
||||
// Default heuristic: 4x the output token limit.
|
||||
// Most models have context windows well above their output limits
|
||||
// (e.g., GPT-4o 128k ctx / 16k out, Claude 200k ctx / 8k out).
|
||||
// 4x is a conservative lower bound that avoids premature
|
||||
// summarization while remaining safe — the reactive
|
||||
// forceCompression handles any overshoot.
|
||||
contextWindow = maxTokens * 4
|
||||
}
|
||||
|
||||
temperature := 0.7
|
||||
if defaults.Temperature != nil {
|
||||
temperature = *defaults.Temperature
|
||||
@@ -182,7 +193,7 @@ func NewAgentInstance(
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
ThinkingLevel: thinkingLevel,
|
||||
ContextWindow: maxTokens,
|
||||
ContextWindow: contextWindow,
|
||||
SummarizeMessageThreshold: summarizeMessageThreshold,
|
||||
SummarizeTokenPercent: summarizeTokenPercent,
|
||||
Provider: provider,
|
||||
|
||||
+1559
-419
File diff suppressed because it is too large
Load Diff
@@ -1235,11 +1235,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Inject some history to simulate a full context
|
||||
// Inject some history to simulate a full context.
|
||||
// Session history only stores user/assistant/tool messages — the system
|
||||
// prompt is built dynamically by BuildMessages and is NOT stored here.
|
||||
sessionKey := "test-session-context"
|
||||
// Create dummy history
|
||||
history := []providers.Message{
|
||||
{Role: "system", Content: "System prompt"},
|
||||
{Role: "user", Content: "Old message 1"},
|
||||
{Role: "assistant", Content: "Old response 1"},
|
||||
{Role: "user", Content: "Old message 2"},
|
||||
@@ -1277,12 +1277,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
// Check final history length
|
||||
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
// We verify that the history has been modified (compressed)
|
||||
// Original length: 6
|
||||
// Expected behavior: compression drops ~50% of history (mid slice)
|
||||
// We can assert that the length is NOT what it would be without compression.
|
||||
// Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
|
||||
if len(finalHistory) >= 8 {
|
||||
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
|
||||
// Original length: 5
|
||||
// Expected behavior: compression drops ~50% of Turns
|
||||
// Without compression: 5 + 1 (new user msg) + 1 (assistant msg) = 7
|
||||
if len(finalHistory) >= 7 {
|
||||
t.Errorf("Expected history to be compressed (len < 7), got %d", len(finalHistory))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,503 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// SteeringMode selects how many queued steering messages a single poll
// removes from a queue.
type SteeringMode string

const (
	// SteeringOneAtATime dequeues only the first queued message per poll.
	SteeringOneAtATime SteeringMode = "one-at-a-time"
	// SteeringAll drains the entire queue in a single poll.
	SteeringAll SteeringMode = "all"
	// MaxQueueSize is the maximum number of messages a single steering scope
	// may hold; a push into a scope that is already at this size is rejected.
	MaxQueueSize = 10
	// manualSteeringScope is the legacy fallback queue used when no active
	// turn/session scope is available.
	manualSteeringScope = "__manual__"
)

// parseSteeringMode converts a config string into a SteeringMode. Only the
// exact string "all" selects SteeringAll; every other value, including the
// empty string, selects SteeringOneAtATime.
func parseSteeringMode(s string) SteeringMode {
	if s == "all" {
		return SteeringAll
	}
	return SteeringOneAtATime
}
|
||||
|
||||
// steeringQueue is a thread-safe queue of user messages that can be injected
// into a running agent loop to interrupt it between tool calls.
//
// Messages are kept in one FIFO slice per scope. Every method that touches
// queues or mode takes mu first.
type steeringQueue struct {
	mu     sync.Mutex                     // guards queues and mode
	queues map[string][]providers.Message // pending messages keyed by scope
	mode   SteeringMode                   // dequeue policy applied by dequeueLocked
}
|
||||
|
||||
func newSteeringQueue(mode SteeringMode) *steeringQueue {
|
||||
return &steeringQueue{
|
||||
queues: make(map[string][]providers.Message),
|
||||
mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSteeringScope(scope string) string {
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope == "" {
|
||||
return manualSteeringScope
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
// push enqueues a steering message in the legacy fallback scope
// (manualSteeringScope). It returns pushScope's error unchanged, which is
// non-nil when that scope already holds MaxQueueSize messages.
func (sq *steeringQueue) push(msg providers.Message) error {
	return sq.pushScope(manualSteeringScope, msg)
}
|
||||
|
||||
// pushScope enqueues a steering message for the provided scope.
|
||||
func (sq *steeringQueue) pushScope(scope string, msg providers.Message) error {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
scope = normalizeSteeringScope(scope)
|
||||
queue := sq.queues[scope]
|
||||
if len(queue) >= MaxQueueSize {
|
||||
return fmt.Errorf("steering queue is full")
|
||||
}
|
||||
sq.queues[scope] = append(queue, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// dequeue removes and returns pending steering messages from the legacy
// fallback scope (manualSteeringScope) according to the configured mode.
// It returns nil when that scope has nothing queued.
func (sq *steeringQueue) dequeue() []providers.Message {
	return sq.dequeueScope(manualSteeringScope)
}
|
||||
|
||||
// dequeueScope removes and returns pending steering messages for the provided
|
||||
// scope according to the configured mode.
|
||||
func (sq *steeringQueue) dequeueScope(scope string) []providers.Message {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
return sq.dequeueLocked(normalizeSteeringScope(scope))
|
||||
}
|
||||
|
||||
// dequeueScopeWithFallback drains the scoped queue first and falls back to the
|
||||
// legacy manual scope for backwards compatibility.
|
||||
func (sq *steeringQueue) dequeueScopeWithFallback(scope string) []providers.Message {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope != "" {
|
||||
if msgs := sq.dequeueLocked(scope); len(msgs) > 0 {
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
|
||||
return sq.dequeueLocked(manualSteeringScope)
|
||||
}
|
||||
|
||||
func (sq *steeringQueue) dequeueLocked(scope string) []providers.Message {
|
||||
queue := sq.queues[scope]
|
||||
if len(queue) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch sq.mode {
|
||||
case SteeringAll:
|
||||
msgs := append([]providers.Message(nil), queue...)
|
||||
delete(sq.queues, scope)
|
||||
return msgs
|
||||
default:
|
||||
msg := queue[0]
|
||||
queue[0] = providers.Message{} // Clear reference for GC
|
||||
queue = queue[1:]
|
||||
if len(queue) == 0 {
|
||||
delete(sq.queues, scope)
|
||||
} else {
|
||||
sq.queues[scope] = queue
|
||||
}
|
||||
return []providers.Message{msg}
|
||||
}
|
||||
}
|
||||
|
||||
// len returns the number of queued messages across all scopes.
|
||||
func (sq *steeringQueue) len() int {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
total := 0
|
||||
for _, queue := range sq.queues {
|
||||
total += len(queue)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// lenScope returns the number of queued messages for a specific scope.
|
||||
func (sq *steeringQueue) lenScope(scope string) int {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
return len(sq.queues[normalizeSteeringScope(scope)])
|
||||
}
|
||||
|
||||
// setMode updates the steering mode under the queue lock. The new mode
// applies to subsequent dequeue calls; already queued messages are untouched.
func (sq *steeringQueue) setMode(mode SteeringMode) {
	sq.mu.Lock()
	defer sq.mu.Unlock()
	sq.mode = mode
}
|
||||
|
||||
// getMode returns the current steering mode, read under the queue lock.
func (sq *steeringQueue) getMode() SteeringMode {
	sq.mu.Lock()
	defer sq.mu.Unlock()
	return sq.mode
}
|
||||
|
||||
// Steer enqueues a user message to be injected into the currently running
|
||||
// agent loop. The message will be picked up after the current tool finishes
|
||||
// executing, causing any remaining tool calls in the batch to be skipped.
|
||||
func (al *AgentLoop) Steer(msg providers.Message) error {
|
||||
scope := ""
|
||||
agentID := ""
|
||||
if ts := al.getAnyActiveTurnState(); ts != nil {
|
||||
scope = ts.sessionKey
|
||||
agentID = ts.agentID
|
||||
}
|
||||
return al.enqueueSteeringMessage(scope, agentID, msg)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers.Message) error {
|
||||
if al.steering == nil {
|
||||
return fmt.Errorf("steering queue is not initialized")
|
||||
}
|
||||
|
||||
if err := al.steering.pushScope(scope, msg); err != nil {
|
||||
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
|
||||
"error": err.Error(),
|
||||
"role": msg.Role,
|
||||
"scope": normalizeSteeringScope(scope),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
queueDepth := al.steering.lenScope(scope)
|
||||
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
|
||||
"role": msg.Role,
|
||||
"content_len": len(msg.Content),
|
||||
"media_count": len(msg.Media),
|
||||
"queue_len": queueDepth,
|
||||
"scope": normalizeSteeringScope(scope),
|
||||
})
|
||||
|
||||
meta := EventMeta{
|
||||
Source: "Steer",
|
||||
TracePath: "turn.interrupt.received",
|
||||
}
|
||||
if ts := al.getAnyActiveTurnState(); ts != nil {
|
||||
meta = ts.eventMeta("Steer", "turn.interrupt.received")
|
||||
} else {
|
||||
if strings.TrimSpace(agentID) != "" {
|
||||
meta.AgentID = agentID
|
||||
}
|
||||
normalizedScope := normalizeSteeringScope(scope)
|
||||
if normalizedScope != manualSteeringScope {
|
||||
meta.SessionKey = normalizedScope
|
||||
}
|
||||
if meta.AgentID == "" {
|
||||
if registry := al.GetRegistry(); registry != nil {
|
||||
if agent := registry.GetDefaultAgent(); agent != nil {
|
||||
meta.AgentID = agent.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
meta,
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindSteering,
|
||||
Role: msg.Role,
|
||||
ContentLen: len(msg.Content),
|
||||
QueueDepth: queueDepth,
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SteeringMode returns the current steering mode.
|
||||
func (al *AgentLoop) SteeringMode() SteeringMode {
|
||||
if al.steering == nil {
|
||||
return SteeringOneAtATime
|
||||
}
|
||||
return al.steering.getMode()
|
||||
}
|
||||
|
||||
// SetSteeringMode updates the steering mode.
|
||||
func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
|
||||
if al.steering == nil {
|
||||
return
|
||||
}
|
||||
al.steering.setMode(mode)
|
||||
}
|
||||
|
||||
// dequeueSteeringMessages is the internal method called by the agent loop
|
||||
// to poll for steering messages in the legacy fallback scope.
|
||||
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
|
||||
if al.steering == nil {
|
||||
return nil
|
||||
}
|
||||
return al.steering.dequeue()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) dequeueSteeringMessagesForScope(scope string) []providers.Message {
|
||||
if al.steering == nil {
|
||||
return nil
|
||||
}
|
||||
return al.steering.dequeueScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) dequeueSteeringMessagesForScopeWithFallback(scope string) []providers.Message {
|
||||
if al.steering == nil {
|
||||
return nil
|
||||
}
|
||||
return al.steering.dequeueScopeWithFallback(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
|
||||
if al.steering == nil {
|
||||
return 0
|
||||
}
|
||||
return al.steering.lenScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) continueWithSteeringMessages(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
sessionKey, channel, chatID string,
|
||||
steeringMsgs []providers.Message,
|
||||
) (string, error) {
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
InitialSteeringMessages: steeringMsgs,
|
||||
SkipInitialSteeringPoll: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
|
||||
registry := al.GetRegistry()
|
||||
if registry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
|
||||
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
|
||||
return agent
|
||||
}
|
||||
}
|
||||
|
||||
return registry.GetDefaultAgent()
|
||||
}
|
||||
|
||||
// Continue resumes an idle agent by dequeuing any pending steering messages
|
||||
// and running them through the agent loop. This is used when the agent's last
|
||||
// message was from the assistant (i.e., it has stopped processing) and the
|
||||
// user has since enqueued steering messages.
|
||||
//
|
||||
// If no steering messages are pending, it returns an empty string.
|
||||
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
return "", fmt.Errorf("turn %s is still active", active.TurnID)
|
||||
}
|
||||
if err := al.ensureHooksInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey)
|
||||
if len(steeringMsgs) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
agent := al.agentForSession(sessionKey)
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no agent available for session %q", sessionKey)
|
||||
}
|
||||
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
|
||||
resetter.ResetSentInRound()
|
||||
}
|
||||
}
|
||||
|
||||
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptGraceful(hint string) error {
|
||||
ts := al.getAnyActiveTurnState()
|
||||
if ts == nil {
|
||||
return fmt.Errorf("no active turn")
|
||||
}
|
||||
if !ts.requestGracefulInterrupt(hint) {
|
||||
return fmt.Errorf("turn %s cannot accept graceful interrupt", ts.turnID)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
ts.eventMeta("InterruptGraceful", "turn.interrupt.received"),
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindGraceful,
|
||||
HintLen: len(hint),
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptHard() error {
|
||||
ts := al.getAnyActiveTurnState()
|
||||
if ts == nil {
|
||||
return fmt.Errorf("no active turn")
|
||||
}
|
||||
if !ts.requestHardAbort() {
|
||||
return fmt.Errorf("turn %s is already aborting", ts.turnID)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
ts.eventMeta("InterruptHard", "turn.interrupt.received"),
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindHard,
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ====================== SubTurn Result Polling ======================
|
||||
|
||||
// dequeuePendingSubTurnResults polls the SubTurn result channel for the given
|
||||
// session and returns all available results without blocking.
|
||||
// Returns nil if no active turn state exists for this session.
|
||||
func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.ToolResult {
|
||||
tsInterface, ok := al.activeTurnStates.Load(sessionKey)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ts, ok := tsInterface.(*turnState)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results []*tools.ToolResult
|
||||
for {
|
||||
select {
|
||||
case result, ok := <-ts.pendingResults:
|
||||
if !ok {
|
||||
return results
|
||||
}
|
||||
if result != nil {
|
||||
results = append(results, result)
|
||||
}
|
||||
default:
|
||||
return results
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== Hard Abort ======================
|
||||
|
||||
// HardAbort immediately cancels the running agent loop for the given session,
|
||||
// cascading the cancellation to all child SubTurns. This is a destructive operation
|
||||
// that terminates execution without waiting for graceful cleanup.
|
||||
//
|
||||
// Use this when the user explicitly requests immediate termination (e.g., "stop now", "abort").
|
||||
// For graceful interruption that allows the agent to finish the current tool and summarize,
|
||||
// use Steer() instead.
|
||||
func (al *AgentLoop) HardAbort(sessionKey string) error {
|
||||
tsInterface, ok := al.activeTurnStates.Load(sessionKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("no active turn state found for session %s", sessionKey)
|
||||
}
|
||||
|
||||
ts, ok := tsInterface.(*turnState)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid turn state type for session %s", sessionKey)
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Hard abort triggered", map[string]any{
|
||||
"session_key": sessionKey,
|
||||
"turn_id": ts.turnID,
|
||||
"depth": ts.depth,
|
||||
"initial_history_length": ts.initialHistoryLength,
|
||||
})
|
||||
|
||||
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
|
||||
// from adding more messages to the session. This prevents race conditions
|
||||
// where rollback happens while children are still writing.
|
||||
// Use isHardAbort=true for hard abort to immediately cancel all children.
|
||||
ts.Finish(true)
|
||||
|
||||
// Roll back session history to the state before the turn started.
|
||||
if ts.session != nil {
|
||||
history := ts.session.GetHistory(sessionKey)
|
||||
if ts.initialHistoryLength < len(history) {
|
||||
ts.session.SetHistory(sessionKey, history[:ts.initialHistoryLength])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ====================== Follow-Up Injection ======================
|
||||
|
||||
// InjectFollowUp enqueues a message to be automatically processed after the current
// turn completes. Unlike Steer(), which interrupts the current execution, InjectFollowUp
// waits for the current turn to finish naturally before processing the message.
//
// This is useful for:
//   - Automated workflows that need to chain multiple turns
//   - Background tasks that should run after the main task completes
//   - Scheduled follow-up actions
//
// The message will be processed via Continue() when the agent becomes idle.
//
// NOTE(review): the implementation is a direct call to Steer(), so the two
// differ only in when callers are expected to invoke them; the message lands
// in the same steering queue either way.
func (al *AgentLoop) InjectFollowUp(msg providers.Message) error {
	// InjectFollowUp uses the same steering queue mechanism as Steer(),
	// but the semantic difference is in when it's called:
	// - Steer() is called during active execution to interrupt
	// - InjectFollowUp() is called when planning future work
	//
	// Both end up in the same queue and are processed by Continue()
	// when the agent is idle.
	return al.Steer(msg)
}
|
||||
|
||||
// ====================== API Aliases for Design Document Compatibility ======================
|
||||
|
||||
// InjectSteering is an alias for Steer() to match the design document naming.
// It injects a steering message into the currently running agent loop and
// returns Steer's error unchanged.
func (al *AgentLoop) InjectSteering(msg providers.Message) error {
	return al.Steer(msg)
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,671 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// ====================== Config & Constants ======================
|
||||
const (
	// Default values for SubTurn configuration (used when config is not set or is zero).
	// getSubTurnConfig substitutes each of the first four whenever the
	// configured value is zero or negative.
	defaultMaxSubTurnDepth       = 3
	defaultMaxConcurrentSubTurns = 5
	defaultConcurrencyTimeout    = 30 * time.Second
	defaultSubTurnTimeout        = 5 * time.Minute
	// maxEphemeralHistorySize limits the number of messages stored in ephemeral sessions.
	// This prevents memory accumulation in long-running sub-turns.
	maxEphemeralHistorySize = 50
)
|
||||
|
||||
// Sentinel errors for sub-turn spawning. Compare against them with errors.Is.
var (
	// ErrDepthLimitExceeded indicates the sub-turn depth limit was exceeded.
	ErrDepthLimitExceeded = errors.New("sub-turn depth limit exceeded")
	// ErrInvalidSubTurnConfig indicates an invalid sub-turn configuration.
	ErrInvalidSubTurnConfig = errors.New("invalid sub-turn config")
	// ErrConcurrencyTimeout indicates a timeout while waiting for a concurrency slot.
	ErrConcurrencyTimeout = errors.New("timeout waiting for concurrency slot")
)
|
||||
|
||||
// getSubTurnConfig returns the effective SubTurn configuration with defaults applied.
|
||||
func (al *AgentLoop) getSubTurnConfig() subTurnRuntimeConfig {
|
||||
cfg := al.cfg.Agents.Defaults.SubTurn
|
||||
|
||||
maxDepth := cfg.MaxDepth
|
||||
if maxDepth <= 0 {
|
||||
maxDepth = defaultMaxSubTurnDepth
|
||||
}
|
||||
|
||||
maxConcurrent := cfg.MaxConcurrent
|
||||
if maxConcurrent <= 0 {
|
||||
maxConcurrent = defaultMaxConcurrentSubTurns
|
||||
}
|
||||
|
||||
concurrencyTimeout := time.Duration(cfg.ConcurrencyTimeoutSec) * time.Second
|
||||
if concurrencyTimeout <= 0 {
|
||||
concurrencyTimeout = defaultConcurrencyTimeout
|
||||
}
|
||||
|
||||
defaultTimeout := time.Duration(cfg.DefaultTimeoutMinutes) * time.Minute
|
||||
if defaultTimeout <= 0 {
|
||||
defaultTimeout = defaultSubTurnTimeout
|
||||
}
|
||||
|
||||
return subTurnRuntimeConfig{
|
||||
maxDepth: maxDepth,
|
||||
maxConcurrent: maxConcurrent,
|
||||
concurrencyTimeout: concurrencyTimeout,
|
||||
defaultTimeout: defaultTimeout,
|
||||
defaultTokenBudget: cfg.DefaultTokenBudget,
|
||||
}
|
||||
}
|
||||
|
||||
// subTurnRuntimeConfig holds the effective runtime configuration for SubTurn
// execution, as produced by getSubTurnConfig after defaults are applied.
type subTurnRuntimeConfig struct {
	maxDepth           int           // from config MaxDepth, or defaultMaxSubTurnDepth
	maxConcurrent      int           // from config MaxConcurrent, or defaultMaxConcurrentSubTurns
	concurrencyTimeout time.Duration // from config ConcurrencyTimeoutSec, or defaultConcurrencyTimeout
	defaultTimeout     time.Duration // from config DefaultTimeoutMinutes, or defaultSubTurnTimeout
	defaultTokenBudget int           // copied from config DefaultTokenBudget without a fallback
}
|
||||
|
||||
// ====================== SubTurn Config ======================
|
||||
|
||||
// SubTurnConfig configures the execution of a child sub-turn.
//
// Usage Examples:
//
// Synchronous sub-turn (Async=false):
//
//	cfg := SubTurnConfig{
//	    Model:        "gpt-4o-mini",
//	    SystemPrompt: "Analyze this code",
//	    Async:        false, // Result returned immediately
//	}
//	result, err := SpawnSubTurn(ctx, cfg)
//	// Use result directly here
//	processResult(result)
//
// Asynchronous sub-turn (Async=true):
//
//	cfg := SubTurnConfig{
//	    Model:        "gpt-4o-mini",
//	    SystemPrompt: "Background analysis",
//	    Async:        true, // Result delivered to channel
//	}
//	result, err := SpawnSubTurn(ctx, cfg)
//	// Result also available in parent's pendingResults channel
//	// Parent turn will poll and process it in a later iteration
type SubTurnConfig struct {
	Model        string
	Tools        []tools.Tool
	SystemPrompt string
	MaxTokens    int

	// Async controls the result delivery mechanism:
	//
	// When Async = false (synchronous sub-turn):
	//   - The caller blocks until the sub-turn completes
	//   - The result is ONLY returned via the function return value
	//   - The result is NOT delivered to the parent's pendingResults channel
	//   - This prevents double delivery: caller gets result immediately, no need for channel
	//   - Use case: When the caller needs the result immediately to continue execution
	//   - Example: A tool that needs to process the sub-turn result before returning
	//
	// When Async = true (asynchronous sub-turn):
	//   - The sub-turn runs in the background (still blocks the caller, but semantically async)
	//   - The result is delivered to the parent's pendingResults channel
	//   - The result is ALSO returned via the function return value (for consistency)
	//   - The parent turn can poll pendingResults in later iterations to process results
	//   - Use case: Fire-and-forget operations, or when results are processed in batches
	//   - Example: Spawning multiple sub-turns in parallel and collecting results later
	//
	// IMPORTANT: The Async flag does NOT make the call non-blocking. It only controls
	// whether the result is delivered via the channel. For true non-blocking execution,
	// the caller must spawn the sub-turn in a separate goroutine.
	Async bool

	// Critical indicates this SubTurn's result is important and should continue
	// running even after the parent turn finishes gracefully.
	//
	// When parent finishes gracefully (Finish(false)):
	//   - Critical=true: SubTurn continues running, delivers result as orphan
	//   - Critical=false: SubTurn exits gracefully without error
	//
	// When parent finishes with hard abort (Finish(true)):
	//   - All SubTurns are canceled regardless of Critical flag
	Critical bool

	// Timeout is the maximum duration for this SubTurn.
	// If the SubTurn runs longer than this, it will be canceled.
	// Default is 5 minutes (defaultSubTurnTimeout) if not specified.
	Timeout time.Duration

	// MaxContextRunes limits the context size (in runes) passed to the SubTurn.
	// This prevents context window overflow by truncating message history before LLM calls.
	//
	// Values:
	//   0  = Auto-calculate based on model's ContextWindow * 0.75 (default, recommended)
	//   -1 = No limit (disable soft truncation, rely only on hard context errors)
	//   >0 = Use specified rune limit
	//
	// The soft limit acts as a first line of defense before hitting the provider's
	// hard context window limit. When exceeded, older messages are intelligently
	// truncated while preserving system messages and recent context.
	MaxContextRunes int

	// ActualSystemPrompt is injected as the true 'system' role message for the childAgent.
	// The legacy SystemPrompt field is actually used as the first 'user' message (task description).
	ActualSystemPrompt string

	// InitialMessages preloads the ephemeral session history before the agent loop starts.
	// Used by evaluator-optimizer patterns to pass the full worker context across multiple iterations.
	InitialMessages []providers.Message

	// InitialTokenBudget is a shared atomic counter for tracking remaining tokens.
	// If set, the SubTurn will inherit this budget and deduct tokens after each LLM call.
	// If nil, the SubTurn will inherit the parent's tokenBudget (if any).
	// Used by team tool to enforce token limits across all team members.
	InitialTokenBudget *atomic.Int64

	// Can be extended with temperature, topP, etc.
}
|
||||
|
||||
// ====================== Context Keys ======================
|
||||
// agentLoopKeyType is the unexported type of the context key under which an
// *AgentLoop is stored; a dedicated type prevents key collisions with other packages.
type agentLoopKeyType struct{}
|
||||
|
||||
// agentLoopKey is the context key used by WithAgentLoop and AgentLoopFromContext.
var agentLoopKey = agentLoopKeyType{}
|
||||
|
||||
// WithAgentLoop injects AgentLoop into context for tool access
|
||||
func WithAgentLoop(ctx context.Context, al *AgentLoop) context.Context {
|
||||
return context.WithValue(ctx, agentLoopKey, al)
|
||||
}
|
||||
|
||||
// AgentLoopFromContext retrieves AgentLoop from context
|
||||
func AgentLoopFromContext(ctx context.Context) *AgentLoop {
|
||||
al, _ := ctx.Value(agentLoopKey).(*AgentLoop)
|
||||
return al
|
||||
}
|
||||
|
||||
// ====================== Helper Functions ======================
|
||||
|
||||
func (al *AgentLoop) generateSubTurnID() string {
|
||||
return fmt.Sprintf("subturn-%d", al.subTurnCounter.Add(1))
|
||||
}
|
||||
|
||||
// ====================== Core Function: spawnSubTurn ======================
|
||||
|
||||
// AgentLoopSpawner implements the tools.SubTurnSpawner interface.
// It lets tools spawn sub-turns without the tools package having to import
// this package (which would create a circular dependency).
type AgentLoopSpawner struct {
	// al is the loop on which sub-turns are spawned.
	al *AgentLoop
}
|
||||
|
||||
// SpawnSubTurn implements tools.SubTurnSpawner interface.
|
||||
func (s *AgentLoopSpawner) SpawnSubTurn(
|
||||
ctx context.Context,
|
||||
cfg tools.SubTurnConfig,
|
||||
) (*tools.ToolResult, error) {
|
||||
parentTS := turnStateFromContext(ctx)
|
||||
if parentTS == nil {
|
||||
return nil, errors.New(
|
||||
"parent turnState not found in context - cannot spawn sub-turn outside of a turn",
|
||||
)
|
||||
}
|
||||
|
||||
// Convert tools.SubTurnConfig to agent.SubTurnConfig
|
||||
agentCfg := SubTurnConfig{
|
||||
Model: cfg.Model,
|
||||
Tools: cfg.Tools,
|
||||
SystemPrompt: cfg.SystemPrompt,
|
||||
ActualSystemPrompt: cfg.ActualSystemPrompt,
|
||||
InitialMessages: cfg.InitialMessages,
|
||||
InitialTokenBudget: cfg.InitialTokenBudget,
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
Async: cfg.Async,
|
||||
Critical: cfg.Critical,
|
||||
Timeout: cfg.Timeout,
|
||||
MaxContextRunes: cfg.MaxContextRunes,
|
||||
}
|
||||
|
||||
return spawnSubTurn(ctx, s.al, parentTS, agentCfg)
|
||||
}
|
||||
|
||||
// NewSubTurnSpawner creates a SubTurnSpawner for the given AgentLoop.
|
||||
func NewSubTurnSpawner(al *AgentLoop) *AgentLoopSpawner {
|
||||
return &AgentLoopSpawner{al: al}
|
||||
}
|
||||
|
||||
// SpawnSubTurn is the exported entry point for tools to spawn sub-turns.
|
||||
// It retrieves AgentLoop and parent turnState from context and delegates to spawnSubTurn.
|
||||
func SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*tools.ToolResult, error) {
|
||||
al := AgentLoopFromContext(ctx)
|
||||
if al == nil {
|
||||
return nil, errors.New(
|
||||
"AgentLoop not found in context - ensure context is properly initialized",
|
||||
)
|
||||
}
|
||||
|
||||
parentTS := turnStateFromContext(ctx)
|
||||
if parentTS == nil {
|
||||
return nil, errors.New(
|
||||
"parent turnState not found in context - cannot spawn sub-turn outside of a turn",
|
||||
)
|
||||
}
|
||||
|
||||
return spawnSubTurn(ctx, al, parentTS, cfg)
|
||||
}
|
||||
|
||||
func spawnSubTurn(
|
||||
ctx context.Context,
|
||||
al *AgentLoop,
|
||||
parentTS *turnState,
|
||||
cfg SubTurnConfig,
|
||||
) (result *tools.ToolResult, err error) {
|
||||
// Get effective SubTurn configuration
|
||||
rtCfg := al.getSubTurnConfig()
|
||||
|
||||
// 0. Acquire concurrency semaphore FIRST to ensure it's released even if early validation fails.
|
||||
// Blocks if parent already has maxConcurrentSubTurns running, with a timeout to prevent indefinite blocking.
|
||||
// Also respects context cancellation so we don't block forever if parent is aborted.
|
||||
// NOTE: The semaphore is released immediately after runTurn completes (not in a defer) to
|
||||
// ensure it is freed before the cleanup phase (async result delivery), which may block on
|
||||
// a full pendingResults channel. Holding the semaphore through cleanup would allow the
|
||||
// parent's goroutine to be blocked waiting for a semaphore slot while child turns are
|
||||
// blocked delivering results — a deadlock.
|
||||
var semAcquired bool
|
||||
if parentTS.concurrencySem != nil {
|
||||
// Create a timeout context for semaphore acquisition
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, rtCfg.concurrencyTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case parentTS.concurrencySem <- struct{}{}:
|
||||
semAcquired = true
|
||||
defer func() {
|
||||
if semAcquired {
|
||||
<-parentTS.concurrencySem
|
||||
}
|
||||
}()
|
||||
case <-timeoutCtx.Done():
|
||||
// Check parent context first - if it was canceled, propagate that error
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
// Otherwise it's our timeout
|
||||
return nil, fmt.Errorf("%w: all %d slots occupied for %v",
|
||||
ErrConcurrencyTimeout, rtCfg.maxConcurrent, rtCfg.concurrencyTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Depth limit check
|
||||
if parentTS.depth >= rtCfg.maxDepth {
|
||||
logger.WarnCF("subturn", "Depth limit exceeded", map[string]any{
|
||||
"parent_id": parentTS.turnID,
|
||||
"depth": parentTS.depth,
|
||||
"max_depth": rtCfg.maxDepth,
|
||||
})
|
||||
return nil, ErrDepthLimitExceeded
|
||||
}
|
||||
|
||||
// 2. Config validation
|
||||
if cfg.Model == "" {
|
||||
return nil, ErrInvalidSubTurnConfig
|
||||
}
|
||||
|
||||
// 3. Determine timeout for child SubTurn
|
||||
timeout := cfg.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = rtCfg.defaultTimeout
|
||||
}
|
||||
|
||||
// 4. Create INDEPENDENT child context (not derived from parent ctx).
|
||||
// This allows the child to continue running after parent finishes gracefully.
|
||||
// The child has its own timeout for self-protection.
|
||||
childCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
childID := al.generateSubTurnID()
|
||||
|
||||
// Get the agent instance from parent, falling back to the default agent.
|
||||
// Wrap it in a shallow copy that uses an ephemeral (in-memory only) session store
|
||||
// so that child turns never pollute or persist to the parent's session history.
|
||||
baseAgent := parentTS.agent
|
||||
if baseAgent == nil {
|
||||
baseAgent = al.registry.GetDefaultAgent()
|
||||
}
|
||||
if baseAgent == nil {
|
||||
return nil, errors.New("parent turnState has no agent instance")
|
||||
}
|
||||
ephemeralStore := newEphemeralSession(nil)
|
||||
agent := *baseAgent // shallow copy
|
||||
agent.Sessions = ephemeralStore
|
||||
// Clone the tool registry so child turn's tool registrations
|
||||
// don't pollute the parent's registry.
|
||||
if baseAgent.Tools != nil {
|
||||
agent.Tools = baseAgent.Tools.Clone()
|
||||
}
|
||||
|
||||
// Create processOptions for the child turn
|
||||
opts := processOptions{
|
||||
SessionKey: childID,
|
||||
Channel: parentTS.channel,
|
||||
ChatID: parentTS.chatID,
|
||||
SenderID: parentTS.opts.SenderID,
|
||||
SenderDisplayName: parentTS.opts.SenderDisplayName,
|
||||
UserMessage: cfg.SystemPrompt, // Task description becomes the first user message
|
||||
SystemPromptOverride: cfg.ActualSystemPrompt,
|
||||
Media: nil,
|
||||
InitialSteeringMessages: cfg.InitialMessages,
|
||||
DefaultResponse: "",
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
NoHistory: true, // SubTurns don't use session history
|
||||
SkipInitialSteeringPoll: true,
|
||||
}
|
||||
|
||||
// Create event scope for the child turn
|
||||
scope := al.newTurnEventScope(agent.ID, childID)
|
||||
|
||||
// Create child turnState using the new API
|
||||
childTS := newTurnState(&agent, opts, scope)
|
||||
|
||||
// Set SubTurn-specific fields
|
||||
childTS.cancelFunc = cancel
|
||||
childTS.critical = cfg.Critical
|
||||
childTS.depth = parentTS.depth + 1
|
||||
childTS.parentTurnID = parentTS.turnID
|
||||
childTS.parentTurnState = parentTS
|
||||
childTS.pendingResults = make(chan *tools.ToolResult, 16)
|
||||
childTS.concurrencySem = make(chan struct{}, rtCfg.maxConcurrent)
|
||||
childTS.al = al // back-ref for hard abort cascade
|
||||
childTS.session = ephemeralStore // same store as agent.Sessions
|
||||
|
||||
// Token budget initialization/inheritance
|
||||
// If InitialTokenBudget is explicitly provided (e.g., by team tool), use it.
|
||||
// Otherwise, inherit from parent's tokenBudget (for nested SubTurns).
|
||||
if cfg.InitialTokenBudget != nil {
|
||||
childTS.tokenBudget = cfg.InitialTokenBudget
|
||||
} else if parentTS.tokenBudget != nil {
|
||||
childTS.tokenBudget = parentTS.tokenBudget
|
||||
} else if rtCfg.defaultTokenBudget > 0 {
|
||||
// Apply default token budget from config if no budget is set
|
||||
budget := &atomic.Int64{}
|
||||
budget.Store(int64(rtCfg.defaultTokenBudget))
|
||||
childTS.tokenBudget = budget
|
||||
}
|
||||
|
||||
// IMPORTANT: Put childTS into childCtx so that code inside runTurn can retrieve it
|
||||
childCtx = withTurnState(childCtx, childTS)
|
||||
childCtx = WithAgentLoop(childCtx, al) // Propagate AgentLoop to child turn
|
||||
|
||||
childTS.ctx = childCtx
|
||||
|
||||
// Register child turn state so GetAllActiveTurns/Subagents can find it
|
||||
al.activeTurnStates.Store(childID, childTS)
|
||||
defer al.activeTurnStates.Delete(childID)
|
||||
|
||||
// 5. Establish parent-child relationship (thread-safe)
|
||||
parentTS.mu.Lock()
|
||||
parentTS.childTurnIDs = append(parentTS.childTurnIDs, childID)
|
||||
parentTS.mu.Unlock()
|
||||
|
||||
// 6. Emit Spawn event
|
||||
al.emitEvent(EventKindSubTurnSpawn,
|
||||
childTS.eventMeta("spawnSubTurn", "subturn.spawn"),
|
||||
SubTurnSpawnPayload{
|
||||
AgentID: childTS.agentID,
|
||||
Label: childID,
|
||||
ParentTurnID: parentTS.turnID,
|
||||
},
|
||||
)
|
||||
|
||||
// 7. Defer cleanup: deliver result (for async), emit End event, and recover from panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("subturn panicked: %v", r)
|
||||
result = nil
|
||||
logger.ErrorCF("subturn", "SubTurn panicked", map[string]any{
|
||||
"child_id": childID,
|
||||
"parent_id": parentTS.turnID,
|
||||
"panic": r,
|
||||
})
|
||||
}
|
||||
|
||||
// Result Delivery Strategy (Async vs Sync)
|
||||
if cfg.Async {
|
||||
deliverSubTurnResult(al, parentTS, childID, result)
|
||||
}
|
||||
|
||||
status := "completed"
|
||||
if err != nil {
|
||||
status = "error"
|
||||
}
|
||||
al.emitEvent(EventKindSubTurnEnd,
|
||||
childTS.eventMeta("spawnSubTurn", "subturn.end"),
|
||||
SubTurnEndPayload{
|
||||
AgentID: childTS.agentID,
|
||||
Status: status,
|
||||
},
|
||||
)
|
||||
}()
|
||||
|
||||
// 8. Execute sub-turn via the real agent loop.
|
||||
turnRes, turnErr := al.runTurn(childCtx, childTS)
|
||||
|
||||
// Release the concurrency semaphore immediately after runTurn completes,
|
||||
// before the cleanup defer runs. This prevents a deadlock where:
|
||||
// - All semaphore slots are held by sub-turns in their cleanup phase
|
||||
// - Cleanup blocks on a full pendingResults channel
|
||||
// - The parent goroutine is blocked waiting for a semaphore slot
|
||||
// - The parent cannot consume pendingResults because it is blocked on the semaphore
|
||||
if semAcquired {
|
||||
<-parentTS.concurrencySem
|
||||
semAcquired = false // prevent the defer from double-releasing
|
||||
}
|
||||
|
||||
// Convert turnResult to tools.ToolResult
|
||||
if turnErr != nil {
|
||||
err = turnErr
|
||||
result = &tools.ToolResult{
|
||||
Err: turnErr,
|
||||
ForLLM: fmt.Sprintf("SubTurn failed: %v", turnErr),
|
||||
}
|
||||
} else {
|
||||
result = &tools.ToolResult{
|
||||
ForLLM: turnRes.finalContent,
|
||||
ForUser: turnRes.finalContent,
|
||||
}
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ====================== Result Delivery ======================
|
||||
|
||||
// deliverSubTurnResult delivers a sub-turn result to the parent turn's pendingResults channel.
|
||||
//
|
||||
// IMPORTANT: This function is ONLY called for asynchronous sub-turns (Async=true).
|
||||
// For synchronous sub-turns (Async=false), results are returned directly via the function
|
||||
// return value to avoid double delivery.
|
||||
//
|
||||
// Delivery behavior:
|
||||
// - If parent turn is still running: attempts to deliver to pendingResults channel
|
||||
// - If channel is full: emits SubTurnOrphanResultEvent (result is lost from channel but tracked)
|
||||
// - If parent turn has finished: emits SubTurnOrphanResultEvent (late arrival)
|
||||
//
|
||||
// Thread safety:
|
||||
// - Reads parent state under lock, then releases lock before channel send
|
||||
// - Small race window exists but is acceptable (worst case: result becomes orphan)
|
||||
//
|
||||
// Event emissions:
|
||||
// - SubTurnResultDeliveredEvent: successful delivery to channel
|
||||
// - SubTurnOrphanResultEvent: delivery failed (parent finished or channel full)
|
||||
func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, result *tools.ToolResult) {
|
||||
// Let GC clean up the pendingResults channel; parent Finish will no longer close it.
|
||||
// We use defer/recover to catch any unlikely channel panics if it were ever closed.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.WarnCF("subturn", "recovered panic sending to pendingResults", map[string]any{
|
||||
"parent_id": parentTS.turnID,
|
||||
"child_id": childID,
|
||||
"recover": r,
|
||||
})
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(EventKindSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "panic"},
|
||||
)
|
||||
}
|
||||
}
|
||||
}()
|
||||
parentTS.mu.Lock()
|
||||
isFinished := parentTS.isFinished.Load()
|
||||
resultChan := parentTS.pendingResults
|
||||
parentTS.mu.Unlock()
|
||||
|
||||
// If parent turn has already finished, treat this as an orphan result
|
||||
if isFinished || resultChan == nil {
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(EventKindSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "parent_finished"},
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Parent Turn is still running → attempt to deliver result
|
||||
// We use a select statement with parentTS.Finished() to ensure that if the
|
||||
// parent turn finishes while we are waiting to send the result (e.g. channel
|
||||
// is full), we don't leak this goroutine by blocking forever.
|
||||
select {
|
||||
case resultChan <- result:
|
||||
// Successfully delivered
|
||||
if al != nil {
|
||||
al.emitEvent(EventKindSubTurnResultDelivered,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.result_delivered"),
|
||||
SubTurnResultDeliveredPayload{ContentLen: len(result.ForLLM)},
|
||||
)
|
||||
}
|
||||
case <-parentTS.Finished():
|
||||
// Parent finished while we were waiting to deliver.
|
||||
// The result cannot be delivered to the LLM, so it becomes an orphan.
|
||||
logger.WarnCF("subturn", "parent finished before result could be delivered", map[string]any{
|
||||
"parent_id": parentTS.turnID,
|
||||
"child_id": childID,
|
||||
})
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(
|
||||
EventKindSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{
|
||||
ParentTurnID: parentTS.turnID,
|
||||
ChildTurnID: childID,
|
||||
Reason: "parent_finished_waiting",
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== Other Types ======================
|
||||
|
||||
// ephemeralSessionStore is an in-memory session.SessionStore used by SubTurns.
// It does not persist to disk and auto-truncates history to maxEphemeralHistorySize.
// All methods ignore the session key: one store holds exactly one conversation.
type ephemeralSessionStore struct {
	mu      sync.Mutex          // guards history and summary
	history []providers.Message // messages, oldest first
	summary string              // conversation summary as last set via SetSummary
}
|
||||
|
||||
func newEphemeralSession(initial []providers.Message) ephemeralSessionStoreIface {
|
||||
s := &ephemeralSessionStore{}
|
||||
if len(initial) > 0 {
|
||||
s.history = append(s.history, initial...)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// ephemeralSessionStoreIface is satisfied by *ephemeralSessionStore.
// Declared so newEphemeralSession can return a typed interface.
type ephemeralSessionStoreIface interface {
	// AddMessage appends a message built from role and content.
	AddMessage(sessionKey, role, content string)
	// AddFullMessage appends msg as-is.
	AddFullMessage(sessionKey string, msg providers.Message)
	// GetHistory returns the stored messages.
	GetHistory(key string) []providers.Message
	// GetSummary returns the stored summary.
	GetSummary(key string) string
	// SetSummary replaces the stored summary.
	SetSummary(key, summary string)
	// SetHistory replaces the stored messages.
	SetHistory(key string, history []providers.Message)
	// TruncateHistory keeps only the last keepLast messages.
	TruncateHistory(key string, keepLast int)
	// Save persists the session identified by key.
	Save(key string) error
	// Close releases the store.
	Close() error
}
|
||||
|
||||
func (e *ephemeralSessionStore) AddMessage(_, role, content string) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.history = append(e.history, providers.Message{Role: role, Content: content})
|
||||
e.truncateLocked()
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) AddFullMessage(_ string, msg providers.Message) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.history = append(e.history, msg)
|
||||
e.truncateLocked()
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) GetHistory(_ string) []providers.Message {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]providers.Message, len(e.history))
|
||||
copy(out, e.history)
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) GetSummary(_ string) string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.summary
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) SetSummary(_, summary string) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.summary = summary
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) SetHistory(_ string, history []providers.Message) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.history = make([]providers.Message, len(history))
|
||||
copy(e.history, history)
|
||||
e.truncateLocked()
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
if keepLast <= 0 {
|
||||
e.history = nil
|
||||
return
|
||||
}
|
||||
|
||||
if keepLast >= len(e.history) {
|
||||
return
|
||||
}
|
||||
e.history = e.history[len(e.history)-keepLast:]
|
||||
}
|
||||
|
||||
// Save is a no-op because the store is in-memory only; it always returns nil.
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
|
||||
// Close is a no-op; it always returns nil.
func (e *ephemeralSessionStore) Close() error { return nil }
|
||||
|
||||
func (e *ephemeralSessionStore) truncateLocked() {
|
||||
if len(e.history) > maxEphemeralHistorySize {
|
||||
e.history = e.history[len(e.history)-maxEphemeralHistorySize:]
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,481 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// TurnPhase names the stage a turn is currently in; see the TurnPhase* constants.
type TurnPhase string
|
||||
|
||||
// Known TurnPhase values. A new turnState starts in TurnPhaseSetup.
const (
	TurnPhaseSetup      TurnPhase = "setup"
	TurnPhaseRunning    TurnPhase = "running"
	TurnPhaseTools      TurnPhase = "tools"
	TurnPhaseFinalizing TurnPhase = "finalizing"
	TurnPhaseCompleted  TurnPhase = "completed"
	TurnPhaseAborted    TurnPhase = "aborted"
)
|
||||
|
||||
// ActiveTurnInfo is an exported, point-in-time copy of a turn's state,
// as produced by turnState.snapshot.
type ActiveTurnInfo struct {
	TurnID       string
	AgentID      string
	SessionKey   string
	Channel      string
	ChatID       string
	UserMessage  string
	Phase        TurnPhase
	Iteration    int
	StartedAt    time.Time
	Depth        int      // SubTurn depth (0 for root turn)
	ParentTurnID string   // empty for root turn
	ChildTurnIDs []string // copy of the turn's child IDs at snapshot time
}
|
||||
|
||||
// turnResult is the outcome of a single turn.
type turnResult struct {
	finalContent string               // final response text of the turn
	status       TurnEndStatus        // how the turn ended
	followUps    []bus.InboundMessage // follow-up inbound messages carried with the result
}
|
||||
|
||||
// turnState holds the mutable state of one turn (root turn or SubTurn).
// Fields read or written through the accessor methods are guarded by mu.
type turnState struct {
	mu sync.RWMutex

	agent *AgentInstance
	opts  processOptions
	scope turnEventScope

	turnID     string
	agentID    string
	sessionKey string

	channel     string
	chatID      string
	userMessage string
	media       []string

	phase        TurnPhase
	iteration    int
	startedAt    time.Time
	finalContent string

	followUps []bus.InboundMessage

	// Interrupt/abort state, see requestGracefulInterrupt and requestHardAbort.
	gracefulInterrupt     bool
	gracefulInterruptHint string
	gracefulTerminalUsed  bool
	hardAbort             bool
	providerCancel        context.CancelFunc
	turnCancel            context.CancelFunc

	// Session restore point, see captureRestorePoint and restoreSession.
	restorePointHistory []providers.Message
	restorePointSummary string
	persistedMessages   []providers.Message

	// SubTurn support
	depth                int                    // SubTurn depth (0 for root turn)
	parentTurnID         string                 // Parent turn ID (empty for root turn)
	childTurnIDs         []string               // Child turn IDs
	pendingResults       chan *tools.ToolResult // Channel for SubTurn results
	concurrencySem       chan struct{}          // Semaphore for limiting concurrent SubTurns
	isFinished           atomic.Bool            // Whether this turn has finished
	session              session.SessionStore   // Session store reference
	initialHistoryLength int                    // Snapshot of history length at turn start

	// Additional SubTurn fields
	ctx             context.Context    // Context for this turn
	cancelFunc      context.CancelFunc // Cancel function for this turn's context
	critical        bool               // Whether this SubTurn should continue after parent ends
	parentTurnState *turnState         // Reference to parent turnState
	parentEnded     atomic.Bool        // Set by Finish when a root turn finishes gracefully
	closeOnce       sync.Once          // Ensures pendingResults and finishedChan are closed only once
	finishedChan    chan struct{}      // Closed when turn finishes; created lazily by Finished/Finish

	// Token budget tracking
	tokenBudget      *atomic.Int64        // Shared token budget counter
	lastFinishReason string               // Last LLM finish_reason
	lastUsage        *providers.UsageInfo // Last LLM usage info

	// Back-reference to the owning AgentLoop (set for SubTurns only, used for hard abort cascade)
	al *AgentLoop
}
|
||||
|
||||
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
|
||||
ts := &turnState{
|
||||
agent: agent,
|
||||
opts: opts,
|
||||
scope: scope,
|
||||
turnID: scope.turnID,
|
||||
agentID: agent.ID,
|
||||
sessionKey: opts.SessionKey,
|
||||
channel: opts.Channel,
|
||||
chatID: opts.ChatID,
|
||||
userMessage: opts.UserMessage,
|
||||
media: append([]string(nil), opts.Media...),
|
||||
phase: TurnPhaseSetup,
|
||||
startedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Bind session store and capture initial history length for rollback logic
|
||||
if agent != nil && agent.Sessions != nil {
|
||||
ts.session = agent.Sessions
|
||||
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.SessionKey))
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
|
||||
al.activeTurnStates.Store(ts.sessionKey, ts)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
|
||||
al.activeTurnStates.Delete(ts.sessionKey)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) getActiveTurnState(sessionKey string) *turnState {
|
||||
if val, ok := al.activeTurnStates.Load(sessionKey); ok {
|
||||
return val.(*turnState)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAnyActiveTurnState returns any active turn state (for backward compatibility)
|
||||
func (al *AgentLoop) getAnyActiveTurnState() *turnState {
|
||||
var firstTS *turnState
|
||||
al.activeTurnStates.Range(func(key, value any) bool {
|
||||
firstTS = value.(*turnState)
|
||||
return false // stop after first
|
||||
})
|
||||
return firstTS
|
||||
}
|
||||
|
||||
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
|
||||
// For backward compatibility, return the first active turn found
|
||||
// In the new architecture, there can be multiple concurrent turns
|
||||
var firstTS *turnState
|
||||
al.activeTurnStates.Range(func(key, value any) bool {
|
||||
firstTS = value.(*turnState)
|
||||
return false // stop after first
|
||||
})
|
||||
if firstTS == nil {
|
||||
return nil
|
||||
}
|
||||
info := firstTS.snapshot()
|
||||
return &info
|
||||
}
|
||||
|
||||
func (al *AgentLoop) GetActiveTurnBySession(sessionKey string) *ActiveTurnInfo {
|
||||
ts := al.getActiveTurnState(sessionKey)
|
||||
if ts == nil {
|
||||
return nil
|
||||
}
|
||||
info := ts.snapshot()
|
||||
return &info
|
||||
}
|
||||
|
||||
func (ts *turnState) snapshot() ActiveTurnInfo {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
return ActiveTurnInfo{
|
||||
TurnID: ts.turnID,
|
||||
AgentID: ts.agentID,
|
||||
SessionKey: ts.sessionKey,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
UserMessage: ts.userMessage,
|
||||
Phase: ts.phase,
|
||||
Iteration: ts.iteration,
|
||||
StartedAt: ts.startedAt,
|
||||
Depth: ts.depth,
|
||||
ParentTurnID: ts.parentTurnID,
|
||||
ChildTurnIDs: append([]string(nil), ts.childTurnIDs...),
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *turnState) setPhase(phase TurnPhase) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.phase = phase
|
||||
}
|
||||
|
||||
func (ts *turnState) setIteration(iteration int) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.iteration = iteration
|
||||
}
|
||||
|
||||
func (ts *turnState) currentIteration() int {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.iteration
|
||||
}
|
||||
|
||||
func (ts *turnState) setFinalContent(content string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.finalContent = content
|
||||
}
|
||||
|
||||
func (ts *turnState) finalContentLen() int {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return len(ts.finalContent)
|
||||
}
|
||||
|
||||
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.turnCancel = cancel
|
||||
}
|
||||
|
||||
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.providerCancel = cancel
|
||||
}
|
||||
|
||||
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.providerCancel = nil
|
||||
}
|
||||
|
||||
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if ts.hardAbort {
|
||||
return false
|
||||
}
|
||||
ts.gracefulInterrupt = true
|
||||
ts.gracefulInterruptHint = hint
|
||||
return true
|
||||
}
|
||||
|
||||
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
|
||||
}
|
||||
|
||||
func (ts *turnState) markGracefulTerminalUsed() {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.gracefulTerminalUsed = true
|
||||
}
|
||||
|
||||
func (ts *turnState) requestHardAbort() bool {
|
||||
ts.mu.Lock()
|
||||
if ts.hardAbort {
|
||||
ts.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
ts.hardAbort = true
|
||||
turnCancel := ts.turnCancel
|
||||
providerCancel := ts.providerCancel
|
||||
ts.mu.Unlock()
|
||||
|
||||
if providerCancel != nil {
|
||||
providerCancel()
|
||||
}
|
||||
if turnCancel != nil {
|
||||
turnCancel()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ts *turnState) hardAbortRequested() bool {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.hardAbort
|
||||
}
|
||||
|
||||
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
|
||||
snap := ts.snapshot()
|
||||
return EventMeta{
|
||||
AgentID: snap.AgentID,
|
||||
TurnID: snap.TurnID,
|
||||
SessionKey: snap.SessionKey,
|
||||
Iteration: snap.Iteration,
|
||||
Source: source,
|
||||
TracePath: tracePath,
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.restorePointHistory = append([]providers.Message(nil), history...)
|
||||
ts.restorePointSummary = summary
|
||||
}
|
||||
|
||||
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.persistedMessages = append(ts.persistedMessages, msg)
|
||||
}
|
||||
|
||||
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
|
||||
history := agent.Sessions.GetHistory(ts.sessionKey)
|
||||
summary := agent.Sessions.GetSummary(ts.sessionKey)
|
||||
|
||||
ts.mu.RLock()
|
||||
persisted := append([]providers.Message(nil), ts.persistedMessages...)
|
||||
ts.mu.RUnlock()
|
||||
|
||||
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
|
||||
history = append([]providers.Message(nil), history[:len(history)-matched]...)
|
||||
}
|
||||
|
||||
ts.captureRestorePoint(history, summary)
|
||||
}
|
||||
|
||||
func (ts *turnState) restoreSession(agent *AgentInstance) error {
|
||||
ts.mu.RLock()
|
||||
history := append([]providers.Message(nil), ts.restorePointHistory...)
|
||||
summary := ts.restorePointSummary
|
||||
ts.mu.RUnlock()
|
||||
|
||||
agent.Sessions.SetHistory(ts.sessionKey, history)
|
||||
agent.Sessions.SetSummary(ts.sessionKey, summary)
|
||||
return agent.Sessions.Save(ts.sessionKey)
|
||||
}
|
||||
|
||||
func matchingTurnMessageTail(history, persisted []providers.Message) int {
|
||||
maxMatch := min(len(history), len(persisted))
|
||||
for size := maxMatch; size > 0; size-- {
|
||||
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
|
||||
return size
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (ts *turnState) interruptHintMessage() providers.Message {
|
||||
_, hint := ts.gracefulInterruptRequested()
|
||||
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
|
||||
if hint != "" {
|
||||
content += "\n\nInterrupt hint: " + hint
|
||||
}
|
||||
return providers.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
}
|
||||
}
|
||||
|
||||
// SubTurn-related methods
|
||||
|
||||
// Finish marks the turn as finished and closes the pendingResults channel
|
||||
func (ts *turnState) Finish(isHardAbort bool) {
|
||||
ts.isFinished.Store(true)
|
||||
|
||||
// Close pendingResults channel exactly once
|
||||
ts.closeOnce.Do(func() {
|
||||
if ts.pendingResults != nil {
|
||||
close(ts.pendingResults)
|
||||
}
|
||||
ts.mu.Lock()
|
||||
if ts.finishedChan == nil {
|
||||
ts.finishedChan = make(chan struct{})
|
||||
}
|
||||
close(ts.finishedChan)
|
||||
ts.mu.Unlock()
|
||||
})
|
||||
|
||||
// If this is a graceful finish (not hard abort), signal to children
|
||||
if !isHardAbort && ts.parentTurnState == nil {
|
||||
// This is a root turn finishing gracefully
|
||||
ts.parentEnded.Store(true)
|
||||
}
|
||||
|
||||
// Cancel the turn context
|
||||
if ts.cancelFunc != nil {
|
||||
ts.cancelFunc()
|
||||
}
|
||||
|
||||
// Hard abort cascades to all child turns
|
||||
if isHardAbort && ts.al != nil {
|
||||
ts.mu.RLock()
|
||||
children := append([]string(nil), ts.childTurnIDs...)
|
||||
ts.mu.RUnlock()
|
||||
for _, childID := range children {
|
||||
if val, ok := ts.al.activeTurnStates.Load(childID); ok {
|
||||
val.(*turnState).Finish(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finished returns whether the turn has finished
|
||||
func (ts *turnState) Finished() chan struct{} {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if ts.finishedChan == nil {
|
||||
ts.finishedChan = make(chan struct{})
|
||||
}
|
||||
return ts.finishedChan
|
||||
}
|
||||
|
||||
// IsParentEnded checks if the parent turn has ended
|
||||
func (ts *turnState) IsParentEnded() bool {
|
||||
if ts.parentTurnState == nil {
|
||||
return false
|
||||
}
|
||||
return ts.parentTurnState.parentEnded.Load()
|
||||
}
|
||||
|
||||
// GetLastFinishReason returns the last LLM finish_reason
|
||||
func (ts *turnState) GetLastFinishReason() string {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.lastFinishReason
|
||||
}
|
||||
|
||||
// SetLastFinishReason sets the last LLM finish_reason
|
||||
func (ts *turnState) SetLastFinishReason(reason string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.lastFinishReason = reason
|
||||
}
|
||||
|
||||
// GetLastUsage returns the last LLM usage info
|
||||
func (ts *turnState) GetLastUsage() *providers.UsageInfo {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.lastUsage
|
||||
}
|
||||
|
||||
// SetLastUsage sets the last LLM usage info
|
||||
func (ts *turnState) SetLastUsage(usage *providers.UsageInfo) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.lastUsage = usage
|
||||
}
|
||||
|
||||
// Context helper functions for SubTurn
|
||||
|
||||
type turnStateKeyType struct{}
|
||||
|
||||
var turnStateKey = turnStateKeyType{}
|
||||
|
||||
func withTurnState(ctx context.Context, ts *turnState) context.Context {
|
||||
return context.WithValue(ctx, turnStateKey, ts)
|
||||
}
|
||||
|
||||
func turnStateFromContext(ctx context.Context) *turnState {
|
||||
ts, _ := ctx.Value(turnStateKey).(*turnState)
|
||||
return ts
|
||||
}
|
||||
|
||||
// TurnStateFromContext retrieves turnState from context (exported for tools)
|
||||
func TurnStateFromContext(ctx context.Context) *turnState {
|
||||
return turnStateFromContext(ctx)
|
||||
}
|
||||
@@ -14,6 +14,7 @@ func BuiltinDefinitions() []Definition {
|
||||
switchCommand(),
|
||||
checkCommand(),
|
||||
clearCommand(),
|
||||
subagentsCommand(),
|
||||
reloadCommand(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// TurnInfo mirrors agent.TurnInfo so that the commands package can describe a
// turn without importing the agent package (which would create an import
// cycle).
type TurnInfo struct {
	// Identity and position in the turn tree.
	TurnID       string
	ParentTurnID string
	Depth        int

	// Direct children and completion flag.
	ChildTurnIDs []string
	IsFinished   bool
}
|
||||
|
||||
func subagentsCommand() Definition {
|
||||
return Definition{
|
||||
Name: "subagents",
|
||||
Description: "Show running subagents and task tree",
|
||||
Handler: func(ctx context.Context, req Request, rt *Runtime) error {
|
||||
getTurnFn := rt.GetActiveTurn
|
||||
if getTurnFn == nil {
|
||||
return req.Reply("Runtime does not support querying active turns.")
|
||||
}
|
||||
|
||||
turnRaw := getTurnFn()
|
||||
if turnRaw == nil {
|
||||
return req.Reply("No active tasks running in this session.")
|
||||
}
|
||||
|
||||
if treeStr, ok := turnRaw.(string); ok {
|
||||
if treeStr == "" {
|
||||
return req.Reply("No active tasks running in this session.")
|
||||
}
|
||||
return req.Reply(fmt.Sprintf("🤖 **Active Subagents Tree**\n```text\n%s\n```", treeStr))
|
||||
}
|
||||
|
||||
return req.Reply(fmt.Sprintf("🤖 **Active Subagents List**\n```text\n%+v\n```", turnRaw))
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ type Runtime struct {
|
||||
ListDefinitions func() []Definition
|
||||
ListSkillNames func() []string
|
||||
GetEnabledChannels func() []string
|
||||
GetActiveTurn func() any // Returning any to avoid circular dependency with agent package
|
||||
SwitchModel func(value string) (oldModel string, err error)
|
||||
SwitchChannel func(value string) error
|
||||
ClearHistory func() error
|
||||
|
||||
+58
-8
@@ -84,6 +84,7 @@ type Config struct {
|
||||
Providers ProvidersConfig `json:"providers,omitempty"`
|
||||
ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration
|
||||
Gateway GatewayConfig `json:"gateway"`
|
||||
Hooks HooksConfig `json:"hooks,omitempty"`
|
||||
Tools ToolsConfig `json:"tools"`
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat"`
|
||||
Devices DevicesConfig `json:"devices"`
|
||||
@@ -92,6 +93,36 @@ type Config struct {
|
||||
BuildInfo BuildInfo `json:"build_info,omitempty"`
|
||||
}
|
||||
|
||||
type HooksConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Defaults HookDefaultsConfig `json:"defaults,omitempty"`
|
||||
Builtins map[string]BuiltinHookConfig `json:"builtins,omitempty"`
|
||||
Processes map[string]ProcessHookConfig `json:"processes,omitempty"`
|
||||
}
|
||||
|
||||
// HookDefaultsConfig holds the default hook timeouts, each expressed in
// milliseconds (per the MS suffix and the *_timeout_ms JSON keys).
type HookDefaultsConfig struct {
	ObserverTimeoutMS    int `json:"observer_timeout_ms,omitempty"`
	InterceptorTimeoutMS int `json:"interceptor_timeout_ms,omitempty"`
	ApprovalTimeoutMS    int `json:"approval_timeout_ms,omitempty"`
}
|
||||
|
||||
// BuiltinHookConfig configures one entry of HooksConfig.Builtins.
type BuiltinHookConfig struct {
	Enabled  bool `json:"enabled"`
	Priority int  `json:"priority,omitempty"`
	// Config is kept as raw JSON so its decoding can be deferred.
	Config json.RawMessage `json:"config,omitempty"`
}
|
||||
|
||||
// ProcessHookConfig configures one entry of HooksConfig.Processes, i.e. a
// hook described by a command line rather than a builtin.
type ProcessHookConfig struct {
	Enabled  bool `json:"enabled"`
	Priority int  `json:"priority,omitempty"`

	// How the hook process is reached and launched.
	Transport string            `json:"transport,omitempty"`
	Command   []string          `json:"command,omitempty"`
	Dir       string            `json:"dir,omitempty"`
	Env       map[string]string `json:"env,omitempty"`

	// Event names listed under "observe" and "intercept" respectively.
	Observe   []string `json:"observe,omitempty"`
	Intercept []string `json:"intercept,omitempty"`
}
|
||||
|
||||
// BuildInfo contains build-time version information
|
||||
type BuildInfo struct {
|
||||
Version string `json:"version"`
|
||||
@@ -219,9 +250,15 @@ type RoutingConfig struct {
|
||||
Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model
|
||||
}
|
||||
|
||||
// SubTurnConfig configures the SubTurn execution system.
//
// Every field is loaded from the JSON key and the environment variable named
// in its tags.
//
// NOTE(review): AgentDefaults.SubTurn also carries an envPrefix tag with the
// same PICOCLAW_AGENTS_DEFAULTS_SUBTURN_ prefix that is already spelled out in
// the env names below — confirm the env loader does not prepend it a second
// time.
type SubTurnConfig struct {
	MaxDepth              int `json:"max_depth"               env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_MAX_DEPTH"`
	MaxConcurrent         int `json:"max_concurrent"          env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_MAX_CONCURRENT"`
	DefaultTimeoutMinutes int `json:"default_timeout_minutes" env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_DEFAULT_TIMEOUT_MINUTES"`
	DefaultTokenBudget    int `json:"default_token_budget"    env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_DEFAULT_TOKEN_BUDGET"`
	ConcurrencyTimeoutSec int `json:"concurrency_timeout_sec" env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_CONCURRENCY_TIMEOUT_SEC"`
}

// ToolFeedbackConfig controls whether tool execution details are sent to the
// chat channel as real-time feedback messages. When enabled, every tool call
// produces a short notification with the tool name and its parameters.
type ToolFeedbackConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_ENABLED"`
|
||||
MaxArgsLength int `json:"max_args_length" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_MAX_ARGS_LENGTH"`
|
||||
@@ -238,12 +275,15 @@ type AgentDefaults struct {
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
ContextWindow int `json:"context_window,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_WINDOW"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Routing *RoutingConfig `json:"routing,omitempty"`
|
||||
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
|
||||
SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"`
|
||||
ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"`
|
||||
LogLevel string `json:"log_level,omitempty" env:"PICOCLAW_LOG_LEVEL"`
|
||||
}
|
||||
@@ -923,10 +963,13 @@ func LoadConfig(path string) (*Config, error) {
|
||||
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
for _, m := range cfg.ModelList {
|
||||
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") && !strings.HasPrefix(m.APIKey, "file://") {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") &&
|
||||
!strings.HasPrefix(m.APIKey, "file://") {
|
||||
fmt.Fprintf(
|
||||
os.Stderr,
|
||||
"picoclaw: warning: model %q has a plaintext api_key; call SaveConfig to encrypt it\n",
|
||||
m.ModelName)
|
||||
m.ModelName,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -979,7 +1022,8 @@ func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelCo
|
||||
changed := false
|
||||
for i := range sealed {
|
||||
m := &sealed[i]
|
||||
if m.APIKey == "" || strings.HasPrefix(m.APIKey, "enc://") || strings.HasPrefix(m.APIKey, "file://") {
|
||||
if m.APIKey == "" || strings.HasPrefix(m.APIKey, "enc://") ||
|
||||
strings.HasPrefix(m.APIKey, "file://") {
|
||||
continue
|
||||
}
|
||||
encrypted, err := credential.Encrypt(passphrase, "", m.APIKey)
|
||||
@@ -1012,7 +1056,13 @@ func resolveAPIKeys(models []ModelConfig, configDir string) error {
|
||||
for j, key := range models[i].APIKeys {
|
||||
resolved, err := cr.Resolve(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("model_list[%d] (%s): api_keys[%d]: %w", i, models[i].ModelName, j, err)
|
||||
return fmt.Errorf(
|
||||
"model_list[%d] (%s): api_keys[%d]: %w",
|
||||
i,
|
||||
models[i].ModelName,
|
||||
j,
|
||||
err,
|
||||
)
|
||||
}
|
||||
models[i].APIKeys[j] = resolved
|
||||
}
|
||||
|
||||
@@ -470,6 +470,22 @@ func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_HooksDefaults(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if !cfg.Hooks.Enabled {
|
||||
t.Fatal("DefaultConfig().Hooks.Enabled should be true")
|
||||
}
|
||||
if cfg.Hooks.Defaults.ObserverTimeoutMS != 500 {
|
||||
t.Fatalf("ObserverTimeoutMS = %d, want 500", cfg.Hooks.Defaults.ObserverTimeoutMS)
|
||||
}
|
||||
if cfg.Hooks.Defaults.InterceptorTimeoutMS != 5000 {
|
||||
t.Fatalf("InterceptorTimeoutMS = %d, want 5000", cfg.Hooks.Defaults.InterceptorTimeoutMS)
|
||||
}
|
||||
if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 {
|
||||
t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_LogLevel(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg.Agents.Defaults.LogLevel != "fatal" {
|
||||
@@ -562,6 +578,88 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_HooksProcessConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.json")
|
||||
configJSON := `{
|
||||
"hooks": {
|
||||
"processes": {
|
||||
"review-gate": {
|
||||
"enabled": true,
|
||||
"transport": "stdio",
|
||||
"command": ["uvx", "picoclaw-hook-reviewer"],
|
||||
"dir": "/tmp/hooks",
|
||||
"env": {
|
||||
"HOOK_MODE": "rewrite"
|
||||
},
|
||||
"observe": ["turn_start", "turn_end"],
|
||||
"intercept": ["before_tool", "approve_tool"]
|
||||
}
|
||||
},
|
||||
"builtins": {
|
||||
"audit": {
|
||||
"enabled": true,
|
||||
"priority": 5,
|
||||
"config": {
|
||||
"label": "audit"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(configJSON), 0o600); err != nil {
|
||||
t.Fatalf("os.WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
|
||||
processCfg, ok := cfg.Hooks.Processes["review-gate"]
|
||||
if !ok {
|
||||
t.Fatal("expected review-gate process hook")
|
||||
}
|
||||
if !processCfg.Enabled {
|
||||
t.Fatal("expected review-gate process hook to be enabled")
|
||||
}
|
||||
if processCfg.Transport != "stdio" {
|
||||
t.Fatalf("Transport = %q, want stdio", processCfg.Transport)
|
||||
}
|
||||
if len(processCfg.Command) != 2 || processCfg.Command[0] != "uvx" {
|
||||
t.Fatalf("Command = %v", processCfg.Command)
|
||||
}
|
||||
if processCfg.Dir != "/tmp/hooks" {
|
||||
t.Fatalf("Dir = %q, want /tmp/hooks", processCfg.Dir)
|
||||
}
|
||||
if processCfg.Env["HOOK_MODE"] != "rewrite" {
|
||||
t.Fatalf("HOOK_MODE = %q, want rewrite", processCfg.Env["HOOK_MODE"])
|
||||
}
|
||||
if len(processCfg.Observe) != 2 || processCfg.Observe[1] != "turn_end" {
|
||||
t.Fatalf("Observe = %v", processCfg.Observe)
|
||||
}
|
||||
if len(processCfg.Intercept) != 2 || processCfg.Intercept[1] != "approve_tool" {
|
||||
t.Fatalf("Intercept = %v", processCfg.Intercept)
|
||||
}
|
||||
|
||||
builtinCfg, ok := cfg.Hooks.Builtins["audit"]
|
||||
if !ok {
|
||||
t.Fatal("expected audit builtin hook")
|
||||
}
|
||||
if !builtinCfg.Enabled {
|
||||
t.Fatal("expected audit builtin hook to be enabled")
|
||||
}
|
||||
if builtinCfg.Priority != 5 {
|
||||
t.Fatalf("Priority = %d, want 5", builtinCfg.Priority)
|
||||
}
|
||||
if !strings.Contains(string(builtinCfg.Config), `"audit"`) {
|
||||
t.Fatalf("Config = %s", string(builtinCfg.Config))
|
||||
}
|
||||
if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 {
|
||||
t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_DMScope verifies the default dm_scope value
|
||||
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
|
||||
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
|
||||
|
||||
@@ -36,6 +36,7 @@ func DefaultConfig() *Config {
|
||||
MaxToolIterations: 50,
|
||||
SummarizeMessageThreshold: 20,
|
||||
SummarizeTokenPercent: 75,
|
||||
SteeringMode: "one-at-a-time",
|
||||
ToolFeedback: ToolFeedbackConfig{
|
||||
Enabled: true,
|
||||
MaxArgsLength: 300,
|
||||
@@ -193,6 +194,14 @@ func DefaultConfig() *Config {
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
},
|
||||
Hooks: HooksConfig{
|
||||
Enabled: true,
|
||||
Defaults: HookDefaultsConfig{
|
||||
ObserverTimeoutMS: 500,
|
||||
InterceptorTimeoutMS: 5000,
|
||||
ApprovalTimeoutMS: 60000,
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: OpenAIProviderConfig{WebSearch: true},
|
||||
},
|
||||
|
||||
@@ -214,11 +214,20 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
Reasoning: choice.Message.Reasoning,
|
||||
ReasoningDetails: choice.Message.ReasoningDetails,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
FinishReason: normalizeFinishReason(choice.FinishReason),
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeFinishReason maps provider-specific finish_reason values onto the
// vocabulary used by the rest of the code: "length" becomes "truncated", and
// every other value (including the empty string) is returned unchanged.
func normalizeFinishReason(reason string) string {
	switch reason {
	case "length":
		return "truncated"
	default:
		return reason
	}
}
|
||||
|
||||
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
|
||||
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
|
||||
@@ -384,3 +384,22 @@ func (r *ToolRegistry) GetSummaries() []string {
|
||||
}
|
||||
return summaries
|
||||
}
|
||||
|
||||
// GetAll returns all registered tools (both core and non-core with TTL > 0).
|
||||
// Used by SubTurn to inherit parent's tool set.
|
||||
func (r *ToolRegistry) GetAll() []Tool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
sorted := r.sortedToolNames()
|
||||
tools := make([]Tool, 0, len(sorted))
|
||||
for _, name := range sorted {
|
||||
entry := r.tools[name]
|
||||
|
||||
// Include core tools and non-core tools with active TTL
|
||||
if entry.IsCore || entry.TTL > 0 {
|
||||
tools = append(tools, entry.Tool)
|
||||
}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
+10
-1
@@ -1,6 +1,10 @@
|
||||
package tools
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ToolResult represents the structured return value from tool execution.
|
||||
// It provides clear semantics for different types of results and supports
|
||||
@@ -34,6 +38,11 @@ type ToolResult struct {
|
||||
// Media contains media store refs produced by this tool.
|
||||
// When non-empty, the agent will publish these as OutboundMediaMessage.
|
||||
Media []string `json:"media,omitempty"`
|
||||
|
||||
// Messages holds the ephemeral session history after execution.
|
||||
// Only populated by SubTurn executions; used by evaluator_optimizer
|
||||
// to carry stateful worker context across evaluation iterations.
|
||||
Messages []providers.Message `json:"-"`
|
||||
}
|
||||
|
||||
// NewToolResult creates a basic ToolResult with content for the LLM.
|
||||
|
||||
+71
-25
@@ -7,7 +7,10 @@ import (
|
||||
)
|
||||
|
||||
type SpawnTool struct {
|
||||
manager *SubagentManager
|
||||
spawner SubTurnSpawner
|
||||
defaultModel string
|
||||
maxTokens int
|
||||
temperature float64
|
||||
allowlistCheck func(targetAgentID string) bool
|
||||
}
|
||||
|
||||
@@ -15,9 +18,19 @@ type SpawnTool struct {
|
||||
var _ AsyncExecutor = (*SpawnTool)(nil)
|
||||
|
||||
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||
return &SpawnTool{
|
||||
manager: manager,
|
||||
if manager == nil {
|
||||
return &SpawnTool{}
|
||||
}
|
||||
return &SpawnTool{
|
||||
defaultModel: manager.defaultModel,
|
||||
maxTokens: manager.maxTokens,
|
||||
temperature: manager.temperature,
|
||||
}
|
||||
}
|
||||
|
||||
// SetSpawner sets the SubTurnSpawner for direct sub-turn execution.
|
||||
func (t *SpawnTool) SetSpawner(spawner SubTurnSpawner) {
|
||||
t.spawner = spawner
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Name() string {
|
||||
@@ -59,11 +72,19 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResul
|
||||
|
||||
// ExecuteAsync implements AsyncExecutor. The callback is passed through to the
|
||||
// subagent manager as a call parameter — never stored on the SpawnTool instance.
|
||||
func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
|
||||
func (t *SpawnTool) ExecuteAsync(
|
||||
ctx context.Context,
|
||||
args map[string]any,
|
||||
cb AsyncCallback,
|
||||
) *ToolResult {
|
||||
return t.execute(ctx, args, cb)
|
||||
}
|
||||
|
||||
func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
|
||||
func (t *SpawnTool) execute(
|
||||
ctx context.Context,
|
||||
args map[string]any,
|
||||
cb AsyncCallback,
|
||||
) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok || strings.TrimSpace(task) == "" {
|
||||
return ErrorResult("task is required and must be a non-empty string")
|
||||
@@ -79,28 +100,53 @@ func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCa
|
||||
}
|
||||
}
|
||||
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
// Build system prompt for spawned subagent
|
||||
systemPrompt := fmt.Sprintf(
|
||||
`You are a spawned subagent running in the background. Complete the given task independently and report back when done.
|
||||
|
||||
Task: %s`,
|
||||
task,
|
||||
)
|
||||
|
||||
if label != "" {
|
||||
systemPrompt = fmt.Sprintf(
|
||||
`You are a spawned subagent labeled "%s" running in the background. Complete the given task independently and report back when done.
|
||||
|
||||
Task: %s`,
|
||||
label,
|
||||
task,
|
||||
)
|
||||
}
|
||||
|
||||
// Read channel/chatID from context (injected by registry).
|
||||
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
|
||||
// to preserve the same defaults as the original NewSpawnTool constructor.
|
||||
channel := ToolChannel(ctx)
|
||||
if channel == "" {
|
||||
channel = "cli"
|
||||
}
|
||||
chatID := ToolChatID(ctx)
|
||||
if chatID == "" {
|
||||
chatID = "direct"
|
||||
// Use spawner if available (direct SpawnSubTurn call)
|
||||
if t.spawner != nil {
|
||||
// Launch async sub-turn in goroutine
|
||||
go func() {
|
||||
result, err := t.spawner.SpawnSubTurn(ctx, SubTurnConfig{
|
||||
Model: t.defaultModel,
|
||||
Tools: nil, // Will inherit from parent via context
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxTokens: t.maxTokens,
|
||||
Temperature: t.temperature,
|
||||
Async: true, // Async execution
|
||||
})
|
||||
if err != nil {
|
||||
result = ErrorResult(fmt.Sprintf("Spawn failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
// Call callback if provided
|
||||
if cb != nil {
|
||||
cb(ctx, result)
|
||||
}
|
||||
}()
|
||||
|
||||
// Return immediate acknowledgment
|
||||
if label != "" {
|
||||
return AsyncResult(fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task))
|
||||
}
|
||||
return AsyncResult(fmt.Sprintf("Spawned subagent for task: %s", task))
|
||||
}
|
||||
|
||||
// Pass callback to manager for async completion notification
|
||||
result, err := t.manager.Spawn(ctx, task, label, agentID, channel, chatID, cb)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
|
||||
}
|
||||
|
||||
// Return AsyncResult since the task runs in background
|
||||
return AsyncResult(result)
|
||||
// Fallback: spawner not configured
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,24 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockSpawner implements SubTurnSpawner for testing
|
||||
type mockSpawner struct{}
|
||||
|
||||
func (m *mockSpawner) SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*ToolResult, error) {
|
||||
// Extract task from system prompt for response
|
||||
task := cfg.SystemPrompt
|
||||
if strings.Contains(task, "Task: ") {
|
||||
parts := strings.Split(task, "Task: ")
|
||||
if len(parts) > 1 {
|
||||
task = parts[1]
|
||||
}
|
||||
}
|
||||
return &ToolResult{
|
||||
ForLLM: "Task completed: " + task,
|
||||
ForUser: "Task completed",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
@@ -44,6 +62,7 @@ func TestSpawnTool_Execute_ValidTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnTool(manager)
|
||||
tool.SetSpawner(&mockSpawner{})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
|
||||
+180
-127
@@ -4,11 +4,34 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// SubTurnSpawner is an interface for spawning sub-turns.
|
||||
// This avoids circular dependency between tools and agent packages.
|
||||
type SubTurnSpawner interface {
|
||||
SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*ToolResult, error)
|
||||
}
|
||||
|
||||
// SubTurnConfig holds configuration for spawning a sub-turn.
|
||||
type SubTurnConfig struct {
|
||||
Model string
|
||||
Tools []Tool
|
||||
SystemPrompt string
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
Async bool // true for async (spawn), false for sync (subagent)
|
||||
Critical bool // continue running after parent finishes gracefully
|
||||
Timeout time.Duration // 0 = use default (5 minutes)
|
||||
MaxContextRunes int // 0 = auto, -1 = no limit, >0 = explicit limit
|
||||
ActualSystemPrompt string
|
||||
InitialMessages []providers.Message
|
||||
InitialTokenBudget *atomic.Int64 // Shared token budget for team members; nil if no budget
|
||||
}
|
||||
|
||||
type SubagentTask struct {
|
||||
ID string
|
||||
Task string
|
||||
@@ -21,6 +44,15 @@ type SubagentTask struct {
|
||||
Created int64
|
||||
}
|
||||
|
||||
type SpawnSubTurnFunc func(
|
||||
ctx context.Context,
|
||||
task, label, agentID string,
|
||||
tools *ToolRegistry,
|
||||
maxTokens int,
|
||||
temperature float64,
|
||||
hasMaxTokens, hasTemperature bool,
|
||||
) (*ToolResult, error)
|
||||
|
||||
type SubagentManager struct {
|
||||
tasks map[string]*SubagentTask
|
||||
mu sync.RWMutex
|
||||
@@ -34,6 +66,7 @@ type SubagentManager struct {
|
||||
hasMaxTokens bool
|
||||
hasTemperature bool
|
||||
nextID int
|
||||
spawner SpawnSubTurnFunc
|
||||
}
|
||||
|
||||
func NewSubagentManager(
|
||||
@@ -51,6 +84,12 @@ func NewSubagentManager(
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) SetSpawner(spawner SpawnSubTurnFunc) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.spawner = spawner
|
||||
}
|
||||
|
||||
// SetLLMOptions sets max tokens and temperature for subagent LLM calls.
|
||||
func (sm *SubagentManager) SetLLMOptions(maxTokens int, temperature float64) {
|
||||
sm.mu.Lock()
|
||||
@@ -108,22 +147,16 @@ func (sm *SubagentManager) Spawn(
|
||||
return fmt.Sprintf("Spawned subagent for task: %s", task), nil
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
|
||||
// Build system prompt for subagent
|
||||
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
|
||||
You have access to tools - use them as needed to complete your task.
|
||||
After completing the task, provide a clear summary of what was done.`
|
||||
|
||||
messages := []providers.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: task.Task,
|
||||
},
|
||||
}
|
||||
func (sm *SubagentManager) runTask(
|
||||
ctx context.Context,
|
||||
task *SubagentTask,
|
||||
callback AsyncCallback,
|
||||
) {
|
||||
task.Status = "running"
|
||||
task.Created = time.Now().UnixMilli()
|
||||
// TODO(eventbus): once subagents are modeled as child turns inside
|
||||
// pkg/agent, emit SubTurnEnd and SubTurnResultDelivered from the parent
|
||||
// AgentLoop instead of this legacy manager.
|
||||
|
||||
// Check if context is already canceled before starting
|
||||
select {
|
||||
@@ -136,8 +169,8 @@ After completing the task, provide a clear summary of what was done.`
|
||||
default:
|
||||
}
|
||||
|
||||
// Run tool loop with access to tools
|
||||
sm.mu.RLock()
|
||||
spawner := sm.spawner
|
||||
tools := sm.tools
|
||||
maxIter := sm.maxIterations
|
||||
maxTokens := sm.maxTokens
|
||||
@@ -146,27 +179,69 @@ After completing the task, provide a clear summary of what was done.`
|
||||
hasTemperature := sm.hasTemperature
|
||||
sm.mu.RUnlock()
|
||||
|
||||
var llmOptions map[string]any
|
||||
if hasMaxTokens || hasTemperature {
|
||||
llmOptions = map[string]any{}
|
||||
if hasMaxTokens {
|
||||
llmOptions["max_tokens"] = maxTokens
|
||||
var result *ToolResult
|
||||
var err error
|
||||
|
||||
if spawner != nil {
|
||||
result, err = spawner(
|
||||
ctx,
|
||||
task.Task,
|
||||
task.Label,
|
||||
task.AgentID,
|
||||
tools,
|
||||
maxTokens,
|
||||
temperature,
|
||||
hasMaxTokens,
|
||||
hasTemperature,
|
||||
)
|
||||
} else {
|
||||
// Fallback to legacy RunToolLoop
|
||||
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
|
||||
You have access to tools - use them as needed to complete your task.
|
||||
After completing the task, provide a clear summary of what was done.`
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: task.Task},
|
||||
}
|
||||
if hasTemperature {
|
||||
llmOptions["temperature"] = temperature
|
||||
|
||||
var llmOptions map[string]any
|
||||
if hasMaxTokens || hasTemperature {
|
||||
llmOptions = map[string]any{}
|
||||
if hasMaxTokens {
|
||||
llmOptions["max_tokens"] = maxTokens
|
||||
}
|
||||
if hasTemperature {
|
||||
llmOptions["temperature"] = temperature
|
||||
}
|
||||
}
|
||||
|
||||
var loopResult *ToolLoopResult
|
||||
loopResult, err = RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: llmOptions,
|
||||
}, messages, task.OriginChannel, task.OriginChatID)
|
||||
|
||||
if err == nil {
|
||||
result = &ToolResult{
|
||||
ForLLM: fmt.Sprintf(
|
||||
"Subagent '%s' completed (iterations: %d): %s",
|
||||
task.Label,
|
||||
loopResult.Iterations,
|
||||
loopResult.Content,
|
||||
),
|
||||
ForUser: loopResult.Content,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: llmOptions,
|
||||
}, messages, task.OriginChannel, task.OriginChatID)
|
||||
|
||||
sm.mu.Lock()
|
||||
var result *ToolResult
|
||||
defer func() {
|
||||
sm.mu.Unlock()
|
||||
// Call callback if provided and result is set
|
||||
@@ -193,19 +268,7 @@ After completing the task, provide a clear summary of what was done.`
|
||||
}
|
||||
} else {
|
||||
task.Status = "completed"
|
||||
task.Result = loopResult.Content
|
||||
result = &ToolResult{
|
||||
ForLLM: fmt.Sprintf(
|
||||
"Subagent '%s' completed (iterations: %d): %s",
|
||||
task.Label,
|
||||
loopResult.Iterations,
|
||||
loopResult.Content,
|
||||
),
|
||||
ForUser: loopResult.Content,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
task.Result = result.ForLLM
|
||||
}
|
||||
}
|
||||
|
||||
@@ -253,16 +316,28 @@ func (sm *SubagentManager) ListTaskCopies() []SubagentTask {
|
||||
}
|
||||
|
||||
// SubagentTool executes a subagent task synchronously and returns the result.
|
||||
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
|
||||
// and returns the result directly in the ToolResult.
|
||||
// It directly calls SubTurnSpawner with Async=false for synchronous execution.
|
||||
type SubagentTool struct {
|
||||
manager *SubagentManager
|
||||
spawner SubTurnSpawner
|
||||
defaultModel string
|
||||
maxTokens int
|
||||
temperature float64
|
||||
}
|
||||
|
||||
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
|
||||
return &SubagentTool{
|
||||
manager: manager,
|
||||
if manager == nil {
|
||||
return &SubagentTool{}
|
||||
}
|
||||
return &SubagentTool{
|
||||
defaultModel: manager.defaultModel,
|
||||
maxTokens: manager.maxTokens,
|
||||
temperature: manager.temperature,
|
||||
}
|
||||
}
|
||||
|
||||
// SetSpawner sets the SubTurnSpawner for direct sub-turn execution.
|
||||
func (t *SubagentTool) SetSpawner(spawner SubTurnSpawner) {
|
||||
t.spawner = spawner
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Name() string {
|
||||
@@ -298,86 +373,64 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
|
||||
label, _ := args["label"].(string)
|
||||
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
|
||||
// Build system prompt for subagent
|
||||
systemPrompt := fmt.Sprintf(
|
||||
`You are a subagent. Complete the given task independently and provide a clear, concise result.
|
||||
|
||||
Task: %s`,
|
||||
task,
|
||||
)
|
||||
|
||||
if label != "" {
|
||||
systemPrompt = fmt.Sprintf(
|
||||
`You are a subagent labeled "%s". Complete the given task independently and provide a clear, concise result.
|
||||
|
||||
Task: %s`,
|
||||
label,
|
||||
task,
|
||||
)
|
||||
}
|
||||
|
||||
// Build messages for subagent
|
||||
messages := []providers.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: task,
|
||||
},
|
||||
}
|
||||
|
||||
// Use RunToolLoop to execute with tools (same as async SpawnTool)
|
||||
sm := t.manager
|
||||
sm.mu.RLock()
|
||||
tools := sm.tools
|
||||
maxIter := sm.maxIterations
|
||||
maxTokens := sm.maxTokens
|
||||
temperature := sm.temperature
|
||||
hasMaxTokens := sm.hasMaxTokens
|
||||
hasTemperature := sm.hasTemperature
|
||||
sm.mu.RUnlock()
|
||||
|
||||
var llmOptions map[string]any
|
||||
if hasMaxTokens || hasTemperature {
|
||||
llmOptions = map[string]any{}
|
||||
if hasMaxTokens {
|
||||
llmOptions["max_tokens"] = maxTokens
|
||||
// Use spawner if available (direct SpawnSubTurn call)
|
||||
if t.spawner != nil {
|
||||
result, err := t.spawner.SpawnSubTurn(ctx, SubTurnConfig{
|
||||
Model: t.defaultModel,
|
||||
Tools: nil, // Will inherit from parent via context
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxTokens: t.maxTokens,
|
||||
Temperature: t.temperature,
|
||||
Async: false, // Synchronous execution
|
||||
})
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
if hasTemperature {
|
||||
llmOptions["temperature"] = temperature
|
||||
|
||||
// Format result for display
|
||||
userContent := result.ForLLM
|
||||
if result.ForUser != "" {
|
||||
userContent = result.ForUser
|
||||
}
|
||||
maxUserLen := 500
|
||||
if len(userContent) > maxUserLen {
|
||||
userContent = userContent[:maxUserLen] + "..."
|
||||
}
|
||||
|
||||
labelStr := label
|
||||
if labelStr == "" {
|
||||
labelStr = "(unnamed)"
|
||||
}
|
||||
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nResult: %s",
|
||||
labelStr, result.ForLLM)
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: llmContent,
|
||||
ForUser: userContent,
|
||||
Silent: false,
|
||||
IsError: result.IsError,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
|
||||
// to preserve the same defaults as the original NewSubagentTool constructor.
|
||||
channel := ToolChannel(ctx)
|
||||
if channel == "" {
|
||||
channel = "cli"
|
||||
}
|
||||
chatID := ToolChatID(ctx)
|
||||
if chatID == "" {
|
||||
chatID = "direct"
|
||||
}
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: llmOptions,
|
||||
}, messages, channel, chatID)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
// ForUser: Brief summary for user (truncated if too long)
|
||||
userContent := loopResult.Content
|
||||
maxUserLen := 500
|
||||
if len(userContent) > maxUserLen {
|
||||
userContent = userContent[:maxUserLen] + "..."
|
||||
}
|
||||
|
||||
// ForLLM: Full execution details
|
||||
labelStr := label
|
||||
if labelStr == "" {
|
||||
labelStr = "(unnamed)"
|
||||
}
|
||||
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s",
|
||||
labelStr, loopResult.Iterations, loopResult.Content)
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: llmContent,
|
||||
ForUser: userContent,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
// Fallback: spawner not configured
|
||||
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("spawner not set"))
|
||||
}
|
||||
|
||||
@@ -48,24 +48,19 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
manager.SetLLMOptions(2048, 0.6)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
args := map[string]any{"task": "Do something"}
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result == nil || result.IsError {
|
||||
t.Fatalf("Expected successful result, got: %+v", result)
|
||||
// Verify options are set on manager
|
||||
if manager.maxTokens != 2048 {
|
||||
t.Errorf("manager.maxTokens = %d, want 2048", manager.maxTokens)
|
||||
}
|
||||
|
||||
if provider.lastOptions == nil {
|
||||
t.Fatal("Expected LLM options to be passed, got nil")
|
||||
if manager.temperature != 0.6 {
|
||||
t.Errorf("manager.temperature = %f, want 0.6", manager.temperature)
|
||||
}
|
||||
if provider.lastOptions["max_tokens"] != 2048 {
|
||||
t.Fatalf("max_tokens = %v, want %d", provider.lastOptions["max_tokens"], 2048)
|
||||
if !manager.hasMaxTokens {
|
||||
t.Error("manager.hasMaxTokens should be true")
|
||||
}
|
||||
if provider.lastOptions["temperature"] != 0.6 {
|
||||
t.Fatalf("temperature = %v, want %v", provider.lastOptions["temperature"], 0.6)
|
||||
if !manager.hasTemperature {
|
||||
t.Error("manager.hasTemperature should be true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,6 +145,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetSpawner(&mockSpawner{})
|
||||
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-123")
|
||||
args := map[string]any{
|
||||
@@ -204,6 +200,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetSpawner(&mockSpawner{})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
@@ -277,6 +274,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetSpawner(&mockSpawner{})
|
||||
|
||||
channel := "test-channel"
|
||||
chatID := "test-chat"
|
||||
@@ -302,6 +300,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetSpawner(&mockSpawner{})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// CalculateDefaultMaxContextRunes derives a default rune budget for the
// conversation context from a model's context window, given in tokens.
//
// The budget is 75% of the window (the rest is headroom), converted to runes
// at 3 runes per token. The ratio is the "mixed text" estimate chosen between
// the rough figures for English (~4 chars per token) and Chinese (~1.5-2 chars
// per token).
//
// A non-positive contextWindow is treated as "unknown" and yields a fixed
// fallback of 8000 runes.
func CalculateDefaultMaxContextRunes(contextWindow int) int {
	const (
		fallbackRunes = 8000 // ~2000 tokens; used when the window is unknown
		windowShare   = 0.75 // fraction of the window handed to the context
		runesPerToken = 3    // token-to-rune estimate for mixed-language text
	)

	if contextWindow > 0 {
		// The float product is truncated toward zero before scaling to runes.
		budgetTokens := int(float64(contextWindow) * windowShare)
		return budgetTokens * runesPerToken
	}
	return fallbackRunes
}
|
||||
|
||||
// ResolveMaxContextRunes determines the final MaxContextRunes value to use.
|
||||
// Priority: explicit config > auto-calculate > conservative default
|
||||
func ResolveMaxContextRunes(configValue, contextWindow int) int {
|
||||
switch {
|
||||
case configValue > 0:
|
||||
// Explicitly configured, use as-is
|
||||
return configValue
|
||||
case configValue == -1:
|
||||
// Explicitly disabled
|
||||
return -1
|
||||
default:
|
||||
// 0 or unset: auto-calculate
|
||||
return CalculateDefaultMaxContextRunes(contextWindow)
|
||||
}
|
||||
}
|
||||
|
||||
// MeasureContextRunes calculates the total rune count of a message list.
|
||||
// Includes content, reasoning content, and estimates for tool calls.
|
||||
func MeasureContextRunes(messages []providers.Message) int {
|
||||
totalRunes := 0
|
||||
for _, msg := range messages {
|
||||
totalRunes += utf8.RuneCountInString(msg.Content)
|
||||
totalRunes += utf8.RuneCountInString(msg.ReasoningContent)
|
||||
|
||||
// Tool calls: serialize to JSON and count
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
totalRunes += utf8.RuneCountInString(tc.Name)
|
||||
// Arguments: serialize and count
|
||||
if argsJSON, err := json.Marshal(tc.Arguments); err == nil {
|
||||
totalRunes += utf8.RuneCount(argsJSON)
|
||||
} else {
|
||||
// Fallback estimate if serialization fails
|
||||
totalRunes += 100
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ToolCallID
|
||||
totalRunes += utf8.RuneCountInString(msg.ToolCallID)
|
||||
}
|
||||
return totalRunes
|
||||
}
|
||||
|
||||
// TruncateContextSmart intelligently truncates message history to fit within maxRunes.
|
||||
//
|
||||
// Strategy:
|
||||
// 1. Always preserve system messages (they define the agent's behavior)
|
||||
// 2. Keep the most recent messages (they contain current context)
|
||||
// 3. Drop older middle messages when necessary
|
||||
// 4. Insert a truncation notice to inform the LLM
|
||||
//
|
||||
// Returns the truncated message list.
|
||||
func TruncateContextSmart(messages []providers.Message, maxRunes int) []providers.Message {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// Separate system messages from others
|
||||
var systemMsgs []providers.Message
|
||||
var otherMsgs []providers.Message
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemMsgs = append(systemMsgs, msg)
|
||||
} else {
|
||||
otherMsgs = append(otherMsgs, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate system message size
|
||||
systemRunes := 0
|
||||
for _, msg := range systemMsgs {
|
||||
systemRunes += utf8.RuneCountInString(msg.Content)
|
||||
systemRunes += utf8.RuneCountInString(msg.ReasoningContent)
|
||||
}
|
||||
|
||||
// Reserve space for truncation notice (estimate ~80 runes)
|
||||
const truncationNoticeEstimate = 80
|
||||
|
||||
// Allocate remaining space for other messages
|
||||
remainingRunes := maxRunes - systemRunes - truncationNoticeEstimate
|
||||
if remainingRunes <= 0 {
|
||||
// System messages already exceed limit - return only system messages
|
||||
return systemMsgs
|
||||
}
|
||||
|
||||
// Collect recent messages in reverse order until we hit the limit
|
||||
var keptMsgs []providers.Message
|
||||
currentRunes := 0
|
||||
|
||||
for i := len(otherMsgs) - 1; i >= 0; i-- {
|
||||
msg := otherMsgs[i]
|
||||
msgRunes := utf8.RuneCountInString(msg.Content) +
|
||||
utf8.RuneCountInString(msg.ReasoningContent)
|
||||
|
||||
// Estimate tool call size
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
msgRunes += utf8.RuneCountInString(tc.Name)
|
||||
if argsJSON, err := json.Marshal(tc.Arguments); err == nil {
|
||||
msgRunes += utf8.RuneCount(argsJSON)
|
||||
} else {
|
||||
msgRunes += 100
|
||||
}
|
||||
}
|
||||
}
|
||||
msgRunes += utf8.RuneCountInString(msg.ToolCallID)
|
||||
|
||||
if currentRunes+msgRunes > remainingRunes {
|
||||
// Would exceed limit, stop collecting
|
||||
break
|
||||
}
|
||||
|
||||
// Prepend to maintain chronological order
|
||||
keptMsgs = append([]providers.Message{msg}, keptMsgs...)
|
||||
currentRunes += msgRunes
|
||||
}
|
||||
|
||||
// If we dropped messages, add a truncation notice
|
||||
result := systemMsgs
|
||||
if len(keptMsgs) < len(otherMsgs) {
|
||||
droppedCount := len(otherMsgs) - len(keptMsgs)
|
||||
truncationNotice := providers.Message{
|
||||
Role: "system",
|
||||
Content: fmt.Sprintf(
|
||||
"[Context truncated: %d earlier messages omitted to stay within context limits]",
|
||||
droppedCount,
|
||||
),
|
||||
}
|
||||
result = append(result, truncationNotice)
|
||||
}
|
||||
|
||||
result = append(result, keptMsgs...)
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,450 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestCalculateDefaultMaxContextRunes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
contextWindow int
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "zero context window uses fallback",
|
||||
contextWindow: 0,
|
||||
want: 8000,
|
||||
},
|
||||
{
|
||||
name: "negative context window uses fallback",
|
||||
contextWindow: -1,
|
||||
want: 8000,
|
||||
},
|
||||
{
|
||||
name: "small context window (4k tokens)",
|
||||
contextWindow: 4000,
|
||||
want: 9000, // 4000 * 0.75 * 3 = 9000
|
||||
},
|
||||
{
|
||||
name: "medium context window (128k tokens)",
|
||||
contextWindow: 128000,
|
||||
want: 288000, // 128000 * 0.75 * 3 = 288000
|
||||
},
|
||||
{
|
||||
name: "large context window (1M tokens)",
|
||||
contextWindow: 1000000,
|
||||
want: 2250000, // 1000000 * 0.75 * 3 = 2250000
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := CalculateDefaultMaxContextRunes(tt.contextWindow)
|
||||
if got != tt.want {
|
||||
t.Errorf("CalculateDefaultMaxContextRunes(%d) = %d, want %d",
|
||||
tt.contextWindow, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMaxContextRunes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configValue int
|
||||
contextWindow int
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "explicit positive value",
|
||||
configValue: 12000,
|
||||
contextWindow: 4000,
|
||||
want: 12000,
|
||||
},
|
||||
{
|
||||
name: "explicit disable (-1)",
|
||||
configValue: -1,
|
||||
contextWindow: 4000,
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "zero uses auto-calculate",
|
||||
configValue: 0,
|
||||
contextWindow: 4000,
|
||||
want: 9000, // 4000 * 0.75 * 3
|
||||
},
|
||||
{
|
||||
name: "unset (0) with unknown context window",
|
||||
configValue: 0,
|
||||
contextWindow: 0,
|
||||
want: 8000, // fallback
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ResolveMaxContextRunes(tt.configValue, tt.contextWindow)
|
||||
if got != tt.want {
|
||||
t.Errorf("ResolveMaxContextRunes(%d, %d) = %d, want %d",
|
||||
tt.configValue, tt.contextWindow, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeasureContextRunes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []providers.Message
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "empty messages",
|
||||
messages: []providers.Message{},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "single simple message",
|
||||
messages: []providers.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
want: 5, // "Hello" = 5 runes
|
||||
},
|
||||
{
|
||||
name: "message with reasoning",
|
||||
messages: []providers.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Answer",
|
||||
ReasoningContent: "Thinking",
|
||||
},
|
||||
},
|
||||
want: 14, // "Answer" (6) + "Thinking" (8) = 14
|
||||
},
|
||||
{
|
||||
name: "message with tool call",
|
||||
messages: []providers.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Using tool",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
Name: "test_tool",
|
||||
Arguments: map[string]any{"key": "value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 10 + 9 + 15, // "Using tool" + "test_tool" + {"key":"value"}
|
||||
},
|
||||
{
|
||||
name: "multiple messages",
|
||||
messages: []providers.Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Hello!"},
|
||||
},
|
||||
want: 15 + 2 + 6, // 15 + 2 + 6 = 23
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
messages: []providers.Message{
|
||||
{Role: "user", Content: "\u4f60\u597d\u4e16\u754c"}, // 4 Chinese characters
|
||||
},
|
||||
want: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MeasureContextRunes(tt.messages)
|
||||
if got != tt.want {
|
||||
t.Errorf("MeasureContextRunes() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateContextSmart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []providers.Message
|
||||
maxRunes int
|
||||
wantLen int
|
||||
wantHas []string // Content strings that should be present
|
||||
wantNot []string // Content strings that should be absent
|
||||
}{
|
||||
{
|
||||
name: "empty messages",
|
||||
messages: []providers.Message{},
|
||||
maxRunes: 100,
|
||||
wantLen: 0,
|
||||
},
|
||||
{
|
||||
name: "no truncation needed",
|
||||
messages: []providers.Message{
|
||||
{Role: "system", Content: "System"},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
maxRunes: 100,
|
||||
wantLen: 2,
|
||||
wantHas: []string{"System", "Hello"},
|
||||
},
|
||||
{
|
||||
name: "truncate when limit is tight",
|
||||
messages: []providers.Message{
|
||||
{Role: "system", Content: "System"},
|
||||
{Role: "user", Content: "Message 1 with some content here"},
|
||||
{Role: "assistant", Content: "Response 1 with some content here"},
|
||||
{Role: "user", Content: "Message 2 with some content here"},
|
||||
{Role: "assistant", Content: "Response 2 with some content here"},
|
||||
{Role: "user", Content: "Latest"},
|
||||
},
|
||||
maxRunes: 120, // Tight limit to force truncation
|
||||
wantLen: -1, // Don't check exact length, just verify truncation occurred
|
||||
wantHas: []string{"System", "Latest"},
|
||||
wantNot: []string{"Message 1", "Response 1"},
|
||||
},
|
||||
{
|
||||
name: "system messages exceed limit",
|
||||
messages: []providers.Message{
|
||||
{Role: "system", Content: "Very long system message"},
|
||||
{Role: "user", Content: "User message"},
|
||||
},
|
||||
maxRunes: 10, // Less than system message
|
||||
wantLen: 1, // Only system message
|
||||
wantHas: []string{"Very long system message"},
|
||||
wantNot: []string{"User message"},
|
||||
},
|
||||
{
|
||||
name: "preserve multiple system messages",
|
||||
messages: []providers.Message{
|
||||
{Role: "system", Content: "Sys1"},
|
||||
{Role: "system", Content: "Sys2"},
|
||||
{Role: "user", Content: "Old"},
|
||||
{Role: "user", Content: "New"},
|
||||
},
|
||||
maxRunes: 200, // Generous limit
|
||||
wantLen: 4, // Both system + truncation notice + new
|
||||
wantHas: []string{"Sys1", "Sys2", "New"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := TruncateContextSmart(tt.messages, tt.maxRunes)
|
||||
|
||||
if tt.wantLen >= 0 && len(got) != tt.wantLen {
|
||||
t.Errorf("TruncateContextSmart() returned %d messages, want %d",
|
||||
len(got), tt.wantLen)
|
||||
}
|
||||
|
||||
// Check for expected content
|
||||
allContent := ""
|
||||
for _, msg := range got {
|
||||
allContent += msg.Content + " "
|
||||
}
|
||||
|
||||
for _, want := range tt.wantHas {
|
||||
found := false
|
||||
for _, msg := range got {
|
||||
if msg.Content == want || containsSubstring(msg.Content, want) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected content %q not found in truncated messages", want)
|
||||
}
|
||||
}
|
||||
|
||||
for _, notWant := range tt.wantNot {
|
||||
for _, msg := range got {
|
||||
if containsSubstring(msg.Content, notWant) {
|
||||
t.Errorf("Unexpected content %q found in truncated messages", notWant)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// containsSubstring reports whether substr occurs anywhere in s. A substr
// longer than s can never match and is rejected before scanning.
func containsSubstring(s, substr string) bool {
	if len(substr) > len(s) {
		return false
	}
	return findSubstring(s, substr)
}

// findSubstring reports whether substr occurs in s by comparing every window
// of len(substr) bytes in s against substr. An empty substr always matches;
// a substr longer than s never does.
func findSubstring(s, substr string) bool {
	width := len(substr)
	// end is the exclusive upper bound of the window currently compared.
	for end := width; end <= len(s); end++ {
		if s[end-width:end] == substr {
			return true
		}
	}
	return false
}
|
||||
|
||||
// TestSubTurnConfigMaxContextRunes verifies that MaxContextRunes configuration
|
||||
// is properly integrated into the SubTurn execution flow.
|
||||
func TestSubTurnConfigMaxContextRunes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
maxContextRunes int
|
||||
contextWindow int
|
||||
wantResolved int
|
||||
}{
|
||||
{
|
||||
name: "default (0) auto-calculates from context window",
|
||||
maxContextRunes: 0,
|
||||
contextWindow: 4000,
|
||||
wantResolved: 9000, // 4000 * 0.75 * 3
|
||||
},
|
||||
{
|
||||
name: "explicit value is used",
|
||||
maxContextRunes: 12000,
|
||||
contextWindow: 4000,
|
||||
wantResolved: 12000,
|
||||
},
|
||||
{
|
||||
name: "disabled (-1) returns -1",
|
||||
maxContextRunes: -1,
|
||||
contextWindow: 4000,
|
||||
wantResolved: -1,
|
||||
},
|
||||
{
|
||||
name: "fallback when context window unknown",
|
||||
maxContextRunes: 0,
|
||||
contextWindow: 0,
|
||||
wantResolved: 8000, // conservative fallback
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ResolveMaxContextRunes(tt.maxContextRunes, tt.contextWindow)
|
||||
if got != tt.wantResolved {
|
||||
t.Errorf("utils.ResolveMaxContextRunes(%d, %d) = %d, want %d",
|
||||
tt.maxContextRunes, tt.contextWindow, got, tt.wantResolved)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestContextTruncationFlow verifies the complete context truncation flow:
|
||||
// 1. Messages accumulate beyond soft limit
|
||||
// 2. Truncation is triggered
|
||||
// 3. System messages are preserved
|
||||
// 4. Recent messages are kept
|
||||
func TestContextTruncationFlow(t *testing.T) {
|
||||
// Build a message history that exceeds the limit
|
||||
messages := []providers.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant"}, // ~27 runes
|
||||
{Role: "user", Content: "First question"}, // ~14 runes
|
||||
{Role: "assistant", Content: "First answer"}, // ~12 runes
|
||||
{Role: "user", Content: "Second question"}, // ~15 runes
|
||||
{Role: "assistant", Content: "Second answer"}, // ~13 runes
|
||||
{Role: "user", Content: "Third question"}, // ~14 runes
|
||||
{Role: "assistant", Content: "Third answer"}, // ~12 runes
|
||||
{Role: "user", Content: "Latest question"}, // ~15 runes
|
||||
}
|
||||
|
||||
// Total: ~122 runes
|
||||
totalRunes := MeasureContextRunes(messages)
|
||||
if totalRunes < 100 {
|
||||
t.Errorf("Expected total runes > 100, got %d", totalRunes)
|
||||
}
|
||||
|
||||
// Set limit to 150 runes - should force truncation of old messages
|
||||
// but preserve system + truncation notice + recent messages
|
||||
maxRunes := 150
|
||||
truncated := TruncateContextSmart(messages, maxRunes)
|
||||
|
||||
// Verify truncation occurred
|
||||
if len(truncated) >= len(messages) {
|
||||
t.Errorf("Expected truncation, but got %d messages (original: %d)",
|
||||
len(truncated), len(messages))
|
||||
}
|
||||
|
||||
// Verify system message is preserved
|
||||
foundSystem := false
|
||||
for _, msg := range truncated {
|
||||
if msg.Role == "system" && msg.Content == "You are a helpful assistant" {
|
||||
foundSystem = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundSystem {
|
||||
t.Error("System message was not preserved after truncation")
|
||||
}
|
||||
|
||||
// Verify latest message is preserved
|
||||
foundLatest := false
|
||||
for _, msg := range truncated {
|
||||
if msg.Content == "Latest question" {
|
||||
foundLatest = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundLatest {
|
||||
t.Error("Latest message was not preserved after truncation")
|
||||
}
|
||||
|
||||
// Verify truncation notice is present
|
||||
foundNotice := false
|
||||
for _, msg := range truncated {
|
||||
if msg.Role == "system" && containsSubstring(msg.Content, "truncated") {
|
||||
foundNotice = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundNotice {
|
||||
t.Error("Truncation notice was not added")
|
||||
}
|
||||
|
||||
// Verify result is within limit (with some tolerance for estimation)
|
||||
resultRunes := MeasureContextRunes(truncated)
|
||||
if resultRunes > maxRunes+20 { // Allow 20 rune tolerance
|
||||
t.Errorf("Truncated context (%d runes) significantly exceeds limit (%d runes)",
|
||||
resultRunes, maxRunes)
|
||||
}
|
||||
}
|
||||
|
||||
// TestContextTruncationPreservesToolCalls verifies that tool calls are
|
||||
// properly handled during context truncation.
|
||||
func TestContextTruncationPreservesToolCalls(t *testing.T) {
|
||||
messages := []providers.Message{
|
||||
{Role: "system", Content: "System"},
|
||||
{Role: "user", Content: "Old message that should be dropped"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Recent tool use",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
Name: "important_tool",
|
||||
Arguments: map[string]any{"key": "value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Set a generous limit that should keep the tool call message
|
||||
maxRunes := 200
|
||||
truncated := TruncateContextSmart(messages, maxRunes)
|
||||
|
||||
// Verify tool call message is preserved
|
||||
foundToolCall := false
|
||||
for _, msg := range truncated {
|
||||
if len(msg.ToolCalls) > 0 && msg.ToolCalls[0].Name == "important_tool" {
|
||||
foundToolCall = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolCall {
|
||||
t.Error("Tool call message was not preserved during truncation")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user