fix(agent): drain scoped follow-up queue when pending stop skips turn startup

This commit is contained in:
afjcjsbx
2026-05-05 19:24:15 +02:00
parent d63430ab33
commit a7e52e8a25
3 changed files with 229 additions and 15 deletions
+13
View File
@@ -247,6 +247,19 @@ func (al *AgentLoop) Run(ctx context.Context) error {
if al.takePendingStop(sessionKey) {
al.activeTurnStates.Delete(sessionKey)
target := &continuationTarget{
SessionKey: sessionKey,
Channel: m.Channel,
ChatID: m.ChatID,
}
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
if continueErr != nil {
al.maybePublishError(ctx, m.Channel, m.ChatID, sessionKey, continueErr)
return
}
if continued != "" {
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, continued)
}
return
}
+31 -15
View File
@@ -44,11 +44,36 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
return
}
// Drain steering queue using existing Continue mechanism
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
if continueErr != nil {
logger.WarnCF("agent", "Failed to continue queued steering",
map[string]any{
"channel": target.Channel,
"chat_id": target.ChatID,
"error": continueErr.Error(),
})
} else if continued != "" {
finalResponse = continued
}
// Publish final response
if finalResponse != "" {
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
}
}
func (al *AgentLoop) drainQueuedSteeringContinuations(
ctx context.Context,
target *continuationTarget,
) (string, error) {
if target == nil {
return "", nil
}
finalResponse := ""
for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
// Check for context cancellation between iterations
if ctx.Err() != nil {
return
if err := ctx.Err(); err != nil {
return finalResponse, err
}
logger.InfoCF("agent", "Continuing queued steering after turn end",
@@ -61,13 +86,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
if continueErr != nil {
logger.WarnCF("agent", "Failed to continue queued steering",
map[string]any{
"channel": target.Channel,
"chat_id": target.ChatID,
"error": continueErr.Error(),
})
break
return finalResponse, continueErr
}
if continued == "" {
break
@@ -75,10 +94,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
finalResponse = continued
}
// Publish final response
if finalResponse != "" {
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
}
return finalResponse, nil
}
func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) {
+185
View File
@@ -840,6 +840,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
}
}
// TestAgentLoop_Run_PendingStopStillContinuesQueuedFollowUp checks that a
// follow-up message queued for a session is still processed when that
// session's first turn is stopped before it ever starts.
//
// Scenario, in order:
//  1. A "blocker" session publishes a message and the test waits until the
//     provider signals that its first call has started.
//  2. A "target" session publishes "skip this turn"; the test waits until the
//     target's active turn state carries a turn ID with pendingTurnPrefix.
//  3. The target sends "/stop" and the test waits for the stop reply.
//  4. The target sends "run this instead" and the test waits until the scoped
//     steering queue for the target is non-empty.
//  5. The blocker's provider call is released and the test expects a
//     "continued response" to be published to the target chat.
//
// Finally it asserts the provider was called exactly twice and that the second
// call saw the follow-up message but not the stopped one.
func TestAgentLoop_Run_PendingStopStillContinuesQueuedFollowUp(t *testing.T) {
	workspace, err := os.MkdirTemp("", "agent-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(workspace)

	// NOTE(review): MaxParallelTurns is 1, presumably so the blocker turn
	// occupies the only turn slot and the target turn stays a pending
	// placeholder — confirm against the AgentLoop scheduling code.
	defaults := config.AgentDefaults{
		Workspace:         workspace,
		ModelName:         "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
		MaxParallelTurns:  1,
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{Defaults: defaults},
	}

	msgBus := bus.NewMessageBus()
	provider := &lateSteeringProvider{
		firstCallStarted: make(chan struct{}),
		releaseFirstCall: make(chan struct{}),
	}
	al := NewAgentLoop(cfg, msgBus, provider)

	runCtx, cancelRun := context.WithCancel(context.Background())
	defer cancelRun()

	runDone := make(chan error, 1)
	go func() {
		runDone <- al.Run(runCtx)
	}()
	// On exit, stop the loop and require Run to return without error within
	// two seconds.
	defer func() {
		cancelRun()
		select {
		case runErr := <-runDone:
			if runErr != nil {
				t.Fatalf("Run() error = %v", runErr)
			}
		case <-time.After(2 * time.Second):
			t.Fatal("timeout waiting for Run to stop")
		}
	}()

	blockerSessionKey := session.BuildOpaqueSessionKey("agent:main:test:blocker")
	targetSessionKey := session.BuildOpaqueSessionKey("agent:main:test:target")
	blockerCtx := bus.InboundContext{
		Channel:  "test",
		ChatID:   "blocker-chat",
		ChatType: "direct",
		SenderID: "user1",
	}
	targetCtx := bus.InboundContext{
		Channel:  "test",
		ChatID:   "target-chat",
		ChatType: "direct",
		SenderID: "user1",
	}

	// pollUntil re-evaluates done every 10ms and fails the test with failMsg
	// once timeout has elapsed without done reporting true.
	pollUntil := func(timeout time.Duration, failMsg string, done func() bool) {
		t.Helper()
		limit := time.Now().Add(timeout)
		for !done() {
			if time.Now().After(limit) {
				t.Fatal(failMsg)
			}
			time.Sleep(10 * time.Millisecond)
		}
	}

	// awaitTargetReply consumes outbound messages until one addressed to the
	// target chat carries exactly the wanted content. The deadline is only
	// checked when no message arrived within a 10ms window.
	awaitTargetReply := func(timeout time.Duration, want, failMsg string) {
		t.Helper()
		limit := time.Now().Add(timeout)
		for {
			select {
			case outbound := <-msgBus.OutboundChan():
				if outbound.ChatID == "target-chat" && outbound.Content == want {
					return
				}
			case <-time.After(10 * time.Millisecond):
				if time.Now().After(limit) {
					t.Fatal(failMsg)
				}
			}
		}
	}

	// sendToTarget publishes content on the target session and fails the test
	// using failFormat if publishing returns an error.
	sendToTarget := func(content, failFormat string) {
		t.Helper()
		inbound := bus.InboundMessage{
			Context:    targetCtx,
			Content:    content,
			SessionKey: targetSessionKey,
		}
		if pubErr := msgBus.PublishInbound(context.Background(), inbound); pubErr != nil {
			t.Fatalf(failFormat, pubErr)
		}
	}

	// Step 1: start the blocker turn and wait for its provider call to begin.
	if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
		Context:    blockerCtx,
		Content:    "block worker pool",
		SessionKey: blockerSessionKey,
	}); err != nil {
		t.Fatalf("PublishInbound(blocker) error = %v", err)
	}
	select {
	case <-provider.firstCallStarted:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for blocker turn to start")
	}

	// Step 2: queue the target turn and wait for its pending placeholder.
	sendToTarget("skip this turn", "PublishInbound(target start) error = %v")
	pollUntil(2*time.Second, "timeout waiting for pending placeholder", func() bool {
		state := al.getActiveTurnState(targetSessionKey)
		return state != nil && strings.HasPrefix(state.turnID, pendingTurnPrefix)
	})

	// Step 3: stop the target turn while it is still pending.
	sendToTarget("/stop", "PublishInbound(/stop) error = %v")
	awaitTargetReply(
		2*time.Second,
		"Task stopped. Current task was canceled.",
		"timeout waiting for /stop reply",
	)

	// Step 4: send the follow-up and wait for it to land in the scoped queue.
	sendToTarget("run this instead", "PublishInbound(follow-up) error = %v")
	pollUntil(2*time.Second, "timeout waiting for follow-up to enter scoped steering queue", func() bool {
		return al.pendingSteeringCountForScope(targetSessionKey) != 0
	})

	// Step 5: release the blocker's provider call and expect the queued
	// follow-up to produce a continuation reply on the target chat.
	close(provider.releaseFirstCall)
	awaitTargetReply(
		5*time.Second,
		"continued response",
		"timeout waiting for queued follow-up continuation",
	)

	// The target session must end with no active turn and an empty queue.
	pollUntil(2*time.Second, "timeout waiting for target session to go idle", func() bool {
		return al.GetActiveTurnBySession(targetSessionKey) == nil &&
			al.pendingSteeringCountForScope(targetSessionKey) == 0
	})

	// Snapshot the provider's recorded state under its mutex.
	provider.mu.Lock()
	callCount := provider.calls
	replayed := append([]providers.Message(nil), provider.secondCallMessages...)
	provider.mu.Unlock()

	if callCount != 2 {
		t.Fatalf("expected 2 provider calls (blocker + continuation), got %d", callCount)
	}

	// The second call must include the follow-up and must not include the
	// message whose turn was stopped.
	sawFollowUp := false
	for _, m := range replayed {
		if m.Role != "user" {
			continue
		}
		switch m.Content {
		case "run this instead":
			sawFollowUp = true
		case "skip this turn":
			t.Fatalf("unexpected canceled message in continuation context: %q", m.Content)
		}
	}
	if !sawFollowUp {
		t.Fatal("expected queued follow-up to be processed after pending stop")
	}
}
func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {