Merge pull request #2762 from afjcjsbx/feat/stop-command

feat(agent): stop command
This commit is contained in:
Mauro
2026-05-06 18:19:14 +02:00
committed by GitHub
12 changed files with 662 additions and 16 deletions
+23
View File
@@ -58,6 +58,7 @@ type AgentLoop struct {
hookRuntime hookRuntime
steering *steeringQueue
pendingSkills sync.Map
pendingStops sync.Map
mu sync.RWMutex
// workerSem limits concurrent turn processing workers.
@@ -177,6 +178,10 @@ func (al *AgentLoop) Run(ctx context.Context) error {
phase: TurnPhaseSetup,
}
if _, loaded := al.activeTurnStates.LoadOrStore(sessionKey, placeholder); loaded {
if al.tryHandleStopCommand(ctx, msg, sessionKey) {
continue
}
// Another turn is already active (or reserved) for this session — enqueue
if err := al.enqueueSteeringMessage(sessionKey, agentID, providers.Message{
Role: "user",
@@ -240,6 +245,24 @@ func (al *AgentLoop) Run(ctx context.Context) error {
defer al.channelManager.InvokeTypingStop(m.Channel, m.ChatID)
}
if al.takePendingStop(sessionKey) {
al.activeTurnStates.Delete(sessionKey)
target := &continuationTarget{
SessionKey: sessionKey,
Channel: m.Channel,
ChatID: m.ChatID,
}
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
if continueErr != nil {
al.maybePublishError(ctx, m.Channel, m.ChatID, sessionKey, continueErr)
return
}
if continued != "" {
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, continued)
}
return
}
al.runTurnWithSteering(ctx, m)
}(msg)
+6
View File
@@ -274,6 +274,12 @@ func (al *AgentLoop) buildCommandsRuntime(
return nil
},
}
rt.StopActiveTurn = func() (commands.StopResult, error) {
if opts == nil {
return commands.StopResult{}, fmt.Errorf("process options not available")
}
return al.stopActiveTurnForSession(opts.Dispatch.SessionKey)
}
if agent != nil && agent.ContextBuilder != nil {
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
}
+31 -15
View File
@@ -44,11 +44,36 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
return
}
// Drain steering queue using existing Continue mechanism
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
if continueErr != nil {
logger.WarnCF("agent", "Failed to continue queued steering",
map[string]any{
"channel": target.Channel,
"chat_id": target.ChatID,
"error": continueErr.Error(),
})
} else if continued != "" {
finalResponse = continued
}
// Publish final response
if finalResponse != "" {
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
}
}
func (al *AgentLoop) drainQueuedSteeringContinuations(
ctx context.Context,
target *continuationTarget,
) (string, error) {
if target == nil {
return "", nil
}
finalResponse := ""
for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
// Check for context cancellation between iterations
if ctx.Err() != nil {
return
if err := ctx.Err(); err != nil {
return finalResponse, err
}
logger.InfoCF("agent", "Continuing queued steering after turn end",
@@ -61,13 +86,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
if continueErr != nil {
logger.WarnCF("agent", "Failed to continue queued steering",
map[string]any{
"channel": target.Channel,
"chat_id": target.ChatID,
"error": continueErr.Error(),
})
break
return finalResponse, continueErr
}
if continued == "" {
break
@@ -75,10 +94,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
finalResponse = continued
}
// Publish final response
if finalResponse != "" {
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
}
return finalResponse, nil
}
func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) {
+122
View File
@@ -0,0 +1,122 @@
package agent
import (
"context"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/commands"
)
func (al *AgentLoop) tryHandleStopCommand(
ctx context.Context,
msg bus.InboundMessage,
sessionKey string,
) bool {
cmdName, ok := commands.CommandName(msg.Content)
if !ok || cmdName != "stop" {
return false
}
result, err := al.stopActiveTurnForSession(sessionKey)
// This function is only called when loaded=true (another turn already
// claimed this session). If stopActiveTurnForSession found a pending
// placeholder but didn't stop it, that placeholder belongs to the other
// message's worker which hasn't started yet — arm a pending stop so the
// worker will bail when it checks before running.
if err == nil && !result.Stopped {
if ts := al.getActiveTurnState(sessionKey); ts != nil {
snap := ts.snapshot()
if strings.HasPrefix(snap.TurnID, pendingTurnPrefix) {
al.markPendingStop(sessionKey)
result.Stopped = true
}
}
}
reply := commands.FormatStopReply(result)
if err != nil {
reply = "Failed to stop task: " + err.Error()
}
if al.channelManager != nil {
al.channelManager.InvokeTypingStop(msg.Channel, msg.ChatID)
}
al.resetMessageToolRound(sessionKey)
al.PublishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, sessionKey, reply)
return true
}
func (al *AgentLoop) stopActiveTurnForSession(sessionKey string) (commands.StopResult, error) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return commands.StopResult{}, fmt.Errorf("session key is required")
}
result := commands.StopResult{}
cleared := al.clearSteeringMessagesForScope(sessionKey)
al.clearPendingSkills(sessionKey)
ts := al.getActiveTurnState(sessionKey)
if ts == nil {
result.Stopped = cleared > 0
return result, nil
}
snap := ts.snapshot()
result.TaskName = snap.UserMessage
if strings.HasPrefix(snap.TurnID, pendingTurnPrefix) {
// A pending placeholder means this session is either idle (our own
// placeholder from the /stop command) or another message is queued but
// hasn't started yet. In both cases, we don't arm a pending stop here;
// the caller (tryHandleStopCommand) handles the "another message queued"
// case explicitly, since it knows loaded=true.
return result, nil
}
if err := al.HardAbort(sessionKey); err != nil {
if al.getActiveTurnState(sessionKey) == nil {
result.Stopped = cleared > 0
return result, nil
}
return commands.StopResult{}, err
}
result.Stopped = true
return result, nil
}
func (al *AgentLoop) markPendingStop(sessionKey string) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return
}
al.pendingStops.Store(sessionKey, struct{}{})
}
func (al *AgentLoop) takePendingStop(sessionKey string) bool {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return false
}
_, ok := al.pendingStops.LoadAndDelete(sessionKey)
return ok
}
func (al *AgentLoop) resetMessageToolRound(sessionKey string) {
if strings.TrimSpace(sessionKey) == "" {
return
}
if registry := al.GetRegistry(); registry != nil {
if agent := registry.GetDefaultAgent(); agent != nil {
if tool, ok := agent.Tools.Get("message"); ok {
if resetter, ok := tool.(interface{ ResetSentInRound(sessionKey string) }); ok {
resetter.ResetSentInRound(sessionKey)
}
}
}
}
}
+23
View File
@@ -156,6 +156,18 @@ func (sq *steeringQueue) lenScope(scope string) int {
return len(sq.queues[normalizeSteeringScope(scope)])
}
func (sq *steeringQueue) clearScope(scope string) int {
sq.mu.Lock()
defer sq.mu.Unlock()
scope = normalizeSteeringScope(scope)
count := len(sq.queues[scope])
if count > 0 {
delete(sq.queues, scope)
}
return count
}
// setMode updates the steering mode.
func (sq *steeringQueue) setMode(mode SteeringMode) {
sq.mu.Lock()
@@ -290,6 +302,13 @@ func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
return al.steering.lenScope(scope)
}
func (al *AgentLoop) clearSteeringMessagesForScope(scope string) int {
if al.steering == nil {
return 0
}
return al.steering.clearScope(scope)
}
func (al *AgentLoop) continueWithSteeringMessages(
ctx context.Context,
agent *AgentInstance,
@@ -511,6 +530,10 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
"initial_history_length": ts.initialHistoryLength,
})
// Cancel the active provider/tool turn contexts immediately so long-running
// execution stops as soon as possible on the root turn.
_ = ts.requestHardAbort()
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
// from adding more messages to the session. This prevents race conditions
// where rollback happens while children are still writing.
+328
View File
@@ -840,6 +840,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
}
}
func TestAgentLoop_Run_PendingStopStillContinuesQueuedFollowUp(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
MaxParallelTurns: 1,
},
},
}
msgBus := bus.NewMessageBus()
provider := &lateSteeringProvider{
firstCallStarted: make(chan struct{}),
releaseFirstCall: make(chan struct{}),
}
al := NewAgentLoop(cfg, msgBus, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
defer func() {
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run() error = %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}()
blockerSessionKey := session.BuildOpaqueSessionKey("agent:main:test:blocker")
targetSessionKey := session.BuildOpaqueSessionKey("agent:main:test:target")
blockerCtx := bus.InboundContext{
Channel: "test",
ChatID: "blocker-chat",
ChatType: "direct",
SenderID: "user1",
}
targetCtx := bus.InboundContext{
Channel: "test",
ChatID: "target-chat",
ChatType: "direct",
SenderID: "user1",
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: blockerCtx,
Content: "block worker pool",
SessionKey: blockerSessionKey,
}); err != nil {
t.Fatalf("PublishInbound(blocker) error = %v", err)
}
select {
case <-provider.firstCallStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for blocker turn to start")
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: targetCtx,
Content: "skip this turn",
SessionKey: targetSessionKey,
}); err != nil {
t.Fatalf("PublishInbound(target start) error = %v", err)
}
deadline := time.Now().Add(2 * time.Second)
for {
ts := al.getActiveTurnState(targetSessionKey)
if ts != nil && strings.HasPrefix(ts.turnID, pendingTurnPrefix) {
break
}
if time.Now().After(deadline) {
t.Fatal("timeout waiting for pending placeholder")
}
time.Sleep(10 * time.Millisecond)
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: targetCtx,
Content: "/stop",
SessionKey: targetSessionKey,
}); err != nil {
t.Fatalf("PublishInbound(/stop) error = %v", err)
}
deadline = time.Now().Add(2 * time.Second)
stopSeen := false
for !stopSeen {
select {
case outbound := <-msgBus.OutboundChan():
if outbound.ChatID == "target-chat" && outbound.Content == "Task stopped. Current task was canceled." {
stopSeen = true
}
case <-time.After(10 * time.Millisecond):
if time.Now().After(deadline) {
t.Fatal("timeout waiting for /stop reply")
}
}
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: targetCtx,
Content: "run this instead",
SessionKey: targetSessionKey,
}); err != nil {
t.Fatalf("PublishInbound(follow-up) error = %v", err)
}
deadline = time.Now().Add(2 * time.Second)
for al.pendingSteeringCountForScope(targetSessionKey) == 0 {
if time.Now().After(deadline) {
t.Fatal("timeout waiting for follow-up to enter scoped steering queue")
}
time.Sleep(10 * time.Millisecond)
}
close(provider.releaseFirstCall)
deadline = time.Now().Add(5 * time.Second)
followUpSeen := false
for !followUpSeen {
select {
case outbound := <-msgBus.OutboundChan():
if outbound.ChatID == "target-chat" && outbound.Content == "continued response" {
followUpSeen = true
}
case <-time.After(10 * time.Millisecond):
if time.Now().After(deadline) {
t.Fatal("timeout waiting for queued follow-up continuation")
}
}
}
deadline = time.Now().Add(2 * time.Second)
for {
if al.GetActiveTurnBySession(targetSessionKey) == nil &&
al.pendingSteeringCountForScope(targetSessionKey) == 0 {
break
}
if time.Now().After(deadline) {
t.Fatal("timeout waiting for target session to go idle")
}
time.Sleep(10 * time.Millisecond)
}
provider.mu.Lock()
calls := provider.calls
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls (blocker + continuation), got %d", calls)
}
foundFollowUp := false
for _, msg := range secondMessages {
if msg.Role == "user" && msg.Content == "run this instead" {
foundFollowUp = true
}
if msg.Role == "user" && msg.Content == "skip this turn" {
t.Fatalf("unexpected canceled message in continuation context: %q", msg.Content)
}
}
if !foundFollowUp {
t.Fatal("expected queued follow-up to be processed after pending stop")
}
}
func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
@@ -1392,6 +1577,149 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
}
}
func TestAgentLoop_StopCommand_AbortsActiveTurnAndClearsQueuedSteering(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "cancel_tool",
Function: &providers.FunctionCall{
Name: "cancel_tool",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "should not continue",
}
al := NewAgentLoop(cfg, msgBus, provider)
started := make(chan struct{})
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
defer func() {
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run() error = %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}()
baseMsg := testInboundMessage(bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
SessionKey: sessionKey,
})
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: baseMsg.Context,
Content: "do work",
SessionKey: sessionKey,
}); err != nil {
t.Fatalf("PublishInbound(start) error = %v", err)
}
select {
case <-started:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for interruptible tool to start")
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: baseMsg.Context,
Content: "follow up after cancel",
SessionKey: sessionKey,
}); err != nil {
t.Fatalf("PublishInbound(follow-up) error = %v", err)
}
deadline := time.Now().Add(2 * time.Second)
for al.pendingSteeringCountForScope(sessionKey) == 0 {
if time.Now().After(deadline) {
t.Fatal("timeout waiting for follow-up message to enter steering queue")
}
time.Sleep(10 * time.Millisecond)
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: baseMsg.Context,
Content: "/stop",
SessionKey: sessionKey,
}); err != nil {
t.Fatalf("PublishInbound(/stop) error = %v", err)
}
select {
case outbound := <-msgBus.OutboundChan():
want := "Task stopped. \"do work\" was canceled."
if outbound.Content != want {
t.Fatalf("stop reply = %q, want %q", outbound.Content, want)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /stop reply")
}
deadline = time.Now().Add(5 * time.Second)
for al.GetActiveTurnBySession(sessionKey) != nil {
if time.Now().After(deadline) {
t.Fatal("timeout waiting for active turn to stop")
}
time.Sleep(10 * time.Millisecond)
}
if got := al.pendingSteeringCountForScope(sessionKey); got != 0 {
t.Fatalf("expected cleared steering queue, got %d pending message(s)", got)
}
select {
case outbound := <-msgBus.OutboundChan():
t.Fatalf("unexpected outbound after stop: %q", outbound.Content)
case <-time.After(300 * time.Millisecond):
}
provider.mu.Lock()
calls := provider.calls
provider.mu.Unlock()
if calls != 1 {
t.Fatalf("expected provider to stop before follow-up turn, got %d calls", calls)
}
}
// capturingMockProvider captures messages sent to Chat for inspection.
type capturingMockProvider struct {
response string
+9
View File
@@ -26,6 +26,10 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
al.registerActiveTurn(ts)
defer al.clearActiveTurn(ts)
if al.takePendingStop(ts.sessionKey) {
_ = ts.requestHardAbort()
}
turnStatus := TurnEndStatusCompleted
defer func() {
al.emitEvent(
@@ -40,6 +44,11 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
)
}()
if ts.hardAbortRequested() {
turnStatus = TurnEndStatusAborted
return al.abortTurn(ts)
}
al.emitEvent(
runtimeevents.KindAgentTurnStart,
ts.eventMeta("runTurn", "turn.start"),
+4 -1
View File
@@ -256,7 +256,10 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
// Bind session store and capture initial history length for rollback logic
if agent != nil && agent.Sessions != nil {
ts.session = agent.Sessions
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey))
history := agent.Sessions.GetHistory(opts.Dispatch.SessionKey)
ts.initialHistoryLength = len(history)
ts.restorePointHistory = append([]providers.Message(nil), history...)
ts.restorePointSummary = agent.Sessions.GetSummary(opts.Dispatch.SessionKey)
}
return ts