From f3ef7090c5d40b463cf1730132b770c1039daca9 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Mon, 4 May 2026 08:41:17 +0200 Subject: [PATCH] feat(agent): stop command --- pkg/agent/agent.go | 10 +++ pkg/agent/agent_command.go | 6 ++ pkg/agent/pipeline_llm.go | 2 +- pkg/agent/steering.go | 23 ++++++ pkg/agent/steering_test.go | 143 +++++++++++++++++++++++++++++++++++ pkg/agent/turn_coord.go | 9 +++ pkg/agent/turn_state.go | 5 +- pkg/commands/builtin.go | 1 + pkg/commands/builtin_test.go | 56 ++++++++++++++ pkg/commands/runtime.go | 7 ++ 10 files changed, 260 insertions(+), 2 deletions(-) diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 84849aece..bb21b7c5e 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -58,6 +58,7 @@ type AgentLoop struct { hookRuntime hookRuntime steering *steeringQueue pendingSkills sync.Map + pendingStops sync.Map mu sync.RWMutex // workerSem limits concurrent turn processing workers. @@ -177,6 +178,10 @@ func (al *AgentLoop) Run(ctx context.Context) error { phase: TurnPhaseSetup, } if _, loaded := al.activeTurnStates.LoadOrStore(sessionKey, placeholder); loaded { + if al.tryHandleStopCommand(ctx, msg, sessionKey) { + continue + } + // Another turn is already active (or reserved) for this session — enqueue if err := al.enqueueSteeringMessage(sessionKey, agentID, providers.Message{ Role: "user", @@ -240,6 +245,11 @@ func (al *AgentLoop) Run(ctx context.Context) error { defer al.channelManager.InvokeTypingStop(m.Channel, m.ChatID) } + if al.takePendingStop(sessionKey) { + al.activeTurnStates.Delete(sessionKey) + return + } + al.runTurnWithSteering(ctx, m) }(msg) diff --git a/pkg/agent/agent_command.go b/pkg/agent/agent_command.go index a2ed068d6..ae0293d71 100644 --- a/pkg/agent/agent_command.go +++ b/pkg/agent/agent_command.go @@ -274,6 +274,12 @@ func (al *AgentLoop) buildCommandsRuntime( return nil }, } + rt.StopActiveTurn = func() (commands.StopResult, error) { + if opts == nil { + return commands.StopResult{}, fmt.Errorf("process options not available") + } + return al.stopActiveTurnForSession(opts.Dispatch.SessionKey) + } if agent != nil && agent.ContextBuilder != nil { rt.ListSkillNames = agent.ContextBuilder.ListSkillNames } diff --git a/pkg/agent/pipeline_llm.go b/pkg/agent/pipeline_llm.go index ff242aef7..496fcd7e4 100644 --- a/pkg/agent/pipeline_llm.go +++ b/pkg/agent/pipeline_llm.go @@ -292,7 +292,7 @@ func (p *Pipeline) CallLLM( if isNetworkError && retry < maxRetries { backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second al.emitEvent( - EventKindLLMRetry, + runtimeevents.KindAgentLLMRetry, ts.eventMeta("runTurn", "turn.llm.retry"), LLMRetryPayload{ Attempt: retry + 1, diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index ba171fe5d..7bddbfc31 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -156,6 +156,18 @@ func (sq *steeringQueue) lenScope(scope string) int { return len(sq.queues[normalizeSteeringScope(scope)]) } +func (sq *steeringQueue) clearScope(scope string) int { + sq.mu.Lock() + defer sq.mu.Unlock() + + scope = normalizeSteeringScope(scope) + count := len(sq.queues[scope]) + if count > 0 { + delete(sq.queues, scope) + } + return count +} + // setMode updates the steering mode. func (sq *steeringQueue) setMode(mode SteeringMode) { sq.mu.Lock() @@ -290,6 +302,13 @@ func (al *AgentLoop) pendingSteeringCountForScope(scope string) int { return al.steering.lenScope(scope) } +func (al *AgentLoop) clearSteeringMessagesForScope(scope string) int { + if al.steering == nil { + return 0 + } + return al.steering.clearScope(scope) +} + func (al *AgentLoop) continueWithSteeringMessages( ctx context.Context, agent *AgentInstance, @@ -511,6 +530,10 @@ func (al *AgentLoop) HardAbort(sessionKey string) error { "initial_history_length": ts.initialHistoryLength, }) + // Cancel the active provider/tool turn contexts immediately so long-running + // execution stops as soon as possible on the root turn. + _ = ts.requestHardAbort() + // IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns // from adding more messages to the session. This prevents race conditions // where rollback happens while children are still writing. diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go index 25e06d7a2..1ee1653e9 100644 --- a/pkg/agent/steering_test.go +++ b/pkg/agent/steering_test.go @@ -1392,6 +1392,149 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) { } } +func TestAgentLoop_StopCommand_AbortsActiveTurnAndClearsQueuedSteering(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &toolCallProvider{ + toolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "cancel_tool", + Function: &providers.FunctionCall{ + Name: "cancel_tool", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + }, + finalResp: "should not continue", + } + + al := NewAgentLoop(cfg, msgBus, provider) + started := make(chan struct{}) + al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started}) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) + + runCtx, cancelRun := context.WithCancel(context.Background()) + defer cancelRun() + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- al.Run(runCtx) + }() + defer func() { + cancelRun() + select { + case err := <-runErrCh: + if err != nil { + t.Fatalf("Run() error = %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for Run to stop") + } + }() + + baseMsg := testInboundMessage(bus.InboundMessage{ + Context: bus.InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + }, + SessionKey: sessionKey, + }) + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: baseMsg.Context, + Content: "do work", + SessionKey: sessionKey, + }); err != nil { + t.Fatalf("PublishInbound(start) error = %v", err) + } + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for interruptible tool to start") + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: baseMsg.Context, + Content: "follow up after cancel", + SessionKey: sessionKey, + }); err != nil { + t.Fatalf("PublishInbound(follow-up) error = %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for al.pendingSteeringCountForScope(sessionKey) == 0 { + if time.Now().After(deadline) { + t.Fatal("timeout waiting for follow-up message to enter steering queue") + } + time.Sleep(10 * time.Millisecond) + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: baseMsg.Context, + Content: "/stop", + SessionKey: sessionKey, + }); err != nil { + t.Fatalf("PublishInbound(/stop) error = %v", err) + } + + select { + case outbound := <-msgBus.OutboundChan(): + want := "⏹️ Task stopped. \"do work\" was canceled." + if outbound.Content != want { + t.Fatalf("stop reply = %q, want %q", outbound.Content, want) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for /stop reply") + } + + deadline = time.Now().Add(5 * time.Second) + for al.GetActiveTurnBySession(sessionKey) != nil { + if time.Now().After(deadline) { + t.Fatal("timeout waiting for active turn to stop") + } + time.Sleep(10 * time.Millisecond) + } + + if got := al.pendingSteeringCountForScope(sessionKey); got != 0 { + t.Fatalf("expected cleared steering queue, got %d pending message(s)", got) + } + + select { + case outbound := <-msgBus.OutboundChan(): + t.Fatalf("unexpected outbound after stop: %q", outbound.Content) + case <-time.After(300 * time.Millisecond): + } + + provider.mu.Lock() + calls := provider.calls + provider.mu.Unlock() + if calls != 1 { + t.Fatalf("expected provider to stop before follow-up turn, got %d calls", calls) + } +} + // capturingMockProvider captures messages sent to Chat for inspection. type capturingMockProvider struct { response string diff --git a/pkg/agent/turn_coord.go b/pkg/agent/turn_coord.go index ae6bd8c82..2826e662c 100644 --- a/pkg/agent/turn_coord.go +++ b/pkg/agent/turn_coord.go @@ -26,6 +26,10 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel al.registerActiveTurn(ts) defer al.clearActiveTurn(ts) + if al.takePendingStop(ts.sessionKey) { + _ = ts.requestHardAbort() + } + turnStatus := TurnEndStatusCompleted defer func() { al.emitEvent( @@ -40,6 +44,11 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel ) }() + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + al.emitEvent( runtimeevents.KindAgentTurnStart, ts.eventMeta("runTurn", "turn.start"), diff --git a/pkg/agent/turn_state.go b/pkg/agent/turn_state.go index 85e7dd3c0..b769ebcd0 100644 --- a/pkg/agent/turn_state.go +++ b/pkg/agent/turn_state.go @@ -256,7 +256,10 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop // Bind session store and capture initial history length for rollback logic if agent != nil && agent.Sessions != nil { ts.session = agent.Sessions - ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey)) + history := agent.Sessions.GetHistory(opts.Dispatch.SessionKey) + ts.initialHistoryLength = len(history) + ts.restorePointHistory = append([]providers.Message(nil), history...) + ts.restorePointSummary = agent.Sessions.GetSummary(opts.Dispatch.SessionKey) } return ts diff --git a/pkg/commands/builtin.go b/pkg/commands/builtin.go index a7e401bb8..e268812a0 100644 --- a/pkg/commands/builtin.go +++ b/pkg/commands/builtin.go @@ -8,6 +8,7 @@ func BuiltinDefinitions() []Definition { return []Definition{ startCommand(), helpCommand(), + stopCommand(), showCommand(), listCommand(), useCommand(), diff --git a/pkg/commands/builtin_test.go b/pkg/commands/builtin_test.go index efd27fa00..bb9abe360 100644 --- a/pkg/commands/builtin_test.go +++ b/pkg/commands/builtin_test.go @@ -42,6 +42,9 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) { if !strings.Contains(reply, "/list [models|channels|agents|skills|mcp]") { t.Fatalf("/help reply missing /list usage, got %q", reply) } + if !strings.Contains(reply, "/stop") { + t.Fatalf("/help reply missing /stop usage, got %q", reply) + } if !strings.Contains(reply, "/use ") { if !strings.Contains(reply, "/use [message]") { t.Fatalf("/help reply missing /use usage, got %q", reply) @@ -49,6 +52,59 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) { } } +func TestBuiltinStop_UsesRuntimeStopper(t *testing.T) { + rt := &Runtime{ + StopActiveTurn: func() (StopResult, error) { + return StopResult{ + Stopped: true, + TaskName: "sync the long running job", + }, nil + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/stop", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Task stopped. \"sync the long running job\" was canceled." { + t.Fatalf("/stop reply=%q", reply) + } +} + +func TestBuiltinStop_NoActiveTask(t *testing.T) { + rt := &Runtime{ + StopActiveTurn: func() (StopResult, error) { + return StopResult{}, nil + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/stop", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "No active task to stop." { + t.Fatalf("/stop reply=%q, want no-active message", reply) + } +} + func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) { defs := BuiltinDefinitions() ex := NewExecutor(NewRegistry(defs), nil) diff --git a/pkg/commands/runtime.go b/pkg/commands/runtime.go index c17b7cf1c..b0327c863 100644 --- a/pkg/commands/runtime.go +++ b/pkg/commands/runtime.go @@ -36,6 +36,12 @@ type ContextStats struct { MessageCount int } +// StopResult describes the outcome of a stop request for the current session. +type StopResult struct { + Stopped bool + TaskName string +} + // Runtime provides runtime dependencies to command handlers. It is constructed // per-request by the agent loop so that per-request state (like session scope) // can coexist with long-lived callbacks (like GetModelInfo). @@ -55,4 +61,5 @@ type Runtime struct { SwitchChannel func(value string) error ClearHistory func() error ReloadConfig func() error + StopActiveTurn func() (StopResult, error) }