feat(agent): steering (#1517)

* feat(agent): steering

* fix loop

* fix lint

* fix lint
This commit is contained in:
Mauro
2026-03-15 17:08:16 +01:00
committed by GitHub
parent 0f700a6bf0
commit 021aa7d6d5
7 changed files with 1589 additions and 102 deletions
+306
View File
@@ -0,0 +1,306 @@
# Steering — Implementation Specification
## Problem
When the agent is running (executing a chain of tool calls), the user has no way to redirect it. They must wait for the full cycle to complete before sending a new message. This creates a poor experience when the agent takes a wrong direction — the user watches it waste time on tools that are no longer relevant.
## Solution
Steering introduces a **message queue** that external callers can push into at any time. The agent loop polls this queue at well-defined checkpoints. When a steering message is found, the agent:
1. Stops executing further tools in the current batch
2. Injects the user's message into the conversation context
3. Calls the LLM again with the updated context
The user's intent reaches the model **as soon as the current tool finishes**, not after the entire turn completes.
## Architecture Overview
```mermaid
graph TD
subgraph External Callers
TG[Telegram]
DC[Discord]
SL[Slack]
end
subgraph AgentLoop
BUS[MessageBus]
DRAIN[drainBusToSteering goroutine]
SQ[steeringQueue]
RLI[runLLMIteration]
TE[Tool Execution Loop]
LLM[LLM Call]
end
TG -->|PublishInbound| BUS
DC -->|PublishInbound| BUS
SL -->|PublishInbound| BUS
BUS -->|ConsumeInbound while busy| DRAIN
DRAIN -->|Steer| SQ
RLI -->|1. initial poll| SQ
TE -->|2. poll after each tool| SQ
SQ -->|pendingMessages| RLI
RLI -->|inject into context| LLM
```
### Bus drain mechanism
Channels (Telegram, Discord, etc.) publish messages to the `MessageBus` via `PublishInbound`. Without additional wiring, these messages would sit in the bus buffer until the current `processMessage` finishes — meaning steering would never work for real users.
The solution: when `Run()` starts processing a message, it spawns a **drain goroutine** (`drainBusToSteering`) that keeps consuming from the bus and calling `Steer()`. When `processMessage` returns, the drain is canceled and normal consumption resumes.
```mermaid
sequenceDiagram
participant Bus
participant Run
participant Drain
participant AgentLoop
Run->>Bus: ConsumeInbound() → msg
Run->>Drain: spawn drainBusToSteering(ctx)
Run->>Run: processMessage(msg)
Note over Drain: running concurrently
Bus-->>Drain: ConsumeInbound() → newMsg
Drain->>AgentLoop: al.transcribeAudioInMessage(ctx, newMsg)
Drain->>AgentLoop: Steer(providers.Message{Content: newMsg.Content})
Run->>Run: processMessage returns
Run->>Drain: cancel context
Note over Drain: exits
```
## Data Structures
### steeringQueue
A thread-safe FIFO queue, private to the `agent` package.
| Field | Type | Description |
|-------|------|-------------|
| `mu` | `sync.Mutex` | Protects all access to `queue` and `mode` |
| `queue` | `[]providers.Message` | Pending steering messages |
| `mode` | `SteeringMode` | Dequeue strategy |
**Methods:**
| Method | Description |
|--------|-------------|
| `push(msg) error` | Appends a message to the queue. Returns an error if the queue is full (`MaxQueueSize`) |
| `dequeue() []Message` | Removes and returns messages according to `mode`. Returns `nil` if empty |
| `len() int` | Returns the current queue length |
| `setMode(mode)` | Updates the dequeue strategy |
| `getMode() SteeringMode` | Returns the current mode |
### SteeringMode
| Value | Constant | Behavior |
|-------|----------|----------|
| `"one-at-a-time"` | `SteeringOneAtATime` | `dequeue()` returns only the **first** message. Remaining messages stay in the queue for subsequent polls. |
| `"all"` | `SteeringAll` | `dequeue()` drains the **entire** queue and returns all messages at once. |
Default: `"one-at-a-time"`.
### processOptions extension
A new field was added to `processOptions`:
| Field | Type | Description |
|-------|------|-------------|
| `SkipInitialSteeringPoll` | `bool` | When `true`, the initial steering poll at loop start is skipped. Used by `Continue()` to avoid double-dequeuing. |
## Public API on AgentLoop
| Method | Signature | Description |
|--------|-----------|-------------|
| `Steer` | `Steer(msg providers.Message) error` | Enqueues a steering message. Returns an error if the queue is full or not initialized. Thread-safe, can be called from any goroutine. |
| `SteeringMode` | `SteeringMode() SteeringMode` | Returns the current dequeue mode. |
| `SetSteeringMode` | `SetSteeringMode(mode SteeringMode)` | Changes the dequeue mode at runtime. |
| `Continue` | `Continue(ctx, sessionKey, channel, chatID) (string, error)` | Resumes an idle agent using pending steering messages. Returns `""` if queue is empty. |
## Integration into the Agent Loop
### Where steering is wired
The steering queue lives as a field on `AgentLoop`:
```
AgentLoop
├── bus
├── cfg
├── registry
├── steering *steeringQueue ← new
├── ...
```
It is initialized in `NewAgentLoop` from `cfg.Agents.Defaults.SteeringMode`.
### Detailed flow through runLLMIteration
```mermaid
sequenceDiagram
participant User
participant AgentLoop
participant runLLMIteration
participant ToolExecution
participant LLM
User->>AgentLoop: Steer(message)
Note over AgentLoop: steeringQueue.push(message)
Note over runLLMIteration: ── iteration starts ──
runLLMIteration->>AgentLoop: dequeueSteeringMessages()<br/>[initial poll]
AgentLoop-->>runLLMIteration: [] (empty, or messages)
alt pendingMessages not empty
runLLMIteration->>runLLMIteration: inject into messages[]<br/>save to session
end
runLLMIteration->>LLM: Chat(messages, tools)
LLM-->>runLLMIteration: response with toolCalls[0..N]
loop for each tool call (sequential)
ToolExecution->>ToolExecution: execute tool[i]
ToolExecution->>ToolExecution: process result,<br/>append to messages[]
ToolExecution->>AgentLoop: dequeueSteeringMessages()
AgentLoop-->>ToolExecution: steeringMessages
alt steering found
opt remaining tools > 0
Note over ToolExecution: Mark tool[i+1..N-1] as<br/>"Skipped due to queued user message."
end
Note over ToolExecution: steeringAfterTools = steeringMessages
Note over ToolExecution: break out of tool loop
end
end
alt steeringAfterTools not empty
ToolExecution-->>runLLMIteration: pendingMessages = steeringAfterTools
Note over runLLMIteration: next iteration will inject<br/>these before calling LLM
end
Note over runLLMIteration: ── loop back to iteration start ──
```
### Polling checkpoints
| # | Location | When | Purpose |
|---|----------|------|---------|
| 1 | Top of `runLLMIteration`, before first LLM call | Once, at loop entry | Catch messages enqueued while the agent was still setting up context |
| 2 | After every tool completes (including the first and the last) | Immediately after each tool's result is processed | Interrupt the batch as early as possible — if steering is found and there are remaining tools, they are all skipped |
### What happens to skipped tools
When steering interrupts a tool batch after tool `[i]` completes, all tools from `[i+1]` to `[N-1]` are **not executed**. Instead, a tool result message is generated for each:
```json
{
"role": "tool",
"content": "Skipped due to queued user message.",
"tool_call_id": "<original_call_id>"
}
```
These results are:
- Appended to the conversation `messages[]`
- Saved to the session via `AddFullMessage`
This ensures the LLM knows which of its requested actions were not performed.
### Loop condition change
The iteration loop condition was changed from:
```go
for iteration < agent.MaxIterations
```
to:
```go
for iteration < agent.MaxIterations || len(pendingMessages) > 0
```
This allows **one extra iteration** when steering arrives right at the max iteration boundary, ensuring the steering message is always processed.
### Tool execution: parallel → sequential
**Before steering:** all tool calls in a batch were executed in parallel using `sync.WaitGroup`.
**After steering:** tool calls execute **sequentially**. This is required because steering must be polled between individual tool completions. A parallel execution model would not allow interrupting mid-batch.
> **Trade-off:** This introduces latency when the LLM requests multiple independent tools in a single turn. In practice, most batches contain 1-2 tools, so the impact is minimal. The benefit of being able to interrupt outweighs the cost.
### Why skip remaining tools (instead of letting them finish)
Two strategies were considered when a steering message is detected mid-batch:
1. **Skip remaining tools** (chosen) — stop executing, mark the rest as skipped, inject steering
2. **Finish all tools, then inject** — let everything run, append steering afterwards
Strategy 2 was rejected for three reasons:
**Irreversible side effects.** Tools can send emails, write files, spawn subagents, or call external APIs. If the user says "stop" or "change direction", those actions have already happened and cannot be undone.
| Tool batch | Steering | Skip (1) | Finish (2) |
|---|---|---|---|
| `[search, send_email]` | "don't send it" | Email not sent | Email sent |
| `[query, write_file, spawn]` | "wrong database" | Only query runs | File + subagent wasted |
| `[fetch₁, fetch₂, fetch₃, write]` | topic change | 1 fetch | 3 fetches + write, all discarded |
**Wasted latency.** Tools like web fetches and API calls take seconds each. In a 3-tool batch averaging 3-4s per tool, the user would wait 10+ seconds for work that gets thrown away.
**The LLM retains full awareness.** Skipped tools receive an explicit `"Skipped due to queued user message."` result, so the model knows what was not done and can decide whether to re-execute with the new context or take a different path.
## The Continue() method
`Continue` handles the case where the agent is **idle** (its last message was from the assistant) and the user has enqueued steering messages in the meantime.
```mermaid
flowchart TD
A[Continue called] --> B{dequeueSteeringMessages}
B -->|empty| C["return ('', nil)"]
B -->|messages found| D[Combine message contents]
D --> E["runAgentLoop with<br/>SkipInitialSteeringPoll: true"]
E --> F[Return response]
```
**Why `SkipInitialSteeringPoll: true`?** Because `Continue` already dequeued the messages itself. Without this flag, `runLLMIteration` would poll again at the start and find nothing (the queue is already empty), or worse, double-process if new messages arrived in the meantime.
## Configuration
```json
{
"agents": {
"defaults": {
"steering_mode": "one-at-a-time"
}
}
}
```
| Field | Type | Default | Env var |
|-------|------|---------|---------|
| `steering_mode` | `string` | `"one-at-a-time"` | `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` |
## Design decisions and trade-offs
| Decision | Rationale |
|----------|-----------|
| Sequential tool execution | Required for per-tool steering polls. Parallel execution cannot be interrupted mid-batch. |
| Polling-based (not channel/signal) | Keeps the implementation simple. No need for `select` or signal channels. The polling cost is negligible (mutex lock + slice length check). |
| `one-at-a-time` as default | Gives the model a chance to react to each steering message individually. More predictable behavior than dumping all messages at once. |
| Skipped tools get explicit error results | The LLM protocol requires a tool result for every tool call in the assistant message. Omitting them would cause API errors. The skip message also informs the model about what was not done. |
| `Continue()` uses `SkipInitialSteeringPoll` | Prevents race conditions and double-dequeuing when resuming an idle agent. |
| Queue stored on `AgentLoop`, not `AgentInstance` | Steering is a loop-level concern (it affects the iteration flow), not a per-agent concern. All agents share the same steering queue since `processMessage` is sequential. |
| Bus drain goroutine in `Run()` | Channels (Telegram, Discord, etc.) publish to the bus via `PublishInbound`. Without the drain, messages would queue in the bus channel buffer and only be consumed after `processMessage` returns — defeating the purpose of steering. The drain goroutine bridges the gap by consuming new bus messages and calling `Steer()` while the agent is busy. |
| Audio transcription before steering | The drain goroutine calls `al.transcribeAudioInMessage(ctx, msg)` before steering, so voice messages are converted to text before the agent sees them. If transcription fails, the error is silently discarded and the original message is steered as-is. |
| `MaxQueueSize = 10` | Prevents unbounded memory growth if a user sends many messages while the agent is busy. Excess messages are dropped with a warning. |
+166
View File
@@ -0,0 +1,166 @@
# Steering
Steering allows injecting messages into an already-running agent loop, interrupting it between tool calls without waiting for the entire cycle to complete.
## How it works
When the agent is executing a sequence of tool calls (e.g. the model requested 3 tools in a single turn), steering checks the queue **after each tool** completes. If it finds queued messages:
1. The remaining tools are **skipped** and receive `"Skipped due to queued user message."` as their result
2. The steering messages are **injected into the conversation context**
3. The model is called again with the updated context, including the user's steering message
```
User ──► Steer("change approach")
Agent Loop ▼
├─ tool[0] ✔ (executed)
├─ [polling] → steering found!
├─ tool[1] ✘ (skipped)
├─ tool[2] ✘ (skipped)
└─ new LLM turn with steering message
```
## Configuration
In `config.json`, under `agents.defaults`:
```json
{
"agents": {
"defaults": {
"steering_mode": "one-at-a-time"
}
}
}
```
### Modes
| Value | Behavior |
|-------|----------|
| `"one-at-a-time"` | **(default)** Dequeues only one message per polling cycle. If there are 3 messages in the queue, they are processed one at a time across 3 successive iterations. |
| `"all"` | Drains the entire queue in a single poll. All pending messages are injected into the context together. |
The environment variable `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` can be used as an alternative.
## Go API
### Steer — Send a steering message
```go
err := agentLoop.Steer(providers.Message{
Role: "user",
Content: "change direction, focus on X instead",
})
if err != nil {
// Queue is full (MaxQueueSize=10) or not initialized
}
```
The message is enqueued in a thread-safe manner. Returns an error if the queue is full or not initialized. It will be picked up at the next polling point (after the current tool finishes).
### SteeringMode / SetSteeringMode
```go
// Read the current mode
mode := agentLoop.SteeringMode() // SteeringOneAtATime | SteeringAll
// Change it at runtime
agentLoop.SetSteeringMode(agent.SteeringAll)
```
### Continue — Resume an idle agent
When the agent is idle (it has finished processing and its last message was from the assistant), `Continue` checks if there are steering messages in the queue and uses them to start a new cycle:
```go
response, err := agentLoop.Continue(ctx, sessionKey, channel, chatID)
if err != nil {
// Error (e.g. "no default agent available")
}
if response == "" {
// No steering messages in queue, the agent stays idle
}
```
`Continue` internally uses `SkipInitialSteeringPoll: true` to avoid double-dequeuing the same messages (since it already extracted them and passes them directly as input).
## Polling points in the loop
Steering is checked at **two points** in the agent cycle:
1. **At loop start** — before the first LLM call, to catch messages enqueued during setup
2. **After every tool completes** — including the first and the last. If steering is found and there are remaining tools, they are all skipped immediately
## Why remaining tools are skipped
When a steering message is detected, all remaining tools in the batch are skipped rather than executed. The alternative — let all tools finish and inject the steering message afterwards — was considered and rejected. Here is why.
### Preventing unwanted side effects
Tools can have **irreversible side effects**. If the user says "no, wait" while the agent is mid-batch, executing the remaining tools means those side effects happen anyway:
| Tool batch | Steering message | With skip | Without skip |
|---|---|---|---|
| `[web_search, send_email]` | "don't send it" | Email **not** sent | Email sent, damage done |
| `[query_db, write_file, spawn_agent]` | "use another database" | Only the query runs | File written + subagent spawned, all wasted |
| `[search₁, search₂, search₃, write_file]` | user changes topic entirely | 1 search | 3 searches + file write, all irrelevant |
### Avoiding wasted time
Tools that take seconds (web fetches, API calls, database queries) would all run to completion before the agent sees the user's correction. In a batch of 3 tools each taking 3-4 seconds, that's 10+ seconds of work that will be discarded.
With skipping, the agent reacts as soon as the current tool finishes — typically within a few seconds instead of waiting for the entire batch.
### The LLM gets full context
Skipped tools receive an explicit error result (`"Skipped due to queued user message."`), so the model knows exactly which actions were not performed. It can then decide whether to re-execute them with the new context, or take a different path entirely.
### Trade-off: sequential execution
Skipping requires tools to run **sequentially** (the previous implementation ran them in parallel). This introduces latency when the LLM requests multiple independent tools in a single turn. In practice, most batches contain 1-2 tools, so the impact is minimal compared to the benefit of being able to stop unwanted actions.
## Skipped tool result format
When steering interrupts a batch, each tool that was not executed receives a `tool` result with:
```
Content: "Skipped due to queued user message."
```
This is saved to the session via `AddFullMessage` and sent to the model, so it is aware that some requested actions were not performed.
## Full flow example
```
1. User: "search for info on X, write a file, and send me a message"
2. LLM responds with 3 tool calls: [web_search, write_file, message]
3. web_search is executed → result saved
4. [polling] → User called Steer("no, search for Y instead")
5. write_file is skipped → "Skipped due to queued user message."
message is skipped → "Skipped due to queued user message."
6. Message "search for Y instead" injected into context
7. LLM receives the full updated context and responds accordingly
```
## Automatic bus drain
When the agent loop (`Run()`) starts processing a message, it spawns a background goroutine that keeps consuming new inbound messages from the bus. These messages are automatically redirected into the steering queue via `Steer()`. This means:
- Users on any channel (Telegram, Discord, etc.) don't need to do anything special — their messages are automatically captured as steering when the agent is busy
- Audio messages are transcribed before being steered, so the agent receives text. If transcription fails, the original (non-transcribed) message is steered as-is
- When `processMessage` finishes, the drain goroutine is canceled and normal message consumption resumes
## Notes
- Steering **does not interrupt** a tool that is currently executing. It waits for the current tool to finish, then checks the queue.
- With `one-at-a-time` mode, if multiple messages are enqueued rapidly, they will be processed one per iteration. This gives the model the opportunity to react to each message individually.
- With `all` mode, all pending messages are combined into a single injection. Useful when you want the agent to receive all the context at once.
- The steering queue has a maximum capacity of 10 messages (`MaxQueueSize`). `Steer()` returns an error when the queue is full. In the bus drain path, the error is logged as a warning and the message is effectively dropped.
+183 -102
View File
@@ -48,6 +48,7 @@ type AgentLoop struct {
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
steering *steeringQueue
mu sync.RWMutex
// Track active requests for safe provider cleanup
activeRequests sync.WaitGroup
@@ -55,15 +56,16 @@ type AgentLoop struct {
// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
}
const (
@@ -105,6 +107,7 @@ func NewAgentLoop(
summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
return al
@@ -257,6 +260,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
continue
}
// Start a goroutine that drains the bus while processMessage is
// running. Any inbound messages that arrive during processing are
// redirected into the steering queue so the agent loop can pick
// them up between tool calls.
drainCtx, drainCancel := context.WithCancel(ctx)
go al.drainBusToSteering(drainCtx)
// Process message
func() {
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
@@ -272,6 +282,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// }
// }()
defer drainCancel()
response, err := al.processMessage(ctx, msg)
if err != nil {
response = fmt.Sprintf("Error processing message: %v", err)
@@ -318,6 +330,39 @@ func (al *AgentLoop) Run(ctx context.Context) error {
return nil
}
// drainBusToSteering continuously consumes inbound messages and redirects
// them into the steering queue. It runs in a goroutine while processMessage
// is active and stops when drainCtx is canceled (i.e., processMessage returns).
func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
for {
msg, ok := al.bus.ConsumeInbound(ctx)
if !ok {
return
}
// Transcribe audio if needed before steering, so the agent sees text.
msg, _ = al.transcribeAudioInMessage(ctx, msg)
logger.InfoCF("agent", "Redirecting inbound message to steering queue",
map[string]any{
"channel": msg.Channel,
"sender_id": msg.SenderID,
"content_len": len(msg.Content),
})
if err := al.Steer(providers.Message{
Role: "user",
Content: msg.Content,
}); err != nil {
logger.WarnCF("agent", "Failed to steer message, will be lost",
map[string]any{
"error": err.Error(),
"channel": msg.Channel,
})
}
}
}
func (al *AgentLoop) Stop() {
al.running.Store(false)
}
@@ -999,6 +1044,16 @@ func (al *AgentLoop) runLLMIteration(
) (string, int, error) {
iteration := 0
var finalContent string
var pendingMessages []providers.Message
// Poll for steering messages at loop start (in case the user typed while
// the agent was setting up), unless the caller already provided initial
// steering messages (e.g. Continue).
if !opts.SkipInitialSteeringPoll {
if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 {
pendingMessages = msgs
}
}
// Determine effective model tier for this conversation turn.
// selectCandidates evaluates routing once and the decision is sticky for
@@ -1006,9 +1061,25 @@ func (al *AgentLoop) runLLMIteration(
// tool chain doesn't switch models mid-way through.
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
for iteration < agent.MaxIterations {
for iteration < agent.MaxIterations || len(pendingMessages) > 0 {
iteration++
// Inject pending steering messages into the conversation context
// before the next LLM call.
if len(pendingMessages) > 0 {
for _, pm := range pendingMessages {
messages = append(messages, pm)
agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content)
logger.InfoCF("agent", "Injected steering message into context",
map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"content_len": len(pm.Content),
})
}
pendingMessages = nil
}
logger.DebugCF("agent", "LLM iteration",
map[string]any{
"agent_id": agent.ID,
@@ -1251,107 +1322,83 @@ func (al *AgentLoop) runLLMIteration(
// Save assistant message with tool calls to session
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
// Execute tool calls in parallel
type indexedAgentResult struct {
result *tools.ToolResult
tc providers.ToolCall
}
agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
var wg sync.WaitGroup
// Execute tool calls sequentially. After each tool completes, check
// for steering messages. If any are found, skip remaining tools.
var steeringAfterTools []providers.Message
for i, tc := range normalizedToolCalls {
agentResults[i].tc = tc
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
wg.Add(1)
go func(idx int, tc providers.ToolCall) {
defer wg.Done()
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
// Create async callback for tools that implement AsyncExecutor.
// When the background work completes, this publishes the result
// as an inbound system message so processSystemMessage routes it
// back to the user via the normal agent loop.
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
// Send ForUser content directly to the user (immediate feedback),
// mirroring the synchronous tool execution path.
if !result.Silent && result.ForUser != "" {
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer outCancel()
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: result.ForUser,
})
}
// Determine content for the agent loop (ForLLM or error).
content := result.ForLLM
if content == "" && result.Err != nil {
content = result.Err.Error()
}
if content == "" {
return
}
logger.InfoCF("agent", "Async tool completed, publishing result",
map[string]any{
"tool": tc.Name,
"content_len": len(content),
"channel": opts.Channel,
})
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("async:%s", tc.Name),
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
Content: content,
// Create async callback for tools that implement AsyncExecutor.
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer outCancel()
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: result.ForUser,
})
}
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
agentResults[idx].result = toolResult
}(i, tc)
}
wg.Wait()
content := result.ForLLM
if content == "" && result.Err != nil {
content = result.Err.Error()
}
if content == "" {
return
}
// Process results in original order (send to user, save to session)
for _, r := range agentResults {
// Send ForUser content to user immediately if not Silent
if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
logger.InfoCF("agent", "Async tool completed, publishing result",
map[string]any{
"tool": tc.Name,
"content_len": len(content),
"channel": opts.Channel,
})
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("async:%s", tc.Name),
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
Content: content,
})
}
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
// Process tool result
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: r.result.ForUser,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
"tool": r.tc.Name,
"content_len": len(r.result.ForUser),
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
})
}
// If tool returned media refs, publish them as outbound media
if len(r.result.Media) > 0 {
parts := make([]bus.MediaPart, 0, len(r.result.Media))
for _, ref := range r.result.Media {
if len(toolResult.Media) > 0 {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
part := bus.MediaPart{Ref: ref}
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
@@ -1369,21 +1416,55 @@ func (al *AgentLoop) runLLMIteration(
})
}
// Determine content for LLM based on tool result
contentForLLM := r.result.ForLLM
if contentForLLM == "" && r.result.Err != nil {
contentForLLM = r.result.Err.Error()
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: r.tc.ID,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
// Save tool result message to session
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
// After EVERY tool (including the first and last), check for
// steering messages. If found and there are remaining tools,
// skip them all.
if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
remaining := len(normalizedToolCalls) - i - 1
if remaining > 0 {
logger.InfoCF("agent", "Steering interrupt: skipping remaining tools",
map[string]any{
"agent_id": agent.ID,
"completed": i + 1,
"skipped": remaining,
"total_tools": len(normalizedToolCalls),
"steering_count": len(steerMsgs),
})
// Mark remaining tool calls as skipped
for j := i + 1; j < len(normalizedToolCalls); j++ {
skippedTC := normalizedToolCalls[j]
toolResultMsg := providers.Message{
Role: "tool",
Content: "Skipped due to queued user message.",
ToolCallID: skippedTC.ID,
}
messages = append(messages, toolResultMsg)
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
}
}
steeringAfterTools = steerMsgs
break
}
}
// If steering messages were captured during tool execution, they
// become pendingMessages for the next iteration of the inner loop.
if len(steeringAfterTools) > 0 {
pendingMessages = steeringAfterTools
}
// Tick down TTL of discovered tools after processing tool results.
+188
View File
@@ -0,0 +1,188 @@
package agent
import (
"context"
"fmt"
"strings"
"sync"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
// SteeringMode controls how queued steering messages are dequeued.
type SteeringMode string
const (
// SteeringOneAtATime dequeues only the first queued message per poll.
SteeringOneAtATime SteeringMode = "one-at-a-time"
// SteeringAll drains the entire queue in a single poll.
SteeringAll SteeringMode = "all"
// MaxQueueSize number of possible messages in the Steering Queue
MaxQueueSize = 10
)
// parseSteeringMode normalizes a config string into a SteeringMode.
func parseSteeringMode(s string) SteeringMode {
switch s {
case "all":
return SteeringAll
default:
return SteeringOneAtATime
}
}
// steeringQueue is a thread-safe queue of user messages that can be injected
// into a running agent loop to interrupt it between tool calls.
type steeringQueue struct {
mu sync.Mutex
queue []providers.Message
mode SteeringMode
}
func newSteeringQueue(mode SteeringMode) *steeringQueue {
return &steeringQueue{
mode: mode,
}
}
// push enqueues a steering message.
func (sq *steeringQueue) push(msg providers.Message) error {
sq.mu.Lock()
defer sq.mu.Unlock()
if len(sq.queue) >= MaxQueueSize {
return fmt.Errorf("steering queue is full")
}
sq.queue = append(sq.queue, msg)
return nil
}
// dequeue removes and returns pending steering messages according to the
// configured mode. Returns nil when the queue is empty.
func (sq *steeringQueue) dequeue() []providers.Message {
sq.mu.Lock()
defer sq.mu.Unlock()
if len(sq.queue) == 0 {
return nil
}
switch sq.mode {
case SteeringAll:
msgs := sq.queue
sq.queue = nil
return msgs
default: // one-at-a-time
msg := sq.queue[0]
sq.queue[0] = providers.Message{} // Clear reference for GC
sq.queue = sq.queue[1:]
return []providers.Message{msg}
}
}
// len returns the number of queued messages.
func (sq *steeringQueue) len() int {
sq.mu.Lock()
defer sq.mu.Unlock()
return len(sq.queue)
}
// setMode updates the steering mode.
func (sq *steeringQueue) setMode(mode SteeringMode) {
sq.mu.Lock()
defer sq.mu.Unlock()
sq.mode = mode
}
// getMode returns the current steering mode.
func (sq *steeringQueue) getMode() SteeringMode {
sq.mu.Lock()
defer sq.mu.Unlock()
return sq.mode
}
// --- AgentLoop steering API ---
// Steer enqueues a user message to be injected into the currently running
// agent loop. The message will be picked up after the current tool finishes
// executing, causing any remaining tool calls in the batch to be skipped.
func (al *AgentLoop) Steer(msg providers.Message) error {
if al.steering == nil {
return fmt.Errorf("steering queue is not initialized")
}
if err := al.steering.push(msg); err != nil {
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
"error": err.Error(),
"role": msg.Role,
})
return err
}
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
"role": msg.Role,
"content_len": len(msg.Content),
"queue_len": al.steering.len(),
})
return nil
}
// SteeringMode returns the current steering mode.
func (al *AgentLoop) SteeringMode() SteeringMode {
if al.steering == nil {
return SteeringOneAtATime
}
return al.steering.getMode()
}
// SetSteeringMode updates the steering mode.
func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
if al.steering == nil {
return
}
al.steering.setMode(mode)
}
// dequeueSteeringMessages is the internal method called by the agent loop
// to poll for steering messages. Returns nil when no messages are pending.
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
if al.steering == nil {
return nil
}
return al.steering.dequeue()
}
// Continue resumes an idle agent by dequeuing any pending steering messages
// and running them through the agent loop. This is used when the agent's last
// message was from the assistant (i.e., it has stopped processing) and the
// user has since enqueued steering messages.
//
// If no steering messages are pending, it returns an empty string.
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
steeringMsgs := al.dequeueSteeringMessages()
if len(steeringMsgs) == 0 {
return "", nil
}
agent := al.GetRegistry().GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent available")
}
// Build a combined user message from the steering messages.
var contents []string
for _, msg := range steeringMsgs {
contents = append(contents, msg.Content)
}
combinedContent := strings.Join(contents, "\n")
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: sessionKey,
Channel: channel,
ChatID: chatID,
UserMessage: combinedContent,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
SkipInitialSteeringPoll: true,
})
}
+744
View File
@@ -0,0 +1,744 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"os"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// --- steeringQueue unit tests ---
func TestSteeringQueue_PushDequeue_OneAtATime(t *testing.T) {
sq := newSteeringQueue(SteeringOneAtATime)
sq.push(providers.Message{Role: "user", Content: "msg1"})
sq.push(providers.Message{Role: "user", Content: "msg2"})
sq.push(providers.Message{Role: "user", Content: "msg3"})
if sq.len() != 3 {
t.Fatalf("expected 3 messages, got %d", sq.len())
}
msgs := sq.dequeue()
if len(msgs) != 1 {
t.Fatalf("expected 1 message in one-at-a-time mode, got %d", len(msgs))
}
if msgs[0].Content != "msg1" {
t.Fatalf("expected 'msg1', got %q", msgs[0].Content)
}
if sq.len() != 2 {
t.Fatalf("expected 2 remaining, got %d", sq.len())
}
msgs = sq.dequeue()
if len(msgs) != 1 || msgs[0].Content != "msg2" {
t.Fatalf("expected 'msg2', got %v", msgs)
}
msgs = sq.dequeue()
if len(msgs) != 1 || msgs[0].Content != "msg3" {
t.Fatalf("expected 'msg3', got %v", msgs)
}
msgs = sq.dequeue()
if msgs != nil {
t.Fatalf("expected nil from empty queue, got %v", msgs)
}
}
func TestSteeringQueue_PushDequeue_All(t *testing.T) {
sq := newSteeringQueue(SteeringAll)
sq.push(providers.Message{Role: "user", Content: "msg1"})
sq.push(providers.Message{Role: "user", Content: "msg2"})
sq.push(providers.Message{Role: "user", Content: "msg3"})
msgs := sq.dequeue()
if len(msgs) != 3 {
t.Fatalf("expected 3 messages in all mode, got %d", len(msgs))
}
if msgs[0].Content != "msg1" || msgs[1].Content != "msg2" || msgs[2].Content != "msg3" {
t.Fatalf("unexpected messages: %v", msgs)
}
if sq.len() != 0 {
t.Fatalf("expected 0 remaining, got %d", sq.len())
}
msgs = sq.dequeue()
if msgs != nil {
t.Fatalf("expected nil from empty queue, got %v", msgs)
}
}
func TestSteeringQueue_EmptyDequeue(t *testing.T) {
sq := newSteeringQueue(SteeringOneAtATime)
if msgs := sq.dequeue(); msgs != nil {
t.Fatalf("expected nil, got %v", msgs)
}
}
func TestSteeringQueue_SetMode(t *testing.T) {
sq := newSteeringQueue(SteeringOneAtATime)
if sq.getMode() != SteeringOneAtATime {
t.Fatalf("expected one-at-a-time, got %v", sq.getMode())
}
sq.setMode(SteeringAll)
if sq.getMode() != SteeringAll {
t.Fatalf("expected all, got %v", sq.getMode())
}
// Push two messages and verify all-mode drains them
sq.push(providers.Message{Role: "user", Content: "a"})
sq.push(providers.Message{Role: "user", Content: "b"})
msgs := sq.dequeue()
if len(msgs) != 2 {
t.Fatalf("expected 2 messages after mode switch, got %d", len(msgs))
}
}
func TestSteeringQueue_ConcurrentAccess(t *testing.T) {
sq := newSteeringQueue(SteeringOneAtATime)
var wg sync.WaitGroup
const n = MaxQueueSize
// Push from multiple goroutines
for i := 0; i < n; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)})
}(i)
}
wg.Wait()
if sq.len() != n {
t.Fatalf("expected %d messages, got %d", n, sq.len())
}
// Drain from multiple goroutines
var drained int
var mu sync.Mutex
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if msgs := sq.dequeue(); len(msgs) > 0 {
mu.Lock()
drained += len(msgs)
mu.Unlock()
}
}()
}
wg.Wait()
if drained != n {
t.Fatalf("expected to drain %d messages, got %d", n, drained)
}
}
func TestSteeringQueue_Overflow(t *testing.T) {
sq := newSteeringQueue(SteeringOneAtATime)
// Fill the queue up to its maximum capacity
for i := 0; i < MaxQueueSize; i++ {
err := sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)})
if err != nil {
t.Fatalf("unexpected error pushing message %d: %v", i, err)
}
}
// Sanity check: ensure the queue is actually full
if sq.len() != MaxQueueSize {
t.Fatalf("expected queue length %d, got %d", MaxQueueSize, sq.len())
}
// Attempt to push one more message, which MUST fail
err := sq.push(providers.Message{Role: "user", Content: "overflow_msg"})
// Assert the error happened and is the exact one we expect
if err == nil {
t.Fatal("expected an error when pushing to a full queue, but got nil")
}
expectedErr := "steering queue is full"
if err.Error() != expectedErr {
t.Errorf("expected error message %q, got %q", expectedErr, err.Error())
}
}
func TestParseSteeringMode(t *testing.T) {
tests := []struct {
input string
expected SteeringMode
}{
{"", SteeringOneAtATime},
{"one-at-a-time", SteeringOneAtATime},
{"all", SteeringAll},
{"unknown", SteeringOneAtATime},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
if got := parseSteeringMode(tt.input); got != tt.expected {
t.Fatalf("parseSteeringMode(%q) = %v, want %v", tt.input, got, tt.expected)
}
})
}
}
// --- AgentLoop steering integration tests ---
func TestAgentLoop_Steer_Enqueues(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
if cfg == nil {
t.Fatal("expected config to be initialized")
}
if msgBus == nil {
t.Fatal("expected message bus to be initialized")
}
if provider == nil {
t.Fatal("expected provider to be initialized")
}
al.Steer(providers.Message{Role: "user", Content: "interrupt me"})
if al.steering.len() != 1 {
t.Fatalf("expected 1 steering message, got %d", al.steering.len())
}
msgs := al.dequeueSteeringMessages()
if len(msgs) != 1 || msgs[0].Content != "interrupt me" {
t.Fatalf("unexpected dequeued message: %v", msgs)
}
}
func TestAgentLoop_SteeringMode_GetSet(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
if cfg == nil {
t.Fatal("expected config to be initialized")
}
if msgBus == nil {
t.Fatal("expected message bus to be initialized")
}
if provider == nil {
t.Fatal("expected provider to be initialized")
}
if al.SteeringMode() != SteeringOneAtATime {
t.Fatalf("expected default mode one-at-a-time, got %v", al.SteeringMode())
}
al.SetSteeringMode(SteeringAll)
if al.SteeringMode() != SteeringAll {
t.Fatalf("expected all mode, got %v", al.SteeringMode())
}
}
func TestAgentLoop_SteeringMode_ConfiguredFromConfig(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
SteeringMode: "all",
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
if al.SteeringMode() != SteeringAll {
t.Fatalf("expected 'all' mode from config, got %v", al.SteeringMode())
}
}
func TestAgentLoop_Continue_NoMessages(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
if cfg == nil {
t.Fatal("expected config to be initialized")
}
if msgBus == nil {
t.Fatal("expected message bus to be initialized")
}
if provider == nil {
t.Fatal("expected provider to be initialized")
}
resp, err := al.Continue(context.Background(), "test-session", "test", "chat1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp != "" {
t.Fatalf("expected empty response for no steering messages, got %q", resp)
}
}
func TestAgentLoop_Continue_WithMessages(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "continued response"}
al := NewAgentLoop(cfg, msgBus, provider)
al.Steer(providers.Message{Role: "user", Content: "new direction"})
resp, err := al.Continue(context.Background(), "test-session", "test", "chat1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp != "continued response" {
t.Fatalf("expected 'continued response', got %q", resp)
}
}
// slowTool simulates a tool that takes some time to execute.
type slowTool struct {
name string
duration time.Duration
execCh chan struct{} // closed when Execute starts
}
func (t *slowTool) Name() string { return t.name }
func (t *slowTool) Description() string { return "slow tool for testing" }
func (t *slowTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
func (t *slowTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
if t.execCh != nil {
close(t.execCh)
}
time.Sleep(t.duration)
return tools.SilentResult(fmt.Sprintf("executed %s", t.name))
}
// toolCallProvider returns an LLM response with tool calls on the first call,
// then a direct response on subsequent calls.
type toolCallProvider struct {
mu sync.Mutex
calls int
toolCalls []providers.ToolCall
finalResp string
}
func (m *toolCallProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls++
if m.calls == 1 && len(m.toolCalls) > 0 {
return &providers.LLMResponse{
Content: "",
ToolCalls: m.toolCalls,
}, nil
}
return &providers.LLMResponse{
Content: m.finalResp,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *toolCallProvider) GetDefaultModel() string {
return "tool-call-mock"
}
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
tool1ExecCh := make(chan struct{})
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "tool_one",
Function: &providers.FunctionCall{
Name: "tool_one",
Arguments: "{}",
},
Arguments: map[string]any{},
},
{
ID: "call_2",
Type: "function",
Name: "tool_two",
Function: &providers.FunctionCall{
Name: "tool_two",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "steered response",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
// Start processing in a goroutine
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"do something",
"test-session",
"test",
"chat1",
)
resultCh <- result{resp, err}
}()
// Wait for tool_one to start executing, then enqueue a steering message
select {
case <-tool1ExecCh:
// tool_one has started executing
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for tool_one to start")
}
al.Steer(providers.Message{Role: "user", Content: "change course"})
// Get the result
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
if r.resp != "steered response" {
t.Fatalf("expected 'steered response', got %q", r.resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for agent loop to complete")
}
// The provider should have been called twice:
// 1. first call returned tool calls
// 2. second call (after steering) returned the final response
provider.mu.Lock()
calls := provider.calls
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
}
func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Provider that captures messages it receives
var capturedMessages []providers.Message
var capMu sync.Mutex
provider := &capturingMockProvider{
response: "ack",
captureFn: func(msgs []providers.Message) {
capMu.Lock()
capturedMessages = make([]providers.Message, len(msgs))
copy(capturedMessages, msgs)
capMu.Unlock()
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
// Enqueue a steering message before processing starts
al.Steer(providers.Message{Role: "user", Content: "pre-enqueued steering"})
// Process a normal message - the initial steering poll should inject the steering message
_, err = al.ProcessDirectWithChannel(
context.Background(),
"initial message",
"test-session",
"test",
"chat1",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// The steering message should have been injected into the conversation
capMu.Lock()
msgs := capturedMessages
capMu.Unlock()
// Look for the steering message in the captured messages
found := false
for _, m := range msgs {
if m.Content == "pre-enqueued steering" {
found = true
break
}
}
if !found {
t.Fatal("expected steering message to be injected into conversation context")
}
}
// capturingMockProvider captures messages sent to Chat for inspection.
type capturingMockProvider struct {
response string
calls int
captureFn func([]providers.Message)
}
func (m *capturingMockProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.calls++
if m.captureFn != nil {
m.captureFn(messages)
}
return &providers.LLMResponse{
Content: m.response,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *capturingMockProvider) GetDefaultModel() string {
return "capturing-mock"
}
func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
execCh := make(chan struct{})
tool1 := &slowTool{name: "slow_tool", duration: 50 * time.Millisecond, execCh: execCh}
tool2 := &slowTool{name: "skipped_tool", duration: 50 * time.Millisecond}
// Provider that captures messages on the second call (after tools)
var secondCallMessages []providers.Message
var capMu sync.Mutex
callCount := 0
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "slow_tool",
Function: &providers.FunctionCall{
Name: "slow_tool",
Arguments: "{}",
},
Arguments: map[string]any{},
},
{
ID: "call_2",
Type: "function",
Name: "skipped_tool",
Function: &providers.FunctionCall{
Name: "skipped_tool",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "done",
}
// Wrap provider to capture messages on second call
wrappedProvider := &wrappingProvider{
inner: provider,
onChat: func(msgs []providers.Message) {
capMu.Lock()
callCount++
if callCount >= 2 {
secondCallMessages = make([]providers.Message, len(msgs))
copy(secondCallMessages, msgs)
}
capMu.Unlock()
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, wrappedProvider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
resultCh := make(chan string, 1)
go func() {
resp, _ := al.ProcessDirectWithChannel(
context.Background(), "go", "test-session", "test", "chat1",
)
resultCh <- resp
}()
<-execCh
al.Steer(providers.Message{Role: "user", Content: "interrupt!"})
select {
case <-resultCh:
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
// Check that the skipped tool result message is in the conversation
capMu.Lock()
msgs := secondCallMessages
capMu.Unlock()
foundSkipped := false
for _, m := range msgs {
if m.Role == "tool" && m.ToolCallID == "call_2" && m.Content == "Skipped due to queued user message." {
foundSkipped = true
break
}
}
if !foundSkipped {
// Log what we actually got
for i, m := range msgs {
t.Logf("msg[%d]: role=%s toolCallID=%s content=%s", i, m.Role, m.ToolCallID, truncate(m.Content, 80))
}
t.Fatal("expected skipped tool result for call_2")
}
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
// wrappingProvider wraps another provider to hook into Chat calls.
type wrappingProvider struct {
inner providers.LLMProvider
onChat func([]providers.Message)
}
func (w *wrappingProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
if w.onChat != nil {
w.onChat(messages)
}
return w.inner.Chat(ctx, messages, tools, model, opts)
}
func (w *wrappingProvider) GetDefaultModel() string {
return w.inner.GetDefaultModel()
}
// Ensure NormalizeToolCall handles our test tool calls.
func init() {
// This is a no-op init; we just need the tool call tests to work
// with the proper argument serialization.
_ = json.Marshal
}
+1
View File
@@ -234,6 +234,7 @@ type AgentDefaults struct {
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
}
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
+1
View File
@@ -35,6 +35,7 @@ func DefaultConfig() *Config {
MaxToolIterations: 50,
SummarizeMessageThreshold: 20,
SummarizeTokenPercent: 75,
SteeringMode: "one-at-a-time",
},
},
Bindings: []AgentBinding{},