diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 84849aece..97ee4fe7d 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -58,6 +58,7 @@ type AgentLoop struct { hookRuntime hookRuntime steering *steeringQueue pendingSkills sync.Map + pendingStops sync.Map mu sync.RWMutex // workerSem limits concurrent turn processing workers. @@ -177,6 +178,10 @@ func (al *AgentLoop) Run(ctx context.Context) error { phase: TurnPhaseSetup, } if _, loaded := al.activeTurnStates.LoadOrStore(sessionKey, placeholder); loaded { + if al.tryHandleStopCommand(ctx, msg, sessionKey) { + continue + } + // Another turn is already active (or reserved) for this session — enqueue if err := al.enqueueSteeringMessage(sessionKey, agentID, providers.Message{ Role: "user", @@ -240,6 +245,24 @@ func (al *AgentLoop) Run(ctx context.Context) error { defer al.channelManager.InvokeTypingStop(m.Channel, m.ChatID) } + if al.takePendingStop(sessionKey) { + al.activeTurnStates.Delete(sessionKey) + target := &continuationTarget{ + SessionKey: sessionKey, + Channel: m.Channel, + ChatID: m.ChatID, + } + continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target) + if continueErr != nil { + al.maybePublishError(ctx, m.Channel, m.ChatID, sessionKey, continueErr) + return + } + if continued != "" { + al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, continued) + } + return + } + al.runTurnWithSteering(ctx, m) }(msg) diff --git a/pkg/agent/agent_command.go b/pkg/agent/agent_command.go index a2ed068d6..ae0293d71 100644 --- a/pkg/agent/agent_command.go +++ b/pkg/agent/agent_command.go @@ -274,6 +274,12 @@ func (al *AgentLoop) buildCommandsRuntime( return nil }, } + rt.StopActiveTurn = func() (commands.StopResult, error) { + if opts == nil { + return commands.StopResult{}, fmt.Errorf("process options not available") + } + return al.stopActiveTurnForSession(opts.Dispatch.SessionKey) + } if agent != nil && agent.ContextBuilder != nil { rt.ListSkillNames = agent.ContextBuilder.ListSkillNames } diff --git a/pkg/agent/agent_steering.go b/pkg/agent/agent_steering.go index c674bcafa..9b136e7cd 100644 --- a/pkg/agent/agent_steering.go +++ b/pkg/agent/agent_steering.go @@ -44,11 +44,36 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb return } - // Drain steering queue using existing Continue mechanism + continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target) + if continueErr != nil { + logger.WarnCF("agent", "Failed to continue queued steering", + map[string]any{ + "channel": target.Channel, + "chat_id": target.ChatID, + "error": continueErr.Error(), + }) + } else if continued != "" { + finalResponse = continued + } + + // Publish final response + if finalResponse != "" { + al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse) + } +} + +func (al *AgentLoop) drainQueuedSteeringContinuations( + ctx context.Context, + target *continuationTarget, +) (string, error) { + if target == nil { + return "", nil + } + + finalResponse := "" for al.pendingSteeringCountForScope(target.SessionKey) > 0 { - // Check for context cancellation between iterations - if ctx.Err() != nil { - return + if err := ctx.Err(); err != nil { + return finalResponse, err } logger.InfoCF("agent", "Continuing queued steering after turn end", @@ -61,13 +86,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID) if continueErr != nil { - logger.WarnCF("agent", "Failed to continue queued steering", - map[string]any{ - "channel": target.Channel, - "chat_id": target.ChatID, - "error": continueErr.Error(), - }) - break + return finalResponse, continueErr } if continued == "" { break @@ -75,10 +94,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb finalResponse = continued } - // Publish final response - if finalResponse != "" { - al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse) - } + return finalResponse, nil } func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) { diff --git a/pkg/agent/agent_stop.go b/pkg/agent/agent_stop.go new file mode 100644 index 000000000..54cd51477 --- /dev/null +++ b/pkg/agent/agent_stop.go @@ -0,0 +1,122 @@ +package agent + +import ( + "context" + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/commands" +) + +func (al *AgentLoop) tryHandleStopCommand( + ctx context.Context, + msg bus.InboundMessage, + sessionKey string, +) bool { + cmdName, ok := commands.CommandName(msg.Content) + if !ok || cmdName != "stop" { + return false + } + + result, err := al.stopActiveTurnForSession(sessionKey) + + // This function is only called when loaded=true (another turn already + // claimed this session). If stopActiveTurnForSession found a pending + // placeholder but didn't stop it, that placeholder belongs to the other + // message's worker which hasn't started yet — arm a pending stop so the + // worker will bail when it checks before running. + if err == nil && !result.Stopped { + if ts := al.getActiveTurnState(sessionKey); ts != nil { + snap := ts.snapshot() + if strings.HasPrefix(snap.TurnID, pendingTurnPrefix) { + al.markPendingStop(sessionKey) + result.Stopped = true + } + } + } + + reply := commands.FormatStopReply(result) + if err != nil { + reply = "Failed to stop task: " + err.Error() + } + + if al.channelManager != nil { + al.channelManager.InvokeTypingStop(msg.Channel, msg.ChatID) + } + al.resetMessageToolRound(sessionKey) + al.PublishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, sessionKey, reply) + return true +} + +func (al *AgentLoop) stopActiveTurnForSession(sessionKey string) (commands.StopResult, error) { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + return commands.StopResult{}, fmt.Errorf("session key is required") + } + + result := commands.StopResult{} + cleared := al.clearSteeringMessagesForScope(sessionKey) + al.clearPendingSkills(sessionKey) + + ts := al.getActiveTurnState(sessionKey) + if ts == nil { + result.Stopped = cleared > 0 + return result, nil + } + + snap := ts.snapshot() + result.TaskName = snap.UserMessage + + if strings.HasPrefix(snap.TurnID, pendingTurnPrefix) { + // A pending placeholder means this session is either idle (our own + // placeholder from the /stop command) or another message is queued but + // hasn't started yet. In both cases, we don't arm a pending stop here; + // the caller (tryHandleStopCommand) handles the "another message queued" + // case explicitly, since it knows loaded=true. + return result, nil + } + + if err := al.HardAbort(sessionKey); err != nil { + if al.getActiveTurnState(sessionKey) == nil { + result.Stopped = cleared > 0 + return result, nil + } + return commands.StopResult{}, err + } + + result.Stopped = true + return result, nil +} + +func (al *AgentLoop) markPendingStop(sessionKey string) { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + return + } + al.pendingStops.Store(sessionKey, struct{}{}) +} + +func (al *AgentLoop) takePendingStop(sessionKey string) bool { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + return false + } + _, ok := al.pendingStops.LoadAndDelete(sessionKey) + return ok +} + +func (al *AgentLoop) resetMessageToolRound(sessionKey string) { + if strings.TrimSpace(sessionKey) == "" { + return + } + if registry := al.GetRegistry(); registry != nil { + if agent := registry.GetDefaultAgent(); agent != nil { + if tool, ok := agent.Tools.Get("message"); ok { + if resetter, ok := tool.(interface{ ResetSentInRound(sessionKey string) }); ok { + resetter.ResetSentInRound(sessionKey) + } + } + } + } +} diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index ba171fe5d..7bddbfc31 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -156,6 +156,18 @@ func (sq *steeringQueue) lenScope(scope string) int { return len(sq.queues[normalizeSteeringScope(scope)]) } +func (sq *steeringQueue) clearScope(scope string) int { + sq.mu.Lock() + defer sq.mu.Unlock() + + scope = normalizeSteeringScope(scope) + count := len(sq.queues[scope]) + if count > 0 { + delete(sq.queues, scope) + } + return count +} + // setMode updates the steering mode. func (sq *steeringQueue) setMode(mode SteeringMode) { sq.mu.Lock() @@ -290,6 +302,13 @@ func (al *AgentLoop) pendingSteeringCountForScope(scope string) int { return al.steering.lenScope(scope) } +func (al *AgentLoop) clearSteeringMessagesForScope(scope string) int { + if al.steering == nil { + return 0 + } + return al.steering.clearScope(scope) +} + func (al *AgentLoop) continueWithSteeringMessages( ctx context.Context, agent *AgentInstance, @@ -511,6 +530,10 @@ func (al *AgentLoop) HardAbort(sessionKey string) error { "initial_history_length": ts.initialHistoryLength, }) + // Cancel the active provider/tool turn contexts immediately so long-running + // execution stops as soon as possible on the root turn. + _ = ts.requestHardAbort() + // IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns // from adding more messages to the session. This prevents race conditions // where rollback happens while children are still writing. diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go index 25e06d7a2..813013649 100644 --- a/pkg/agent/steering_test.go +++ b/pkg/agent/steering_test.go @@ -840,6 +840,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) { } } +func TestAgentLoop_Run_PendingStopStillContinuesQueuedFollowUp(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + MaxParallelTurns: 1, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &lateSteeringProvider{ + firstCallStarted: make(chan struct{}), + releaseFirstCall: make(chan struct{}), + } + al := NewAgentLoop(cfg, msgBus, provider) + + runCtx, cancelRun := context.WithCancel(context.Background()) + defer cancelRun() + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- al.Run(runCtx) + }() + defer func() { + cancelRun() + select { + case err := <-runErrCh: + if err != nil { + t.Fatalf("Run() error = %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for Run to stop") + } + }() + + blockerSessionKey := session.BuildOpaqueSessionKey("agent:main:test:blocker") + targetSessionKey := session.BuildOpaqueSessionKey("agent:main:test:target") + blockerCtx := bus.InboundContext{ + Channel: "test", + ChatID: "blocker-chat", + ChatType: "direct", + SenderID: "user1", + } + targetCtx := bus.InboundContext{ + Channel: "test", + ChatID: "target-chat", + ChatType: "direct", + SenderID: "user1", + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: blockerCtx, + Content: "block worker pool", + SessionKey: blockerSessionKey, + }); err != nil { + t.Fatalf("PublishInbound(blocker) error = %v", err) + } + + select { + case <-provider.firstCallStarted: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for blocker turn to start") + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: targetCtx, + Content: "skip this turn", + SessionKey: targetSessionKey, + }); err != nil { + t.Fatalf("PublishInbound(target start) error = %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for { + ts := al.getActiveTurnState(targetSessionKey) + if ts != nil && strings.HasPrefix(ts.turnID, pendingTurnPrefix) { + break + } + if time.Now().After(deadline) { + t.Fatal("timeout waiting for pending placeholder") + } + time.Sleep(10 * time.Millisecond) + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: targetCtx, + Content: "/stop", + SessionKey: targetSessionKey, + }); err != nil { + t.Fatalf("PublishInbound(/stop) error = %v", err) + } + + deadline = time.Now().Add(2 * time.Second) + stopSeen := false + for !stopSeen { + select { + case outbound := <-msgBus.OutboundChan(): + if outbound.ChatID == "target-chat" && outbound.Content == "Task stopped. Current task was canceled." { + stopSeen = true + } + case <-time.After(10 * time.Millisecond): + if time.Now().After(deadline) { + t.Fatal("timeout waiting for /stop reply") + } + } + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: targetCtx, + Content: "run this instead", + SessionKey: targetSessionKey, + }); err != nil { + t.Fatalf("PublishInbound(follow-up) error = %v", err) + } + + deadline = time.Now().Add(2 * time.Second) + for al.pendingSteeringCountForScope(targetSessionKey) == 0 { + if time.Now().After(deadline) { + t.Fatal("timeout waiting for follow-up to enter scoped steering queue") + } + time.Sleep(10 * time.Millisecond) + } + + close(provider.releaseFirstCall) + + deadline = time.Now().Add(5 * time.Second) + followUpSeen := false + for !followUpSeen { + select { + case outbound := <-msgBus.OutboundChan(): + if outbound.ChatID == "target-chat" && outbound.Content == "continued response" { + followUpSeen = true + } + case <-time.After(10 * time.Millisecond): + if time.Now().After(deadline) { + t.Fatal("timeout waiting for queued follow-up continuation") + } + } + } + + deadline = time.Now().Add(2 * time.Second) + for { + if al.GetActiveTurnBySession(targetSessionKey) == nil && + al.pendingSteeringCountForScope(targetSessionKey) == 0 { + break + } + if time.Now().After(deadline) { + t.Fatal("timeout waiting for target session to go idle") + } + time.Sleep(10 * time.Millisecond) + } + + provider.mu.Lock() + calls := provider.calls + secondMessages := append([]providers.Message(nil), provider.secondCallMessages...) + provider.mu.Unlock() + + if calls != 2 { + t.Fatalf("expected 2 provider calls (blocker + continuation), got %d", calls) + } + + foundFollowUp := false + for _, msg := range secondMessages { + if msg.Role == "user" && msg.Content == "run this instead" { + foundFollowUp = true + } + if msg.Role == "user" && msg.Content == "skip this turn" { + t.Fatalf("unexpected canceled message in continuation context: %q", msg.Content) + } + } + if !foundFollowUp { + t.Fatal("expected queued follow-up to be processed after pending stop") + } +} + func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil { @@ -1392,6 +1577,149 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) { } } +func TestAgentLoop_StopCommand_AbortsActiveTurnAndClearsQueuedSteering(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &toolCallProvider{ + toolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "cancel_tool", + Function: &providers.FunctionCall{ + Name: "cancel_tool", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + }, + finalResp: "should not continue", + } + + al := NewAgentLoop(cfg, msgBus, provider) + started := make(chan struct{}) + al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started}) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) + + runCtx, cancelRun := context.WithCancel(context.Background()) + defer cancelRun() + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- al.Run(runCtx) + }() + defer func() { + cancelRun() + select { + case err := <-runErrCh: + if err != nil { + t.Fatalf("Run() error = %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for Run to stop") + } + }() + + baseMsg := testInboundMessage(bus.InboundMessage{ + Context: bus.InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + }, + SessionKey: sessionKey, + }) + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: baseMsg.Context, + Content: "do work", + SessionKey: sessionKey, + }); err != nil { + t.Fatalf("PublishInbound(start) error = %v", err) + } + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for interruptible tool to start") + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: baseMsg.Context, + Content: "follow up after cancel", + SessionKey: sessionKey, + }); err != nil { + t.Fatalf("PublishInbound(follow-up) error = %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for al.pendingSteeringCountForScope(sessionKey) == 0 { + if time.Now().After(deadline) { + t.Fatal("timeout waiting for follow-up message to enter steering queue") + } + time.Sleep(10 * time.Millisecond) + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: baseMsg.Context, + Content: "/stop", + SessionKey: sessionKey, + }); err != nil { + t.Fatalf("PublishInbound(/stop) error = %v", err) + } + + select { + case outbound := <-msgBus.OutboundChan(): + want := "Task stopped. \"do work\" was canceled." + if outbound.Content != want { + t.Fatalf("stop reply = %q, want %q", outbound.Content, want) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for /stop reply") + } + + deadline = time.Now().Add(5 * time.Second) + for al.GetActiveTurnBySession(sessionKey) != nil { + if time.Now().After(deadline) { + t.Fatal("timeout waiting for active turn to stop") + } + time.Sleep(10 * time.Millisecond) + } + + if got := al.pendingSteeringCountForScope(sessionKey); got != 0 { + t.Fatalf("expected cleared steering queue, got %d pending message(s)", got) + } + + select { + case outbound := <-msgBus.OutboundChan(): + t.Fatalf("unexpected outbound after stop: %q", outbound.Content) + case <-time.After(300 * time.Millisecond): + } + + provider.mu.Lock() + calls := provider.calls + provider.mu.Unlock() + if calls != 1 { + t.Fatalf("expected provider to stop before follow-up turn, got %d calls", calls) + } +} + // capturingMockProvider captures messages sent to Chat for inspection. type capturingMockProvider struct { response string diff --git a/pkg/agent/turn_coord.go b/pkg/agent/turn_coord.go index ae6bd8c82..2826e662c 100644 --- a/pkg/agent/turn_coord.go +++ b/pkg/agent/turn_coord.go @@ -26,6 +26,10 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel al.registerActiveTurn(ts) defer al.clearActiveTurn(ts) + if al.takePendingStop(ts.sessionKey) { + _ = ts.requestHardAbort() + } + turnStatus := TurnEndStatusCompleted defer func() { al.emitEvent( @@ -40,6 +44,11 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel ) }() + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + al.emitEvent( runtimeevents.KindAgentTurnStart, ts.eventMeta("runTurn", "turn.start"), diff --git a/pkg/agent/turn_state.go b/pkg/agent/turn_state.go index 85e7dd3c0..b769ebcd0 100644 --- a/pkg/agent/turn_state.go +++ b/pkg/agent/turn_state.go @@ -256,7 +256,10 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop // Bind session store and capture initial history length for rollback logic if agent != nil && agent.Sessions != nil { ts.session = agent.Sessions - ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey)) + history := agent.Sessions.GetHistory(opts.Dispatch.SessionKey) + ts.initialHistoryLength = len(history) + ts.restorePointHistory = append([]providers.Message(nil), history...) + ts.restorePointSummary = agent.Sessions.GetSummary(opts.Dispatch.SessionKey) } return ts diff --git a/pkg/commands/builtin.go b/pkg/commands/builtin.go index a7e401bb8..e268812a0 100644 --- a/pkg/commands/builtin.go +++ b/pkg/commands/builtin.go @@ -8,6 +8,7 @@ func BuiltinDefinitions() []Definition { return []Definition{ startCommand(), helpCommand(), + stopCommand(), showCommand(), listCommand(), useCommand(), diff --git a/pkg/commands/builtin_test.go b/pkg/commands/builtin_test.go index efd27fa00..bb9abe360 100644 --- a/pkg/commands/builtin_test.go +++ b/pkg/commands/builtin_test.go @@ -42,6 +42,9 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) { if !strings.Contains(reply, "/list [models|channels|agents|skills|mcp]") { t.Fatalf("/help reply missing /list usage, got %q", reply) } + if !strings.Contains(reply, "/stop") { + t.Fatalf("/help reply missing /stop usage, got %q", reply) + } if !strings.Contains(reply, "/use ") { if !strings.Contains(reply, "/use [message]") { t.Fatalf("/help reply missing /use usage, got %q", reply) @@ -49,6 +52,59 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) { } } +func TestBuiltinStop_UsesRuntimeStopper(t *testing.T) { + rt := &Runtime{ + StopActiveTurn: func() (StopResult, error) { + return StopResult{ + Stopped: true, + TaskName: "sync the long running job", + }, nil + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/stop", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Task stopped. \"sync the long running job\" was canceled." { + t.Fatalf("/stop reply=%q", reply) + } +} + +func TestBuiltinStop_NoActiveTask(t *testing.T) { + rt := &Runtime{ + StopActiveTurn: func() (StopResult, error) { + return StopResult{}, nil + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/stop", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "No active task to stop." { + t.Fatalf("/stop reply=%q, want no-active message", reply) + } +} + func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) { defs := BuiltinDefinitions() ex := NewExecutor(NewRegistry(defs), nil) diff --git a/pkg/commands/cmd_stop.go b/pkg/commands/cmd_stop.go new file mode 100644 index 000000000..147688bdc --- /dev/null +++ b/pkg/commands/cmd_stop.go @@ -0,0 +1,52 @@ +package commands + +import ( + "context" + "fmt" + "strings" +) + +func stopCommand() Definition { + return Definition{ + Name: "stop", + Description: "Stop the current task", + Usage: "/stop", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.StopActiveTurn == nil { + return req.Reply(unavailableMsg) + } + + result, err := rt.StopActiveTurn() + if err != nil { + return req.Reply("Failed to stop task: " + err.Error()) + } + + return req.Reply(FormatStopReply(result)) + }, + } +} + +// FormatStopReply renders a user-facing reply for a stop request. +func FormatStopReply(result StopResult) string { + if !result.Stopped { + return "No active task to stop." + } + + taskName := compactStopTaskName(result.TaskName) + if taskName == "" { + return "Task stopped. Current task was canceled." + } + + return fmt.Sprintf("Task stopped. %q was canceled.", taskName) +} + +func compactStopTaskName(taskName string) string { + taskName = strings.Join(strings.Fields(strings.TrimSpace(taskName)), " ") + if taskName == "" { + return "" + } + if len(taskName) > 80 { + return taskName[:77] + "..." + } + return taskName +} diff --git a/pkg/commands/runtime.go b/pkg/commands/runtime.go index c17b7cf1c..b0327c863 100644 --- a/pkg/commands/runtime.go +++ b/pkg/commands/runtime.go @@ -36,6 +36,12 @@ type ContextStats struct { MessageCount int } +// StopResult describes the outcome of a stop request for the current session. +type StopResult struct { + Stopped bool + TaskName string +} + // Runtime provides runtime dependencies to command handlers. It is constructed // per-request by the agent loop so that per-request state (like session scope) // can coexist with long-lived callbacks (like GetModelInfo). @@ -55,4 +61,5 @@ type Runtime struct { SwitchChannel func(value string) error ClearHistory func() error ReloadConfig func() error + StopActiveTurn func() (StopResult, error) }