mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): steering (#1517)
* feat(agent): steering * fix loop * fix lint * fix lint
This commit is contained in:
@@ -0,0 +1,306 @@
|
||||
# Steering — Implementation Specification
|
||||
|
||||
## Problem
|
||||
|
||||
When the agent is running (executing a chain of tool calls), the user has no way to redirect it. They must wait for the full cycle to complete before sending a new message. This creates a poor experience when the agent takes a wrong direction — the user watches it waste time on tools that are no longer relevant.
|
||||
|
||||
## Solution
|
||||
|
||||
Steering introduces a **message queue** that external callers can push into at any time. The agent loop polls this queue at well-defined checkpoints. When a steering message is found, the agent:
|
||||
|
||||
1. Stops executing further tools in the current batch
|
||||
2. Injects the user's message into the conversation context
|
||||
3. Calls the LLM again with the updated context
|
||||
|
||||
The user's intent reaches the model **as soon as the current tool finishes**, not after the entire turn completes.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph External Callers
|
||||
TG[Telegram]
|
||||
DC[Discord]
|
||||
SL[Slack]
|
||||
end
|
||||
|
||||
subgraph AgentLoop
|
||||
BUS[MessageBus]
|
||||
DRAIN[drainBusToSteering goroutine]
|
||||
SQ[steeringQueue]
|
||||
RLI[runLLMIteration]
|
||||
TE[Tool Execution Loop]
|
||||
LLM[LLM Call]
|
||||
end
|
||||
|
||||
TG -->|PublishInbound| BUS
|
||||
DC -->|PublishInbound| BUS
|
||||
SL -->|PublishInbound| BUS
|
||||
|
||||
BUS -->|ConsumeInbound while busy| DRAIN
|
||||
DRAIN -->|Steer| SQ
|
||||
|
||||
RLI -->|1. initial poll| SQ
|
||||
TE -->|2. poll after each tool| SQ
|
||||
|
||||
SQ -->|pendingMessages| RLI
|
||||
RLI -->|inject into context| LLM
|
||||
```
|
||||
|
||||
### Bus drain mechanism
|
||||
|
||||
Channels (Telegram, Discord, etc.) publish messages to the `MessageBus` via `PublishInbound`. Without additional wiring, these messages would sit in the bus buffer until the current `processMessage` finishes — meaning steering would never work for real users.
|
||||
|
||||
The solution: when `Run()` starts processing a message, it spawns a **drain goroutine** (`drainBusToSteering`) that keeps consuming from the bus and calling `Steer()`. When `processMessage` returns, the drain is canceled and normal consumption resumes.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Bus
|
||||
participant Run
|
||||
participant Drain
|
||||
participant AgentLoop
|
||||
|
||||
Run->>Bus: ConsumeInbound() → msg
|
||||
Run->>Drain: spawn drainBusToSteering(ctx)
|
||||
Run->>Run: processMessage(msg)
|
||||
|
||||
Note over Drain: running concurrently
|
||||
|
||||
Bus-->>Drain: ConsumeInbound() → newMsg
|
||||
Drain->>AgentLoop: al.transcribeAudioInMessage(ctx, newMsg)
|
||||
Drain->>AgentLoop: Steer(providers.Message{Content: newMsg.Content})
|
||||
|
||||
Run->>Run: processMessage returns
|
||||
Run->>Drain: cancel context
|
||||
Note over Drain: exits
|
||||
```
|
||||
|
||||
## Data Structures
|
||||
|
||||
### steeringQueue
|
||||
|
||||
A thread-safe FIFO queue, private to the `agent` package.
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `mu` | `sync.Mutex` | Protects all access to `queue` and `mode` |
|
||||
| `queue` | `[]providers.Message` | Pending steering messages |
|
||||
| `mode` | `SteeringMode` | Dequeue strategy |
|
||||
|
||||
**Methods:**
|
||||
|
||||
| Method | Description |
|
||||
|--------|-------------|
|
||||
| `push(msg) error` | Appends a message to the queue. Returns an error if the queue is full (`MaxQueueSize`) |
|
||||
| `dequeue() []Message` | Removes and returns messages according to `mode`. Returns `nil` if empty |
|
||||
| `len() int` | Returns the current queue length |
|
||||
| `setMode(mode)` | Updates the dequeue strategy |
|
||||
| `getMode() SteeringMode` | Returns the current mode |
|
||||
|
||||
### SteeringMode
|
||||
|
||||
| Value | Constant | Behavior |
|
||||
|-------|----------|----------|
|
||||
| `"one-at-a-time"` | `SteeringOneAtATime` | `dequeue()` returns only the **first** message. Remaining messages stay in the queue for subsequent polls. |
|
||||
| `"all"` | `SteeringAll` | `dequeue()` drains the **entire** queue and returns all messages at once. |
|
||||
|
||||
Default: `"one-at-a-time"`.
|
||||
|
||||
### processOptions extension
|
||||
|
||||
A new field was added to `processOptions`:
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `SkipInitialSteeringPoll` | `bool` | When `true`, the initial steering poll at loop start is skipped. Used by `Continue()` to avoid double-dequeuing. |
|
||||
|
||||
## Public API on AgentLoop
|
||||
|
||||
| Method | Signature | Description |
|
||||
|--------|-----------|-------------|
|
||||
| `Steer` | `Steer(msg providers.Message) error` | Enqueues a steering message. Returns an error if the queue is full or not initialized. Thread-safe, can be called from any goroutine. |
|
||||
| `SteeringMode` | `SteeringMode() SteeringMode` | Returns the current dequeue mode. |
|
||||
| `SetSteeringMode` | `SetSteeringMode(mode SteeringMode)` | Changes the dequeue mode at runtime. |
|
||||
| `Continue` | `Continue(ctx, sessionKey, channel, chatID) (string, error)` | Resumes an idle agent using pending steering messages. Returns `""` if queue is empty. |
|
||||
|
||||
## Integration into the Agent Loop
|
||||
|
||||
### Where steering is wired
|
||||
|
||||
The steering queue lives as a field on `AgentLoop`:
|
||||
|
||||
```
|
||||
AgentLoop
|
||||
├── bus
|
||||
├── cfg
|
||||
├── registry
|
||||
├── steering *steeringQueue ← new
|
||||
├── ...
|
||||
```
|
||||
|
||||
It is initialized in `NewAgentLoop` from `cfg.Agents.Defaults.SteeringMode`.
|
||||
|
||||
### Detailed flow through runLLMIteration
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant User
|
||||
participant AgentLoop
|
||||
participant runLLMIteration
|
||||
participant ToolExecution
|
||||
participant LLM
|
||||
|
||||
User->>AgentLoop: Steer(message)
|
||||
Note over AgentLoop: steeringQueue.push(message)
|
||||
|
||||
Note over runLLMIteration: ── iteration starts ──
|
||||
|
||||
runLLMIteration->>AgentLoop: dequeueSteeringMessages()<br/>[initial poll]
|
||||
AgentLoop-->>runLLMIteration: [] (empty, or messages)
|
||||
|
||||
alt pendingMessages not empty
|
||||
runLLMIteration->>runLLMIteration: inject into messages[]<br/>save to session
|
||||
end
|
||||
|
||||
runLLMIteration->>LLM: Chat(messages, tools)
|
||||
LLM-->>runLLMIteration: response with toolCalls[0..N]
|
||||
|
||||
loop for each tool call (sequential)
|
||||
ToolExecution->>ToolExecution: execute tool[i]
|
||||
ToolExecution->>ToolExecution: process result,<br/>append to messages[]
|
||||
|
||||
ToolExecution->>AgentLoop: dequeueSteeringMessages()
|
||||
AgentLoop-->>ToolExecution: steeringMessages
|
||||
|
||||
alt steering found
|
||||
opt remaining tools > 0
|
||||
Note over ToolExecution: Mark tool[i+1..N-1] as<br/>"Skipped due to queued user message."
|
||||
end
|
||||
Note over ToolExecution: steeringAfterTools = steeringMessages
|
||||
Note over ToolExecution: break out of tool loop
|
||||
end
|
||||
end
|
||||
|
||||
alt steeringAfterTools not empty
|
||||
ToolExecution-->>runLLMIteration: pendingMessages = steeringAfterTools
|
||||
Note over runLLMIteration: next iteration will inject<br/>these before calling LLM
|
||||
end
|
||||
|
||||
Note over runLLMIteration: ── loop back to iteration start ──
|
||||
```
|
||||
|
||||
### Polling checkpoints
|
||||
|
||||
| # | Location | When | Purpose |
|
||||
|---|----------|------|---------|
|
||||
| 1 | Top of `runLLMIteration`, before first LLM call | Once, at loop entry | Catch messages enqueued while the agent was still setting up context |
|
||||
| 2 | After every tool completes (including the first and the last) | Immediately after each tool's result is processed | Interrupt the batch as early as possible — if steering is found and there are remaining tools, they are all skipped |
|
||||
|
||||
### What happens to skipped tools
|
||||
|
||||
When steering interrupts a tool batch after tool `[i]` completes, all tools from `[i+1]` to `[N-1]` are **not executed**. Instead, a tool result message is generated for each:
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "Skipped due to queued user message.",
|
||||
"tool_call_id": "<original_call_id>"
|
||||
}
|
||||
```
|
||||
|
||||
These results are:
|
||||
- Appended to the conversation `messages[]`
|
||||
- Saved to the session via `AddFullMessage`
|
||||
|
||||
This ensures the LLM knows which of its requested actions were not performed.
|
||||
|
||||
### Loop condition change
|
||||
|
||||
The iteration loop condition was changed from:
|
||||
|
||||
```go
|
||||
for iteration < agent.MaxIterations
|
||||
```
|
||||
|
||||
to:
|
||||
|
||||
```go
|
||||
for iteration < agent.MaxIterations || len(pendingMessages) > 0
|
||||
```
|
||||
|
||||
This allows **one extra iteration** when steering arrives right at the max iteration boundary, ensuring the steering message is always processed.
|
||||
|
||||
### Tool execution: parallel → sequential
|
||||
|
||||
**Before steering:** all tool calls in a batch were executed in parallel using `sync.WaitGroup`.
|
||||
|
||||
**After steering:** tool calls execute **sequentially**. This is required because steering must be polled between individual tool completions. A parallel execution model would not allow interrupting mid-batch.
|
||||
|
||||
> **Trade-off:** This introduces latency when the LLM requests multiple independent tools in a single turn. In practice, most batches contain 1-2 tools, so the impact is minimal. The benefit of being able to interrupt outweighs the cost.
|
||||
|
||||
### Why skip remaining tools (instead of letting them finish)
|
||||
|
||||
Two strategies were considered when a steering message is detected mid-batch:
|
||||
|
||||
1. **Skip remaining tools** (chosen) — stop executing, mark the rest as skipped, inject steering
|
||||
2. **Finish all tools, then inject** — let everything run, append steering afterwards
|
||||
|
||||
Strategy 2 was rejected for three reasons:
|
||||
|
||||
**Irreversible side effects.** Tools can send emails, write files, spawn subagents, or call external APIs. If the user says "stop" or "change direction", those actions have already happened and cannot be undone.
|
||||
|
||||
| Tool batch | Steering | Skip (1) | Finish (2) |
|
||||
|---|---|---|---|
|
||||
| `[search, send_email]` | "don't send it" | Email not sent | Email sent |
|
||||
| `[query, write_file, spawn]` | "wrong database" | Only query runs | File + subagent wasted |
|
||||
| `[fetch₁, fetch₂, fetch₃, write]` | topic change | 1 fetch | 3 fetches + write, all discarded |
|
||||
|
||||
**Wasted latency.** Tools like web fetches and API calls take seconds each. In a 3-tool batch averaging 3-4s per tool, the user would wait 10+ seconds for work that gets thrown away.
|
||||
|
||||
**The LLM retains full awareness.** Skipped tools receive an explicit `"Skipped due to queued user message."` result, so the model knows what was not done and can decide whether to re-execute with the new context or take a different path.
|
||||
|
||||
## The Continue() method
|
||||
|
||||
`Continue` handles the case where the agent is **idle** (its last message was from the assistant) and the user has enqueued steering messages in the meantime.
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Continue called] --> B{dequeueSteeringMessages}
|
||||
B -->|empty| C["return ('', nil)"]
|
||||
B -->|messages found| D[Combine message contents]
|
||||
D --> E["runAgentLoop with<br/>SkipInitialSteeringPoll: true"]
|
||||
E --> F[Return response]
|
||||
```
|
||||
|
||||
**Why `SkipInitialSteeringPoll: true`?** Because `Continue` already dequeued the messages itself. Without this flag, `runLLMIteration` would poll again at the start and find nothing (the queue is already empty), or worse, double-process if new messages arrived in the meantime.
|
||||
|
||||
## Configuration
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"steering_mode": "one-at-a-time"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Type | Default | Env var |
|
||||
|-------|------|---------|---------|
|
||||
| `steering_mode` | `string` | `"one-at-a-time"` | `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` |
|
||||
|
||||
|
||||
## Design decisions and trade-offs
|
||||
|
||||
| Decision | Rationale |
|
||||
|----------|-----------|
|
||||
| Sequential tool execution | Required for per-tool steering polls. Parallel execution cannot be interrupted mid-batch. |
|
||||
| Polling-based (not channel/signal) | Keeps the implementation simple. No need for `select` or signal channels. The polling cost is negligible (mutex lock + slice length check). |
|
||||
| `one-at-a-time` as default | Gives the model a chance to react to each steering message individually. More predictable behavior than dumping all messages at once. |
|
||||
| Skipped tools get explicit error results | The LLM protocol requires a tool result for every tool call in the assistant message. Omitting them would cause API errors. The skip message also informs the model about what was not done. |
|
||||
| `Continue()` uses `SkipInitialSteeringPoll` | Prevents race conditions and double-dequeuing when resuming an idle agent. |
|
||||
| Queue stored on `AgentLoop`, not `AgentInstance` | Steering is a loop-level concern (it affects the iteration flow), not a per-agent concern. All agents share the same steering queue since `processMessage` is sequential. |
|
||||
| Bus drain goroutine in `Run()` | Channels (Telegram, Discord, etc.) publish to the bus via `PublishInbound`. Without the drain, messages would queue in the bus channel buffer and only be consumed after `processMessage` returns — defeating the purpose of steering. The drain goroutine bridges the gap by consuming new bus messages and calling `Steer()` while the agent is busy. |
|
||||
| Audio transcription before steering | The drain goroutine calls `al.transcribeAudioInMessage(ctx, msg)` before steering, so voice messages are converted to text before the agent sees them. If transcription fails, the error is silently discarded and the original message is steered as-is. |
|
||||
| `MaxQueueSize = 10` | Prevents unbounded memory growth if a user sends many messages while the agent is busy. Excess messages are dropped with a warning. |
|
||||
@@ -0,0 +1,166 @@
|
||||
# Steering
|
||||
|
||||
Steering allows injecting messages into an already-running agent loop, interrupting it between tool calls without waiting for the entire cycle to complete.
|
||||
|
||||
## How it works
|
||||
|
||||
When the agent is executing a sequence of tool calls (e.g. the model requested 3 tools in a single turn), steering checks the queue **after each tool** completes. If it finds queued messages:
|
||||
|
||||
1. The remaining tools are **skipped** and receive `"Skipped due to queued user message."` as their result
|
||||
2. The steering messages are **injected into the conversation context**
|
||||
3. The model is called again with the updated context, including the user's steering message
|
||||
|
||||
```
|
||||
User ──► Steer("change approach")
|
||||
│
|
||||
Agent Loop ▼
|
||||
├─ tool[0] ✔ (executed)
|
||||
├─ [polling] → steering found!
|
||||
├─ tool[1] ✘ (skipped)
|
||||
├─ tool[2] ✘ (skipped)
|
||||
└─ new LLM turn with steering message
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
In `config.json`, under `agents.defaults`:
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"steering_mode": "one-at-a-time"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Modes
|
||||
|
||||
| Value | Behavior |
|
||||
|-------|----------|
|
||||
| `"one-at-a-time"` | **(default)** Dequeues only one message per polling cycle. If there are 3 messages in the queue, they are processed one at a time across 3 successive iterations. |
|
||||
| `"all"` | Drains the entire queue in a single poll. All pending messages are injected into the context together. |
|
||||
|
||||
The environment variable `PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE` can be used as an alternative.
|
||||
|
||||
## Go API
|
||||
|
||||
### Steer — Send a steering message
|
||||
|
||||
```go
|
||||
err := agentLoop.Steer(providers.Message{
|
||||
Role: "user",
|
||||
Content: "change direction, focus on X instead",
|
||||
})
|
||||
if err != nil {
|
||||
// Queue is full (MaxQueueSize=10) or not initialized
|
||||
}
|
||||
```
|
||||
|
||||
The message is enqueued in a thread-safe manner. Returns an error if the queue is full or not initialized. It will be picked up at the next polling point (after the current tool finishes).
|
||||
|
||||
### SteeringMode / SetSteeringMode
|
||||
|
||||
```go
|
||||
// Read the current mode
|
||||
mode := agentLoop.SteeringMode() // SteeringOneAtATime | SteeringAll
|
||||
|
||||
// Change it at runtime
|
||||
agentLoop.SetSteeringMode(agent.SteeringAll)
|
||||
```
|
||||
|
||||
### Continue — Resume an idle agent
|
||||
|
||||
When the agent is idle (it has finished processing and its last message was from the assistant), `Continue` checks if there are steering messages in the queue and uses them to start a new cycle:
|
||||
|
||||
```go
|
||||
response, err := agentLoop.Continue(ctx, sessionKey, channel, chatID)
|
||||
if err != nil {
|
||||
// Error (e.g. "no default agent available")
|
||||
}
|
||||
if response == "" {
|
||||
// No steering messages in queue, the agent stays idle
|
||||
}
|
||||
```
|
||||
|
||||
`Continue` internally uses `SkipInitialSteeringPoll: true` to avoid double-dequeuing the same messages (since it already extracted them and passes them directly as input).
|
||||
|
||||
## Polling points in the loop
|
||||
|
||||
Steering is checked at **two points** in the agent cycle:
|
||||
|
||||
1. **At loop start** — before the first LLM call, to catch messages enqueued during setup
|
||||
2. **After every tool completes** — including the first and the last. If steering is found and there are remaining tools, they are all skipped immediately
|
||||
|
||||
## Why remaining tools are skipped
|
||||
|
||||
When a steering message is detected, all remaining tools in the batch are skipped rather than executed. The alternative — let all tools finish and inject the steering message afterwards — was considered and rejected. Here is why.
|
||||
|
||||
### Preventing unwanted side effects
|
||||
|
||||
Tools can have **irreversible side effects**. If the user says "no, wait" while the agent is mid-batch, executing the remaining tools means those side effects happen anyway:
|
||||
|
||||
| Tool batch | Steering message | With skip | Without skip |
|
||||
|---|---|---|---|
|
||||
| `[web_search, send_email]` | "don't send it" | Email **not** sent | Email sent, damage done |
|
||||
| `[query_db, write_file, spawn_agent]` | "use another database" | Only the query runs | File written + subagent spawned, all wasted |
|
||||
| `[search₁, search₂, search₃, write_file]` | user changes topic entirely | 1 search | 3 searches + file write, all irrelevant |
|
||||
|
||||
### Avoiding wasted time
|
||||
|
||||
Tools that take seconds (web fetches, API calls, database queries) would all run to completion before the agent sees the user's correction. In a batch of 3 tools each taking 3-4 seconds, that's 10+ seconds of work that will be discarded.
|
||||
|
||||
With skipping, the agent reacts as soon as the current tool finishes — typically within a few seconds instead of waiting for the entire batch.
|
||||
|
||||
### The LLM gets full context
|
||||
|
||||
Skipped tools receive an explicit error result (`"Skipped due to queued user message."`), so the model knows exactly which actions were not performed. It can then decide whether to re-execute them with the new context, or take a different path entirely.
|
||||
|
||||
### Trade-off: sequential execution
|
||||
|
||||
Skipping requires tools to run **sequentially** (the previous implementation ran them in parallel). This introduces latency when the LLM requests multiple independent tools in a single turn. In practice, most batches contain 1-2 tools, so the impact is minimal compared to the benefit of being able to stop unwanted actions.
|
||||
|
||||
## Skipped tool result format
|
||||
|
||||
When steering interrupts a batch, each tool that was not executed receives a `tool` result with:
|
||||
|
||||
```
|
||||
Content: "Skipped due to queued user message."
|
||||
```
|
||||
|
||||
This is saved to the session via `AddFullMessage` and sent to the model, so it is aware that some requested actions were not performed.
|
||||
|
||||
## Full flow example
|
||||
|
||||
```
|
||||
1. User: "search for info on X, write a file, and send me a message"
|
||||
|
||||
2. LLM responds with 3 tool calls: [web_search, write_file, message]
|
||||
|
||||
3. web_search is executed → result saved
|
||||
|
||||
4. [polling] → User called Steer("no, search for Y instead")
|
||||
|
||||
5. write_file is skipped → "Skipped due to queued user message."
|
||||
message is skipped → "Skipped due to queued user message."
|
||||
|
||||
6. Message "search for Y instead" injected into context
|
||||
|
||||
7. LLM receives the full updated context and responds accordingly
|
||||
```
|
||||
|
||||
## Automatic bus drain
|
||||
|
||||
When the agent loop (`Run()`) starts processing a message, it spawns a background goroutine that keeps consuming new inbound messages from the bus. These messages are automatically redirected into the steering queue via `Steer()`. This means:
|
||||
|
||||
- Users on any channel (Telegram, Discord, etc.) don't need to do anything special — their messages are automatically captured as steering when the agent is busy
|
||||
- Audio messages are transcribed before being steered, so the agent receives text. If transcription fails, the original (non-transcribed) message is steered as-is
|
||||
- When `processMessage` finishes, the drain goroutine is canceled and normal message consumption resumes
|
||||
|
||||
## Notes
|
||||
|
||||
- Steering **does not interrupt** a tool that is currently executing. It waits for the current tool to finish, then checks the queue.
|
||||
- With `one-at-a-time` mode, if multiple messages are enqueued rapidly, they will be processed one per iteration. This gives the model the opportunity to react to each message individually.
|
||||
- With `all` mode, all pending messages are combined into a single injection. Useful when you want the agent to receive all the context at once.
|
||||
- The steering queue has a maximum capacity of 10 messages (`MaxQueueSize`). `Steer()` returns an error when the queue is full. In the bus drain path, the error is logged as a warning and the message is effectively dropped.
|
||||
+183
-102
@@ -48,6 +48,7 @@ type AgentLoop struct {
|
||||
transcriber voice.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
steering *steeringQueue
|
||||
mu sync.RWMutex
|
||||
// Track active requests for safe provider cleanup
|
||||
activeRequests sync.WaitGroup
|
||||
@@ -55,15 +56,16 @@ type AgentLoop struct {
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -105,6 +107,7 @@ func NewAgentLoop(
|
||||
summarizing: sync.Map{},
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
|
||||
}
|
||||
|
||||
return al
|
||||
@@ -257,6 +260,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Start a goroutine that drains the bus while processMessage is
|
||||
// running. Any inbound messages that arrive during processing are
|
||||
// redirected into the steering queue so the agent loop can pick
|
||||
// them up between tool calls.
|
||||
drainCtx, drainCancel := context.WithCancel(ctx)
|
||||
go al.drainBusToSteering(drainCtx)
|
||||
|
||||
// Process message
|
||||
func() {
|
||||
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
|
||||
@@ -272,6 +282,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
// }
|
||||
// }()
|
||||
|
||||
defer drainCancel()
|
||||
|
||||
response, err := al.processMessage(ctx, msg)
|
||||
if err != nil {
|
||||
response = fmt.Sprintf("Error processing message: %v", err)
|
||||
@@ -318,6 +330,39 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// drainBusToSteering continuously consumes inbound messages and redirects
|
||||
// them into the steering queue. It runs in a goroutine while processMessage
|
||||
// is active and stops when drainCtx is canceled (i.e., processMessage returns).
|
||||
func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
|
||||
for {
|
||||
msg, ok := al.bus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Transcribe audio if needed before steering, so the agent sees text.
|
||||
msg, _ = al.transcribeAudioInMessage(ctx, msg)
|
||||
|
||||
logger.InfoCF("agent", "Redirecting inbound message to steering queue",
|
||||
map[string]any{
|
||||
"channel": msg.Channel,
|
||||
"sender_id": msg.SenderID,
|
||||
"content_len": len(msg.Content),
|
||||
})
|
||||
|
||||
if err := al.Steer(providers.Message{
|
||||
Role: "user",
|
||||
Content: msg.Content,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Failed to steer message, will be lost",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) Stop() {
|
||||
al.running.Store(false)
|
||||
}
|
||||
@@ -999,6 +1044,16 @@ func (al *AgentLoop) runLLMIteration(
|
||||
) (string, int, error) {
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
var pendingMessages []providers.Message
|
||||
|
||||
// Poll for steering messages at loop start (in case the user typed while
|
||||
// the agent was setting up), unless the caller already provided initial
|
||||
// steering messages (e.g. Continue).
|
||||
if !opts.SkipInitialSteeringPoll {
|
||||
if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 {
|
||||
pendingMessages = msgs
|
||||
}
|
||||
}
|
||||
|
||||
// Determine effective model tier for this conversation turn.
|
||||
// selectCandidates evaluates routing once and the decision is sticky for
|
||||
@@ -1006,9 +1061,25 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// tool chain doesn't switch models mid-way through.
|
||||
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
|
||||
|
||||
for iteration < agent.MaxIterations {
|
||||
for iteration < agent.MaxIterations || len(pendingMessages) > 0 {
|
||||
iteration++
|
||||
|
||||
// Inject pending steering messages into the conversation context
|
||||
// before the next LLM call.
|
||||
if len(pendingMessages) > 0 {
|
||||
for _, pm := range pendingMessages {
|
||||
messages = append(messages, pm)
|
||||
agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content)
|
||||
logger.InfoCF("agent", "Injected steering message into context",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"content_len": len(pm.Content),
|
||||
})
|
||||
}
|
||||
pendingMessages = nil
|
||||
}
|
||||
|
||||
logger.DebugCF("agent", "LLM iteration",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
@@ -1251,107 +1322,83 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// Save assistant message with tool calls to session
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
|
||||
|
||||
// Execute tool calls in parallel
|
||||
type indexedAgentResult struct {
|
||||
result *tools.ToolResult
|
||||
tc providers.ToolCall
|
||||
}
|
||||
|
||||
agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
|
||||
var wg sync.WaitGroup
|
||||
// Execute tool calls sequentially. After each tool completes, check
|
||||
// for steering messages. If any are found, skip remaining tools.
|
||||
var steeringAfterTools []providers.Message
|
||||
|
||||
for i, tc := range normalizedToolCalls {
|
||||
agentResults[i].tc = tc
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
go func(idx int, tc providers.ToolCall) {
|
||||
defer wg.Done()
|
||||
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Create async callback for tools that implement AsyncExecutor.
|
||||
// When the background work completes, this publishes the result
|
||||
// as an inbound system message so processSystemMessage routes it
|
||||
// back to the user via the normal agent loop.
|
||||
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
|
||||
// Send ForUser content directly to the user (immediate feedback),
|
||||
// mirroring the synchronous tool execution path.
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer outCancel()
|
||||
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: result.ForUser,
|
||||
})
|
||||
}
|
||||
|
||||
// Determine content for the agent loop (ForLLM or error).
|
||||
content := result.ForLLM
|
||||
if content == "" && result.Err != nil {
|
||||
content = result.Err.Error()
|
||||
}
|
||||
if content == "" {
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Async tool completed, publishing result",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(content),
|
||||
"channel": opts.Channel,
|
||||
})
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("async:%s", tc.Name),
|
||||
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
|
||||
Content: content,
|
||||
// Create async callback for tools that implement AsyncExecutor.
|
||||
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer outCancel()
|
||||
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: result.ForUser,
|
||||
})
|
||||
}
|
||||
|
||||
toolResult := agent.Tools.ExecuteWithContext(
|
||||
ctx,
|
||||
tc.Name,
|
||||
tc.Arguments,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
asyncCallback,
|
||||
)
|
||||
agentResults[idx].result = toolResult
|
||||
}(i, tc)
|
||||
}
|
||||
wg.Wait()
|
||||
content := result.ForLLM
|
||||
if content == "" && result.Err != nil {
|
||||
content = result.Err.Error()
|
||||
}
|
||||
if content == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Process results in original order (send to user, save to session)
|
||||
for _, r := range agentResults {
|
||||
// Send ForUser content to user immediately if not Silent
|
||||
if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
|
||||
logger.InfoCF("agent", "Async tool completed, publishing result",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(content),
|
||||
"channel": opts.Channel,
|
||||
})
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("async:%s", tc.Name),
|
||||
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
toolResult := agent.Tools.ExecuteWithContext(
|
||||
ctx,
|
||||
tc.Name,
|
||||
tc.Arguments,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
asyncCallback,
|
||||
)
|
||||
|
||||
// Process tool result
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: r.result.ForUser,
|
||||
Content: toolResult.ForUser,
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]any{
|
||||
"tool": r.tc.Name,
|
||||
"content_len": len(r.result.ForUser),
|
||||
"tool": tc.Name,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
|
||||
// If tool returned media refs, publish them as outbound media
|
||||
if len(r.result.Media) > 0 {
|
||||
parts := make([]bus.MediaPart, 0, len(r.result.Media))
|
||||
for _, ref := range r.result.Media {
|
||||
if len(toolResult.Media) > 0 {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
@@ -1369,21 +1416,55 @@ func (al *AgentLoop) runLLMIteration(
|
||||
})
|
||||
}
|
||||
|
||||
// Determine content for LLM based on tool result
|
||||
contentForLLM := r.result.ForLLM
|
||||
if contentForLLM == "" && r.result.Err != nil {
|
||||
contentForLLM = r.result.Err.Error()
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: r.tc.ID,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
|
||||
// Save tool result message to session
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
|
||||
|
||||
// After EVERY tool (including the first and last), check for
|
||||
// steering messages. If found and there are remaining tools,
|
||||
// skip them all.
|
||||
if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
|
||||
remaining := len(normalizedToolCalls) - i - 1
|
||||
if remaining > 0 {
|
||||
logger.InfoCF("agent", "Steering interrupt: skipping remaining tools",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"completed": i + 1,
|
||||
"skipped": remaining,
|
||||
"total_tools": len(normalizedToolCalls),
|
||||
"steering_count": len(steerMsgs),
|
||||
})
|
||||
|
||||
// Mark remaining tool calls as skipped
|
||||
for j := i + 1; j < len(normalizedToolCalls); j++ {
|
||||
skippedTC := normalizedToolCalls[j]
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: "Skipped due to queued user message.",
|
||||
ToolCallID: skippedTC.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
|
||||
}
|
||||
}
|
||||
steeringAfterTools = steerMsgs
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If steering messages were captured during tool execution, they
|
||||
// become pendingMessages for the next iteration of the inner loop.
|
||||
if len(steeringAfterTools) > 0 {
|
||||
pendingMessages = steeringAfterTools
|
||||
}
|
||||
|
||||
// Tick down TTL of discovered tools after processing tool results.
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// SteeringMode controls how queued steering messages are dequeued.
|
||||
type SteeringMode string
|
||||
|
||||
const (
|
||||
// SteeringOneAtATime dequeues only the first queued message per poll.
|
||||
SteeringOneAtATime SteeringMode = "one-at-a-time"
|
||||
// SteeringAll drains the entire queue in a single poll.
|
||||
SteeringAll SteeringMode = "all"
|
||||
// MaxQueueSize number of possible messages in the Steering Queue
|
||||
MaxQueueSize = 10
|
||||
)
|
||||
|
||||
// parseSteeringMode normalizes a config string into a SteeringMode.
|
||||
func parseSteeringMode(s string) SteeringMode {
|
||||
switch s {
|
||||
case "all":
|
||||
return SteeringAll
|
||||
default:
|
||||
return SteeringOneAtATime
|
||||
}
|
||||
}
|
||||
|
||||
// steeringQueue is a thread-safe queue of user messages that can be injected
|
||||
// into a running agent loop to interrupt it between tool calls.
|
||||
type steeringQueue struct {
|
||||
mu sync.Mutex
|
||||
queue []providers.Message
|
||||
mode SteeringMode
|
||||
}
|
||||
|
||||
func newSteeringQueue(mode SteeringMode) *steeringQueue {
|
||||
return &steeringQueue{
|
||||
mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// push enqueues a steering message.
|
||||
func (sq *steeringQueue) push(msg providers.Message) error {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
if len(sq.queue) >= MaxQueueSize {
|
||||
return fmt.Errorf("steering queue is full")
|
||||
}
|
||||
sq.queue = append(sq.queue, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// dequeue removes and returns pending steering messages according to the
|
||||
// configured mode. Returns nil when the queue is empty.
|
||||
func (sq *steeringQueue) dequeue() []providers.Message {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
if len(sq.queue) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch sq.mode {
|
||||
case SteeringAll:
|
||||
msgs := sq.queue
|
||||
sq.queue = nil
|
||||
return msgs
|
||||
default: // one-at-a-time
|
||||
msg := sq.queue[0]
|
||||
sq.queue[0] = providers.Message{} // Clear reference for GC
|
||||
sq.queue = sq.queue[1:]
|
||||
return []providers.Message{msg}
|
||||
}
|
||||
}
|
||||
|
||||
// len returns the number of queued messages.
|
||||
func (sq *steeringQueue) len() int {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
return len(sq.queue)
|
||||
}
|
||||
|
||||
// setMode updates the steering mode.
|
||||
func (sq *steeringQueue) setMode(mode SteeringMode) {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
sq.mode = mode
|
||||
}
|
||||
|
||||
// getMode returns the current steering mode.
|
||||
func (sq *steeringQueue) getMode() SteeringMode {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
return sq.mode
|
||||
}
|
||||
|
||||
// --- AgentLoop steering API ---
|
||||
|
||||
// Steer enqueues a user message to be injected into the currently running
|
||||
// agent loop. The message will be picked up after the current tool finishes
|
||||
// executing, causing any remaining tool calls in the batch to be skipped.
|
||||
func (al *AgentLoop) Steer(msg providers.Message) error {
|
||||
if al.steering == nil {
|
||||
return fmt.Errorf("steering queue is not initialized")
|
||||
}
|
||||
if err := al.steering.push(msg); err != nil {
|
||||
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
|
||||
"error": err.Error(),
|
||||
"role": msg.Role,
|
||||
})
|
||||
return err
|
||||
}
|
||||
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
|
||||
"role": msg.Role,
|
||||
"content_len": len(msg.Content),
|
||||
"queue_len": al.steering.len(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SteeringMode returns the current steering mode.
|
||||
func (al *AgentLoop) SteeringMode() SteeringMode {
|
||||
if al.steering == nil {
|
||||
return SteeringOneAtATime
|
||||
}
|
||||
return al.steering.getMode()
|
||||
}
|
||||
|
||||
// SetSteeringMode updates the steering mode.
|
||||
func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
|
||||
if al.steering == nil {
|
||||
return
|
||||
}
|
||||
al.steering.setMode(mode)
|
||||
}
|
||||
|
||||
// dequeueSteeringMessages is the internal method called by the agent loop
|
||||
// to poll for steering messages. Returns nil when no messages are pending.
|
||||
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
|
||||
if al.steering == nil {
|
||||
return nil
|
||||
}
|
||||
return al.steering.dequeue()
|
||||
}
|
||||
|
||||
// Continue resumes an idle agent by dequeuing any pending steering messages
|
||||
// and running them through the agent loop. This is used when the agent's last
|
||||
// message was from the assistant (i.e., it has stopped processing) and the
|
||||
// user has since enqueued steering messages.
|
||||
//
|
||||
// If no steering messages are pending, it returns an empty string.
|
||||
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
|
||||
steeringMsgs := al.dequeueSteeringMessages()
|
||||
if len(steeringMsgs) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent available")
|
||||
}
|
||||
|
||||
// Build a combined user message from the steering messages.
|
||||
var contents []string
|
||||
for _, msg := range steeringMsgs {
|
||||
contents = append(contents, msg.Content)
|
||||
}
|
||||
combinedContent := strings.Join(contents, "\n")
|
||||
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
UserMessage: combinedContent,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
SkipInitialSteeringPoll: true,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,744 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// --- steeringQueue unit tests ---
|
||||
|
||||
func TestSteeringQueue_PushDequeue_OneAtATime(t *testing.T) {
|
||||
sq := newSteeringQueue(SteeringOneAtATime)
|
||||
|
||||
sq.push(providers.Message{Role: "user", Content: "msg1"})
|
||||
sq.push(providers.Message{Role: "user", Content: "msg2"})
|
||||
sq.push(providers.Message{Role: "user", Content: "msg3"})
|
||||
|
||||
if sq.len() != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d", sq.len())
|
||||
}
|
||||
|
||||
msgs := sq.dequeue()
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected 1 message in one-at-a-time mode, got %d", len(msgs))
|
||||
}
|
||||
if msgs[0].Content != "msg1" {
|
||||
t.Fatalf("expected 'msg1', got %q", msgs[0].Content)
|
||||
}
|
||||
if sq.len() != 2 {
|
||||
t.Fatalf("expected 2 remaining, got %d", sq.len())
|
||||
}
|
||||
|
||||
msgs = sq.dequeue()
|
||||
if len(msgs) != 1 || msgs[0].Content != "msg2" {
|
||||
t.Fatalf("expected 'msg2', got %v", msgs)
|
||||
}
|
||||
|
||||
msgs = sq.dequeue()
|
||||
if len(msgs) != 1 || msgs[0].Content != "msg3" {
|
||||
t.Fatalf("expected 'msg3', got %v", msgs)
|
||||
}
|
||||
|
||||
msgs = sq.dequeue()
|
||||
if msgs != nil {
|
||||
t.Fatalf("expected nil from empty queue, got %v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSteeringQueue_PushDequeue_All(t *testing.T) {
|
||||
sq := newSteeringQueue(SteeringAll)
|
||||
|
||||
sq.push(providers.Message{Role: "user", Content: "msg1"})
|
||||
sq.push(providers.Message{Role: "user", Content: "msg2"})
|
||||
sq.push(providers.Message{Role: "user", Content: "msg3"})
|
||||
|
||||
msgs := sq.dequeue()
|
||||
if len(msgs) != 3 {
|
||||
t.Fatalf("expected 3 messages in all mode, got %d", len(msgs))
|
||||
}
|
||||
if msgs[0].Content != "msg1" || msgs[1].Content != "msg2" || msgs[2].Content != "msg3" {
|
||||
t.Fatalf("unexpected messages: %v", msgs)
|
||||
}
|
||||
|
||||
if sq.len() != 0 {
|
||||
t.Fatalf("expected 0 remaining, got %d", sq.len())
|
||||
}
|
||||
|
||||
msgs = sq.dequeue()
|
||||
if msgs != nil {
|
||||
t.Fatalf("expected nil from empty queue, got %v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSteeringQueue_EmptyDequeue(t *testing.T) {
|
||||
sq := newSteeringQueue(SteeringOneAtATime)
|
||||
if msgs := sq.dequeue(); msgs != nil {
|
||||
t.Fatalf("expected nil, got %v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSteeringQueue_SetMode(t *testing.T) {
|
||||
sq := newSteeringQueue(SteeringOneAtATime)
|
||||
if sq.getMode() != SteeringOneAtATime {
|
||||
t.Fatalf("expected one-at-a-time, got %v", sq.getMode())
|
||||
}
|
||||
|
||||
sq.setMode(SteeringAll)
|
||||
if sq.getMode() != SteeringAll {
|
||||
t.Fatalf("expected all, got %v", sq.getMode())
|
||||
}
|
||||
|
||||
// Push two messages and verify all-mode drains them
|
||||
sq.push(providers.Message{Role: "user", Content: "a"})
|
||||
sq.push(providers.Message{Role: "user", Content: "b"})
|
||||
|
||||
msgs := sq.dequeue()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("expected 2 messages after mode switch, got %d", len(msgs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSteeringQueue_ConcurrentAccess(t *testing.T) {
|
||||
sq := newSteeringQueue(SteeringOneAtATime)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const n = MaxQueueSize
|
||||
|
||||
// Push from multiple goroutines
|
||||
for i := 0; i < n; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)})
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if sq.len() != n {
|
||||
t.Fatalf("expected %d messages, got %d", n, sq.len())
|
||||
}
|
||||
|
||||
// Drain from multiple goroutines
|
||||
var drained int
|
||||
var mu sync.Mutex
|
||||
for i := 0; i < n; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if msgs := sq.dequeue(); len(msgs) > 0 {
|
||||
mu.Lock()
|
||||
drained += len(msgs)
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if drained != n {
|
||||
t.Fatalf("expected to drain %d messages, got %d", n, drained)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSteeringQueue_Overflow(t *testing.T) {
|
||||
sq := newSteeringQueue(SteeringOneAtATime)
|
||||
|
||||
// Fill the queue up to its maximum capacity
|
||||
for i := 0; i < MaxQueueSize; i++ {
|
||||
err := sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error pushing message %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity check: ensure the queue is actually full
|
||||
if sq.len() != MaxQueueSize {
|
||||
t.Fatalf("expected queue length %d, got %d", MaxQueueSize, sq.len())
|
||||
}
|
||||
|
||||
// Attempt to push one more message, which MUST fail
|
||||
err := sq.push(providers.Message{Role: "user", Content: "overflow_msg"})
|
||||
|
||||
// Assert the error happened and is the exact one we expect
|
||||
if err == nil {
|
||||
t.Fatal("expected an error when pushing to a full queue, but got nil")
|
||||
}
|
||||
|
||||
expectedErr := "steering queue is full"
|
||||
if err.Error() != expectedErr {
|
||||
t.Errorf("expected error message %q, got %q", expectedErr, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSteeringMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected SteeringMode
|
||||
}{
|
||||
{"", SteeringOneAtATime},
|
||||
{"one-at-a-time", SteeringOneAtATime},
|
||||
{"all", SteeringAll},
|
||||
{"unknown", SteeringOneAtATime},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
if got := parseSteeringMode(tt.input); got != tt.expected {
|
||||
t.Fatalf("parseSteeringMode(%q) = %v, want %v", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- AgentLoop steering integration tests ---
|
||||
|
||||
func TestAgentLoop_Steer_Enqueues(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("expected config to be initialized")
|
||||
}
|
||||
if msgBus == nil {
|
||||
t.Fatal("expected message bus to be initialized")
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("expected provider to be initialized")
|
||||
}
|
||||
|
||||
al.Steer(providers.Message{Role: "user", Content: "interrupt me"})
|
||||
|
||||
if al.steering.len() != 1 {
|
||||
t.Fatalf("expected 1 steering message, got %d", al.steering.len())
|
||||
}
|
||||
|
||||
msgs := al.dequeueSteeringMessages()
|
||||
if len(msgs) != 1 || msgs[0].Content != "interrupt me" {
|
||||
t.Fatalf("unexpected dequeued message: %v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_SteeringMode_GetSet(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("expected config to be initialized")
|
||||
}
|
||||
if msgBus == nil {
|
||||
t.Fatal("expected message bus to be initialized")
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("expected provider to be initialized")
|
||||
}
|
||||
|
||||
if al.SteeringMode() != SteeringOneAtATime {
|
||||
t.Fatalf("expected default mode one-at-a-time, got %v", al.SteeringMode())
|
||||
}
|
||||
|
||||
al.SetSteeringMode(SteeringAll)
|
||||
if al.SteeringMode() != SteeringAll {
|
||||
t.Fatalf("expected all mode, got %v", al.SteeringMode())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_SteeringMode_ConfiguredFromConfig(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
SteeringMode: "all",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
if al.SteeringMode() != SteeringAll {
|
||||
t.Fatalf("expected 'all' mode from config, got %v", al.SteeringMode())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Continue_NoMessages(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("expected config to be initialized")
|
||||
}
|
||||
if msgBus == nil {
|
||||
t.Fatal("expected message bus to be initialized")
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("expected provider to be initialized")
|
||||
}
|
||||
|
||||
resp, err := al.Continue(context.Background(), "test-session", "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp != "" {
|
||||
t.Fatalf("expected empty response for no steering messages, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Continue_WithMessages(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "continued response"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
al.Steer(providers.Message{Role: "user", Content: "new direction"})
|
||||
|
||||
resp, err := al.Continue(context.Background(), "test-session", "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp != "continued response" {
|
||||
t.Fatalf("expected 'continued response', got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// slowTool simulates a tool that takes some time to execute.
|
||||
type slowTool struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
execCh chan struct{} // closed when Execute starts
|
||||
}
|
||||
|
||||
func (t *slowTool) Name() string { return t.name }
|
||||
func (t *slowTool) Description() string { return "slow tool for testing" }
|
||||
func (t *slowTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *slowTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
if t.execCh != nil {
|
||||
close(t.execCh)
|
||||
}
|
||||
time.Sleep(t.duration)
|
||||
return tools.SilentResult(fmt.Sprintf("executed %s", t.name))
|
||||
}
|
||||
|
||||
// toolCallProvider returns an LLM response with tool calls on the first call,
|
||||
// then a direct response on subsequent calls.
|
||||
type toolCallProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
toolCalls []providers.ToolCall
|
||||
finalResp string
|
||||
}
|
||||
|
||||
func (m *toolCallProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.calls++
|
||||
|
||||
if m.calls == 1 && len(m.toolCalls) > 0 {
|
||||
return &providers.LLMResponse{
|
||||
Content: "",
|
||||
ToolCalls: m.toolCalls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: m.finalResp,
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *toolCallProvider) GetDefaultModel() string {
|
||||
return "tool-call-mock"
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool1ExecCh := make(chan struct{})
|
||||
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
|
||||
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
|
||||
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "tool_one",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_one",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Type: "function",
|
||||
Name: "tool_two",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_two",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "steered response",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
|
||||
// Start processing in a goroutine
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"do something",
|
||||
"test-session",
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp, err}
|
||||
}()
|
||||
|
||||
// Wait for tool_one to start executing, then enqueue a steering message
|
||||
select {
|
||||
case <-tool1ExecCh:
|
||||
// tool_one has started executing
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for tool_one to start")
|
||||
}
|
||||
|
||||
al.Steer(providers.Message{Role: "user", Content: "change course"})
|
||||
|
||||
// Get the result
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
if r.resp != "steered response" {
|
||||
t.Fatalf("expected 'steered response', got %q", r.resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for agent loop to complete")
|
||||
}
|
||||
|
||||
// The provider should have been called twice:
|
||||
// 1. first call returned tool calls
|
||||
// 2. second call (after steering) returned the final response
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
provider.mu.Unlock()
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Provider that captures messages it receives
|
||||
var capturedMessages []providers.Message
|
||||
var capMu sync.Mutex
|
||||
provider := &capturingMockProvider{
|
||||
response: "ack",
|
||||
captureFn: func(msgs []providers.Message) {
|
||||
capMu.Lock()
|
||||
capturedMessages = make([]providers.Message, len(msgs))
|
||||
copy(capturedMessages, msgs)
|
||||
capMu.Unlock()
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Enqueue a steering message before processing starts
|
||||
al.Steer(providers.Message{Role: "user", Content: "pre-enqueued steering"})
|
||||
|
||||
// Process a normal message - the initial steering poll should inject the steering message
|
||||
_, err = al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"initial message",
|
||||
"test-session",
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// The steering message should have been injected into the conversation
|
||||
capMu.Lock()
|
||||
msgs := capturedMessages
|
||||
capMu.Unlock()
|
||||
|
||||
// Look for the steering message in the captured messages
|
||||
found := false
|
||||
for _, m := range msgs {
|
||||
if m.Content == "pre-enqueued steering" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected steering message to be injected into conversation context")
|
||||
}
|
||||
}
|
||||
|
||||
// capturingMockProvider captures messages sent to Chat for inspection.
|
||||
type capturingMockProvider struct {
|
||||
response string
|
||||
calls int
|
||||
captureFn func([]providers.Message)
|
||||
}
|
||||
|
||||
func (m *capturingMockProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
if m.captureFn != nil {
|
||||
m.captureFn(messages)
|
||||
}
|
||||
return &providers.LLMResponse{
|
||||
Content: m.response,
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *capturingMockProvider) GetDefaultModel() string {
|
||||
return "capturing-mock"
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
execCh := make(chan struct{})
|
||||
tool1 := &slowTool{name: "slow_tool", duration: 50 * time.Millisecond, execCh: execCh}
|
||||
tool2 := &slowTool{name: "skipped_tool", duration: 50 * time.Millisecond}
|
||||
|
||||
// Provider that captures messages on the second call (after tools)
|
||||
var secondCallMessages []providers.Message
|
||||
var capMu sync.Mutex
|
||||
callCount := 0
|
||||
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "slow_tool",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "slow_tool",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Type: "function",
|
||||
Name: "skipped_tool",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "skipped_tool",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "done",
|
||||
}
|
||||
|
||||
// Wrap provider to capture messages on second call
|
||||
wrappedProvider := &wrappingProvider{
|
||||
inner: provider,
|
||||
onChat: func(msgs []providers.Message) {
|
||||
capMu.Lock()
|
||||
callCount++
|
||||
if callCount >= 2 {
|
||||
secondCallMessages = make([]providers.Message, len(msgs))
|
||||
copy(secondCallMessages, msgs)
|
||||
}
|
||||
capMu.Unlock()
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, wrappedProvider)
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
|
||||
resultCh := make(chan string, 1)
|
||||
go func() {
|
||||
resp, _ := al.ProcessDirectWithChannel(
|
||||
context.Background(), "go", "test-session", "test", "chat1",
|
||||
)
|
||||
resultCh <- resp
|
||||
}()
|
||||
|
||||
<-execCh
|
||||
al.Steer(providers.Message{Role: "user", Content: "interrupt!"})
|
||||
|
||||
select {
|
||||
case <-resultCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// Check that the skipped tool result message is in the conversation
|
||||
capMu.Lock()
|
||||
msgs := secondCallMessages
|
||||
capMu.Unlock()
|
||||
|
||||
foundSkipped := false
|
||||
for _, m := range msgs {
|
||||
if m.Role == "tool" && m.ToolCallID == "call_2" && m.Content == "Skipped due to queued user message." {
|
||||
foundSkipped = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundSkipped {
|
||||
// Log what we actually got
|
||||
for i, m := range msgs {
|
||||
t.Logf("msg[%d]: role=%s toolCallID=%s content=%s", i, m.Role, m.ToolCallID, truncate(m.Content, 80))
|
||||
}
|
||||
t.Fatal("expected skipped tool result for call_2")
|
||||
}
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n] + "..."
|
||||
}
|
||||
|
||||
// wrappingProvider wraps another provider to hook into Chat calls.
|
||||
type wrappingProvider struct {
|
||||
inner providers.LLMProvider
|
||||
onChat func([]providers.Message)
|
||||
}
|
||||
|
||||
func (w *wrappingProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
if w.onChat != nil {
|
||||
w.onChat(messages)
|
||||
}
|
||||
return w.inner.Chat(ctx, messages, tools, model, opts)
|
||||
}
|
||||
|
||||
func (w *wrappingProvider) GetDefaultModel() string {
|
||||
return w.inner.GetDefaultModel()
|
||||
}
|
||||
|
||||
// Ensure NormalizeToolCall handles our test tool calls.
|
||||
func init() {
|
||||
// This is a no-op init; we just need the tool call tests to work
|
||||
// with the proper argument serialization.
|
||||
_ = json.Marshal
|
||||
}
|
||||
@@ -234,6 +234,7 @@ type AgentDefaults struct {
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Routing *RoutingConfig `json:"routing,omitempty"`
|
||||
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
|
||||
}
|
||||
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
|
||||
@@ -35,6 +35,7 @@ func DefaultConfig() *Config {
|
||||
MaxToolIterations: 50,
|
||||
SummarizeMessageThreshold: 20,
|
||||
SummarizeTokenPercent: 75,
|
||||
SteeringMode: "one-at-a-time",
|
||||
},
|
||||
},
|
||||
Bindings: []AgentBinding{},
|
||||
|
||||
Reference in New Issue
Block a user