feat(agent): steering (#1517)

* feat(agent): steering

* fix loop

* fix lint

* fix lint
This commit is contained in:
Mauro
2026-03-15 17:08:16 +01:00
committed by GitHub
parent 0f700a6bf0
commit 021aa7d6d5
7 changed files with 1589 additions and 102 deletions
+183 -102
View File
@@ -48,6 +48,7 @@ type AgentLoop struct {
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
steering *steeringQueue
mu sync.RWMutex
// Track active requests for safe provider cleanup
activeRequests sync.WaitGroup
@@ -55,15 +56,16 @@ type AgentLoop struct {
// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
}
const (
@@ -105,6 +107,7 @@ func NewAgentLoop(
summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
return al
@@ -257,6 +260,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
continue
}
// Start a goroutine that drains the bus while processMessage is
// running. Any inbound messages that arrive during processing are
// redirected into the steering queue so the agent loop can pick
// them up between tool calls.
drainCtx, drainCancel := context.WithCancel(ctx)
go al.drainBusToSteering(drainCtx)
// Process message
func() {
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
@@ -272,6 +282,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// }
// }()
defer drainCancel()
response, err := al.processMessage(ctx, msg)
if err != nil {
response = fmt.Sprintf("Error processing message: %v", err)
@@ -318,6 +330,39 @@ func (al *AgentLoop) Run(ctx context.Context) error {
return nil
}
// drainBusToSteering continuously consumes inbound messages and redirects
// them into the steering queue. It runs in a goroutine while processMessage
// is active and stops when drainCtx is canceled (i.e., processMessage returns).
func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
for {
msg, ok := al.bus.ConsumeInbound(ctx)
if !ok {
return
}
// Transcribe audio if needed before steering, so the agent sees text.
msg, _ = al.transcribeAudioInMessage(ctx, msg)
logger.InfoCF("agent", "Redirecting inbound message to steering queue",
map[string]any{
"channel": msg.Channel,
"sender_id": msg.SenderID,
"content_len": len(msg.Content),
})
if err := al.Steer(providers.Message{
Role: "user",
Content: msg.Content,
}); err != nil {
logger.WarnCF("agent", "Failed to steer message, will be lost",
map[string]any{
"error": err.Error(),
"channel": msg.Channel,
})
}
}
}
func (al *AgentLoop) Stop() {
al.running.Store(false)
}
@@ -999,6 +1044,16 @@ func (al *AgentLoop) runLLMIteration(
) (string, int, error) {
iteration := 0
var finalContent string
var pendingMessages []providers.Message
// Poll for steering messages at loop start (in case the user typed while
// the agent was setting up), unless the caller already provided initial
// steering messages (e.g. Continue).
if !opts.SkipInitialSteeringPoll {
if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 {
pendingMessages = msgs
}
}
// Determine effective model tier for this conversation turn.
// selectCandidates evaluates routing once and the decision is sticky for
@@ -1006,9 +1061,25 @@ func (al *AgentLoop) runLLMIteration(
// tool chain doesn't switch models mid-way through.
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
for iteration < agent.MaxIterations {
for iteration < agent.MaxIterations || len(pendingMessages) > 0 {
iteration++
// Inject pending steering messages into the conversation context
// before the next LLM call.
if len(pendingMessages) > 0 {
for _, pm := range pendingMessages {
messages = append(messages, pm)
agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content)
logger.InfoCF("agent", "Injected steering message into context",
map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"content_len": len(pm.Content),
})
}
pendingMessages = nil
}
logger.DebugCF("agent", "LLM iteration",
map[string]any{
"agent_id": agent.ID,
@@ -1251,107 +1322,83 @@ func (al *AgentLoop) runLLMIteration(
// Save assistant message with tool calls to session
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
// Execute tool calls in parallel
type indexedAgentResult struct {
result *tools.ToolResult
tc providers.ToolCall
}
agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
var wg sync.WaitGroup
// Execute tool calls sequentially. After each tool completes, check
// for steering messages. If any are found, skip remaining tools.
var steeringAfterTools []providers.Message
for i, tc := range normalizedToolCalls {
agentResults[i].tc = tc
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
wg.Add(1)
go func(idx int, tc providers.ToolCall) {
defer wg.Done()
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
// Create async callback for tools that implement AsyncExecutor.
// When the background work completes, this publishes the result
// as an inbound system message so processSystemMessage routes it
// back to the user via the normal agent loop.
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
// Send ForUser content directly to the user (immediate feedback),
// mirroring the synchronous tool execution path.
if !result.Silent && result.ForUser != "" {
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer outCancel()
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: result.ForUser,
})
}
// Determine content for the agent loop (ForLLM or error).
content := result.ForLLM
if content == "" && result.Err != nil {
content = result.Err.Error()
}
if content == "" {
return
}
logger.InfoCF("agent", "Async tool completed, publishing result",
map[string]any{
"tool": tc.Name,
"content_len": len(content),
"channel": opts.Channel,
})
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("async:%s", tc.Name),
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
Content: content,
// Create async callback for tools that implement AsyncExecutor.
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer outCancel()
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: result.ForUser,
})
}
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
agentResults[idx].result = toolResult
}(i, tc)
}
wg.Wait()
content := result.ForLLM
if content == "" && result.Err != nil {
content = result.Err.Error()
}
if content == "" {
return
}
// Process results in original order (send to user, save to session)
for _, r := range agentResults {
// Send ForUser content to user immediately if not Silent
if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
logger.InfoCF("agent", "Async tool completed, publishing result",
map[string]any{
"tool": tc.Name,
"content_len": len(content),
"channel": opts.Channel,
})
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("async:%s", tc.Name),
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
Content: content,
})
}
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
// Process tool result
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: r.result.ForUser,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
"tool": r.tc.Name,
"content_len": len(r.result.ForUser),
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
})
}
// If tool returned media refs, publish them as outbound media
if len(r.result.Media) > 0 {
parts := make([]bus.MediaPart, 0, len(r.result.Media))
for _, ref := range r.result.Media {
if len(toolResult.Media) > 0 {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
part := bus.MediaPart{Ref: ref}
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
@@ -1369,21 +1416,55 @@ func (al *AgentLoop) runLLMIteration(
})
}
// Determine content for LLM based on tool result
contentForLLM := r.result.ForLLM
if contentForLLM == "" && r.result.Err != nil {
contentForLLM = r.result.Err.Error()
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: r.tc.ID,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
// Save tool result message to session
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
// After EVERY tool (including the first and last), check for
// steering messages. If found and there are remaining tools,
// skip them all.
if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
remaining := len(normalizedToolCalls) - i - 1
if remaining > 0 {
logger.InfoCF("agent", "Steering interrupt: skipping remaining tools",
map[string]any{
"agent_id": agent.ID,
"completed": i + 1,
"skipped": remaining,
"total_tools": len(normalizedToolCalls),
"steering_count": len(steerMsgs),
})
// Mark remaining tool calls as skipped
for j := i + 1; j < len(normalizedToolCalls); j++ {
skippedTC := normalizedToolCalls[j]
toolResultMsg := providers.Message{
Role: "tool",
Content: "Skipped due to queued user message.",
ToolCallID: skippedTC.ID,
}
messages = append(messages, toolResultMsg)
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
}
}
steeringAfterTools = steerMsgs
break
}
}
// If steering messages were captured during tool execution, they
// become pendingMessages for the next iteration of the inner loop.
if len(steeringAfterTools) > 0 {
pendingMessages = steeringAfterTools
}
// Tick down TTL of discovered tools after processing tool results.
+188
View File
@@ -0,0 +1,188 @@
package agent
import (
"context"
"fmt"
"strings"
"sync"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
// SteeringMode controls how queued steering messages are dequeued.
type SteeringMode string
const (
// SteeringOneAtATime dequeues only the first queued message per poll.
SteeringOneAtATime SteeringMode = "one-at-a-time"
// SteeringAll drains the entire queue in a single poll.
SteeringAll SteeringMode = "all"
// MaxQueueSize number of possible messages in the Steering Queue
MaxQueueSize = 10
)
// parseSteeringMode normalizes a config string into a SteeringMode.
func parseSteeringMode(s string) SteeringMode {
switch s {
case "all":
return SteeringAll
default:
return SteeringOneAtATime
}
}
// steeringQueue is a thread-safe queue of user messages that can be injected
// into a running agent loop to interrupt it between tool calls.
type steeringQueue struct {
mu sync.Mutex
queue []providers.Message
mode SteeringMode
}
func newSteeringQueue(mode SteeringMode) *steeringQueue {
return &steeringQueue{
mode: mode,
}
}
// push enqueues a steering message.
func (sq *steeringQueue) push(msg providers.Message) error {
sq.mu.Lock()
defer sq.mu.Unlock()
if len(sq.queue) >= MaxQueueSize {
return fmt.Errorf("steering queue is full")
}
sq.queue = append(sq.queue, msg)
return nil
}
// dequeue removes and returns pending steering messages according to the
// configured mode. Returns nil when the queue is empty.
func (sq *steeringQueue) dequeue() []providers.Message {
sq.mu.Lock()
defer sq.mu.Unlock()
if len(sq.queue) == 0 {
return nil
}
switch sq.mode {
case SteeringAll:
msgs := sq.queue
sq.queue = nil
return msgs
default: // one-at-a-time
msg := sq.queue[0]
sq.queue[0] = providers.Message{} // Clear reference for GC
sq.queue = sq.queue[1:]
return []providers.Message{msg}
}
}
// len returns the number of queued messages.
func (sq *steeringQueue) len() int {
sq.mu.Lock()
defer sq.mu.Unlock()
return len(sq.queue)
}
// setMode updates the steering mode.
func (sq *steeringQueue) setMode(mode SteeringMode) {
sq.mu.Lock()
defer sq.mu.Unlock()
sq.mode = mode
}
// getMode returns the current steering mode.
func (sq *steeringQueue) getMode() SteeringMode {
sq.mu.Lock()
defer sq.mu.Unlock()
return sq.mode
}
// --- AgentLoop steering API ---
// Steer enqueues a user message to be injected into the currently running
// agent loop. The message will be picked up after the current tool finishes
// executing, causing any remaining tool calls in the batch to be skipped.
func (al *AgentLoop) Steer(msg providers.Message) error {
if al.steering == nil {
return fmt.Errorf("steering queue is not initialized")
}
if err := al.steering.push(msg); err != nil {
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
"error": err.Error(),
"role": msg.Role,
})
return err
}
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
"role": msg.Role,
"content_len": len(msg.Content),
"queue_len": al.steering.len(),
})
return nil
}
// SteeringMode returns the current steering mode.
func (al *AgentLoop) SteeringMode() SteeringMode {
if al.steering == nil {
return SteeringOneAtATime
}
return al.steering.getMode()
}
// SetSteeringMode updates the steering mode.
func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
if al.steering == nil {
return
}
al.steering.setMode(mode)
}
// dequeueSteeringMessages is the internal method called by the agent loop
// to poll for steering messages. Returns nil when no messages are pending.
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
if al.steering == nil {
return nil
}
return al.steering.dequeue()
}
// Continue resumes an idle agent by dequeuing any pending steering messages
// and running them through the agent loop. This is used when the agent's last
// message was from the assistant (i.e., it has stopped processing) and the
// user has since enqueued steering messages.
//
// If no steering messages are pending, it returns an empty string.
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
steeringMsgs := al.dequeueSteeringMessages()
if len(steeringMsgs) == 0 {
return "", nil
}
agent := al.GetRegistry().GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent available")
}
// Build a combined user message from the steering messages.
var contents []string
for _, msg := range steeringMsgs {
contents = append(contents, msg.Content)
}
combinedContent := strings.Join(contents, "\n")
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: sessionKey,
Channel: channel,
ChatID: chatID,
UserMessage: combinedContent,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
SkipInitialSteeringPoll: true,
})
}
+744
View File
@@ -0,0 +1,744 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"os"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// --- steeringQueue unit tests ---
func TestSteeringQueue_PushDequeue_OneAtATime(t *testing.T) {
sq := newSteeringQueue(SteeringOneAtATime)
sq.push(providers.Message{Role: "user", Content: "msg1"})
sq.push(providers.Message{Role: "user", Content: "msg2"})
sq.push(providers.Message{Role: "user", Content: "msg3"})
if sq.len() != 3 {
t.Fatalf("expected 3 messages, got %d", sq.len())
}
msgs := sq.dequeue()
if len(msgs) != 1 {
t.Fatalf("expected 1 message in one-at-a-time mode, got %d", len(msgs))
}
if msgs[0].Content != "msg1" {
t.Fatalf("expected 'msg1', got %q", msgs[0].Content)
}
if sq.len() != 2 {
t.Fatalf("expected 2 remaining, got %d", sq.len())
}
msgs = sq.dequeue()
if len(msgs) != 1 || msgs[0].Content != "msg2" {
t.Fatalf("expected 'msg2', got %v", msgs)
}
msgs = sq.dequeue()
if len(msgs) != 1 || msgs[0].Content != "msg3" {
t.Fatalf("expected 'msg3', got %v", msgs)
}
msgs = sq.dequeue()
if msgs != nil {
t.Fatalf("expected nil from empty queue, got %v", msgs)
}
}
func TestSteeringQueue_PushDequeue_All(t *testing.T) {
sq := newSteeringQueue(SteeringAll)
sq.push(providers.Message{Role: "user", Content: "msg1"})
sq.push(providers.Message{Role: "user", Content: "msg2"})
sq.push(providers.Message{Role: "user", Content: "msg3"})
msgs := sq.dequeue()
if len(msgs) != 3 {
t.Fatalf("expected 3 messages in all mode, got %d", len(msgs))
}
if msgs[0].Content != "msg1" || msgs[1].Content != "msg2" || msgs[2].Content != "msg3" {
t.Fatalf("unexpected messages: %v", msgs)
}
if sq.len() != 0 {
t.Fatalf("expected 0 remaining, got %d", sq.len())
}
msgs = sq.dequeue()
if msgs != nil {
t.Fatalf("expected nil from empty queue, got %v", msgs)
}
}
func TestSteeringQueue_EmptyDequeue(t *testing.T) {
sq := newSteeringQueue(SteeringOneAtATime)
if msgs := sq.dequeue(); msgs != nil {
t.Fatalf("expected nil, got %v", msgs)
}
}
func TestSteeringQueue_SetMode(t *testing.T) {
sq := newSteeringQueue(SteeringOneAtATime)
if sq.getMode() != SteeringOneAtATime {
t.Fatalf("expected one-at-a-time, got %v", sq.getMode())
}
sq.setMode(SteeringAll)
if sq.getMode() != SteeringAll {
t.Fatalf("expected all, got %v", sq.getMode())
}
// Push two messages and verify all-mode drains them
sq.push(providers.Message{Role: "user", Content: "a"})
sq.push(providers.Message{Role: "user", Content: "b"})
msgs := sq.dequeue()
if len(msgs) != 2 {
t.Fatalf("expected 2 messages after mode switch, got %d", len(msgs))
}
}
func TestSteeringQueue_ConcurrentAccess(t *testing.T) {
sq := newSteeringQueue(SteeringOneAtATime)
var wg sync.WaitGroup
const n = MaxQueueSize
// Push from multiple goroutines
for i := 0; i < n; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)})
}(i)
}
wg.Wait()
if sq.len() != n {
t.Fatalf("expected %d messages, got %d", n, sq.len())
}
// Drain from multiple goroutines
var drained int
var mu sync.Mutex
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if msgs := sq.dequeue(); len(msgs) > 0 {
mu.Lock()
drained += len(msgs)
mu.Unlock()
}
}()
}
wg.Wait()
if drained != n {
t.Fatalf("expected to drain %d messages, got %d", n, drained)
}
}
func TestSteeringQueue_Overflow(t *testing.T) {
sq := newSteeringQueue(SteeringOneAtATime)
// Fill the queue up to its maximum capacity
for i := 0; i < MaxQueueSize; i++ {
err := sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)})
if err != nil {
t.Fatalf("unexpected error pushing message %d: %v", i, err)
}
}
// Sanity check: ensure the queue is actually full
if sq.len() != MaxQueueSize {
t.Fatalf("expected queue length %d, got %d", MaxQueueSize, sq.len())
}
// Attempt to push one more message, which MUST fail
err := sq.push(providers.Message{Role: "user", Content: "overflow_msg"})
// Assert the error happened and is the exact one we expect
if err == nil {
t.Fatal("expected an error when pushing to a full queue, but got nil")
}
expectedErr := "steering queue is full"
if err.Error() != expectedErr {
t.Errorf("expected error message %q, got %q", expectedErr, err.Error())
}
}
func TestParseSteeringMode(t *testing.T) {
tests := []struct {
input string
expected SteeringMode
}{
{"", SteeringOneAtATime},
{"one-at-a-time", SteeringOneAtATime},
{"all", SteeringAll},
{"unknown", SteeringOneAtATime},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
if got := parseSteeringMode(tt.input); got != tt.expected {
t.Fatalf("parseSteeringMode(%q) = %v, want %v", tt.input, got, tt.expected)
}
})
}
}
// --- AgentLoop steering integration tests ---
func TestAgentLoop_Steer_Enqueues(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
if cfg == nil {
t.Fatal("expected config to be initialized")
}
if msgBus == nil {
t.Fatal("expected message bus to be initialized")
}
if provider == nil {
t.Fatal("expected provider to be initialized")
}
al.Steer(providers.Message{Role: "user", Content: "interrupt me"})
if al.steering.len() != 1 {
t.Fatalf("expected 1 steering message, got %d", al.steering.len())
}
msgs := al.dequeueSteeringMessages()
if len(msgs) != 1 || msgs[0].Content != "interrupt me" {
t.Fatalf("unexpected dequeued message: %v", msgs)
}
}
func TestAgentLoop_SteeringMode_GetSet(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
if cfg == nil {
t.Fatal("expected config to be initialized")
}
if msgBus == nil {
t.Fatal("expected message bus to be initialized")
}
if provider == nil {
t.Fatal("expected provider to be initialized")
}
if al.SteeringMode() != SteeringOneAtATime {
t.Fatalf("expected default mode one-at-a-time, got %v", al.SteeringMode())
}
al.SetSteeringMode(SteeringAll)
if al.SteeringMode() != SteeringAll {
t.Fatalf("expected all mode, got %v", al.SteeringMode())
}
}
func TestAgentLoop_SteeringMode_ConfiguredFromConfig(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
SteeringMode: "all",
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
if al.SteeringMode() != SteeringAll {
t.Fatalf("expected 'all' mode from config, got %v", al.SteeringMode())
}
}
func TestAgentLoop_Continue_NoMessages(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
if cfg == nil {
t.Fatal("expected config to be initialized")
}
if msgBus == nil {
t.Fatal("expected message bus to be initialized")
}
if provider == nil {
t.Fatal("expected provider to be initialized")
}
resp, err := al.Continue(context.Background(), "test-session", "test", "chat1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp != "" {
t.Fatalf("expected empty response for no steering messages, got %q", resp)
}
}
func TestAgentLoop_Continue_WithMessages(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "continued response"}
al := NewAgentLoop(cfg, msgBus, provider)
al.Steer(providers.Message{Role: "user", Content: "new direction"})
resp, err := al.Continue(context.Background(), "test-session", "test", "chat1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp != "continued response" {
t.Fatalf("expected 'continued response', got %q", resp)
}
}
// slowTool simulates a tool that takes some time to execute.
type slowTool struct {
name string
duration time.Duration
execCh chan struct{} // closed when Execute starts
}
func (t *slowTool) Name() string { return t.name }
func (t *slowTool) Description() string { return "slow tool for testing" }
func (t *slowTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
func (t *slowTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
if t.execCh != nil {
close(t.execCh)
}
time.Sleep(t.duration)
return tools.SilentResult(fmt.Sprintf("executed %s", t.name))
}
// toolCallProvider returns an LLM response with tool calls on the first call,
// then a direct response on subsequent calls.
type toolCallProvider struct {
mu sync.Mutex
calls int
toolCalls []providers.ToolCall
finalResp string
}
func (m *toolCallProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls++
if m.calls == 1 && len(m.toolCalls) > 0 {
return &providers.LLMResponse{
Content: "",
ToolCalls: m.toolCalls,
}, nil
}
return &providers.LLMResponse{
Content: m.finalResp,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *toolCallProvider) GetDefaultModel() string {
return "tool-call-mock"
}
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
tool1ExecCh := make(chan struct{})
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "tool_one",
Function: &providers.FunctionCall{
Name: "tool_one",
Arguments: "{}",
},
Arguments: map[string]any{},
},
{
ID: "call_2",
Type: "function",
Name: "tool_two",
Function: &providers.FunctionCall{
Name: "tool_two",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "steered response",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
// Start processing in a goroutine
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"do something",
"test-session",
"test",
"chat1",
)
resultCh <- result{resp, err}
}()
// Wait for tool_one to start executing, then enqueue a steering message
select {
case <-tool1ExecCh:
// tool_one has started executing
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for tool_one to start")
}
al.Steer(providers.Message{Role: "user", Content: "change course"})
// Get the result
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
if r.resp != "steered response" {
t.Fatalf("expected 'steered response', got %q", r.resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for agent loop to complete")
}
// The provider should have been called twice:
// 1. first call returned tool calls
// 2. second call (after steering) returned the final response
provider.mu.Lock()
calls := provider.calls
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
}
func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Provider that captures messages it receives
var capturedMessages []providers.Message
var capMu sync.Mutex
provider := &capturingMockProvider{
response: "ack",
captureFn: func(msgs []providers.Message) {
capMu.Lock()
capturedMessages = make([]providers.Message, len(msgs))
copy(capturedMessages, msgs)
capMu.Unlock()
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
// Enqueue a steering message before processing starts
al.Steer(providers.Message{Role: "user", Content: "pre-enqueued steering"})
// Process a normal message - the initial steering poll should inject the steering message
_, err = al.ProcessDirectWithChannel(
context.Background(),
"initial message",
"test-session",
"test",
"chat1",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// The steering message should have been injected into the conversation
capMu.Lock()
msgs := capturedMessages
capMu.Unlock()
// Look for the steering message in the captured messages
found := false
for _, m := range msgs {
if m.Content == "pre-enqueued steering" {
found = true
break
}
}
if !found {
t.Fatal("expected steering message to be injected into conversation context")
}
}
// capturingMockProvider captures messages sent to Chat for inspection.
type capturingMockProvider struct {
response string
calls int
captureFn func([]providers.Message)
}
func (m *capturingMockProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.calls++
if m.captureFn != nil {
m.captureFn(messages)
}
return &providers.LLMResponse{
Content: m.response,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *capturingMockProvider) GetDefaultModel() string {
return "capturing-mock"
}
func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
execCh := make(chan struct{})
tool1 := &slowTool{name: "slow_tool", duration: 50 * time.Millisecond, execCh: execCh}
tool2 := &slowTool{name: "skipped_tool", duration: 50 * time.Millisecond}
// Provider that captures messages on the second call (after tools)
var secondCallMessages []providers.Message
var capMu sync.Mutex
callCount := 0
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "slow_tool",
Function: &providers.FunctionCall{
Name: "slow_tool",
Arguments: "{}",
},
Arguments: map[string]any{},
},
{
ID: "call_2",
Type: "function",
Name: "skipped_tool",
Function: &providers.FunctionCall{
Name: "skipped_tool",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "done",
}
// Wrap provider to capture messages on second call
wrappedProvider := &wrappingProvider{
inner: provider,
onChat: func(msgs []providers.Message) {
capMu.Lock()
callCount++
if callCount >= 2 {
secondCallMessages = make([]providers.Message, len(msgs))
copy(secondCallMessages, msgs)
}
capMu.Unlock()
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, wrappedProvider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
resultCh := make(chan string, 1)
go func() {
resp, _ := al.ProcessDirectWithChannel(
context.Background(), "go", "test-session", "test", "chat1",
)
resultCh <- resp
}()
<-execCh
al.Steer(providers.Message{Role: "user", Content: "interrupt!"})
select {
case <-resultCh:
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
// Check that the skipped tool result message is in the conversation
capMu.Lock()
msgs := secondCallMessages
capMu.Unlock()
foundSkipped := false
for _, m := range msgs {
if m.Role == "tool" && m.ToolCallID == "call_2" && m.Content == "Skipped due to queued user message." {
foundSkipped = true
break
}
}
if !foundSkipped {
// Log what we actually got
for i, m := range msgs {
t.Logf("msg[%d]: role=%s toolCallID=%s content=%s", i, m.Role, m.ToolCallID, truncate(m.Content, 80))
}
t.Fatal("expected skipped tool result for call_2")
}
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
// wrappingProvider wraps another provider to hook into Chat calls.
type wrappingProvider struct {
inner providers.LLMProvider
onChat func([]providers.Message)
}
func (w *wrappingProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
if w.onChat != nil {
w.onChat(messages)
}
return w.inner.Chat(ctx, messages, tools, model, opts)
}
func (w *wrappingProvider) GetDefaultModel() string {
return w.inner.GetDefaultModel()
}
// Ensure NormalizeToolCall handles our test tool calls.
func init() {
// This is a no-op init; we just need the tool call tests to work
// with the proper argument serialization.
_ = json.Marshal
}
+1
View File
@@ -234,6 +234,7 @@ type AgentDefaults struct {
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
}
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
+1
View File
@@ -35,6 +35,7 @@ func DefaultConfig() *Config {
MaxToolIterations: 50,
SummarizeMessageThreshold: 20,
SummarizeTokenPercent: 75,
SteeringMode: "one-at-a-time",
},
},
Bindings: []AgentBinding{},