mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): stop command
This commit is contained in:
@@ -58,6 +58,7 @@ type AgentLoop struct {
|
||||
hookRuntime hookRuntime
|
||||
steering *steeringQueue
|
||||
pendingSkills sync.Map
|
||||
pendingStops sync.Map
|
||||
mu sync.RWMutex
|
||||
|
||||
// workerSem limits concurrent turn processing workers.
|
||||
@@ -177,6 +178,10 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
phase: TurnPhaseSetup,
|
||||
}
|
||||
if _, loaded := al.activeTurnStates.LoadOrStore(sessionKey, placeholder); loaded {
|
||||
if al.tryHandleStopCommand(ctx, msg, sessionKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Another turn is already active (or reserved) for this session — enqueue
|
||||
if err := al.enqueueSteeringMessage(sessionKey, agentID, providers.Message{
|
||||
Role: "user",
|
||||
@@ -240,6 +245,11 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
defer al.channelManager.InvokeTypingStop(m.Channel, m.ChatID)
|
||||
}
|
||||
|
||||
if al.takePendingStop(sessionKey) {
|
||||
al.activeTurnStates.Delete(sessionKey)
|
||||
return
|
||||
}
|
||||
|
||||
al.runTurnWithSteering(ctx, m)
|
||||
}(msg)
|
||||
|
||||
|
||||
@@ -274,6 +274,12 @@ func (al *AgentLoop) buildCommandsRuntime(
|
||||
return nil
|
||||
},
|
||||
}
|
||||
rt.StopActiveTurn = func() (commands.StopResult, error) {
|
||||
if opts == nil {
|
||||
return commands.StopResult{}, fmt.Errorf("process options not available")
|
||||
}
|
||||
return al.stopActiveTurnForSession(opts.Dispatch.SessionKey)
|
||||
}
|
||||
if agent != nil && agent.ContextBuilder != nil {
|
||||
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
|
||||
}
|
||||
|
||||
@@ -292,7 +292,7 @@ func (p *Pipeline) CallLLM(
|
||||
if isNetworkError && retry < maxRetries {
|
||||
backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
runtimeevents.KindAgentLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: retry + 1,
|
||||
|
||||
@@ -156,6 +156,18 @@ func (sq *steeringQueue) lenScope(scope string) int {
|
||||
return len(sq.queues[normalizeSteeringScope(scope)])
|
||||
}
|
||||
|
||||
func (sq *steeringQueue) clearScope(scope string) int {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
scope = normalizeSteeringScope(scope)
|
||||
count := len(sq.queues[scope])
|
||||
if count > 0 {
|
||||
delete(sq.queues, scope)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// setMode updates the steering mode.
|
||||
func (sq *steeringQueue) setMode(mode SteeringMode) {
|
||||
sq.mu.Lock()
|
||||
@@ -290,6 +302,13 @@ func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
|
||||
return al.steering.lenScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) clearSteeringMessagesForScope(scope string) int {
|
||||
if al.steering == nil {
|
||||
return 0
|
||||
}
|
||||
return al.steering.clearScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) continueWithSteeringMessages(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
@@ -511,6 +530,10 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
|
||||
"initial_history_length": ts.initialHistoryLength,
|
||||
})
|
||||
|
||||
// Cancel the active provider/tool turn contexts immediately so long-running
|
||||
// execution stops as soon as possible on the root turn.
|
||||
_ = ts.requestHardAbort()
|
||||
|
||||
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
|
||||
// from adding more messages to the session. This prevents race conditions
|
||||
// where rollback happens while children are still writing.
|
||||
|
||||
@@ -1392,6 +1392,149 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_StopCommand_AbortsActiveTurnAndClearsQueuedSteering(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "cancel_tool",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "cancel_tool",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "should not continue",
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
started := make(chan struct{})
|
||||
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
defer func() {
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}()
|
||||
|
||||
baseMsg := testInboundMessage(bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
SessionKey: sessionKey,
|
||||
})
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: baseMsg.Context,
|
||||
Content: "do work",
|
||||
SessionKey: sessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(start) error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for interruptible tool to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: baseMsg.Context,
|
||||
Content: "follow up after cancel",
|
||||
SessionKey: sessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(follow-up) error = %v", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for al.pendingSteeringCountForScope(sessionKey) == 0 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for follow-up message to enter steering queue")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: baseMsg.Context,
|
||||
Content: "/stop",
|
||||
SessionKey: sessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(/stop) error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
want := "⏹️ Task stopped. \"do work\" was canceled."
|
||||
if outbound.Content != want {
|
||||
t.Fatalf("stop reply = %q, want %q", outbound.Content, want)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for /stop reply")
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(5 * time.Second)
|
||||
for al.GetActiveTurnBySession(sessionKey) != nil {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for active turn to stop")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if got := al.pendingSteeringCountForScope(sessionKey); got != 0 {
|
||||
t.Fatalf("expected cleared steering queue, got %d pending message(s)", got)
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("unexpected outbound after stop: %q", outbound.Content)
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
provider.mu.Unlock()
|
||||
if calls != 1 {
|
||||
t.Fatalf("expected provider to stop before follow-up turn, got %d calls", calls)
|
||||
}
|
||||
}
|
||||
|
||||
// capturingMockProvider captures messages sent to Chat for inspection.
|
||||
type capturingMockProvider struct {
|
||||
response string
|
||||
|
||||
@@ -26,6 +26,10 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
|
||||
al.registerActiveTurn(ts)
|
||||
defer al.clearActiveTurn(ts)
|
||||
|
||||
if al.takePendingStop(ts.sessionKey) {
|
||||
_ = ts.requestHardAbort()
|
||||
}
|
||||
|
||||
turnStatus := TurnEndStatusCompleted
|
||||
defer func() {
|
||||
al.emitEvent(
|
||||
@@ -40,6 +44,11 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
|
||||
)
|
||||
}()
|
||||
|
||||
if ts.hardAbortRequested() {
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
runtimeevents.KindAgentTurnStart,
|
||||
ts.eventMeta("runTurn", "turn.start"),
|
||||
|
||||
@@ -256,7 +256,10 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
|
||||
// Bind session store and capture initial history length for rollback logic
|
||||
if agent != nil && agent.Sessions != nil {
|
||||
ts.session = agent.Sessions
|
||||
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey))
|
||||
history := agent.Sessions.GetHistory(opts.Dispatch.SessionKey)
|
||||
ts.initialHistoryLength = len(history)
|
||||
ts.restorePointHistory = append([]providers.Message(nil), history...)
|
||||
ts.restorePointSummary = agent.Sessions.GetSummary(opts.Dispatch.SessionKey)
|
||||
}
|
||||
|
||||
return ts
|
||||
|
||||
@@ -8,6 +8,7 @@ func BuiltinDefinitions() []Definition {
|
||||
return []Definition{
|
||||
startCommand(),
|
||||
helpCommand(),
|
||||
stopCommand(),
|
||||
showCommand(),
|
||||
listCommand(),
|
||||
useCommand(),
|
||||
|
||||
@@ -42,6 +42,9 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
|
||||
if !strings.Contains(reply, "/list [models|channels|agents|skills|mcp]") {
|
||||
t.Fatalf("/help reply missing /list usage, got %q", reply)
|
||||
}
|
||||
if !strings.Contains(reply, "/stop") {
|
||||
t.Fatalf("/help reply missing /stop usage, got %q", reply)
|
||||
}
|
||||
if !strings.Contains(reply, "/use <skill> <message>") {
|
||||
if !strings.Contains(reply, "/use <skill> [message]") {
|
||||
t.Fatalf("/help reply missing /use usage, got %q", reply)
|
||||
@@ -49,6 +52,59 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinStop_UsesRuntimeStopper(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
StopActiveTurn: func() (StopResult, error) {
|
||||
return StopResult{
|
||||
Stopped: true,
|
||||
TaskName: "sync the long running job",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/stop",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Task stopped. \"sync the long running job\" was canceled." {
|
||||
t.Fatalf("/stop reply=%q", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinStop_NoActiveTask(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
StopActiveTurn: func() (StopResult, error) {
|
||||
return StopResult{}, nil
|
||||
},
|
||||
}
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/stop",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "No active task to stop." {
|
||||
t.Fatalf("/stop reply=%q, want no-active message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) {
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
@@ -36,6 +36,12 @@ type ContextStats struct {
|
||||
MessageCount int
|
||||
}
|
||||
|
||||
// StopResult describes the outcome of a stop request for the current session.
|
||||
type StopResult struct {
|
||||
Stopped bool
|
||||
TaskName string
|
||||
}
|
||||
|
||||
// Runtime provides runtime dependencies to command handlers. It is constructed
|
||||
// per-request by the agent loop so that per-request state (like session scope)
|
||||
// can coexist with long-lived callbacks (like GetModelInfo).
|
||||
@@ -55,4 +61,5 @@ type Runtime struct {
|
||||
SwitchChannel func(value string) error
|
||||
ClearHistory func() error
|
||||
ReloadConfig func() error
|
||||
StopActiveTurn func() (StopResult, error)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user