mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #1844 from afjcjsbx/fix/scope-steering
fix(agent) scope steering
This commit is contained in:
+34
-1
@@ -21,6 +21,18 @@ Agent Loop ▼
|
||||
└─ new LLM turn with steering message
|
||||
```
|
||||
|
||||
## Scoped queues
|
||||
|
||||
Steering is now isolated per resolved session scope, not stored in a single
|
||||
global queue.
|
||||
|
||||
- The active turn writes and reads from its own scope key (usually the routed session key such as `agent:<agent_id>:...`)
|
||||
- `Steer()` still works outside an active turn through a legacy fallback queue
|
||||
- `Continue()` first dequeues messages for the requested session scope, then falls back to the legacy queue for backwards compatibility
|
||||
|
||||
This prevents a message arriving from another chat, DM peer, or routed agent
|
||||
session from being injected into the wrong conversation.
|
||||
|
||||
## Configuration
|
||||
|
||||
In `config.json`, under `agents.defaults`:
|
||||
@@ -86,12 +98,18 @@ if response == "" {
|
||||
|
||||
`Continue` internally uses `SkipInitialSteeringPoll: true` to avoid double-dequeuing the same messages (since it already extracted them and passes them directly as input).
|
||||
|
||||
`Continue` also resolves the target agent from the provided session key, so
|
||||
agent-scoped sessions continue on the correct agent instead of always using
|
||||
the default one.
|
||||
|
||||
## Polling points in the loop
|
||||
|
||||
Steering is checked at **two points** in the agent cycle:
|
||||
Steering is checked at the following points in the agent cycle:
|
||||
|
||||
1. **At loop start** — before the first LLM call, to catch messages enqueued during setup
|
||||
2. **After every tool completes** — including the first and the last. If steering is found and there are remaining tools, they are all skipped immediately
|
||||
3. **After a direct LLM response** — if a new steering message arrived while the model was generating a non-tool response, the loop continues instead of returning a stale answer
|
||||
4. **Right before the turn is finalized** — if steering arrived at the very end of the turn, the agent immediately starts a continuation turn instead of leaving the message orphaned in the queue
|
||||
|
||||
## Why remaining tools are skipped
|
||||
|
||||
@@ -156,11 +174,26 @@ When the agent loop (`Run()`) starts processing a message, it spawns a backgroun
|
||||
|
||||
- Users on any channel (Telegram, Discord, etc.) don't need to do anything special — their messages are automatically captured as steering when the agent is busy
|
||||
- Audio messages are transcribed before being steered, so the agent receives text. If transcription fails, the original (non-transcribed) message is steered as-is
|
||||
- Only messages that resolve to the **same steering scope** as the active turn are redirected. A message for another chat/session is requeued onto the inbound bus so it can be processed normally, and the drain goroutine stops after requeueing it
|
||||
- `system` inbound messages are not treated as steering input
|
||||
- When `processMessage` finishes, the drain goroutine is canceled and normal message consumption resumes
|
||||
|
||||
## Steering with media
|
||||
|
||||
Steering messages can include `Media` refs, just like normal inbound user
|
||||
messages.
|
||||
|
||||
- The original `media://` refs are preserved in session history via `AddFullMessage`
|
||||
- Before the next provider call, steering messages go through the normal media resolution pipeline
|
||||
- Image refs are converted to data URLs for multimodal providers; non-image refs are resolved the same way as standard inbound media
|
||||
|
||||
This applies both to in-turn steering and to idle-session continuation through
|
||||
`Continue()`.
|
||||
|
||||
## Notes
|
||||
|
||||
- Steering **does not interrupt** a tool that is currently executing. It waits for the current tool to finish, then checks the queue.
|
||||
- With `one-at-a-time` mode, if multiple messages are enqueued rapidly, they will be processed one per iteration. This gives the model the opportunity to react to each message individually.
|
||||
- With `all` mode, all pending messages are combined into a single injection. Useful when you want the agent to receive all the context at once.
|
||||
- Each steering scope (including the legacy fallback scope) has its own queue with a maximum capacity of 10 messages (`MaxQueueSize`). `Steer()` returns an error when the target queue is full. In the bus drain path, the error is logged as a warning and the message is dropped.
|
||||
- Manual `Steer()` calls made outside an active turn still go to the legacy fallback queue, so older integrations keep working.
|
||||
|
||||
+151
-42
@@ -64,11 +64,12 @@ type processOptions struct {
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
|
||||
InitialSteeringMessages []providers.Message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
|
||||
}
|
||||
|
||||
type continuationTarget struct {
|
||||
@@ -271,11 +272,14 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Start a goroutine that drains the bus while processMessage is
|
||||
// running. Any inbound messages that arrive during processing are
|
||||
// redirected into the steering queue so the agent loop can pick
|
||||
// them up between tool calls.
|
||||
drainCtx, drainCancel := context.WithCancel(ctx)
|
||||
go al.drainBusToSteering(drainCtx)
|
||||
// running. Only messages that resolve to the active turn scope are
|
||||
// redirected into steering; other inbound messages are requeued.
|
||||
drainCancel := func() {}
|
||||
if activeScope, activeAgentID, ok := al.resolveSteeringTarget(msg); ok {
|
||||
drainCtx, cancel := context.WithCancel(ctx)
|
||||
drainCancel = cancel
|
||||
go al.drainBusToSteering(drainCtx, activeScope, activeAgentID)
|
||||
}
|
||||
|
||||
// Process message
|
||||
func() {
|
||||
@@ -292,16 +296,21 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
// }
|
||||
// }()
|
||||
|
||||
defer drainCancel()
|
||||
drainCanceled := false
|
||||
cancelDrain := func() {
|
||||
if drainCanceled {
|
||||
return
|
||||
}
|
||||
drainCancel()
|
||||
drainCanceled = true
|
||||
}
|
||||
defer cancelDrain()
|
||||
|
||||
response, err := al.processMessage(ctx, msg)
|
||||
if err != nil {
|
||||
response = fmt.Sprintf("Error processing message: %v", err)
|
||||
}
|
||||
|
||||
if response != "" {
|
||||
al.publishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, response)
|
||||
}
|
||||
finalResponse := response
|
||||
|
||||
target, targetErr := al.buildContinuationTarget(msg)
|
||||
if targetErr != nil {
|
||||
@@ -313,16 +322,20 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
return
|
||||
}
|
||||
if target == nil {
|
||||
cancelDrain()
|
||||
if finalResponse != "" {
|
||||
al.publishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, finalResponse)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for al.pendingSteeringCount() > 0 {
|
||||
for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
|
||||
logger.InfoCF("agent", "Continuing queued steering after turn end",
|
||||
map[string]any{
|
||||
"channel": target.Channel,
|
||||
"chat_id": target.ChatID,
|
||||
"session_key": target.SessionKey,
|
||||
"queue_depth": al.pendingSteeringCount(),
|
||||
"queue_depth": al.pendingSteeringCountForScope(target.SessionKey),
|
||||
})
|
||||
|
||||
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
|
||||
@@ -339,7 +352,39 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
return
|
||||
}
|
||||
|
||||
al.publishResponseIfNeeded(ctx, target.Channel, target.ChatID, continued)
|
||||
finalResponse = continued
|
||||
}
|
||||
|
||||
cancelDrain()
|
||||
|
||||
for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
|
||||
logger.InfoCF("agent", "Draining steering queued during turn shutdown",
|
||||
map[string]any{
|
||||
"channel": target.Channel,
|
||||
"chat_id": target.ChatID,
|
||||
"session_key": target.SessionKey,
|
||||
"queue_depth": al.pendingSteeringCountForScope(target.SessionKey),
|
||||
})
|
||||
|
||||
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
|
||||
if continueErr != nil {
|
||||
logger.WarnCF("agent", "Failed to continue queued steering after shutdown drain",
|
||||
map[string]any{
|
||||
"channel": target.Channel,
|
||||
"chat_id": target.ChatID,
|
||||
"error": continueErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if continued == "" {
|
||||
break
|
||||
}
|
||||
|
||||
finalResponse = continued
|
||||
}
|
||||
|
||||
if finalResponse != "" {
|
||||
al.publishResponseIfNeeded(ctx, target.Channel, target.ChatID, finalResponse)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -349,15 +394,27 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// drainBusToSteering continuously consumes inbound messages and redirects
|
||||
// them into the steering queue. It runs in a goroutine while processMessage
|
||||
// is active and stops when drainCtx is canceled (i.e., processMessage returns).
|
||||
func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
|
||||
// messages from the active scope into the steering queue. Messages from other
|
||||
// scopes are requeued so they can be processed normally after the active turn.
|
||||
func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, activeAgentID string) {
|
||||
for {
|
||||
msg, ok := al.bus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
msgScope, _, scopeOK := al.resolveSteeringTarget(msg)
|
||||
if !scopeOK || msgScope != activeScope {
|
||||
if err := al.requeueInboundMessage(msg); err != nil {
|
||||
logger.WarnCF("agent", "Failed to requeue non-steering inbound message", map[string]any{
|
||||
"error": err.Error(),
|
||||
"channel": msg.Channel,
|
||||
"sender_id": msg.SenderID,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Transcribe audio if needed before steering, so the agent sees text.
|
||||
msg, _ = al.transcribeAudioInMessage(ctx, msg)
|
||||
|
||||
@@ -366,11 +423,13 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
|
||||
"channel": msg.Channel,
|
||||
"sender_id": msg.SenderID,
|
||||
"content_len": len(msg.Content),
|
||||
"scope": activeScope,
|
||||
})
|
||||
|
||||
if err := al.Steer(providers.Message{
|
||||
if err := al.enqueueSteeringMessage(activeScope, activeAgentID, providers.Message{
|
||||
Role: "user",
|
||||
Content: msg.Content,
|
||||
Media: append([]string(nil), msg.Media...),
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Failed to steer message, will be lost",
|
||||
map[string]any{
|
||||
@@ -422,13 +481,6 @@ func (al *AgentLoop) publishResponseIfNeeded(ctx context.Context, channel, chatI
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) pendingSteeringCount() int {
|
||||
if al.steering == nil {
|
||||
return 0
|
||||
}
|
||||
return al.steering.len()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) buildContinuationTarget(msg bus.InboundMessage) (*continuationTarget, error) {
|
||||
if msg.Channel == "system" {
|
||||
return nil, nil
|
||||
@@ -1085,6 +1137,25 @@ func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string {
|
||||
return route.SessionKey
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) {
|
||||
if msg.Channel == "system" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
route, agent, err := al.resolveMessageRoute(msg)
|
||||
if err != nil || agent == nil {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
return resolveScopeKey(route, msg.SessionKey), agent.ID, true
|
||||
}
|
||||
|
||||
func (al *AgentLoop) requeueInboundMessage(msg bus.InboundMessage) error {
|
||||
pubCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
return al.bus.PublishInbound(pubCtx, msg)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processSystemMessage(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
@@ -1346,16 +1417,25 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
}
|
||||
}
|
||||
|
||||
if !ts.opts.NoHistory {
|
||||
rootMsg := providers.Message{Role: "user", Content: ts.userMessage}
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
|
||||
if !ts.opts.NoHistory && (strings.TrimSpace(ts.userMessage) != "" || len(ts.media) > 0) {
|
||||
rootMsg := providers.Message{
|
||||
Role: "user",
|
||||
Content: ts.userMessage,
|
||||
Media: append([]string(nil), ts.media...),
|
||||
}
|
||||
if len(rootMsg.Media) > 0 {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, rootMsg)
|
||||
} else {
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
|
||||
}
|
||||
ts.recordPersistedMessage(rootMsg)
|
||||
}
|
||||
|
||||
activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages)
|
||||
var pendingMessages []providers.Message
|
||||
pendingMessages := append([]providers.Message(nil), ts.opts.InitialSteeringMessages...)
|
||||
var finalContent string
|
||||
|
||||
turnLoop:
|
||||
for ts.currentIteration() < ts.agent.MaxIterations || len(pendingMessages) > 0 || func() bool {
|
||||
graceful, _ := ts.gracefulInterruptRequested()
|
||||
return graceful
|
||||
@@ -1369,19 +1449,24 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
ts.setIteration(iteration)
|
||||
ts.setPhase(TurnPhaseRunning)
|
||||
|
||||
if iteration > 1 || !ts.opts.SkipInitialSteeringPoll {
|
||||
if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
|
||||
if iteration > 1 {
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
}
|
||||
} else if !ts.opts.SkipInitialSteeringPoll {
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScopeWithFallback(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(pendingMessages) > 0 {
|
||||
resolvedPending := resolveMediaRefs(pendingMessages, al.mediaStore, maxMediaSize)
|
||||
totalContentLen := 0
|
||||
for _, pm := range pendingMessages {
|
||||
messages = append(messages, pm)
|
||||
for i, pm := range pendingMessages {
|
||||
messages = append(messages, resolvedPending[i])
|
||||
totalContentLen += len(pm.Content)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, pm.Role, pm.Content)
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
|
||||
ts.recordPersistedMessage(pm)
|
||||
}
|
||||
logger.InfoCF("agent", "Injected steering message into context",
|
||||
@@ -1389,6 +1474,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"content_len": len(pm.Content),
|
||||
"media_count": len(pm.Media),
|
||||
})
|
||||
}
|
||||
al.emitEvent(
|
||||
@@ -1660,10 +1746,21 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
})
|
||||
|
||||
if len(response.ToolCalls) == 0 || gracefulTerminal {
|
||||
finalContent = response.Content
|
||||
if finalContent == "" && response.ReasoningContent != "" {
|
||||
finalContent = response.ReasoningContent
|
||||
responseContent := response.Content
|
||||
if responseContent == "" && response.ReasoningContent != "" {
|
||||
responseContent = response.ReasoningContent
|
||||
}
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
logger.InfoCF("agent", "Steering arrived after direct LLM response; continuing turn",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"steering_count": len(steerMsgs),
|
||||
})
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
continue
|
||||
}
|
||||
finalContent = responseContent
|
||||
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
@@ -1870,7 +1967,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
ts.recordPersistedMessage(toolResultMsg)
|
||||
}
|
||||
|
||||
if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
}
|
||||
|
||||
@@ -1926,6 +2023,18 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
})
|
||||
}
|
||||
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
logger.InfoCF("agent", "Steering arrived after turn completion; continuing turn before finalizing",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"steering_count": len(steerMsgs),
|
||||
"session_key": ts.sessionKey,
|
||||
})
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
finalContent = ""
|
||||
goto turnLoop
|
||||
}
|
||||
|
||||
if ts.hardAbortRequested() {
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
|
||||
+176
-47
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
)
|
||||
|
||||
// SteeringMode controls how queued steering messages are dequeued.
|
||||
@@ -20,6 +21,9 @@ const (
|
||||
SteeringAll SteeringMode = "all"
|
||||
// MaxQueueSize is the maximum number of messages a single steering scope queue may hold.
|
||||
MaxQueueSize = 10
|
||||
// manualSteeringScope is the legacy fallback queue used when no active
|
||||
// turn/session scope is available.
|
||||
manualSteeringScope = "__manual__"
|
||||
)
|
||||
|
||||
// parseSteeringMode normalizes a config string into a SteeringMode.
|
||||
@@ -35,56 +39,117 @@ func parseSteeringMode(s string) SteeringMode {
|
||||
// steeringQueue is a thread-safe queue of user messages that can be injected
|
||||
// into a running agent loop to interrupt it between tool calls.
|
||||
type steeringQueue struct {
|
||||
mu sync.Mutex
|
||||
queue []providers.Message
|
||||
mode SteeringMode
|
||||
mu sync.Mutex
|
||||
queues map[string][]providers.Message
|
||||
mode SteeringMode
|
||||
}
|
||||
|
||||
func newSteeringQueue(mode SteeringMode) *steeringQueue {
|
||||
return &steeringQueue{
|
||||
mode: mode,
|
||||
queues: make(map[string][]providers.Message),
|
||||
mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// push enqueues a steering message.
|
||||
func normalizeSteeringScope(scope string) string {
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope == "" {
|
||||
return manualSteeringScope
|
||||
}
|
||||
return scope
|
||||
}
|
||||
|
||||
// push enqueues a steering message in the legacy fallback scope.
|
||||
func (sq *steeringQueue) push(msg providers.Message) error {
|
||||
return sq.pushScope(manualSteeringScope, msg)
|
||||
}
|
||||
|
||||
// pushScope enqueues a steering message for the provided scope.
|
||||
func (sq *steeringQueue) pushScope(scope string, msg providers.Message) error {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
if len(sq.queue) >= MaxQueueSize {
|
||||
|
||||
scope = normalizeSteeringScope(scope)
|
||||
queue := sq.queues[scope]
|
||||
if len(queue) >= MaxQueueSize {
|
||||
return fmt.Errorf("steering queue is full")
|
||||
}
|
||||
sq.queue = append(sq.queue, msg)
|
||||
sq.queues[scope] = append(queue, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// dequeue removes and returns pending steering messages according to the
|
||||
// configured mode. Returns nil when the queue is empty.
|
||||
// dequeue removes and returns pending steering messages from the legacy
|
||||
// fallback scope according to the configured mode.
|
||||
func (sq *steeringQueue) dequeue() []providers.Message {
|
||||
return sq.dequeueScope(manualSteeringScope)
|
||||
}
|
||||
|
||||
// dequeueScope removes and returns pending steering messages for the provided
|
||||
// scope according to the configured mode.
|
||||
func (sq *steeringQueue) dequeueScope(scope string) []providers.Message {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
if len(sq.queue) == 0 {
|
||||
return sq.dequeueLocked(normalizeSteeringScope(scope))
|
||||
}
|
||||
|
||||
// dequeueScopeWithFallback drains the scoped queue first and falls back to the
|
||||
// legacy manual scope for backwards compatibility.
|
||||
func (sq *steeringQueue) dequeueScopeWithFallback(scope string) []providers.Message {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
scope = strings.TrimSpace(scope)
|
||||
if scope != "" {
|
||||
if msgs := sq.dequeueLocked(scope); len(msgs) > 0 {
|
||||
return msgs
|
||||
}
|
||||
}
|
||||
|
||||
return sq.dequeueLocked(manualSteeringScope)
|
||||
}
|
||||
|
||||
func (sq *steeringQueue) dequeueLocked(scope string) []providers.Message {
|
||||
queue := sq.queues[scope]
|
||||
if len(queue) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch sq.mode {
|
||||
case SteeringAll:
|
||||
msgs := sq.queue
|
||||
sq.queue = nil
|
||||
msgs := append([]providers.Message(nil), queue...)
|
||||
delete(sq.queues, scope)
|
||||
return msgs
|
||||
default: // one-at-a-time
|
||||
msg := sq.queue[0]
|
||||
sq.queue[0] = providers.Message{} // Clear reference for GC
|
||||
sq.queue = sq.queue[1:]
|
||||
default:
|
||||
msg := queue[0]
|
||||
queue[0] = providers.Message{} // Clear reference for GC
|
||||
queue = queue[1:]
|
||||
if len(queue) == 0 {
|
||||
delete(sq.queues, scope)
|
||||
} else {
|
||||
sq.queues[scope] = queue
|
||||
}
|
||||
return []providers.Message{msg}
|
||||
}
|
||||
}
|
||||
|
||||
// len returns the number of queued messages.
|
||||
// len returns the number of queued messages across all scopes.
|
||||
func (sq *steeringQueue) len() int {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
return len(sq.queue)
|
||||
|
||||
total := 0
|
||||
for _, queue := range sq.queues {
|
||||
total += len(queue)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// lenScope returns the number of queued messages for a specific scope.
|
||||
func (sq *steeringQueue) lenScope(scope string) int {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
return len(sq.queues[normalizeSteeringScope(scope)])
|
||||
}
|
||||
|
||||
// setMode updates the steering mode.
|
||||
@@ -101,26 +166,40 @@ func (sq *steeringQueue) getMode() SteeringMode {
|
||||
return sq.mode
|
||||
}
|
||||
|
||||
// --- AgentLoop steering API ---
|
||||
|
||||
// Steer enqueues a user message to be injected into the currently running
|
||||
// agent loop. The message will be picked up after the current tool finishes
|
||||
// executing, causing any remaining tool calls in the batch to be skipped.
|
||||
func (al *AgentLoop) Steer(msg providers.Message) error {
|
||||
scope := ""
|
||||
agentID := ""
|
||||
if ts := al.getActiveTurnState(); ts != nil {
|
||||
scope = ts.sessionKey
|
||||
agentID = ts.agentID
|
||||
}
|
||||
return al.enqueueSteeringMessage(scope, agentID, msg)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers.Message) error {
|
||||
if al.steering == nil {
|
||||
return fmt.Errorf("steering queue is not initialized")
|
||||
}
|
||||
if err := al.steering.push(msg); err != nil {
|
||||
|
||||
if err := al.steering.pushScope(scope, msg); err != nil {
|
||||
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
|
||||
"error": err.Error(),
|
||||
"role": msg.Role,
|
||||
"scope": normalizeSteeringScope(scope),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
queueDepth := al.steering.lenScope(scope)
|
||||
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
|
||||
"role": msg.Role,
|
||||
"content_len": len(msg.Content),
|
||||
"queue_len": al.steering.len(),
|
||||
"media_count": len(msg.Media),
|
||||
"queue_len": queueDepth,
|
||||
"scope": normalizeSteeringScope(scope),
|
||||
})
|
||||
|
||||
meta := EventMeta{
|
||||
@@ -129,11 +208,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
|
||||
}
|
||||
if ts := al.getActiveTurnState(); ts != nil {
|
||||
meta = ts.eventMeta("Steer", "turn.interrupt.received")
|
||||
} else if registry := al.GetRegistry(); registry != nil {
|
||||
if agent := registry.GetDefaultAgent(); agent != nil {
|
||||
meta.AgentID = agent.ID
|
||||
} else {
|
||||
if strings.TrimSpace(agentID) != "" {
|
||||
meta.AgentID = agentID
|
||||
}
|
||||
normalizedScope := normalizeSteeringScope(scope)
|
||||
if normalizedScope != manualSteeringScope {
|
||||
meta.SessionKey = normalizedScope
|
||||
}
|
||||
if meta.AgentID == "" {
|
||||
if registry := al.GetRegistry(); registry != nil {
|
||||
if agent := registry.GetDefaultAgent(); agent != nil {
|
||||
meta.AgentID = agent.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
meta,
|
||||
@@ -141,7 +232,7 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
|
||||
Kind: InterruptKindSteering,
|
||||
Role: msg.Role,
|
||||
ContentLen: len(msg.Content),
|
||||
QueueDepth: al.steering.len(),
|
||||
QueueDepth: queueDepth,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -165,7 +256,7 @@ func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
|
||||
}
|
||||
|
||||
// dequeueSteeringMessages is the internal method called by the agent loop
|
||||
// to poll for steering messages. Returns nil when no messages are pending.
|
||||
// to poll for steering messages in the legacy fallback scope.
|
||||
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
|
||||
if al.steering == nil {
|
||||
return nil
|
||||
@@ -173,6 +264,60 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
|
||||
return al.steering.dequeue()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) dequeueSteeringMessagesForScope(scope string) []providers.Message {
|
||||
if al.steering == nil {
|
||||
return nil
|
||||
}
|
||||
return al.steering.dequeueScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) dequeueSteeringMessagesForScopeWithFallback(scope string) []providers.Message {
|
||||
if al.steering == nil {
|
||||
return nil
|
||||
}
|
||||
return al.steering.dequeueScopeWithFallback(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
|
||||
if al.steering == nil {
|
||||
return 0
|
||||
}
|
||||
return al.steering.lenScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) continueWithSteeringMessages(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
sessionKey, channel, chatID string,
|
||||
steeringMsgs []providers.Message,
|
||||
) (string, error) {
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
InitialSteeringMessages: steeringMsgs,
|
||||
SkipInitialSteeringPoll: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
|
||||
registry := al.GetRegistry()
|
||||
if registry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
|
||||
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
|
||||
return agent
|
||||
}
|
||||
}
|
||||
|
||||
return registry.GetDefaultAgent()
|
||||
}
|
||||
|
||||
// Continue resumes an idle agent by dequeuing any pending steering messages
|
||||
// and running them through the agent loop. This is used when the agent's last
|
||||
// message was from the assistant (i.e., it has stopped processing) and the
|
||||
@@ -184,14 +329,14 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
|
||||
return "", fmt.Errorf("turn %s is still active", active.TurnID)
|
||||
}
|
||||
|
||||
steeringMsgs := al.dequeueSteeringMessages()
|
||||
steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey)
|
||||
if len(steeringMsgs) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
agent := al.agentForSession(sessionKey)
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent available")
|
||||
return "", fmt.Errorf("no agent available for session %q", sessionKey)
|
||||
}
|
||||
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
@@ -200,23 +345,7 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
|
||||
}
|
||||
}
|
||||
|
||||
// Build a combined user message from the steering messages.
|
||||
var contents []string
|
||||
for _, msg := range steeringMsgs {
|
||||
contents = append(contents, msg.Content)
|
||||
}
|
||||
combinedContent := strings.Join(contents, "\n")
|
||||
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
UserMessage: combinedContent,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
SkipInitialSteeringPoll: true,
|
||||
})
|
||||
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptGraceful(hint string) error {
|
||||
|
||||
+331
-9
@@ -5,13 +5,16 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -337,6 +340,96 @@ func TestAgentLoop_Continue_WithMessages(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
Session: config.SessionConfig{
|
||||
DMScope: "per-peer",
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &mockProvider{})
|
||||
|
||||
activeMsg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "active turn",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
|
||||
if !ok {
|
||||
t.Fatal("expected active message to resolve to a steering scope")
|
||||
}
|
||||
|
||||
otherMsg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user2",
|
||||
ChatID: "chat2",
|
||||
Content: "other session",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user2",
|
||||
},
|
||||
}
|
||||
otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
|
||||
if !ok {
|
||||
t.Fatal("expected other message to resolve to a steering scope")
|
||||
}
|
||||
if otherScope == activeScope {
|
||||
t.Fatalf("expected different steering scopes, got same scope %q", activeScope)
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), otherMsg); err != nil {
|
||||
t.Fatalf("PublishInbound failed: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
al.drainBusToSteering(ctx, activeScope, activeAgentID)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for drainBusToSteering to stop")
|
||||
}
|
||||
|
||||
if msgs := al.dequeueSteeringMessagesForScope(activeScope); len(msgs) != 0 {
|
||||
t.Fatalf("expected no steering messages for active scope, got %v", msgs)
|
||||
}
|
||||
|
||||
requeued, ok := msgBus.ConsumeInbound(context.Background())
|
||||
if !ok {
|
||||
t.Fatal("expected message to be requeued on the inbound bus")
|
||||
}
|
||||
if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID ||
|
||||
requeued.SenderID != otherMsg.SenderID || requeued.Content != otherMsg.Content {
|
||||
t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// slowTool simulates a tool that takes some time to execute.
|
||||
type slowTool struct {
|
||||
name string
|
||||
@@ -472,6 +565,52 @@ func (p *lateSteeringProvider) GetDefaultModel() string {
|
||||
return "late-steering-mock"
|
||||
}
|
||||
|
||||
// blockingDirectProvider is a test provider whose first Chat call can be held
// open by the test and whose later calls answer immediately.
type blockingDirectProvider struct {
	// mu guards calls and the snapshot of the fields below taken by Chat.
	mu    sync.Mutex
	calls int // number of Chat invocations so far

	// firstStarted is closed by the first Chat call; releaseFirst is closed by
	// the test to let that first call return.
	firstStarted, releaseFirst chan struct{}

	// firstResp is the content of the first reply, finalResp of every later one.
	firstResp, finalResp string
}
|
||||
|
||||
func (p *blockingDirectProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.calls++
|
||||
call := p.calls
|
||||
firstStarted := p.firstStarted
|
||||
releaseFirst := p.releaseFirst
|
||||
firstResp := p.firstResp
|
||||
finalResp := p.finalResp
|
||||
if call == 1 && p.firstStarted != nil {
|
||||
close(p.firstStarted)
|
||||
p.firstStarted = nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
if call == 1 {
|
||||
select {
|
||||
case <-releaseFirst:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return &providers.LLMResponse{Content: firstResp}, nil
|
||||
}
|
||||
|
||||
_ = firstStarted
|
||||
return &providers.LLMResponse{Content: finalResp}, nil
|
||||
}
|
||||
|
||||
func (p *blockingDirectProvider) GetDefaultModel() string {
|
||||
return "blocking-direct-mock"
|
||||
}
|
||||
|
||||
type interruptibleTool struct {
|
||||
name string
|
||||
started chan struct{}
|
||||
@@ -744,18 +883,16 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
|
||||
out1, ok := msgBus.SubscribeOutbound(subCtx)
|
||||
if !ok {
|
||||
t.Fatal("expected first outbound response")
|
||||
t.Fatal("expected outbound response")
|
||||
}
|
||||
if out1.Content != "first response" {
|
||||
t.Fatalf("expected first response, got %q", out1.Content)
|
||||
if out1.Content != "continued response" {
|
||||
t.Fatalf("expected continued response, got %q", out1.Content)
|
||||
}
|
||||
|
||||
out2, ok := msgBus.SubscribeOutbound(subCtx)
|
||||
if !ok {
|
||||
t.Fatal("expected continued outbound response")
|
||||
}
|
||||
if out2.Content != "continued response" {
|
||||
t.Fatalf("expected continued response, got %q", out2.Content)
|
||||
noExtraCtx, cancelNoExtra := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancelNoExtra()
|
||||
if out2, ok := msgBus.SubscribeOutbound(noExtraCtx); ok {
|
||||
t.Fatalf("expected stale direct response to be suppressed, got extra outbound %q", out2.Content)
|
||||
}
|
||||
|
||||
cancelRun()
|
||||
@@ -789,6 +926,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
provider := &blockingDirectProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
firstResp: "stale direct response",
|
||||
finalResp: "fresh response after steering",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
resultCh := make(chan struct {
|
||||
resp string
|
||||
err error
|
||||
}, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"initial request",
|
||||
sessionKey,
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- struct {
|
||||
resp string
|
||||
err error
|
||||
}{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-provider.firstStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first LLM call to start")
|
||||
}
|
||||
|
||||
if err := al.Steer(providers.Message{Role: "user", Content: "follow-up instruction"}); err != nil {
|
||||
t.Fatalf("Steer failed: %v", err)
|
||||
}
|
||||
close(provider.releaseFirst)
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.err != nil {
|
||||
t.Fatalf("unexpected error: %v", result.err)
|
||||
}
|
||||
if result.resp != "fresh response after steering" {
|
||||
t.Fatalf("expected refreshed response, got %q", result.resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for ProcessDirectWithChannel")
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
provider.mu.Unlock()
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", calls)
|
||||
}
|
||||
|
||||
if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 {
|
||||
t.Fatalf("expected steering queue to be empty after continuation, got %v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
pngPath := filepath.Join(tmpDir, "steer.png")
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
||||
0x00, 0x00, 0x00, 0x0D,
|
||||
0x49, 0x48, 0x44, 0x52,
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
||||
0x00, 0x00, 0x00,
|
||||
0x90, 0x77, 0x53, 0xDE,
|
||||
}
|
||||
if err = os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile failed: %v", err)
|
||||
}
|
||||
ref, err := store.Store(pngPath, media.MediaMeta{Filename: "steer.png", ContentType: "image/png"}, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Store failed: %v", err)
|
||||
}
|
||||
|
||||
var capturedMessages []providers.Message
|
||||
var capMu sync.Mutex
|
||||
provider := &capturingMockProvider{
|
||||
response: "ack",
|
||||
captureFn: func(msgs []providers.Message) {
|
||||
capMu.Lock()
|
||||
defer capMu.Unlock()
|
||||
capturedMessages = append([]providers.Message(nil), msgs...)
|
||||
},
|
||||
}
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.SetMediaStore(store)
|
||||
|
||||
if err = al.Steer(providers.Message{
|
||||
Role: "user",
|
||||
Content: "describe this image",
|
||||
Media: []string{ref},
|
||||
}); err != nil {
|
||||
t.Fatalf("Steer failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.Continue(context.Background(), sessionKey, "test", "chat1")
|
||||
if err != nil {
|
||||
t.Fatalf("Continue failed: %v", err)
|
||||
}
|
||||
if resp != "ack" {
|
||||
t.Fatalf("expected ack, got %q", resp)
|
||||
}
|
||||
|
||||
capMu.Lock()
|
||||
msgs := append([]providers.Message(nil), capturedMessages...)
|
||||
capMu.Unlock()
|
||||
|
||||
foundResolvedMedia := false
|
||||
for _, msg := range msgs {
|
||||
if msg.Role != "user" || msg.Content != "describe this image" || len(msg.Media) != 1 {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(msg.Media[0], "data:image/png;base64,") {
|
||||
foundResolvedMedia = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundResolvedMedia {
|
||||
t.Fatal("expected continue path to inject steering media into the provider request")
|
||||
}
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
history := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
foundOriginalRef := false
|
||||
for _, msg := range history {
|
||||
if msg.Role == "user" && len(msg.Media) == 1 && msg.Media[0] == ref {
|
||||
foundOriginalRef = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundOriginalRef {
|
||||
t.Fatal("expected original steering media ref to be preserved in session history")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user