feat(agent): stop command

This commit is contained in:
afjcjsbx
2026-05-04 08:41:17 +02:00
parent be67aed4dc
commit f3ef7090c5
10 changed files with 260 additions and 2 deletions
+10
View File
@@ -58,6 +58,7 @@ type AgentLoop struct {
hookRuntime hookRuntime
steering *steeringQueue
pendingSkills sync.Map
pendingStops sync.Map
mu sync.RWMutex
// workerSem limits concurrent turn processing workers.
@@ -177,6 +178,10 @@ func (al *AgentLoop) Run(ctx context.Context) error {
phase: TurnPhaseSetup,
}
if _, loaded := al.activeTurnStates.LoadOrStore(sessionKey, placeholder); loaded {
if al.tryHandleStopCommand(ctx, msg, sessionKey) {
continue
}
// Another turn is already active (or reserved) for this session — enqueue
if err := al.enqueueSteeringMessage(sessionKey, agentID, providers.Message{
Role: "user",
@@ -240,6 +245,11 @@ func (al *AgentLoop) Run(ctx context.Context) error {
defer al.channelManager.InvokeTypingStop(m.Channel, m.ChatID)
}
if al.takePendingStop(sessionKey) {
al.activeTurnStates.Delete(sessionKey)
return
}
al.runTurnWithSteering(ctx, m)
}(msg)
+6
View File
@@ -274,6 +274,12 @@ func (al *AgentLoop) buildCommandsRuntime(
return nil
},
}
rt.StopActiveTurn = func() (commands.StopResult, error) {
if opts == nil {
return commands.StopResult{}, fmt.Errorf("process options not available")
}
return al.stopActiveTurnForSession(opts.Dispatch.SessionKey)
}
if agent != nil && agent.ContextBuilder != nil {
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
}
+1 -1
View File
@@ -292,7 +292,7 @@ func (p *Pipeline) CallLLM(
if isNetworkError && retry < maxRetries {
backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second
al.emitEvent(
EventKindLLMRetry,
runtimeevents.KindAgentLLMRetry,
ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: retry + 1,
+23
View File
@@ -156,6 +156,18 @@ func (sq *steeringQueue) lenScope(scope string) int {
return len(sq.queues[normalizeSteeringScope(scope)])
}
// clearScope removes the queue stored under the normalized form of scope and
// returns how many entries that queue held. When nothing is queued for the
// scope the map is left untouched and 0 is returned. The queue mutex is held
// for the whole operation.
func (sq *steeringQueue) clearScope(scope string) int {
	sq.mu.Lock()
	defer sq.mu.Unlock()

	key := normalizeSteeringScope(scope)
	pending := len(sq.queues[key])
	if pending == 0 {
		return 0
	}
	delete(sq.queues, key)
	return pending
}
// setMode updates the steering mode.
func (sq *steeringQueue) setMode(mode SteeringMode) {
sq.mu.Lock()
@@ -290,6 +302,13 @@ func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
return al.steering.lenScope(scope)
}
// clearSteeringMessagesForScope delegates to the loop's steering queue to clear
// the given scope and returns the count that queue reports. It returns 0 when
// the loop has no steering queue.
func (al *AgentLoop) clearSteeringMessagesForScope(scope string) int {
	if queue := al.steering; queue != nil {
		return queue.clearScope(scope)
	}
	return 0
}
func (al *AgentLoop) continueWithSteeringMessages(
ctx context.Context,
agent *AgentInstance,
@@ -511,6 +530,10 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
"initial_history_length": ts.initialHistoryLength,
})
// Cancel the active provider/tool turn contexts immediately so long-running
// execution stops as soon as possible on the root turn.
_ = ts.requestHardAbort()
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
// from adding more messages to the session. This prevents race conditions
// where rollback happens while children are still writing.
+143
View File
@@ -1392,6 +1392,149 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
}
}
// TestAgentLoop_StopCommand_AbortsActiveTurnAndClearsQueuedSteering publishes
// "/stop" while a turn is active for the session and checks that:
//   - the first outbound message is the exact stop confirmation,
//   - the session no longer has an active turn,
//   - the session's steering queue is empty,
//   - nothing else is sent afterwards, and
//   - the provider was called exactly once.
func TestAgentLoop_StopCommand_AbortsActiveTurnAndClearsQueuedSteering(t *testing.T) {
// Throwaway workspace directory; removed when the test returns.
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
// Provider scripted with a single tool call to "cancel_tool".
// NOTE(review): finalResp is presumably what the provider would answer if the
// turn were allowed to continue after the tool call — confirm against
// toolCallProvider. The assertions below require that it is never delivered.
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "cancel_tool",
Function: &providers.FunctionCall{
Name: "cancel_tool",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "should not continue",
}
al := NewAgentLoop(cfg, msgBus, provider)
// The tool is handed the started channel so the test can wait for the tool
// to begin before sending further messages.
started := make(chan struct{})
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
// Run the agent loop in the background; its return value is collected on
// runErrCh (buffered so the goroutine never blocks on send).
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
// Shutdown check: cancel the loop and require Run to return nil within two
// seconds. Deferred calls run last-in-first-out, so this executes before the
// workspace directory is removed.
defer func() {
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run() error = %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}()
// Template message; only its Context is reused by the publishes below.
baseMsg := testInboundMessage(bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
SessionKey: sessionKey,
})
// Step 1: publish the initial request that starts the turn.
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: baseMsg.Context,
Content: "do work",
SessionKey: sessionKey,
}); err != nil {
t.Fatalf("PublishInbound(start) error = %v", err)
}
// Wait for the tool to signal that it has started.
select {
case <-started:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for interruptible tool to start")
}
// Step 2: publish a follow-up for the same session while the tool is running.
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: baseMsg.Context,
Content: "follow up after cancel",
SessionKey: sessionKey,
}); err != nil {
t.Fatalf("PublishInbound(follow-up) error = %v", err)
}
// Poll (up to two seconds) until the follow-up shows up in the steering queue,
// so that the stop command has something to clear.
deadline := time.Now().Add(2 * time.Second)
for al.pendingSteeringCountForScope(sessionKey) == 0 {
if time.Now().After(deadline) {
t.Fatal("timeout waiting for follow-up message to enter steering queue")
}
time.Sleep(10 * time.Millisecond)
}
// Step 3: publish the stop command for the same session.
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: baseMsg.Context,
Content: "/stop",
SessionKey: sessionKey,
}); err != nil {
t.Fatalf("PublishInbound(/stop) error = %v", err)
}
// The first outbound message must be the stop confirmation naming the
// original request.
select {
case outbound := <-msgBus.OutboundChan():
want := "⏹️ Task stopped. \"do work\" was canceled."
if outbound.Content != want {
t.Fatalf("stop reply = %q, want %q", outbound.Content, want)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /stop reply")
}
// Poll (up to five seconds) until the session no longer reports an active turn.
deadline = time.Now().Add(5 * time.Second)
for al.GetActiveTurnBySession(sessionKey) != nil {
if time.Now().After(deadline) {
t.Fatal("timeout waiting for active turn to stop")
}
time.Sleep(10 * time.Millisecond)
}
// The queued follow-up must have been discarded.
if got := al.pendingSteeringCountForScope(sessionKey); got != 0 {
t.Fatalf("expected cleared steering queue, got %d pending message(s)", got)
}
// No further outbound message may arrive within 300ms of the stop.
select {
case outbound := <-msgBus.OutboundChan():
t.Fatalf("unexpected outbound after stop: %q", outbound.Content)
case <-time.After(300 * time.Millisecond):
}
// Exactly one provider call is expected: a second call would mean the turn
// continued or the follow-up was processed.
provider.mu.Lock()
calls := provider.calls
provider.mu.Unlock()
if calls != 1 {
t.Fatalf("expected provider to stop before follow-up turn, got %d calls", calls)
}
}
// capturingMockProvider captures messages sent to Chat for inspection.
type capturingMockProvider struct {
response string
+9
View File
@@ -26,6 +26,10 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
al.registerActiveTurn(ts)
defer al.clearActiveTurn(ts)
if al.takePendingStop(ts.sessionKey) {
_ = ts.requestHardAbort()
}
turnStatus := TurnEndStatusCompleted
defer func() {
al.emitEvent(
@@ -40,6 +44,11 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
)
}()
if ts.hardAbortRequested() {
turnStatus = TurnEndStatusAborted
return al.abortTurn(ts)
}
al.emitEvent(
runtimeevents.KindAgentTurnStart,
ts.eventMeta("runTurn", "turn.start"),
+4 -1
View File
@@ -256,7 +256,10 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
// Bind session store and capture initial history length for rollback logic
if agent != nil && agent.Sessions != nil {
ts.session = agent.Sessions
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey))
history := agent.Sessions.GetHistory(opts.Dispatch.SessionKey)
ts.initialHistoryLength = len(history)
ts.restorePointHistory = append([]providers.Message(nil), history...)
ts.restorePointSummary = agent.Sessions.GetSummary(opts.Dispatch.SessionKey)
}
return ts
+1
View File
@@ -8,6 +8,7 @@ func BuiltinDefinitions() []Definition {
return []Definition{
startCommand(),
helpCommand(),
stopCommand(),
showCommand(),
listCommand(),
useCommand(),
+56
View File
@@ -42,6 +42,9 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
if !strings.Contains(reply, "/list [models|channels|agents|skills|mcp]") {
t.Fatalf("/help reply missing /list usage, got %q", reply)
}
if !strings.Contains(reply, "/stop") {
t.Fatalf("/help reply missing /stop usage, got %q", reply)
}
if !strings.Contains(reply, "/use <skill> <message>") {
if !strings.Contains(reply, "/use <skill> [message]") {
t.Fatalf("/help reply missing /use usage, got %q", reply)
@@ -49,6 +52,59 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
}
}
func TestBuiltinStop_UsesRuntimeStopper(t *testing.T) {
rt := &Runtime{
StopActiveTurn: func() (StopResult, error) {
return StopResult{
Stopped: true,
TaskName: "sync the long running job",
}, nil
},
}
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/stop",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Task stopped. \"sync the long running job\" was canceled." {
t.Fatalf("/stop reply=%q", reply)
}
}
// TestBuiltinStop_NoActiveTask executes "/stop" against a Runtime whose
// StopActiveTurn callback returns a zero StopResult and no error, and checks
// that the command is handled and replies with the no-active-task message.
func TestBuiltinStop_NoActiveTask(t *testing.T) {
	runtime := &Runtime{
		StopActiveTurn: func() (StopResult, error) {
			var nothingStopped StopResult
			return nothingStopped, nil
		},
	}
	executor := NewExecutor(NewRegistry(BuiltinDefinitions()), runtime)

	// Capture whatever the handler sends back through Reply.
	var got string
	capture := func(text string) error {
		got = text
		return nil
	}

	outcome := executor.Execute(context.Background(), Request{
		Text:  "/stop",
		Reply: capture,
	}).Outcome
	if outcome != OutcomeHandled {
		t.Fatalf("/stop: outcome=%v, want=%v", outcome, OutcomeHandled)
	}

	const want = "No active task to stop."
	if got != want {
		t.Fatalf("/stop reply=%q, want no-active message", got)
	}
}
func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) {
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), nil)
+7
View File
@@ -36,6 +36,12 @@ type ContextStats struct {
MessageCount int
}
// StopResult describes the outcome of a stop request for the current session.
// The zero value is what a stopper returns when it stopped nothing.
type StopResult struct {
// Stopped is true when the request stopped a task.
// NOTE(review): the /stop handler is not visible here; its tests pair a true
// value with a "Task stopped" reply and the zero value with "No active task
// to stop." — confirm against the handler.
Stopped bool
// TaskName is a label for the stopped task; the /stop tests show it quoted
// in the reply when Stopped is true.
TaskName string
}
// Runtime provides runtime dependencies to command handlers. It is constructed
// per-request by the agent loop so that per-request state (like session scope)
// can coexist with long-lived callbacks (like GetModelInfo).
@@ -55,4 +61,5 @@ type Runtime struct {
SwitchChannel func(value string) error
ClearHistory func() error
ReloadConfig func() error
StopActiveTurn func() (StopResult, error)
}