diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index bb21b7c5e..97ee4fe7d 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -247,6 +247,19 @@ func (al *AgentLoop) Run(ctx context.Context) error { if al.takePendingStop(sessionKey) { al.activeTurnStates.Delete(sessionKey) + target := &continuationTarget{ + SessionKey: sessionKey, + Channel: m.Channel, + ChatID: m.ChatID, + } + continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target) + if continueErr != nil { + al.maybePublishError(ctx, m.Channel, m.ChatID, sessionKey, continueErr) + return + } + if continued != "" { + al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, continued) + } return } diff --git a/pkg/agent/agent_steering.go b/pkg/agent/agent_steering.go index c674bcafa..9b136e7cd 100644 --- a/pkg/agent/agent_steering.go +++ b/pkg/agent/agent_steering.go @@ -44,11 +44,36 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb return } - // Drain steering queue using existing Continue mechanism + continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target) + if continueErr != nil { + logger.WarnCF("agent", "Failed to continue queued steering", + map[string]any{ + "channel": target.Channel, + "chat_id": target.ChatID, + "error": continueErr.Error(), + }) + } else if continued != "" { + finalResponse = continued + } + + // Publish final response + if finalResponse != "" { + al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse) + } +} + +func (al *AgentLoop) drainQueuedSteeringContinuations( + ctx context.Context, + target *continuationTarget, +) (string, error) { + if target == nil { + return "", nil + } + + finalResponse := "" for al.pendingSteeringCountForScope(target.SessionKey) > 0 { - // Check for context cancellation between iterations - if ctx.Err() != nil { - return + if err := ctx.Err(); err != nil { + return finalResponse, err } logger.InfoCF("agent", "Continuing queued steering after turn end", @@ -61,13 +86,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID) if continueErr != nil { - logger.WarnCF("agent", "Failed to continue queued steering", - map[string]any{ - "channel": target.Channel, - "chat_id": target.ChatID, - "error": continueErr.Error(), - }) - break + return finalResponse, continueErr } if continued == "" { break @@ -75,10 +94,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb finalResponse = continued } - // Publish final response - if finalResponse != "" { - al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse) - } + return finalResponse, nil } func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) { diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go index eb8874122..813013649 100644 --- a/pkg/agent/steering_test.go +++ b/pkg/agent/steering_test.go @@ -840,6 +840,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) { } } +func TestAgentLoop_Run_PendingStopStillContinuesQueuedFollowUp(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + MaxParallelTurns: 1, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &lateSteeringProvider{ + firstCallStarted: make(chan struct{}), + releaseFirstCall: make(chan struct{}), + } + al := NewAgentLoop(cfg, msgBus, provider) + + runCtx, cancelRun := context.WithCancel(context.Background()) + defer cancelRun() + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- al.Run(runCtx) + }() + defer func() { + cancelRun() + select { + case err := <-runErrCh: + if err != nil { + t.Fatalf("Run() error = %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for Run to stop") + } + }() + + blockerSessionKey := session.BuildOpaqueSessionKey("agent:main:test:blocker") + targetSessionKey := session.BuildOpaqueSessionKey("agent:main:test:target") + blockerCtx := bus.InboundContext{ + Channel: "test", + ChatID: "blocker-chat", + ChatType: "direct", + SenderID: "user1", + } + targetCtx := bus.InboundContext{ + Channel: "test", + ChatID: "target-chat", + ChatType: "direct", + SenderID: "user1", + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: blockerCtx, + Content: "block worker pool", + SessionKey: blockerSessionKey, + }); err != nil { + t.Fatalf("PublishInbound(blocker) error = %v", err) + } + + select { + case <-provider.firstCallStarted: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for blocker turn to start") + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: targetCtx, + Content: "skip this turn", + SessionKey: targetSessionKey, + }); err != nil { + t.Fatalf("PublishInbound(target start) error = %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for { + ts := al.getActiveTurnState(targetSessionKey) + if ts != nil && strings.HasPrefix(ts.turnID, pendingTurnPrefix) { + break + } + if time.Now().After(deadline) { + t.Fatal("timeout waiting for pending placeholder") + } + time.Sleep(10 * time.Millisecond) + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: targetCtx, + Content: "/stop", + SessionKey: targetSessionKey, + }); err != nil { + t.Fatalf("PublishInbound(/stop) error = %v", err) + } + + deadline = time.Now().Add(2 * time.Second) + stopSeen := false + for !stopSeen { + select { + case outbound := <-msgBus.OutboundChan(): + if outbound.ChatID == "target-chat" && outbound.Content == "Task stopped. Current task was canceled." { + stopSeen = true + } + case <-time.After(10 * time.Millisecond): + if time.Now().After(deadline) { + t.Fatal("timeout waiting for /stop reply") + } + } + } + + if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{ + Context: targetCtx, + Content: "run this instead", + SessionKey: targetSessionKey, + }); err != nil { + t.Fatalf("PublishInbound(follow-up) error = %v", err) + } + + deadline = time.Now().Add(2 * time.Second) + for al.pendingSteeringCountForScope(targetSessionKey) == 0 { + if time.Now().After(deadline) { + t.Fatal("timeout waiting for follow-up to enter scoped steering queue") + } + time.Sleep(10 * time.Millisecond) + } + + close(provider.releaseFirstCall) + + deadline = time.Now().Add(5 * time.Second) + followUpSeen := false + for !followUpSeen { + select { + case outbound := <-msgBus.OutboundChan(): + if outbound.ChatID == "target-chat" && outbound.Content == "continued response" { + followUpSeen = true + } + case <-time.After(10 * time.Millisecond): + if time.Now().After(deadline) { + t.Fatal("timeout waiting for queued follow-up continuation") + } + } + } + + deadline = time.Now().Add(2 * time.Second) + for { + if al.GetActiveTurnBySession(targetSessionKey) == nil && + al.pendingSteeringCountForScope(targetSessionKey) == 0 { + break + } + if time.Now().After(deadline) { + t.Fatal("timeout waiting for target session to go idle") + } + time.Sleep(10 * time.Millisecond) + } + + provider.mu.Lock() + calls := provider.calls + secondMessages := append([]providers.Message(nil), provider.secondCallMessages...) + provider.mu.Unlock() + + if calls != 2 { + t.Fatalf("expected 2 provider calls (blocker + continuation), got %d", calls) + } + + foundFollowUp := false + for _, msg := range secondMessages { + if msg.Role == "user" && msg.Content == "run this instead" { + foundFollowUp = true + } + if msg.Role == "user" && msg.Content == "skip this turn" { + t.Fatalf("unexpected canceled message in continuation context: %q", msg.Content) + } + } + if !foundFollowUp { + t.Fatal("expected queued follow-up to be processed after pending stop") + } +} + func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil {