mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
391 lines
10 KiB
Go
391 lines
10 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/logger"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
"github.com/sipeed/picoclaw/pkg/routing"
|
|
)
|
|
|
|
// SteeringMode controls how queued steering messages are dequeued.
|
|
type SteeringMode string
|
|
|
|
const (
|
|
// SteeringOneAtATime dequeues only the first queued message per poll.
|
|
SteeringOneAtATime SteeringMode = "one-at-a-time"
|
|
// SteeringAll drains the entire queue in a single poll.
|
|
SteeringAll SteeringMode = "all"
|
|
// MaxQueueSize number of possible messages in the Steering Queue
|
|
MaxQueueSize = 10
|
|
// manualSteeringScope is the legacy fallback queue used when no active
|
|
// turn/session scope is available.
|
|
manualSteeringScope = "__manual__"
|
|
)
|
|
|
|
// parseSteeringMode normalizes a config string into a SteeringMode.
|
|
func parseSteeringMode(s string) SteeringMode {
|
|
switch s {
|
|
case "all":
|
|
return SteeringAll
|
|
default:
|
|
return SteeringOneAtATime
|
|
}
|
|
}
|
|
|
|
// steeringQueue is a thread-safe queue of user messages that can be injected
|
|
// into a running agent loop to interrupt it between tool calls.
|
|
type steeringQueue struct {
|
|
mu sync.Mutex
|
|
queues map[string][]providers.Message
|
|
mode SteeringMode
|
|
}
|
|
|
|
func newSteeringQueue(mode SteeringMode) *steeringQueue {
|
|
return &steeringQueue{
|
|
queues: make(map[string][]providers.Message),
|
|
mode: mode,
|
|
}
|
|
}
|
|
|
|
func normalizeSteeringScope(scope string) string {
|
|
scope = strings.TrimSpace(scope)
|
|
if scope == "" {
|
|
return manualSteeringScope
|
|
}
|
|
return scope
|
|
}
|
|
|
|
// push enqueues a steering message in the legacy fallback scope.
|
|
func (sq *steeringQueue) push(msg providers.Message) error {
|
|
return sq.pushScope(manualSteeringScope, msg)
|
|
}
|
|
|
|
// pushScope enqueues a steering message for the provided scope.
|
|
func (sq *steeringQueue) pushScope(scope string, msg providers.Message) error {
|
|
sq.mu.Lock()
|
|
defer sq.mu.Unlock()
|
|
|
|
scope = normalizeSteeringScope(scope)
|
|
queue := sq.queues[scope]
|
|
if len(queue) >= MaxQueueSize {
|
|
return fmt.Errorf("steering queue is full")
|
|
}
|
|
sq.queues[scope] = append(queue, msg)
|
|
return nil
|
|
}
|
|
|
|
// dequeue removes and returns pending steering messages from the legacy
|
|
// fallback scope according to the configured mode.
|
|
func (sq *steeringQueue) dequeue() []providers.Message {
|
|
return sq.dequeueScope(manualSteeringScope)
|
|
}
|
|
|
|
// dequeueScope removes and returns pending steering messages for the provided
|
|
// scope according to the configured mode.
|
|
func (sq *steeringQueue) dequeueScope(scope string) []providers.Message {
|
|
sq.mu.Lock()
|
|
defer sq.mu.Unlock()
|
|
|
|
return sq.dequeueLocked(normalizeSteeringScope(scope))
|
|
}
|
|
|
|
// dequeueScopeWithFallback drains the scoped queue first and falls back to the
|
|
// legacy manual scope for backwards compatibility.
|
|
func (sq *steeringQueue) dequeueScopeWithFallback(scope string) []providers.Message {
|
|
sq.mu.Lock()
|
|
defer sq.mu.Unlock()
|
|
|
|
scope = strings.TrimSpace(scope)
|
|
if scope != "" {
|
|
if msgs := sq.dequeueLocked(scope); len(msgs) > 0 {
|
|
return msgs
|
|
}
|
|
}
|
|
|
|
return sq.dequeueLocked(manualSteeringScope)
|
|
}
|
|
|
|
func (sq *steeringQueue) dequeueLocked(scope string) []providers.Message {
|
|
queue := sq.queues[scope]
|
|
if len(queue) == 0 {
|
|
return nil
|
|
}
|
|
|
|
switch sq.mode {
|
|
case SteeringAll:
|
|
msgs := append([]providers.Message(nil), queue...)
|
|
delete(sq.queues, scope)
|
|
return msgs
|
|
default:
|
|
msg := queue[0]
|
|
queue[0] = providers.Message{} // Clear reference for GC
|
|
queue = queue[1:]
|
|
if len(queue) == 0 {
|
|
delete(sq.queues, scope)
|
|
} else {
|
|
sq.queues[scope] = queue
|
|
}
|
|
return []providers.Message{msg}
|
|
}
|
|
}
|
|
|
|
// len returns the number of queued messages across all scopes.
|
|
func (sq *steeringQueue) len() int {
|
|
sq.mu.Lock()
|
|
defer sq.mu.Unlock()
|
|
|
|
total := 0
|
|
for _, queue := range sq.queues {
|
|
total += len(queue)
|
|
}
|
|
return total
|
|
}
|
|
|
|
// lenScope returns the number of queued messages for a specific scope.
|
|
func (sq *steeringQueue) lenScope(scope string) int {
|
|
sq.mu.Lock()
|
|
defer sq.mu.Unlock()
|
|
return len(sq.queues[normalizeSteeringScope(scope)])
|
|
}
|
|
|
|
// setMode updates the steering mode.
|
|
func (sq *steeringQueue) setMode(mode SteeringMode) {
|
|
sq.mu.Lock()
|
|
defer sq.mu.Unlock()
|
|
sq.mode = mode
|
|
}
|
|
|
|
// getMode returns the current steering mode.
|
|
func (sq *steeringQueue) getMode() SteeringMode {
|
|
sq.mu.Lock()
|
|
defer sq.mu.Unlock()
|
|
return sq.mode
|
|
}
|
|
|
|
// Steer enqueues a user message to be injected into the currently running
|
|
// agent loop. The message will be picked up after the current tool finishes
|
|
// executing, causing any remaining tool calls in the batch to be skipped.
|
|
func (al *AgentLoop) Steer(msg providers.Message) error {
|
|
scope := ""
|
|
agentID := ""
|
|
if ts := al.getActiveTurnState(); ts != nil {
|
|
scope = ts.sessionKey
|
|
agentID = ts.agentID
|
|
}
|
|
return al.enqueueSteeringMessage(scope, agentID, msg)
|
|
}
|
|
|
|
func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers.Message) error {
|
|
if al.steering == nil {
|
|
return fmt.Errorf("steering queue is not initialized")
|
|
}
|
|
|
|
if err := al.steering.pushScope(scope, msg); err != nil {
|
|
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
|
|
"error": err.Error(),
|
|
"role": msg.Role,
|
|
"scope": normalizeSteeringScope(scope),
|
|
})
|
|
return err
|
|
}
|
|
|
|
queueDepth := al.steering.lenScope(scope)
|
|
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
|
|
"role": msg.Role,
|
|
"content_len": len(msg.Content),
|
|
"media_count": len(msg.Media),
|
|
"queue_len": queueDepth,
|
|
"scope": normalizeSteeringScope(scope),
|
|
})
|
|
|
|
meta := EventMeta{
|
|
Source: "Steer",
|
|
TracePath: "turn.interrupt.received",
|
|
}
|
|
if ts := al.getActiveTurnState(); ts != nil {
|
|
meta = ts.eventMeta("Steer", "turn.interrupt.received")
|
|
} else {
|
|
if strings.TrimSpace(agentID) != "" {
|
|
meta.AgentID = agentID
|
|
}
|
|
normalizedScope := normalizeSteeringScope(scope)
|
|
if normalizedScope != manualSteeringScope {
|
|
meta.SessionKey = normalizedScope
|
|
}
|
|
if meta.AgentID == "" {
|
|
if registry := al.GetRegistry(); registry != nil {
|
|
if agent := registry.GetDefaultAgent(); agent != nil {
|
|
meta.AgentID = agent.ID
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
al.emitEvent(
|
|
EventKindInterruptReceived,
|
|
meta,
|
|
InterruptReceivedPayload{
|
|
Kind: InterruptKindSteering,
|
|
Role: msg.Role,
|
|
ContentLen: len(msg.Content),
|
|
QueueDepth: queueDepth,
|
|
},
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
// SteeringMode returns the current steering mode.
|
|
func (al *AgentLoop) SteeringMode() SteeringMode {
|
|
if al.steering == nil {
|
|
return SteeringOneAtATime
|
|
}
|
|
return al.steering.getMode()
|
|
}
|
|
|
|
// SetSteeringMode updates the steering mode.
|
|
func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
|
|
if al.steering == nil {
|
|
return
|
|
}
|
|
al.steering.setMode(mode)
|
|
}
|
|
|
|
// dequeueSteeringMessages is the internal method called by the agent loop
|
|
// to poll for steering messages in the legacy fallback scope.
|
|
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
|
|
if al.steering == nil {
|
|
return nil
|
|
}
|
|
return al.steering.dequeue()
|
|
}
|
|
|
|
func (al *AgentLoop) dequeueSteeringMessagesForScope(scope string) []providers.Message {
|
|
if al.steering == nil {
|
|
return nil
|
|
}
|
|
return al.steering.dequeueScope(scope)
|
|
}
|
|
|
|
func (al *AgentLoop) dequeueSteeringMessagesForScopeWithFallback(scope string) []providers.Message {
|
|
if al.steering == nil {
|
|
return nil
|
|
}
|
|
return al.steering.dequeueScopeWithFallback(scope)
|
|
}
|
|
|
|
func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
|
|
if al.steering == nil {
|
|
return 0
|
|
}
|
|
return al.steering.lenScope(scope)
|
|
}
|
|
|
|
func (al *AgentLoop) continueWithSteeringMessages(
|
|
ctx context.Context,
|
|
agent *AgentInstance,
|
|
sessionKey, channel, chatID string,
|
|
steeringMsgs []providers.Message,
|
|
) (string, error) {
|
|
return al.runAgentLoop(ctx, agent, processOptions{
|
|
SessionKey: sessionKey,
|
|
Channel: channel,
|
|
ChatID: chatID,
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: true,
|
|
SendResponse: false,
|
|
InitialSteeringMessages: steeringMsgs,
|
|
SkipInitialSteeringPoll: true,
|
|
})
|
|
}
|
|
|
|
func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
|
|
registry := al.GetRegistry()
|
|
if registry == nil {
|
|
return nil
|
|
}
|
|
|
|
if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
|
|
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
|
|
return agent
|
|
}
|
|
}
|
|
|
|
return registry.GetDefaultAgent()
|
|
}
|
|
|
|
// Continue resumes an idle agent by dequeuing any pending steering messages
|
|
// and running them through the agent loop. This is used when the agent's last
|
|
// message was from the assistant (i.e., it has stopped processing) and the
|
|
// user has since enqueued steering messages.
|
|
//
|
|
// If no steering messages are pending, it returns an empty string.
|
|
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
|
|
if active := al.GetActiveTurn(); active != nil {
|
|
return "", fmt.Errorf("turn %s is still active", active.TurnID)
|
|
}
|
|
|
|
steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey)
|
|
if len(steeringMsgs) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
agent := al.agentForSession(sessionKey)
|
|
if agent == nil {
|
|
return "", fmt.Errorf("no agent available for session %q", sessionKey)
|
|
}
|
|
|
|
if tool, ok := agent.Tools.Get("message"); ok {
|
|
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
|
|
resetter.ResetSentInRound()
|
|
}
|
|
}
|
|
|
|
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
|
|
}
|
|
|
|
func (al *AgentLoop) InterruptGraceful(hint string) error {
|
|
ts := al.getActiveTurnState()
|
|
if ts == nil {
|
|
return fmt.Errorf("no active turn")
|
|
}
|
|
if !ts.requestGracefulInterrupt(hint) {
|
|
return fmt.Errorf("turn %s cannot accept graceful interrupt", ts.turnID)
|
|
}
|
|
|
|
al.emitEvent(
|
|
EventKindInterruptReceived,
|
|
ts.eventMeta("InterruptGraceful", "turn.interrupt.received"),
|
|
InterruptReceivedPayload{
|
|
Kind: InterruptKindGraceful,
|
|
HintLen: len(hint),
|
|
},
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (al *AgentLoop) InterruptHard() error {
|
|
ts := al.getActiveTurnState()
|
|
if ts == nil {
|
|
return fmt.Errorf("no active turn")
|
|
}
|
|
if !ts.requestHardAbort() {
|
|
return fmt.Errorf("turn %s is already aborting", ts.turnID)
|
|
}
|
|
|
|
al.emitEvent(
|
|
EventKindInterruptReceived,
|
|
ts.eventMeta("InterruptHard", "turn.interrupt.received"),
|
|
InterruptReceivedPayload{
|
|
Kind: InterruptKindHard,
|
|
},
|
|
)
|
|
|
|
return nil
|
|
}
|