mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(agent): drain scoped follow-up queue when pending stop skips turn startup
This commit is contained in:
@@ -247,6 +247,19 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
|
||||
if al.takePendingStop(sessionKey) {
|
||||
al.activeTurnStates.Delete(sessionKey)
|
||||
target := &continuationTarget{
|
||||
SessionKey: sessionKey,
|
||||
Channel: m.Channel,
|
||||
ChatID: m.ChatID,
|
||||
}
|
||||
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
|
||||
if continueErr != nil {
|
||||
al.maybePublishError(ctx, m.Channel, m.ChatID, sessionKey, continueErr)
|
||||
return
|
||||
}
|
||||
if continued != "" {
|
||||
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, continued)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+31
-15
@@ -44,11 +44,36 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
|
||||
return
|
||||
}
|
||||
|
||||
// Drain steering queue using existing Continue mechanism
|
||||
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
|
||||
if continueErr != nil {
|
||||
logger.WarnCF("agent", "Failed to continue queued steering",
|
||||
map[string]any{
|
||||
"channel": target.Channel,
|
||||
"chat_id": target.ChatID,
|
||||
"error": continueErr.Error(),
|
||||
})
|
||||
} else if continued != "" {
|
||||
finalResponse = continued
|
||||
}
|
||||
|
||||
// Publish final response
|
||||
if finalResponse != "" {
|
||||
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) drainQueuedSteeringContinuations(
|
||||
ctx context.Context,
|
||||
target *continuationTarget,
|
||||
) (string, error) {
|
||||
if target == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
finalResponse := ""
|
||||
for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
|
||||
// Check for context cancellation between iterations
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
if err := ctx.Err(); err != nil {
|
||||
return finalResponse, err
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Continuing queued steering after turn end",
|
||||
@@ -61,13 +86,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
|
||||
|
||||
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
|
||||
if continueErr != nil {
|
||||
logger.WarnCF("agent", "Failed to continue queued steering",
|
||||
map[string]any{
|
||||
"channel": target.Channel,
|
||||
"chat_id": target.ChatID,
|
||||
"error": continueErr.Error(),
|
||||
})
|
||||
break
|
||||
return finalResponse, continueErr
|
||||
}
|
||||
if continued == "" {
|
||||
break
|
||||
@@ -75,10 +94,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
|
||||
finalResponse = continued
|
||||
}
|
||||
|
||||
// Publish final response
|
||||
if finalResponse != "" {
|
||||
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
|
||||
}
|
||||
return finalResponse, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) {
|
||||
|
||||
@@ -840,6 +840,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Run_PendingStopStillContinuesQueuedFollowUp(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
MaxParallelTurns: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &lateSteeringProvider{
|
||||
firstCallStarted: make(chan struct{}),
|
||||
releaseFirstCall: make(chan struct{}),
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
defer func() {
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}()
|
||||
|
||||
blockerSessionKey := session.BuildOpaqueSessionKey("agent:main:test:blocker")
|
||||
targetSessionKey := session.BuildOpaqueSessionKey("agent:main:test:target")
|
||||
blockerCtx := bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "blocker-chat",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
}
|
||||
targetCtx := bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "target-chat",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: blockerCtx,
|
||||
Content: "block worker pool",
|
||||
SessionKey: blockerSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(blocker) error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstCallStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for blocker turn to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: targetCtx,
|
||||
Content: "skip this turn",
|
||||
SessionKey: targetSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(target start) error = %v", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
ts := al.getActiveTurnState(targetSessionKey)
|
||||
if ts != nil && strings.HasPrefix(ts.turnID, pendingTurnPrefix) {
|
||||
break
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for pending placeholder")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: targetCtx,
|
||||
Content: "/stop",
|
||||
SessionKey: targetSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(/stop) error = %v", err)
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
stopSeen := false
|
||||
for !stopSeen {
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.ChatID == "target-chat" && outbound.Content == "Task stopped. Current task was canceled." {
|
||||
stopSeen = true
|
||||
}
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for /stop reply")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: targetCtx,
|
||||
Content: "run this instead",
|
||||
SessionKey: targetSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(follow-up) error = %v", err)
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
for al.pendingSteeringCountForScope(targetSessionKey) == 0 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for follow-up to enter scoped steering queue")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
close(provider.releaseFirstCall)
|
||||
|
||||
deadline = time.Now().Add(5 * time.Second)
|
||||
followUpSeen := false
|
||||
for !followUpSeen {
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.ChatID == "target-chat" && outbound.Content == "continued response" {
|
||||
followUpSeen = true
|
||||
}
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for queued follow-up continuation")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
if al.GetActiveTurnBySession(targetSessionKey) == nil &&
|
||||
al.pendingSteeringCountForScope(targetSessionKey) == 0 {
|
||||
break
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for target session to go idle")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
|
||||
provider.mu.Unlock()
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls (blocker + continuation), got %d", calls)
|
||||
}
|
||||
|
||||
foundFollowUp := false
|
||||
for _, msg := range secondMessages {
|
||||
if msg.Role == "user" && msg.Content == "run this instead" {
|
||||
foundFollowUp = true
|
||||
}
|
||||
if msg.Role == "user" && msg.Content == "skip this turn" {
|
||||
t.Fatalf("unexpected canceled message in continuation context: %q", msg.Content)
|
||||
}
|
||||
}
|
||||
if !foundFollowUp {
|
||||
t.Fatal("expected queued follow-up to be processed after pending stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user