fix(agent) scope steering

This commit is contained in:
afjcjsbx
2026-03-20 19:44:00 +01:00
parent 73a683fd16
commit 1c6586681d
4 changed files with 645 additions and 86 deletions
+34 -1
View File
@@ -21,6 +21,18 @@ Agent Loop ▼
└─ new LLM turn with steering message
```
## Scoped queues
Steering is now isolated per resolved session scope, not stored in a single
global queue.
- The active turn writes and reads from its own scope key (usually the routed session key such as `agent:<agent_id>:...`)
- `Steer()` still works outside an active turn through a legacy fallback queue
- `Continue()` first dequeues messages for the requested session scope, then falls back to the legacy queue for backwards compatibility
This prevents a message arriving from another chat, DM peer, or routed agent
session from being injected into the wrong conversation.
## Configuration
In `config.json`, under `agents.defaults`:
@@ -86,12 +98,18 @@ if response == "" {
`Continue` internally uses `SkipInitialSteeringPoll: true` to avoid double-dequeuing the same messages (since it already extracted them and passes them directly as input).
`Continue` also resolves the target agent from the provided session key, so
agent-scoped sessions continue on the correct agent instead of always using
the default one.
## Polling points in the loop
Steering is checked at **two points** in the agent cycle:
Steering is checked at the following points in the agent cycle:
1. **At loop start** — before the first LLM call, to catch messages enqueued during setup
2. **After every tool completes** — including the first and the last. If steering is found and there are remaining tools, they are all skipped immediately
3. **After a direct LLM response** — if a new steering message arrived while the model was generating a non-tool response, the loop continues instead of returning a stale answer
4. **Right before the turn is finalized** — if steering arrived at the very end of the turn, the agent immediately starts a continuation turn instead of leaving the message orphaned in the queue
## Why remaining tools are skipped
@@ -156,11 +174,26 @@ When the agent loop (`Run()`) starts processing a message, it spawns a backgroun
- Users on any channel (Telegram, Discord, etc.) don't need to do anything special — their messages are automatically captured as steering when the agent is busy
- Audio messages are transcribed before being steered, so the agent receives text. If transcription fails, the original (non-transcribed) message is steered as-is
- Only messages that resolve to the **same steering scope** as the active turn are redirected. Messages for other chats/sessions are requeued onto the inbound bus so they can be processed normally
- `system` inbound messages are not treated as steering input
- When `processMessage` finishes, the drain goroutine is canceled and normal message consumption resumes
## Steering with media
Steering messages can include `Media` refs, just like normal inbound user
messages.
- The original `media://` refs are preserved in session history via `AddFullMessage`
- Before the next provider call, steering messages go through the normal media resolution pipeline
- Image refs are converted to data URLs for multimodal providers; non-image refs are resolved the same way as standard inbound media
This applies both to in-turn steering and to idle-session continuation through
`Continue()`.
## Notes
- Steering **does not interrupt** a tool that is currently executing. It waits for the current tool to finish, then checks the queue.
- With `one-at-a-time` mode, if multiple messages are enqueued rapidly, they will be processed one per iteration. This gives the model the opportunity to react to each message individually.
- With `all` mode, all pending messages are combined into a single injection. Useful when you want the agent to receive all the context at once.
- The steering queue has a maximum capacity of 10 messages (`MaxQueueSize`). `Steer()` returns an error when the queue is full. In the bus drain path, the error is logged as a warning and the message is effectively dropped.
- Manual `Steer()` calls made outside an active turn still go to the legacy fallback queue, so older integrations keep working.
+104 -29
View File
@@ -64,11 +64,12 @@ type processOptions struct {
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
InitialSteeringMessages []providers.Message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
}
type continuationTarget struct {
@@ -271,11 +272,14 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
// Start a goroutine that drains the bus while processMessage is
// running. Any inbound messages that arrive during processing are
// redirected into the steering queue so the agent loop can pick
// them up between tool calls.
drainCtx, drainCancel := context.WithCancel(ctx)
go al.drainBusToSteering(drainCtx)
// running. Only messages that resolve to the active turn scope are
// redirected into steering; other inbound messages are requeued.
drainCancel := func() {}
if activeScope, activeAgentID, ok := al.resolveSteeringTarget(msg); ok {
drainCtx, cancel := context.WithCancel(ctx)
drainCancel = cancel
go al.drainBusToSteering(drainCtx, activeScope, activeAgentID)
}
// Process message
func() {
@@ -316,13 +320,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
return
}
for al.pendingSteeringCount() > 0 {
for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
logger.InfoCF("agent", "Continuing queued steering after turn end",
map[string]any{
"channel": target.Channel,
"chat_id": target.ChatID,
"session_key": target.SessionKey,
"queue_depth": al.pendingSteeringCount(),
"queue_depth": al.pendingSteeringCountForScope(target.SessionKey),
})
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
@@ -349,15 +353,27 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
// drainBusToSteering continuously consumes inbound messages and redirects
// them into the steering queue. It runs in a goroutine while processMessage
// is active and stops when drainCtx is canceled (i.e., processMessage returns).
func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
// messages from the active scope into the steering queue. Messages from other
// scopes are requeued so they can be processed normally after the active turn.
func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, activeAgentID string) {
for {
msg, ok := al.bus.ConsumeInbound(ctx)
if !ok {
return
}
msgScope, _, scopeOK := al.resolveSteeringTarget(msg)
if !scopeOK || msgScope != activeScope {
if err := al.requeueInboundMessage(msg); err != nil {
logger.WarnCF("agent", "Failed to requeue non-steering inbound message", map[string]any{
"error": err.Error(),
"channel": msg.Channel,
"sender_id": msg.SenderID,
})
}
return
}
// Transcribe audio if needed before steering, so the agent sees text.
msg, _ = al.transcribeAudioInMessage(ctx, msg)
@@ -366,11 +382,13 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
"channel": msg.Channel,
"sender_id": msg.SenderID,
"content_len": len(msg.Content),
"scope": activeScope,
})
if err := al.Steer(providers.Message{
if err := al.enqueueSteeringMessage(activeScope, activeAgentID, providers.Message{
Role: "user",
Content: msg.Content,
Media: append([]string(nil), msg.Media...),
}); err != nil {
logger.WarnCF("agent", "Failed to steer message, will be lost",
map[string]any{
@@ -1085,6 +1103,25 @@ func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string {
return route.SessionKey
}
func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) {
if msg.Channel == "system" {
return "", "", false
}
route, agent, err := al.resolveMessageRoute(msg)
if err != nil || agent == nil {
return "", "", false
}
return resolveScopeKey(route, msg.SessionKey), agent.ID, true
}
func (al *AgentLoop) requeueInboundMessage(msg bus.InboundMessage) error {
pubCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
return al.bus.PublishInbound(pubCtx, msg)
}
func (al *AgentLoop) processSystemMessage(
ctx context.Context,
msg bus.InboundMessage,
@@ -1346,16 +1383,25 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
}
}
if !ts.opts.NoHistory {
rootMsg := providers.Message{Role: "user", Content: ts.userMessage}
ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
if !ts.opts.NoHistory && (strings.TrimSpace(ts.userMessage) != "" || len(ts.media) > 0) {
rootMsg := providers.Message{
Role: "user",
Content: ts.userMessage,
Media: append([]string(nil), ts.media...),
}
if len(rootMsg.Media) > 0 {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, rootMsg)
} else {
ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
}
ts.recordPersistedMessage(rootMsg)
}
activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages)
var pendingMessages []providers.Message
pendingMessages := append([]providers.Message(nil), ts.opts.InitialSteeringMessages...)
var finalContent string
turnLoop:
for ts.currentIteration() < ts.agent.MaxIterations || len(pendingMessages) > 0 || func() bool {
graceful, _ := ts.gracefulInterruptRequested()
return graceful
@@ -1369,19 +1415,24 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
ts.setIteration(iteration)
ts.setPhase(TurnPhaseRunning)
if iteration > 1 || !ts.opts.SkipInitialSteeringPoll {
if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
if iteration > 1 {
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
pendingMessages = append(pendingMessages, steerMsgs...)
}
} else if !ts.opts.SkipInitialSteeringPoll {
if steerMsgs := al.dequeueSteeringMessagesForScopeWithFallback(ts.sessionKey); len(steerMsgs) > 0 {
pendingMessages = append(pendingMessages, steerMsgs...)
}
}
if len(pendingMessages) > 0 {
resolvedPending := resolveMediaRefs(pendingMessages, al.mediaStore, maxMediaSize)
totalContentLen := 0
for _, pm := range pendingMessages {
messages = append(messages, pm)
for i, pm := range pendingMessages {
messages = append(messages, resolvedPending[i])
totalContentLen += len(pm.Content)
if !ts.opts.NoHistory {
ts.agent.Sessions.AddMessage(ts.sessionKey, pm.Role, pm.Content)
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
ts.recordPersistedMessage(pm)
}
logger.InfoCF("agent", "Injected steering message into context",
@@ -1389,6 +1440,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
"agent_id": ts.agent.ID,
"iteration": iteration,
"content_len": len(pm.Content),
"media_count": len(pm.Media),
})
}
al.emitEvent(
@@ -1660,10 +1712,21 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
})
if len(response.ToolCalls) == 0 || gracefulTerminal {
finalContent = response.Content
if finalContent == "" && response.ReasoningContent != "" {
finalContent = response.ReasoningContent
responseContent := response.Content
if responseContent == "" && response.ReasoningContent != "" {
responseContent = response.ReasoningContent
}
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
logger.InfoCF("agent", "Steering arrived after direct LLM response; continuing turn",
map[string]any{
"agent_id": ts.agent.ID,
"iteration": iteration,
"steering_count": len(steerMsgs),
})
pendingMessages = append(pendingMessages, steerMsgs...)
continue
}
finalContent = responseContent
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
map[string]any{
"agent_id": ts.agent.ID,
@@ -1870,7 +1933,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
ts.recordPersistedMessage(toolResultMsg)
}
if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
pendingMessages = append(pendingMessages, steerMsgs...)
}
@@ -1926,6 +1989,18 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
})
}
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
logger.InfoCF("agent", "Steering arrived after turn completion; continuing turn before finalizing",
map[string]any{
"agent_id": ts.agent.ID,
"steering_count": len(steerMsgs),
"session_key": ts.sessionKey,
})
pendingMessages = append(pendingMessages, steerMsgs...)
finalContent = ""
goto turnLoop
}
if ts.hardAbortRequested() {
turnStatus = TurnEndStatusAborted
return al.abortTurn(ts)
+176 -47
View File
@@ -8,6 +8,7 @@ import (
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
)
// SteeringMode controls how queued steering messages are dequeued.
@@ -20,6 +21,9 @@ const (
SteeringAll SteeringMode = "all"
// MaxQueueSize number of possible messages in the Steering Queue
MaxQueueSize = 10
// manualSteeringScope is the legacy fallback queue used when no active
// turn/session scope is available.
manualSteeringScope = "__manual__"
)
// parseSteeringMode normalizes a config string into a SteeringMode.
@@ -35,56 +39,117 @@ func parseSteeringMode(s string) SteeringMode {
// steeringQueue is a thread-safe queue of user messages that can be injected
// into a running agent loop to interrupt it between tool calls.
type steeringQueue struct {
mu sync.Mutex
queue []providers.Message
mode SteeringMode
mu sync.Mutex
queues map[string][]providers.Message
mode SteeringMode
}
func newSteeringQueue(mode SteeringMode) *steeringQueue {
return &steeringQueue{
mode: mode,
queues: make(map[string][]providers.Message),
mode: mode,
}
}
// push enqueues a steering message.
func normalizeSteeringScope(scope string) string {
scope = strings.TrimSpace(scope)
if scope == "" {
return manualSteeringScope
}
return scope
}
// push enqueues a steering message in the legacy fallback scope.
func (sq *steeringQueue) push(msg providers.Message) error {
return sq.pushScope(manualSteeringScope, msg)
}
// pushScope enqueues a steering message for the provided scope.
func (sq *steeringQueue) pushScope(scope string, msg providers.Message) error {
sq.mu.Lock()
defer sq.mu.Unlock()
if len(sq.queue) >= MaxQueueSize {
scope = normalizeSteeringScope(scope)
queue := sq.queues[scope]
if len(queue) >= MaxQueueSize {
return fmt.Errorf("steering queue is full")
}
sq.queue = append(sq.queue, msg)
sq.queues[scope] = append(queue, msg)
return nil
}
// dequeue removes and returns pending steering messages according to the
// configured mode. Returns nil when the queue is empty.
// dequeue removes and returns pending steering messages from the legacy
// fallback scope according to the configured mode.
func (sq *steeringQueue) dequeue() []providers.Message {
return sq.dequeueScope(manualSteeringScope)
}
// dequeueScope removes and returns pending steering messages for the provided
// scope according to the configured mode.
func (sq *steeringQueue) dequeueScope(scope string) []providers.Message {
sq.mu.Lock()
defer sq.mu.Unlock()
if len(sq.queue) == 0 {
return sq.dequeueLocked(normalizeSteeringScope(scope))
}
// dequeueScopeWithFallback drains the scoped queue first and falls back to the
// legacy manual scope for backwards compatibility.
func (sq *steeringQueue) dequeueScopeWithFallback(scope string) []providers.Message {
sq.mu.Lock()
defer sq.mu.Unlock()
scope = strings.TrimSpace(scope)
if scope != "" {
if msgs := sq.dequeueLocked(scope); len(msgs) > 0 {
return msgs
}
}
return sq.dequeueLocked(manualSteeringScope)
}
func (sq *steeringQueue) dequeueLocked(scope string) []providers.Message {
queue := sq.queues[scope]
if len(queue) == 0 {
return nil
}
switch sq.mode {
case SteeringAll:
msgs := sq.queue
sq.queue = nil
msgs := append([]providers.Message(nil), queue...)
delete(sq.queues, scope)
return msgs
default: // one-at-a-time
msg := sq.queue[0]
sq.queue[0] = providers.Message{} // Clear reference for GC
sq.queue = sq.queue[1:]
default:
msg := queue[0]
queue[0] = providers.Message{} // Clear reference for GC
queue = queue[1:]
if len(queue) == 0 {
delete(sq.queues, scope)
} else {
sq.queues[scope] = queue
}
return []providers.Message{msg}
}
}
// len returns the number of queued messages.
// len returns the number of queued messages across all scopes.
func (sq *steeringQueue) len() int {
sq.mu.Lock()
defer sq.mu.Unlock()
return len(sq.queue)
total := 0
for _, queue := range sq.queues {
total += len(queue)
}
return total
}
// lenScope returns the number of queued messages for a specific scope.
func (sq *steeringQueue) lenScope(scope string) int {
sq.mu.Lock()
defer sq.mu.Unlock()
return len(sq.queues[normalizeSteeringScope(scope)])
}
// setMode updates the steering mode.
@@ -101,26 +166,40 @@ func (sq *steeringQueue) getMode() SteeringMode {
return sq.mode
}
// --- AgentLoop steering API ---
// Steer enqueues a user message to be injected into the currently running
// agent loop. The message will be picked up after the current tool finishes
// executing, causing any remaining tool calls in the batch to be skipped.
func (al *AgentLoop) Steer(msg providers.Message) error {
scope := ""
agentID := ""
if ts := al.getActiveTurnState(); ts != nil {
scope = ts.sessionKey
agentID = ts.agentID
}
return al.enqueueSteeringMessage(scope, agentID, msg)
}
func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers.Message) error {
if al.steering == nil {
return fmt.Errorf("steering queue is not initialized")
}
if err := al.steering.push(msg); err != nil {
if err := al.steering.pushScope(scope, msg); err != nil {
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
"error": err.Error(),
"role": msg.Role,
"scope": normalizeSteeringScope(scope),
})
return err
}
queueDepth := al.steering.lenScope(scope)
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
"role": msg.Role,
"content_len": len(msg.Content),
"queue_len": al.steering.len(),
"media_count": len(msg.Media),
"queue_len": queueDepth,
"scope": normalizeSteeringScope(scope),
})
meta := EventMeta{
@@ -129,11 +208,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
}
if ts := al.getActiveTurnState(); ts != nil {
meta = ts.eventMeta("Steer", "turn.interrupt.received")
} else if registry := al.GetRegistry(); registry != nil {
if agent := registry.GetDefaultAgent(); agent != nil {
meta.AgentID = agent.ID
} else {
if strings.TrimSpace(agentID) != "" {
meta.AgentID = agentID
}
normalizedScope := normalizeSteeringScope(scope)
if normalizedScope != manualSteeringScope {
meta.SessionKey = normalizedScope
}
if meta.AgentID == "" {
if registry := al.GetRegistry(); registry != nil {
if agent := registry.GetDefaultAgent(); agent != nil {
meta.AgentID = agent.ID
}
}
}
}
al.emitEvent(
EventKindInterruptReceived,
meta,
@@ -141,7 +232,7 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
Kind: InterruptKindSteering,
Role: msg.Role,
ContentLen: len(msg.Content),
QueueDepth: al.steering.len(),
QueueDepth: queueDepth,
},
)
@@ -165,7 +256,7 @@ func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
}
// dequeueSteeringMessages is the internal method called by the agent loop
// to poll for steering messages. Returns nil when no messages are pending.
// to poll for steering messages in the legacy fallback scope.
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
if al.steering == nil {
return nil
@@ -173,6 +264,60 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
return al.steering.dequeue()
}
func (al *AgentLoop) dequeueSteeringMessagesForScope(scope string) []providers.Message {
if al.steering == nil {
return nil
}
return al.steering.dequeueScope(scope)
}
func (al *AgentLoop) dequeueSteeringMessagesForScopeWithFallback(scope string) []providers.Message {
if al.steering == nil {
return nil
}
return al.steering.dequeueScopeWithFallback(scope)
}
func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
if al.steering == nil {
return 0
}
return al.steering.lenScope(scope)
}
func (al *AgentLoop) continueWithSteeringMessages(
ctx context.Context,
agent *AgentInstance,
sessionKey, channel, chatID string,
steeringMsgs []providers.Message,
) (string, error) {
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: sessionKey,
Channel: channel,
ChatID: chatID,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
InitialSteeringMessages: steeringMsgs,
SkipInitialSteeringPoll: true,
})
}
func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
registry := al.GetRegistry()
if registry == nil {
return nil
}
if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
return agent
}
}
return registry.GetDefaultAgent()
}
// Continue resumes an idle agent by dequeuing any pending steering messages
// and running them through the agent loop. This is used when the agent's last
// message was from the assistant (i.e., it has stopped processing) and the
@@ -184,14 +329,14 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
return "", fmt.Errorf("turn %s is still active", active.TurnID)
}
steeringMsgs := al.dequeueSteeringMessages()
steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey)
if len(steeringMsgs) == 0 {
return "", nil
}
agent := al.GetRegistry().GetDefaultAgent()
agent := al.agentForSession(sessionKey)
if agent == nil {
return "", fmt.Errorf("no default agent available")
return "", fmt.Errorf("no agent available for session %q", sessionKey)
}
if tool, ok := agent.Tools.Get("message"); ok {
@@ -200,23 +345,7 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
}
}
// Build a combined user message from the steering messages.
var contents []string
for _, msg := range steeringMsgs {
contents = append(contents, msg.Content)
}
combinedContent := strings.Join(contents, "\n")
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: sessionKey,
Channel: channel,
ChatID: chatID,
UserMessage: combinedContent,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
SkipInitialSteeringPoll: true,
})
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
}
func (al *AgentLoop) InterruptGraceful(hint string) error {
+331 -9
View File
@@ -5,13 +5,16 @@ import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -337,6 +340,96 @@ func TestAgentLoop_Continue_WithMessages(t *testing.T) {
}
}
func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
Session: config.SessionConfig{
DMScope: "per-peer",
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, &mockProvider{})
activeMsg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "active turn",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
if !ok {
t.Fatal("expected active message to resolve to a steering scope")
}
otherMsg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user2",
ChatID: "chat2",
Content: "other session",
Peer: bus.Peer{
Kind: "direct",
ID: "user2",
},
}
otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
if !ok {
t.Fatal("expected other message to resolve to a steering scope")
}
if otherScope == activeScope {
t.Fatalf("expected different steering scopes, got same scope %q", activeScope)
}
if err := msgBus.PublishInbound(context.Background(), otherMsg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
done := make(chan struct{})
go func() {
al.drainBusToSteering(ctx, activeScope, activeAgentID)
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for drainBusToSteering to stop")
}
if msgs := al.dequeueSteeringMessagesForScope(activeScope); len(msgs) != 0 {
t.Fatalf("expected no steering messages for active scope, got %v", msgs)
}
requeued, ok := msgBus.ConsumeInbound(context.Background())
if !ok {
t.Fatal("expected message to be requeued on the inbound bus")
}
if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID ||
requeued.SenderID != otherMsg.SenderID || requeued.Content != otherMsg.Content {
t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
}
}
// slowTool simulates a tool that takes some time to execute.
type slowTool struct {
name string
@@ -472,6 +565,52 @@ func (p *lateSteeringProvider) GetDefaultModel() string {
return "late-steering-mock"
}
type blockingDirectProvider struct {
mu sync.Mutex
calls int
firstStarted chan struct{}
releaseFirst chan struct{}
firstResp string
finalResp string
}
func (p *blockingDirectProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
p.calls++
call := p.calls
firstStarted := p.firstStarted
releaseFirst := p.releaseFirst
firstResp := p.firstResp
finalResp := p.finalResp
if call == 1 && p.firstStarted != nil {
close(p.firstStarted)
p.firstStarted = nil
}
p.mu.Unlock()
if call == 1 {
select {
case <-releaseFirst:
case <-ctx.Done():
return nil, ctx.Err()
}
return &providers.LLMResponse{Content: firstResp}, nil
}
_ = firstStarted
return &providers.LLMResponse{Content: finalResp}, nil
}
func (p *blockingDirectProvider) GetDefaultModel() string {
return "blocking-direct-mock"
}
type interruptibleTool struct {
name string
started chan struct{}
@@ -744,18 +883,16 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
out1, ok := msgBus.SubscribeOutbound(subCtx)
if !ok {
t.Fatal("expected first outbound response")
t.Fatal("expected outbound response")
}
if out1.Content != "first response" {
t.Fatalf("expected first response, got %q", out1.Content)
if out1.Content != "continued response" {
t.Fatalf("expected continued response, got %q", out1.Content)
}
out2, ok := msgBus.SubscribeOutbound(subCtx)
if !ok {
t.Fatal("expected continued outbound response")
}
if out2.Content != "continued response" {
t.Fatalf("expected continued response, got %q", out2.Content)
noExtraCtx, cancelNoExtra := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancelNoExtra()
if out2, ok := msgBus.SubscribeOutbound(noExtraCtx); ok {
t.Fatalf("expected stale direct response to be suppressed, got extra outbound %q", out2.Content)
}
cancelRun()
@@ -789,6 +926,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
}
}
func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
provider := &blockingDirectProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
firstResp: "stale direct response",
finalResp: "fresh response after steering",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
resultCh := make(chan struct {
resp string
err error
}, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"initial request",
sessionKey,
"test",
"chat1",
)
resultCh <- struct {
resp string
err error
}{resp: resp, err: err}
}()
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
if err := al.Steer(providers.Message{Role: "user", Content: "follow-up instruction"}); err != nil {
t.Fatalf("Steer failed: %v", err)
}
close(provider.releaseFirst)
select {
case result := <-resultCh:
if result.err != nil {
t.Fatalf("unexpected error: %v", result.err)
}
if result.resp != "fresh response after steering" {
t.Fatalf("expected refreshed response, got %q", result.resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for ProcessDirectWithChannel")
}
provider.mu.Lock()
calls := provider.calls
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 {
t.Fatalf("expected steering queue to be empty after continuation, got %v", msgs)
}
}
func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
store := media.NewFileMediaStore()
pngPath := filepath.Join(tmpDir, "steer.png")
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
0x00, 0x00, 0x00, 0x0D,
0x49, 0x48, 0x44, 0x52,
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
0x00, 0x00, 0x00,
0x90, 0x77, 0x53, 0xDE,
}
if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
ref, err := store.Store(pngPath, media.MediaMeta{Filename: "steer.png", ContentType: "image/png"}, "test")
if err != nil {
t.Fatalf("Store failed: %v", err)
}
var capturedMessages []providers.Message
var capMu sync.Mutex
provider := &capturingMockProvider{
response: "ack",
captureFn: func(msgs []providers.Message) {
capMu.Lock()
defer capMu.Unlock()
capturedMessages = append([]providers.Message(nil), msgs...)
},
}
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.SetMediaStore(store)
if err := al.Steer(providers.Message{
Role: "user",
Content: "describe this image",
Media: []string{ref},
}); err != nil {
t.Fatalf("Steer failed: %v", err)
}
resp, err := al.Continue(context.Background(), sessionKey, "test", "chat1")
if err != nil {
t.Fatalf("Continue failed: %v", err)
}
if resp != "ack" {
t.Fatalf("expected ack, got %q", resp)
}
capMu.Lock()
msgs := append([]providers.Message(nil), capturedMessages...)
capMu.Unlock()
foundResolvedMedia := false
for _, msg := range msgs {
if msg.Role != "user" || msg.Content != "describe this image" || len(msg.Media) != 1 {
continue
}
if strings.HasPrefix(msg.Media[0], "data:image/png;base64,") {
foundResolvedMedia = true
break
}
}
if !foundResolvedMedia {
t.Fatal("expected continue path to inject steering media into the provider request")
}
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
history := defaultAgent.Sessions.GetHistory(sessionKey)
foundOriginalRef := false
for _, msg := range history {
if msg.Role == "user" && len(msg.Media) == 1 && msg.Media[0] == ref {
foundOriginalRef = true
break
}
}
if !foundOriginalRef {
t.Fatal("expected original steering media ref to be preserved in session history")
}
}
func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {