mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): steering (#1517)
* feat(agent): steering * fix loop * fix lint * fix lint
This commit is contained in:
+183
-102
@@ -48,6 +48,7 @@ type AgentLoop struct {
|
||||
transcriber voice.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
steering *steeringQueue
|
||||
mu sync.RWMutex
|
||||
// Track active requests for safe provider cleanup
|
||||
activeRequests sync.WaitGroup
|
||||
@@ -55,15 +56,16 @@ type AgentLoop struct {
|
||||
|
||||
// processOptions bundles the per-message settings that control how a message
// is processed: session and routing identifiers (SessionKey, Channel, ChatID),
// the payload (UserMessage, Media), and behavior flags.
// NOTE(review): this is a rendered diff hunk, so the nine existing fields
// appear twice (before and after re-alignment); SkipInitialSteeringPoll is
// the only field that appears once.
|
||||
type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -105,6 +107,7 @@ func NewAgentLoop(
|
||||
summarizing: sync.Map{},
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
|
||||
}
|
||||
|
||||
return al
|
||||
@@ -257,6 +260,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Start a goroutine that drains the bus while processMessage is
|
||||
// running. Any inbound messages that arrive during processing are
|
||||
// redirected into the steering queue so the agent loop can pick
|
||||
// them up between tool calls.
|
||||
drainCtx, drainCancel := context.WithCancel(ctx)
|
||||
go al.drainBusToSteering(drainCtx)
|
||||
|
||||
// Process message
|
||||
func() {
|
||||
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
|
||||
@@ -272,6 +282,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
// }
|
||||
// }()
|
||||
|
||||
defer drainCancel()
|
||||
|
||||
response, err := al.processMessage(ctx, msg)
|
||||
if err != nil {
|
||||
response = fmt.Sprintf("Error processing message: %v", err)
|
||||
@@ -318,6 +330,39 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// drainBusToSteering continuously consumes inbound messages and redirects
|
||||
// them into the steering queue. It runs in a goroutine while processMessage
|
||||
// is active and stops when drainCtx is canceled (i.e., processMessage returns).
//
// Each consumed message is passed through transcribeAudioInMessage, logged,
// and handed to Steer as a user-role message. Only msg.Content is forwarded;
// Channel and SenderID are used for logging only. A Steer failure is logged
// as a warning and the message is not retried.
|
||||
func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
|
||||
for {
|
||||
msg, ok := al.bus.ConsumeInbound(ctx)
|
||||
// ok == false ends the goroutine; presumably this is how ConsumeInbound
// reports a canceled ctx or a closed bus — verify against the bus package.
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Transcribe audio if needed before steering, so the agent sees text.
// The second return value is deliberately discarded here.
|
||||
msg, _ = al.transcribeAudioInMessage(ctx, msg)
|
||||
|
||||
logger.InfoCF("agent", "Redirecting inbound message to steering queue",
|
||||
map[string]any{
|
||||
"channel": msg.Channel,
|
||||
"sender_id": msg.SenderID,
|
||||
"content_len": len(msg.Content),
|
||||
})
|
||||
|
||||
// Enqueue the text as a user-role steering message.
if err := al.Steer(providers.Message{
|
||||
Role: "user",
|
||||
Content: msg.Content,
|
||||
}); err != nil {
|
||||
// No retry or requeue: the message is dropped and the loop moves on.
logger.WarnCF("agent", "Failed to steer message, will be lost",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
"channel": msg.Channel,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop clears the loop's running flag by storing false in al.running.
// NOTE(review): presumably Run polls this flag to exit its loop — the check
// is not visible in this hunk; confirm against Run.
func (al *AgentLoop) Stop() {
|
||||
al.running.Store(false)
|
||||
}
|
||||
@@ -999,6 +1044,16 @@ func (al *AgentLoop) runLLMIteration(
|
||||
) (string, int, error) {
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
var pendingMessages []providers.Message
|
||||
|
||||
// Poll for steering messages at loop start (in case the user typed while
|
||||
// the agent was setting up), unless the caller already provided initial
|
||||
// steering messages (e.g. Continue).
|
||||
if !opts.SkipInitialSteeringPoll {
|
||||
if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 {
|
||||
pendingMessages = msgs
|
||||
}
|
||||
}
|
||||
|
||||
// Determine effective model tier for this conversation turn.
|
||||
// selectCandidates evaluates routing once and the decision is sticky for
|
||||
@@ -1006,9 +1061,25 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// tool chain doesn't switch models mid-way through.
|
||||
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
|
||||
|
||||
for iteration < agent.MaxIterations {
|
||||
for iteration < agent.MaxIterations || len(pendingMessages) > 0 {
|
||||
iteration++
|
||||
|
||||
// Inject pending steering messages into the conversation context
|
||||
// before the next LLM call.
|
||||
if len(pendingMessages) > 0 {
|
||||
for _, pm := range pendingMessages {
|
||||
messages = append(messages, pm)
|
||||
agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content)
|
||||
logger.InfoCF("agent", "Injected steering message into context",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"content_len": len(pm.Content),
|
||||
})
|
||||
}
|
||||
pendingMessages = nil
|
||||
}
|
||||
|
||||
logger.DebugCF("agent", "LLM iteration",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
@@ -1251,107 +1322,83 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// Save assistant message with tool calls to session
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
|
||||
|
||||
// Execute tool calls in parallel
|
||||
type indexedAgentResult struct {
|
||||
result *tools.ToolResult
|
||||
tc providers.ToolCall
|
||||
}
|
||||
|
||||
agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
|
||||
var wg sync.WaitGroup
|
||||
// Execute tool calls sequentially. After each tool completes, check
|
||||
// for steering messages. If any are found, skip remaining tools.
|
||||
var steeringAfterTools []providers.Message
|
||||
|
||||
for i, tc := range normalizedToolCalls {
|
||||
agentResults[i].tc = tc
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
go func(idx int, tc providers.ToolCall) {
|
||||
defer wg.Done()
|
||||
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Create async callback for tools that implement AsyncExecutor.
|
||||
// When the background work completes, this publishes the result
|
||||
// as an inbound system message so processSystemMessage routes it
|
||||
// back to the user via the normal agent loop.
|
||||
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
|
||||
// Send ForUser content directly to the user (immediate feedback),
|
||||
// mirroring the synchronous tool execution path.
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer outCancel()
|
||||
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: result.ForUser,
|
||||
})
|
||||
}
|
||||
|
||||
// Determine content for the agent loop (ForLLM or error).
|
||||
content := result.ForLLM
|
||||
if content == "" && result.Err != nil {
|
||||
content = result.Err.Error()
|
||||
}
|
||||
if content == "" {
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Async tool completed, publishing result",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(content),
|
||||
"channel": opts.Channel,
|
||||
})
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("async:%s", tc.Name),
|
||||
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
|
||||
Content: content,
|
||||
// Create async callback for tools that implement AsyncExecutor.
|
||||
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer outCancel()
|
||||
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: result.ForUser,
|
||||
})
|
||||
}
|
||||
|
||||
toolResult := agent.Tools.ExecuteWithContext(
|
||||
ctx,
|
||||
tc.Name,
|
||||
tc.Arguments,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
asyncCallback,
|
||||
)
|
||||
agentResults[idx].result = toolResult
|
||||
}(i, tc)
|
||||
}
|
||||
wg.Wait()
|
||||
content := result.ForLLM
|
||||
if content == "" && result.Err != nil {
|
||||
content = result.Err.Error()
|
||||
}
|
||||
if content == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Process results in original order (send to user, save to session)
|
||||
for _, r := range agentResults {
|
||||
// Send ForUser content to user immediately if not Silent
|
||||
if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
|
||||
logger.InfoCF("agent", "Async tool completed, publishing result",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(content),
|
||||
"channel": opts.Channel,
|
||||
})
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("async:%s", tc.Name),
|
||||
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
toolResult := agent.Tools.ExecuteWithContext(
|
||||
ctx,
|
||||
tc.Name,
|
||||
tc.Arguments,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
asyncCallback,
|
||||
)
|
||||
|
||||
// Process tool result
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: r.result.ForUser,
|
||||
Content: toolResult.ForUser,
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]any{
|
||||
"tool": r.tc.Name,
|
||||
"content_len": len(r.result.ForUser),
|
||||
"tool": tc.Name,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
|
||||
// If tool returned media refs, publish them as outbound media
|
||||
if len(r.result.Media) > 0 {
|
||||
parts := make([]bus.MediaPart, 0, len(r.result.Media))
|
||||
for _, ref := range r.result.Media {
|
||||
if len(toolResult.Media) > 0 {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
@@ -1369,21 +1416,55 @@ func (al *AgentLoop) runLLMIteration(
|
||||
})
|
||||
}
|
||||
|
||||
// Determine content for LLM based on tool result
|
||||
contentForLLM := r.result.ForLLM
|
||||
if contentForLLM == "" && r.result.Err != nil {
|
||||
contentForLLM = r.result.Err.Error()
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: r.tc.ID,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
|
||||
// Save tool result message to session
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
|
||||
|
||||
// After EVERY tool (including the first and last), check for
|
||||
// steering messages. If found and there are remaining tools,
|
||||
// skip them all.
|
||||
if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
|
||||
remaining := len(normalizedToolCalls) - i - 1
|
||||
if remaining > 0 {
|
||||
logger.InfoCF("agent", "Steering interrupt: skipping remaining tools",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"completed": i + 1,
|
||||
"skipped": remaining,
|
||||
"total_tools": len(normalizedToolCalls),
|
||||
"steering_count": len(steerMsgs),
|
||||
})
|
||||
|
||||
// Mark remaining tool calls as skipped
|
||||
for j := i + 1; j < len(normalizedToolCalls); j++ {
|
||||
skippedTC := normalizedToolCalls[j]
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: "Skipped due to queued user message.",
|
||||
ToolCallID: skippedTC.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
|
||||
}
|
||||
}
|
||||
steeringAfterTools = steerMsgs
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If steering messages were captured during tool execution, they
|
||||
// become pendingMessages for the next iteration of the inner loop.
|
||||
if len(steeringAfterTools) > 0 {
|
||||
pendingMessages = steeringAfterTools
|
||||
}
|
||||
|
||||
// Tick down TTL of discovered tools after processing tool results.
|
||||
|
||||
Reference in New Issue
Block a user