mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #2762 from afjcjsbx/feat/stop-command
feat(agent): stop command
This commit is contained in:
@@ -58,6 +58,7 @@ type AgentLoop struct {
|
||||
hookRuntime hookRuntime
|
||||
steering *steeringQueue
|
||||
pendingSkills sync.Map
|
||||
pendingStops sync.Map
|
||||
mu sync.RWMutex
|
||||
|
||||
// workerSem limits concurrent turn processing workers.
|
||||
@@ -177,6 +178,10 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
phase: TurnPhaseSetup,
|
||||
}
|
||||
if _, loaded := al.activeTurnStates.LoadOrStore(sessionKey, placeholder); loaded {
|
||||
if al.tryHandleStopCommand(ctx, msg, sessionKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Another turn is already active (or reserved) for this session — enqueue
|
||||
if err := al.enqueueSteeringMessage(sessionKey, agentID, providers.Message{
|
||||
Role: "user",
|
||||
@@ -240,6 +245,24 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
defer al.channelManager.InvokeTypingStop(m.Channel, m.ChatID)
|
||||
}
|
||||
|
||||
if al.takePendingStop(sessionKey) {
|
||||
al.activeTurnStates.Delete(sessionKey)
|
||||
target := &continuationTarget{
|
||||
SessionKey: sessionKey,
|
||||
Channel: m.Channel,
|
||||
ChatID: m.ChatID,
|
||||
}
|
||||
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
|
||||
if continueErr != nil {
|
||||
al.maybePublishError(ctx, m.Channel, m.ChatID, sessionKey, continueErr)
|
||||
return
|
||||
}
|
||||
if continued != "" {
|
||||
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, continued)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
al.runTurnWithSteering(ctx, m)
|
||||
}(msg)
|
||||
|
||||
|
||||
@@ -274,6 +274,12 @@ func (al *AgentLoop) buildCommandsRuntime(
|
||||
return nil
|
||||
},
|
||||
}
|
||||
rt.StopActiveTurn = func() (commands.StopResult, error) {
|
||||
if opts == nil {
|
||||
return commands.StopResult{}, fmt.Errorf("process options not available")
|
||||
}
|
||||
return al.stopActiveTurnForSession(opts.Dispatch.SessionKey)
|
||||
}
|
||||
if agent != nil && agent.ContextBuilder != nil {
|
||||
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
|
||||
}
|
||||
|
||||
+31
-15
@@ -44,11 +44,36 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
|
||||
return
|
||||
}
|
||||
|
||||
// Drain steering queue using existing Continue mechanism
|
||||
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
|
||||
if continueErr != nil {
|
||||
logger.WarnCF("agent", "Failed to continue queued steering",
|
||||
map[string]any{
|
||||
"channel": target.Channel,
|
||||
"chat_id": target.ChatID,
|
||||
"error": continueErr.Error(),
|
||||
})
|
||||
} else if continued != "" {
|
||||
finalResponse = continued
|
||||
}
|
||||
|
||||
// Publish final response
|
||||
if finalResponse != "" {
|
||||
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) drainQueuedSteeringContinuations(
|
||||
ctx context.Context,
|
||||
target *continuationTarget,
|
||||
) (string, error) {
|
||||
if target == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
finalResponse := ""
|
||||
for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
|
||||
// Check for context cancellation between iterations
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
if err := ctx.Err(); err != nil {
|
||||
return finalResponse, err
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Continuing queued steering after turn end",
|
||||
@@ -61,13 +86,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
|
||||
|
||||
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
|
||||
if continueErr != nil {
|
||||
logger.WarnCF("agent", "Failed to continue queued steering",
|
||||
map[string]any{
|
||||
"channel": target.Channel,
|
||||
"chat_id": target.ChatID,
|
||||
"error": continueErr.Error(),
|
||||
})
|
||||
break
|
||||
return finalResponse, continueErr
|
||||
}
|
||||
if continued == "" {
|
||||
break
|
||||
@@ -75,10 +94,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
|
||||
finalResponse = continued
|
||||
}
|
||||
|
||||
// Publish final response
|
||||
if finalResponse != "" {
|
||||
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
|
||||
}
|
||||
return finalResponse, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) {
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
)
|
||||
|
||||
func (al *AgentLoop) tryHandleStopCommand(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
sessionKey string,
|
||||
) bool {
|
||||
cmdName, ok := commands.CommandName(msg.Content)
|
||||
if !ok || cmdName != "stop" {
|
||||
return false
|
||||
}
|
||||
|
||||
result, err := al.stopActiveTurnForSession(sessionKey)
|
||||
|
||||
// This function is only called when loaded=true (another turn already
|
||||
// claimed this session). If stopActiveTurnForSession found a pending
|
||||
// placeholder but didn't stop it, that placeholder belongs to the other
|
||||
// message's worker which hasn't started yet — arm a pending stop so the
|
||||
// worker will bail when it checks before running.
|
||||
if err == nil && !result.Stopped {
|
||||
if ts := al.getActiveTurnState(sessionKey); ts != nil {
|
||||
snap := ts.snapshot()
|
||||
if strings.HasPrefix(snap.TurnID, pendingTurnPrefix) {
|
||||
al.markPendingStop(sessionKey)
|
||||
result.Stopped = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reply := commands.FormatStopReply(result)
|
||||
if err != nil {
|
||||
reply = "Failed to stop task: " + err.Error()
|
||||
}
|
||||
|
||||
if al.channelManager != nil {
|
||||
al.channelManager.InvokeTypingStop(msg.Channel, msg.ChatID)
|
||||
}
|
||||
al.resetMessageToolRound(sessionKey)
|
||||
al.PublishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, sessionKey, reply)
|
||||
return true
|
||||
}
|
||||
|
||||
func (al *AgentLoop) stopActiveTurnForSession(sessionKey string) (commands.StopResult, error) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return commands.StopResult{}, fmt.Errorf("session key is required")
|
||||
}
|
||||
|
||||
result := commands.StopResult{}
|
||||
cleared := al.clearSteeringMessagesForScope(sessionKey)
|
||||
al.clearPendingSkills(sessionKey)
|
||||
|
||||
ts := al.getActiveTurnState(sessionKey)
|
||||
if ts == nil {
|
||||
result.Stopped = cleared > 0
|
||||
return result, nil
|
||||
}
|
||||
|
||||
snap := ts.snapshot()
|
||||
result.TaskName = snap.UserMessage
|
||||
|
||||
if strings.HasPrefix(snap.TurnID, pendingTurnPrefix) {
|
||||
// A pending placeholder means this session is either idle (our own
|
||||
// placeholder from the /stop command) or another message is queued but
|
||||
// hasn't started yet. In both cases, we don't arm a pending stop here;
|
||||
// the caller (tryHandleStopCommand) handles the "another message queued"
|
||||
// case explicitly, since it knows loaded=true.
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if err := al.HardAbort(sessionKey); err != nil {
|
||||
if al.getActiveTurnState(sessionKey) == nil {
|
||||
result.Stopped = cleared > 0
|
||||
return result, nil
|
||||
}
|
||||
return commands.StopResult{}, err
|
||||
}
|
||||
|
||||
result.Stopped = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) markPendingStop(sessionKey string) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return
|
||||
}
|
||||
al.pendingStops.Store(sessionKey, struct{}{})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) takePendingStop(sessionKey string) bool {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return false
|
||||
}
|
||||
_, ok := al.pendingStops.LoadAndDelete(sessionKey)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resetMessageToolRound(sessionKey string) {
|
||||
if strings.TrimSpace(sessionKey) == "" {
|
||||
return
|
||||
}
|
||||
if registry := al.GetRegistry(); registry != nil {
|
||||
if agent := registry.GetDefaultAgent(); agent != nil {
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if resetter, ok := tool.(interface{ ResetSentInRound(sessionKey string) }); ok {
|
||||
resetter.ResetSentInRound(sessionKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -156,6 +156,18 @@ func (sq *steeringQueue) lenScope(scope string) int {
|
||||
return len(sq.queues[normalizeSteeringScope(scope)])
|
||||
}
|
||||
|
||||
func (sq *steeringQueue) clearScope(scope string) int {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
scope = normalizeSteeringScope(scope)
|
||||
count := len(sq.queues[scope])
|
||||
if count > 0 {
|
||||
delete(sq.queues, scope)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// setMode updates the steering mode.
|
||||
func (sq *steeringQueue) setMode(mode SteeringMode) {
|
||||
sq.mu.Lock()
|
||||
@@ -290,6 +302,13 @@ func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
|
||||
return al.steering.lenScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) clearSteeringMessagesForScope(scope string) int {
|
||||
if al.steering == nil {
|
||||
return 0
|
||||
}
|
||||
return al.steering.clearScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) continueWithSteeringMessages(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
@@ -511,6 +530,10 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
|
||||
"initial_history_length": ts.initialHistoryLength,
|
||||
})
|
||||
|
||||
// Cancel the active provider/tool turn contexts immediately so long-running
|
||||
// execution stops as soon as possible on the root turn.
|
||||
_ = ts.requestHardAbort()
|
||||
|
||||
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
|
||||
// from adding more messages to the session. This prevents race conditions
|
||||
// where rollback happens while children are still writing.
|
||||
|
||||
@@ -840,6 +840,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Run_PendingStopStillContinuesQueuedFollowUp(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
MaxParallelTurns: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &lateSteeringProvider{
|
||||
firstCallStarted: make(chan struct{}),
|
||||
releaseFirstCall: make(chan struct{}),
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
defer func() {
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}()
|
||||
|
||||
blockerSessionKey := session.BuildOpaqueSessionKey("agent:main:test:blocker")
|
||||
targetSessionKey := session.BuildOpaqueSessionKey("agent:main:test:target")
|
||||
blockerCtx := bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "blocker-chat",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
}
|
||||
targetCtx := bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "target-chat",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: blockerCtx,
|
||||
Content: "block worker pool",
|
||||
SessionKey: blockerSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(blocker) error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstCallStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for blocker turn to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: targetCtx,
|
||||
Content: "skip this turn",
|
||||
SessionKey: targetSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(target start) error = %v", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
ts := al.getActiveTurnState(targetSessionKey)
|
||||
if ts != nil && strings.HasPrefix(ts.turnID, pendingTurnPrefix) {
|
||||
break
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for pending placeholder")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: targetCtx,
|
||||
Content: "/stop",
|
||||
SessionKey: targetSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(/stop) error = %v", err)
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
stopSeen := false
|
||||
for !stopSeen {
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.ChatID == "target-chat" && outbound.Content == "Task stopped. Current task was canceled." {
|
||||
stopSeen = true
|
||||
}
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for /stop reply")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: targetCtx,
|
||||
Content: "run this instead",
|
||||
SessionKey: targetSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(follow-up) error = %v", err)
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
for al.pendingSteeringCountForScope(targetSessionKey) == 0 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for follow-up to enter scoped steering queue")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
close(provider.releaseFirstCall)
|
||||
|
||||
deadline = time.Now().Add(5 * time.Second)
|
||||
followUpSeen := false
|
||||
for !followUpSeen {
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.ChatID == "target-chat" && outbound.Content == "continued response" {
|
||||
followUpSeen = true
|
||||
}
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for queued follow-up continuation")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
if al.GetActiveTurnBySession(targetSessionKey) == nil &&
|
||||
al.pendingSteeringCountForScope(targetSessionKey) == 0 {
|
||||
break
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for target session to go idle")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
|
||||
provider.mu.Unlock()
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls (blocker + continuation), got %d", calls)
|
||||
}
|
||||
|
||||
foundFollowUp := false
|
||||
for _, msg := range secondMessages {
|
||||
if msg.Role == "user" && msg.Content == "run this instead" {
|
||||
foundFollowUp = true
|
||||
}
|
||||
if msg.Role == "user" && msg.Content == "skip this turn" {
|
||||
t.Fatalf("unexpected canceled message in continuation context: %q", msg.Content)
|
||||
}
|
||||
}
|
||||
if !foundFollowUp {
|
||||
t.Fatal("expected queued follow-up to be processed after pending stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
@@ -1392,6 +1577,149 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_StopCommand_AbortsActiveTurnAndClearsQueuedSteering(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "cancel_tool",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "cancel_tool",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "should not continue",
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
started := make(chan struct{})
|
||||
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
defer func() {
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}()
|
||||
|
||||
baseMsg := testInboundMessage(bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
SessionKey: sessionKey,
|
||||
})
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: baseMsg.Context,
|
||||
Content: "do work",
|
||||
SessionKey: sessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(start) error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for interruptible tool to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: baseMsg.Context,
|
||||
Content: "follow up after cancel",
|
||||
SessionKey: sessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(follow-up) error = %v", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for al.pendingSteeringCountForScope(sessionKey) == 0 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for follow-up message to enter steering queue")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: baseMsg.Context,
|
||||
Content: "/stop",
|
||||
SessionKey: sessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(/stop) error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
want := "Task stopped. \"do work\" was canceled."
|
||||
if outbound.Content != want {
|
||||
t.Fatalf("stop reply = %q, want %q", outbound.Content, want)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for /stop reply")
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(5 * time.Second)
|
||||
for al.GetActiveTurnBySession(sessionKey) != nil {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for active turn to stop")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if got := al.pendingSteeringCountForScope(sessionKey); got != 0 {
|
||||
t.Fatalf("expected cleared steering queue, got %d pending message(s)", got)
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("unexpected outbound after stop: %q", outbound.Content)
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
provider.mu.Unlock()
|
||||
if calls != 1 {
|
||||
t.Fatalf("expected provider to stop before follow-up turn, got %d calls", calls)
|
||||
}
|
||||
}
|
||||
|
||||
// capturingMockProvider captures messages sent to Chat for inspection.
|
||||
type capturingMockProvider struct {
|
||||
response string
|
||||
|
||||
@@ -26,6 +26,10 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
|
||||
al.registerActiveTurn(ts)
|
||||
defer al.clearActiveTurn(ts)
|
||||
|
||||
if al.takePendingStop(ts.sessionKey) {
|
||||
_ = ts.requestHardAbort()
|
||||
}
|
||||
|
||||
turnStatus := TurnEndStatusCompleted
|
||||
defer func() {
|
||||
al.emitEvent(
|
||||
@@ -40,6 +44,11 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
|
||||
)
|
||||
}()
|
||||
|
||||
if ts.hardAbortRequested() {
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
runtimeevents.KindAgentTurnStart,
|
||||
ts.eventMeta("runTurn", "turn.start"),
|
||||
|
||||
@@ -256,7 +256,10 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
|
||||
// Bind session store and capture initial history length for rollback logic
|
||||
if agent != nil && agent.Sessions != nil {
|
||||
ts.session = agent.Sessions
|
||||
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey))
|
||||
history := agent.Sessions.GetHistory(opts.Dispatch.SessionKey)
|
||||
ts.initialHistoryLength = len(history)
|
||||
ts.restorePointHistory = append([]providers.Message(nil), history...)
|
||||
ts.restorePointSummary = agent.Sessions.GetSummary(opts.Dispatch.SessionKey)
|
||||
}
|
||||
|
||||
return ts
|
||||
|
||||
Reference in New Issue
Block a user