feat(agent): steering (#1517)

* feat(agent): steering

* fix loop

* fix lint

* fix lint
This commit is contained in:
Mauro
2026-03-15 17:08:16 +01:00
committed by GitHub
parent 0f700a6bf0
commit 021aa7d6d5
7 changed files with 1589 additions and 102 deletions
+183 -102
View File
@@ -48,6 +48,7 @@ type AgentLoop struct {
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
steering *steeringQueue
mu sync.RWMutex
// Track active requests for safe provider cleanup
activeRequests sync.WaitGroup
@@ -55,15 +56,16 @@ type AgentLoop struct {
// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
}
const (
@@ -105,6 +107,7 @@ func NewAgentLoop(
summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
return al
@@ -257,6 +260,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
continue
}
// Start a goroutine that drains the bus while processMessage is
// running. Any inbound messages that arrive during processing are
// redirected into the steering queue so the agent loop can pick
// them up between tool calls.
drainCtx, drainCancel := context.WithCancel(ctx)
go al.drainBusToSteering(drainCtx)
// Process message
func() {
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
@@ -272,6 +282,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// }
// }()
defer drainCancel()
response, err := al.processMessage(ctx, msg)
if err != nil {
response = fmt.Sprintf("Error processing message: %v", err)
@@ -318,6 +330,39 @@ func (al *AgentLoop) Run(ctx context.Context) error {
return nil
}
// drainBusToSteering continuously consumes inbound messages and redirects
// them into the steering queue. It runs in a goroutine while processMessage
// is active and stops when drainCtx is canceled (i.e., processMessage returns).
func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
for {
msg, ok := al.bus.ConsumeInbound(ctx)
if !ok {
return
}
// Transcribe audio if needed before steering, so the agent sees text.
msg, _ = al.transcribeAudioInMessage(ctx, msg)
logger.InfoCF("agent", "Redirecting inbound message to steering queue",
map[string]any{
"channel": msg.Channel,
"sender_id": msg.SenderID,
"content_len": len(msg.Content),
})
if err := al.Steer(providers.Message{
Role: "user",
Content: msg.Content,
}); err != nil {
logger.WarnCF("agent", "Failed to steer message, will be lost",
map[string]any{
"error": err.Error(),
"channel": msg.Channel,
})
}
}
}
func (al *AgentLoop) Stop() {
al.running.Store(false)
}
@@ -999,6 +1044,16 @@ func (al *AgentLoop) runLLMIteration(
) (string, int, error) {
iteration := 0
var finalContent string
var pendingMessages []providers.Message
// Poll for steering messages at loop start (in case the user typed while
// the agent was setting up), unless the caller already provided initial
// steering messages (e.g. Continue).
if !opts.SkipInitialSteeringPoll {
if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 {
pendingMessages = msgs
}
}
// Determine effective model tier for this conversation turn.
// selectCandidates evaluates routing once and the decision is sticky for
@@ -1006,9 +1061,25 @@ func (al *AgentLoop) runLLMIteration(
// tool chain doesn't switch models mid-way through.
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
for iteration < agent.MaxIterations {
for iteration < agent.MaxIterations || len(pendingMessages) > 0 {
iteration++
// Inject pending steering messages into the conversation context
// before the next LLM call.
if len(pendingMessages) > 0 {
for _, pm := range pendingMessages {
messages = append(messages, pm)
agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content)
logger.InfoCF("agent", "Injected steering message into context",
map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"content_len": len(pm.Content),
})
}
pendingMessages = nil
}
logger.DebugCF("agent", "LLM iteration",
map[string]any{
"agent_id": agent.ID,
@@ -1251,107 +1322,83 @@ func (al *AgentLoop) runLLMIteration(
// Save assistant message with tool calls to session
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
// Execute tool calls in parallel
type indexedAgentResult struct {
result *tools.ToolResult
tc providers.ToolCall
}
agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
var wg sync.WaitGroup
// Execute tool calls sequentially. After each tool completes, check
// for steering messages. If any are found, skip remaining tools.
var steeringAfterTools []providers.Message
for i, tc := range normalizedToolCalls {
agentResults[i].tc = tc
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
wg.Add(1)
go func(idx int, tc providers.ToolCall) {
defer wg.Done()
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
// Create async callback for tools that implement AsyncExecutor.
// When the background work completes, this publishes the result
// as an inbound system message so processSystemMessage routes it
// back to the user via the normal agent loop.
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
// Send ForUser content directly to the user (immediate feedback),
// mirroring the synchronous tool execution path.
if !result.Silent && result.ForUser != "" {
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer outCancel()
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: result.ForUser,
})
}
// Determine content for the agent loop (ForLLM or error).
content := result.ForLLM
if content == "" && result.Err != nil {
content = result.Err.Error()
}
if content == "" {
return
}
logger.InfoCF("agent", "Async tool completed, publishing result",
map[string]any{
"tool": tc.Name,
"content_len": len(content),
"channel": opts.Channel,
})
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("async:%s", tc.Name),
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
Content: content,
// Create async callback for tools that implement AsyncExecutor.
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer outCancel()
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: result.ForUser,
})
}
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
agentResults[idx].result = toolResult
}(i, tc)
}
wg.Wait()
content := result.ForLLM
if content == "" && result.Err != nil {
content = result.Err.Error()
}
if content == "" {
return
}
// Process results in original order (send to user, save to session)
for _, r := range agentResults {
// Send ForUser content to user immediately if not Silent
if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
logger.InfoCF("agent", "Async tool completed, publishing result",
map[string]any{
"tool": tc.Name,
"content_len": len(content),
"channel": opts.Channel,
})
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("async:%s", tc.Name),
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
Content: content,
})
}
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
tc.Arguments,
opts.Channel,
opts.ChatID,
asyncCallback,
)
// Process tool result
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: r.result.ForUser,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
"tool": r.tc.Name,
"content_len": len(r.result.ForUser),
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
})
}
// If tool returned media refs, publish them as outbound media
if len(r.result.Media) > 0 {
parts := make([]bus.MediaPart, 0, len(r.result.Media))
for _, ref := range r.result.Media {
if len(toolResult.Media) > 0 {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
part := bus.MediaPart{Ref: ref}
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
@@ -1369,21 +1416,55 @@ func (al *AgentLoop) runLLMIteration(
})
}
// Determine content for LLM based on tool result
contentForLLM := r.result.ForLLM
if contentForLLM == "" && r.result.Err != nil {
contentForLLM = r.result.Err.Error()
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: r.tc.ID,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
// Save tool result message to session
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
// After EVERY tool (including the first and last), check for
// steering messages. If found and there are remaining tools,
// skip them all.
if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
remaining := len(normalizedToolCalls) - i - 1
if remaining > 0 {
logger.InfoCF("agent", "Steering interrupt: skipping remaining tools",
map[string]any{
"agent_id": agent.ID,
"completed": i + 1,
"skipped": remaining,
"total_tools": len(normalizedToolCalls),
"steering_count": len(steerMsgs),
})
// Mark remaining tool calls as skipped
for j := i + 1; j < len(normalizedToolCalls); j++ {
skippedTC := normalizedToolCalls[j]
toolResultMsg := providers.Message{
Role: "tool",
Content: "Skipped due to queued user message.",
ToolCallID: skippedTC.ID,
}
messages = append(messages, toolResultMsg)
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
}
}
steeringAfterTools = steerMsgs
break
}
}
// If steering messages were captured during tool execution, they
// become pendingMessages for the next iteration of the inner loop.
if len(steeringAfterTools) > 0 {
pendingMessages = steeringAfterTools
}
// Tick down TTL of discovered tools after processing tool results.