From 0e075f7300014e4d305c346f3555742e34cb8174 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Fri, 20 Mar 2026 17:28:12 +0800 Subject: [PATCH] feat(agent): centralize turn lifecycle and continue queued steering Refactor agent loop execution around runTurn, add explicit turn state and interrupt semantics, and automatically continue queued steering that misses the current turn boundary. --- pkg/agent/eventbus_test.go | 3 + pkg/agent/events.go | 14 +- pkg/agent/loop.go | 818 ++++++++++++++++++++++--------------- pkg/agent/steering.go | 70 +++- pkg/agent/steering_test.go | 518 +++++++++++++++++++++++ pkg/agent/turn.go | 309 ++++++++++++++ 6 files changed, 1395 insertions(+), 337 deletions(-) create mode 100644 pkg/agent/turn.go diff --git a/pkg/agent/eventbus_test.go b/pkg/agent/eventbus_test.go index 13f2f2282..9acc6ddd8 100644 --- a/pkg/agent/eventbus_test.go +++ b/pkg/agent/eventbus_test.go @@ -334,6 +334,9 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) { if interruptPayload.Role != "user" { t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role) } + if interruptPayload.Kind != InterruptKindSteering { + t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind) + } if interruptPayload.ContentLen != len("change course") { t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen) } diff --git a/pkg/agent/events.go b/pkg/agent/events.go index fae5033a3..95e4c90d0 100644 --- a/pkg/agent/events.go +++ b/pkg/agent/events.go @@ -105,6 +105,8 @@ const ( TurnEndStatusCompleted TurnEndStatus = "completed" // TurnEndStatusError indicates the turn ended because of an error. TurnEndStatusError TurnEndStatus = "error" + // TurnEndStatusAborted indicates the turn was hard-aborted and rolled back. + TurnEndStatusAborted TurnEndStatus = "aborted" ) // TurnStartPayload describes the start of a turn. @@ -215,11 +217,21 @@ type FollowUpQueuedPayload struct { ContentLen int } -// InterruptReceivedPayload describes a queued soft interrupt. +type InterruptKind string + +const ( + InterruptKindSteering InterruptKind = "steering" + InterruptKindGraceful InterruptKind = "graceful" + InterruptKindHard InterruptKind = "hard_abort" +) + +// InterruptReceivedPayload describes accepted turn-control input. type InterruptReceivedPayload struct { + Kind InterruptKind Role string ContentLen int QueueDepth int + HintLen int } // SubTurnSpawnPayload describes the creation of a child turn. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 877dbbd94..f54482ae8 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -50,6 +50,8 @@ type AgentLoop struct { mcp mcpRuntime steering *steeringQueue mu sync.RWMutex + activeTurnMu sync.RWMutex + activeTurn *turnState turnSeq atomic.Uint64 // Track active requests for safe provider cleanup activeRequests sync.WaitGroup @@ -69,6 +71,12 @@ type processOptions struct { SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue) } +type continuationTarget struct { + SessionKey string + Channel string + ChatID string +} + const ( defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." sessionKeyAgentPrefix = "agent:" @@ -292,38 +300,46 @@ func (al *AgentLoop) Run(ctx context.Context) error { } if response != "" { - // Check if the message tool already sent a response during this round. - // If so, skip publishing to avoid duplicate messages to the user. - // Use default agent's tools to check (message tool is shared). - alreadySent := false - defaultAgent := al.GetRegistry().GetDefaultAgent() - if defaultAgent != nil { - if tool, ok := defaultAgent.Tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() - } - } + al.publishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, response) + } + + target, targetErr := al.buildContinuationTarget(msg) + if targetErr != nil { + logger.WarnCF("agent", "Failed to build steering continuation target", + map[string]any{ + "channel": msg.Channel, + "error": targetErr.Error(), + }) + return + } + if target == nil { + return + } + + for al.pendingSteeringCount() > 0 { + logger.InfoCF("agent", "Continuing queued steering after turn end", + map[string]any{ + "channel": target.Channel, + "chat_id": target.ChatID, + "session_key": target.SessionKey, + "queue_depth": al.pendingSteeringCount(), + }) + + continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID) + if continueErr != nil { + logger.WarnCF("agent", "Failed to continue queued steering", + map[string]any{ + "channel": target.Channel, + "chat_id": target.ChatID, + "error": continueErr.Error(), + }) + return + } + if continued == "" { + return } - if !alreadySent { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) - logger.InfoCF("agent", "Published outbound response", - map[string]any{ - "channel": msg.Channel, - "chat_id": msg.ChatID, - "content_len": len(response), - }) - } else { - logger.DebugCF( - "agent", - "Skipped outbound (message tool already sent)", - map[string]any{"channel": msg.Channel}, - ) - } + al.publishResponseIfNeeded(ctx, target.Channel, target.ChatID, continued) } }() } @@ -369,6 +385,67 @@ func (al *AgentLoop) Stop() { al.running.Store(false) } +func (al *AgentLoop) publishResponseIfNeeded(ctx context.Context, channel, chatID, response string) { + if response == "" { + return + } + + alreadySent := false + defaultAgent := al.GetRegistry().GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } + } + } + + if alreadySent { + logger.DebugCF( + "agent", + "Skipped outbound (message tool already sent)", + map[string]any{"channel": channel}, + ) + return + } + + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: response, + }) + logger.InfoCF("agent", "Published outbound response", + map[string]any{ + "channel": channel, + "chat_id": chatID, + "content_len": len(response), + }) +} + +func (al *AgentLoop) pendingSteeringCount() int { + if al.steering == nil { + return 0 + } + return al.steering.len() +} + +func (al *AgentLoop) buildContinuationTarget(msg bus.InboundMessage) (*continuationTarget, error) { + if msg.Channel == "system" { + return nil, nil + } + + route, _, err := al.resolveMessageRoute(msg) + if err != nil { + return nil, err + } + + return &continuationTarget{ + SessionKey: resolveScopeKey(route, msg.SessionKey), + Channel: msg.Channel, + ChatID: msg.ChatID, + }, nil +} + // Close releases resources held by agent session stores. Call after Stop. func (al *AgentLoop) Close() { mcpManager := al.mcp.takeManager() @@ -543,9 +620,11 @@ func (al *AgentLoop) logEvent(evt Event) { fields["chat_id"] = payload.ChatID fields["content_len"] = payload.ContentLen case InterruptReceivedPayload: + fields["interrupt_kind"] = payload.Kind fields["role"] = payload.Role fields["content_len"] = payload.ContentLen fields["queue_depth"] = payload.QueueDepth + fields["hint_len"] = payload.HintLen case SubTurnSpawnPayload: fields["child_agent_id"] = payload.AgentID fields["label"] = payload.Label @@ -1071,153 +1150,63 @@ func (al *AgentLoop) processSystemMessage( }) } -// runAgentLoop is the core message processing logic. +// runAgentLoop remains the top-level shell that starts a turn and publishes +// any post-turn work. runTurn owns the full turn lifecycle. func (al *AgentLoop) runAgentLoop( ctx context.Context, agent *AgentInstance, opts processOptions, ) (string, error) { - turnScope := al.newTurnEventScope(agent.ID, opts.SessionKey) - turnStartedAt := time.Now() - turnIterations := 0 - turnFinalContentLen := 0 - turnStatus := TurnEndStatusCompleted - defer func() { - al.emitEvent( - EventKindTurnEnd, - turnScope.meta(turnIterations, "runAgentLoop", "turn.end"), - TurnEndPayload{ - Status: turnStatus, - Iterations: turnIterations, - Duration: time.Since(turnStartedAt), - FinalContentLen: turnFinalContentLen, - }, - ) - }() - - al.emitEvent( - EventKindTurnStart, - turnScope.meta(0, "runAgentLoop", "turn.start"), - TurnStartPayload{ - Channel: opts.Channel, - ChatID: opts.ChatID, - UserMessage: opts.UserMessage, - MediaCount: len(opts.Media), - }, - ) - - // 0. Record last channel for heartbeat notifications (skip internal channels and cli) - if opts.Channel != "" && opts.ChatID != "" { - if !constants.IsInternalChannel(opts.Channel) { - channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) - if err := al.RecordLastChannel(channelKey); err != nil { - logger.WarnCF( - "agent", - "Failed to record last channel", - map[string]any{"error": err.Error()}, - ) - } - } - } - - // 1. Build messages (skip history for heartbeat) - var history []providers.Message - var summary string - if !opts.NoHistory { - history = agent.Sessions.GetHistory(opts.SessionKey) - summary = agent.Sessions.GetSummary(opts.SessionKey) - } - messages := agent.ContextBuilder.BuildMessages( - history, - summary, - opts.UserMessage, - opts.Media, - opts.Channel, - opts.ChatID, - ) - - // Resolve media:// refs: images→base64 data URLs, non-images→local paths in content - cfg := al.GetConfig() - maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize() - messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) - - // 1.5. Proactive context budget check: compress before LLM call - // rather than waiting for a 400 context-length error. - if !opts.NoHistory { - toolDefs := agent.Tools.ToProviderDefs() - if isOverContextBudget(agent.ContextWindow, messages, toolDefs, agent.MaxTokens) { - logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call", - map[string]any{"session_key": opts.SessionKey}) - if compression, ok := al.forceCompression(agent, opts.SessionKey); ok { - al.emitEvent( - EventKindContextCompress, - turnScope.meta(0, "runAgentLoop", "turn.context.compress"), - ContextCompressPayload{ - Reason: ContextCompressReasonProactive, - DroppedMessages: compression.DroppedMessages, - RemainingMessages: compression.RemainingMessages, - }, - ) - } - newHistory := agent.Sessions.GetHistory(opts.SessionKey) - newSummary := agent.Sessions.GetSummary(opts.SessionKey) - messages = agent.ContextBuilder.BuildMessages( - newHistory, newSummary, opts.UserMessage, - opts.Media, opts.Channel, opts.ChatID, + if opts.Channel != "" && opts.ChatID != "" && !constants.IsInternalChannel(opts.Channel) { + channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) + if err := al.RecordLastChannel(channelKey); err != nil { + logger.WarnCF( + "agent", + "Failed to record last channel", + map[string]any{"error": err.Error()}, ) - messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) } } - // 2. Save user message to session - agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) - - // 3. Run LLM iteration loop - finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts, turnScope) - turnIterations = iteration + ts := newTurnState(agent, opts, al.newTurnEventScope(agent.ID, opts.SessionKey)) + result, err := al.runTurn(ctx, ts) if err != nil { - turnStatus = TurnEndStatusError return "", err } - - // If last tool had ForUser content and we already sent it, we might not need to send final response - // This is controlled by the tool's Silent flag and ForUser content - - // 4. Handle empty response - if finalContent == "" { - finalContent = opts.DefaultResponse - } - turnFinalContentLen = len(finalContent) - - // 5. Save final assistant message to session - agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent) - agent.Sessions.Save(opts.SessionKey) - - // 6. Optional: summarization - if opts.EnableSummary { - al.maybeSummarize(agent, opts.SessionKey, turnScope) + if result.status == TurnEndStatusAborted { + return "", nil } - // 7. Optional: send response via bus - if opts.SendResponse { + for _, followUp := range result.followUps { + if pubErr := al.bus.PublishInbound(ctx, followUp); pubErr != nil { + logger.WarnCF("agent", "Failed to publish follow-up after turn", + map[string]any{ + "turn_id": ts.turnID, + "error": pubErr.Error(), + }) + } + } + + if opts.SendResponse && result.finalContent != "" { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, - Content: finalContent, + Content: result.finalContent, }) } - // 8. Log response - responsePreview := utils.Truncate(finalContent, 120) - logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), - map[string]any{ - "agent_id": agent.ID, - "session_key": opts.SessionKey, - "iterations": iteration, - "final_length": len(finalContent), - }) + if result.finalContent != "" { + responsePreview := utils.Truncate(result.finalContent, 120) + logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), + map[string]any{ + "agent_id": agent.ID, + "session_key": opts.SessionKey, + "iterations": ts.currentIteration(), + "final_length": len(result.finalContent), + }) + } - return finalContent, nil + return result.finalContent, nil } func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string) { @@ -1276,54 +1265,135 @@ func (al *AgentLoop) handleReasoning( } } -// runLLMIteration executes the LLM call loop with tool handling. -func (al *AgentLoop) runLLMIteration( - ctx context.Context, - agent *AgentInstance, - messages []providers.Message, - opts processOptions, - turnScope turnEventScope, -) (string, int, error) { - iteration := 0 - var finalContent string - var pendingMessages []providers.Message +func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, error) { + turnCtx, turnCancel := context.WithCancel(ctx) + defer turnCancel() + ts.setTurnCancel(turnCancel) - // Poll for steering messages at loop start (in case the user typed while - // the agent was setting up), unless the caller already provided initial - // steering messages (e.g. Continue). - if !opts.SkipInitialSteeringPoll { - if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 { - pendingMessages = msgs + al.registerActiveTurn(ts) + defer al.clearActiveTurn(ts) + + turnStatus := TurnEndStatusCompleted + defer func() { + al.emitEvent( + EventKindTurnEnd, + ts.eventMeta("runTurn", "turn.end"), + TurnEndPayload{ + Status: turnStatus, + Iterations: ts.currentIteration(), + Duration: time.Since(ts.startedAt), + FinalContentLen: ts.finalContentLen(), + }, + ) + }() + + al.emitEvent( + EventKindTurnStart, + ts.eventMeta("runTurn", "turn.start"), + TurnStartPayload{ + Channel: ts.channel, + ChatID: ts.chatID, + UserMessage: ts.userMessage, + MediaCount: len(ts.media), + }, + ) + + var history []providers.Message + var summary string + if !ts.opts.NoHistory { + history = ts.agent.Sessions.GetHistory(ts.sessionKey) + summary = ts.agent.Sessions.GetSummary(ts.sessionKey) + } + ts.captureRestorePoint(history, summary) + + messages := ts.agent.ContextBuilder.BuildMessages( + history, + summary, + ts.userMessage, + ts.media, + ts.channel, + ts.chatID, + ) + + cfg := al.GetConfig() + maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize() + messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) + + if !ts.opts.NoHistory { + toolDefs := ts.agent.Tools.ToProviderDefs() + if isOverContextBudget(ts.agent.ContextWindow, messages, toolDefs, ts.agent.MaxTokens) { + logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call", + map[string]any{"session_key": ts.sessionKey}) + if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok { + al.emitEvent( + EventKindContextCompress, + ts.eventMeta("runTurn", "turn.context.compress"), + ContextCompressPayload{ + Reason: ContextCompressReasonProactive, + DroppedMessages: compression.DroppedMessages, + RemainingMessages: compression.RemainingMessages, + }, + ) + ts.refreshRestorePointFromSession(ts.agent) + } + newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey) + newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey) + messages = ts.agent.ContextBuilder.BuildMessages( + newHistory, newSummary, ts.userMessage, + ts.media, ts.channel, ts.chatID, + ) + messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) } } - // Determine effective model tier for this conversation turn. - // selectCandidates evaluates routing once and the decision is sticky for - // all tool-follow-up iterations within the same turn so that a multi-step - // tool chain doesn't switch models mid-way through. - activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages) + if !ts.opts.NoHistory { + rootMsg := providers.Message{Role: "user", Content: ts.userMessage} + ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content) + ts.recordPersistedMessage(rootMsg) + } - for iteration < agent.MaxIterations || len(pendingMessages) > 0 { - iteration++ + activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages) + var pendingMessages []providers.Message + var finalContent string + + for ts.currentIteration() < ts.agent.MaxIterations || len(pendingMessages) > 0 || func() bool { + graceful, _ := ts.gracefulInterruptRequested() + return graceful + }() { + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + + iteration := ts.currentIteration() + 1 + ts.setIteration(iteration) + ts.setPhase(TurnPhaseRunning) + + if iteration > 1 || !ts.opts.SkipInitialSteeringPoll { + if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { + pendingMessages = append(pendingMessages, steerMsgs...) + } + } - // Inject pending steering messages into the conversation context - // before the next LLM call. if len(pendingMessages) > 0 { totalContentLen := 0 for _, pm := range pendingMessages { messages = append(messages, pm) - agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content) totalContentLen += len(pm.Content) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddMessage(ts.sessionKey, pm.Role, pm.Content) + ts.recordPersistedMessage(pm) + } logger.InfoCF("agent", "Injected steering message into context", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, "content_len": len(pm.Content), }) } al.emitEvent( EventKindSteeringInjected, - turnScope.meta(iteration, "runLLMIteration", "turn.steering.injected"), + ts.eventMeta("runTurn", "turn.steering.injected"), SteeringInjectedPayload{ Count: len(pendingMessages), TotalContentLen: totalContentLen, @@ -1334,78 +1404,81 @@ func (al *AgentLoop) runLLMIteration( logger.DebugCF("agent", "LLM iteration", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, - "max": agent.MaxIterations, + "max": ts.agent.MaxIterations, }) - // Build tool definitions - providerToolDefs := agent.Tools.ToProviderDefs() + gracefulTerminal, _ := ts.gracefulInterruptRequested() + providerToolDefs := ts.agent.Tools.ToProviderDefs() + callMessages := messages + if gracefulTerminal { + callMessages = append(append([]providers.Message(nil), messages...), ts.interruptHintMessage()) + providerToolDefs = nil + ts.markGracefulTerminalUsed() + } + al.emitEvent( EventKindLLMRequest, - turnScope.meta(iteration, "runLLMIteration", "turn.llm.request"), + ts.eventMeta("runTurn", "turn.llm.request"), LLMRequestPayload{ Model: activeModel, - MessagesCount: len(messages), + MessagesCount: len(callMessages), ToolsCount: len(providerToolDefs), - MaxTokens: agent.MaxTokens, - Temperature: agent.Temperature, + MaxTokens: ts.agent.MaxTokens, + Temperature: ts.agent.Temperature, }, ) - // Log LLM request details logger.DebugCF("agent", "LLM request", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, "model": activeModel, - "messages_count": len(messages), + "messages_count": len(callMessages), "tools_count": len(providerToolDefs), - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, - "system_prompt_len": len(messages[0].Content), + "max_tokens": ts.agent.MaxTokens, + "temperature": ts.agent.Temperature, + "system_prompt_len": len(callMessages[0].Content), }) - - // Log full messages (detailed) logger.DebugCF("agent", "Full LLM request", map[string]any{ "iteration": iteration, - "messages_json": formatMessagesForLog(messages), + "messages_json": formatMessagesForLog(callMessages), "tools_json": formatToolsForLog(providerToolDefs), }) - // Call LLM with fallback chain if multiple candidates are configured. - var response *providers.LLMResponse - var err error - llmOpts := map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, - "prompt_cache_key": agent.ID, + "max_tokens": ts.agent.MaxTokens, + "temperature": ts.agent.Temperature, + "prompt_cache_key": ts.agent.ID, } - // parseThinkingLevel guarantees ThinkingOff for empty/unknown values, - // so checking != ThinkingOff is sufficient. - if agent.ThinkingLevel != ThinkingOff { - if tc, ok := agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() { - llmOpts["thinking_level"] = string(agent.ThinkingLevel) + if ts.agent.ThinkingLevel != ThinkingOff { + if tc, ok := ts.agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() { + llmOpts["thinking_level"] = string(ts.agent.ThinkingLevel) } else { logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring", - map[string]any{"agent_id": agent.ID, "thinking_level": string(agent.ThinkingLevel)}) + map[string]any{"agent_id": ts.agent.ID, "thinking_level": string(ts.agent.ThinkingLevel)}) } } - callLLM := func() (*providers.LLMResponse, error) { + callLLM := func(messagesForCall []providers.Message, toolDefsForCall []providers.ToolDefinition) (*providers.LLMResponse, error) { + providerCtx, providerCancel := context.WithCancel(turnCtx) + ts.setProviderCancel(providerCancel) + defer func() { + providerCancel() + ts.clearProviderCancel(providerCancel) + }() + al.activeRequests.Add(1) defer al.activeRequests.Done() - // TODO(eventbus): emit EventKindLLMDelta when providers expose - // streaming callbacks instead of only the final Chat response. if len(activeCandidates) > 1 && al.fallback != nil { fbResult, fbErr := al.fallback.Execute( - ctx, + providerCtx, activeCandidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { - return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts) + return ts.agent.Provider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts) }, ) if fbErr != nil { @@ -1416,32 +1489,34 @@ func (al *AgentLoop) runLLMIteration( "agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts", fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1), - map[string]any{"agent_id": agent.ID, "iteration": iteration}, + map[string]any{"agent_id": ts.agent.ID, "iteration": iteration}, ) } return fbResult.Response, nil } - return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts) + return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, activeModel, llmOpts) } - // Retry loop for context/token errors + var response *providers.LLMResponse + var err error maxRetries := 2 for retry := 0; retry <= maxRetries; retry++ { - response, err = callLLM() + response, err = callLLM(callMessages, providerToolDefs) if err == nil { break } + if ts.hardAbortRequested() && errors.Is(err, context.Canceled) { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } errMsg := strings.ToLower(err.Error()) - - // Check if this is a network/HTTP timeout — not a context window error. isTimeoutError := errors.Is(err, context.DeadlineExceeded) || strings.Contains(errMsg, "deadline exceeded") || strings.Contains(errMsg, "client.timeout") || strings.Contains(errMsg, "timed out") || strings.Contains(errMsg, "timeout exceeded") - // Detect real context window / token limit errors, excluding network timeouts. isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") || strings.Contains(errMsg, "context window") || strings.Contains(errMsg, "maximum context length") || @@ -1456,7 +1531,7 @@ func (al *AgentLoop) runLLMIteration( backoff := time.Duration(retry+1) * 5 * time.Second al.emitEvent( EventKindLLMRetry, - turnScope.meta(iteration, "runLLMIteration", "turn.llm.retry"), + ts.eventMeta("runTurn", "turn.llm.retry"), LLMRetryPayload{ Attempt: retry + 1, MaxRetries: maxRetries, @@ -1470,14 +1545,21 @@ func (al *AgentLoop) runLLMIteration( "retry": retry, "backoff": backoff.String(), }) - time.Sleep(backoff) + if sleepErr := sleepWithContext(turnCtx, backoff); sleepErr != nil { + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + err = sleepErr + break + } continue } - if isContextError && retry < maxRetries { + if isContextError && retry < maxRetries && !ts.opts.NoHistory { al.emitEvent( EventKindLLMRetry, - turnScope.meta(iteration, "runLLMIteration", "turn.llm.retry"), + ts.eventMeta("runTurn", "turn.llm.retry"), LLMRetryPayload{ Attempt: retry + 1, MaxRetries: maxRetries, @@ -1494,40 +1576,47 @@ func (al *AgentLoop) runLLMIteration( }, ) - if retry == 0 && !constants.IsInternalChannel(opts.Channel) { + if retry == 0 && !constants.IsInternalChannel(ts.channel) { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, + Channel: ts.channel, + ChatID: ts.chatID, Content: "Context window exceeded. Compressing history and retrying...", }) } - if compression, ok := al.forceCompression(agent, opts.SessionKey); ok { + if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok { al.emitEvent( EventKindContextCompress, - turnScope.meta(iteration, "runLLMIteration", "turn.context.compress"), + ts.eventMeta("runTurn", "turn.context.compress"), ContextCompressPayload{ Reason: ContextCompressReasonRetry, DroppedMessages: compression.DroppedMessages, RemainingMessages: compression.RemainingMessages, }, ) + ts.refreshRestorePointFromSession(ts.agent) } - newHistory := agent.Sessions.GetHistory(opts.SessionKey) - newSummary := agent.Sessions.GetSummary(opts.SessionKey) - messages = agent.ContextBuilder.BuildMessages( + + newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey) + newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey) + messages = ts.agent.ContextBuilder.BuildMessages( newHistory, newSummary, "", - nil, opts.Channel, opts.ChatID, + nil, ts.channel, ts.chatID, ) + callMessages = messages + if gracefulTerminal { + callMessages = append(append([]providers.Message(nil), messages...), ts.interruptHintMessage()) + } continue } break } if err != nil { + turnStatus = TurnEndStatusError al.emitEvent( EventKindError, - turnScope.meta(iteration, "runLLMIteration", "turn.error"), + ts.eventMeta("runTurn", "turn.error"), ErrorPayload{ Stage: "llm", Message: err.Error(), @@ -1535,23 +1624,23 @@ func (al *AgentLoop) runLLMIteration( ) logger.ErrorCF("agent", "LLM call failed", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, "model": activeModel, "error": err.Error(), }) - return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) + return turnResult{}, fmt.Errorf("LLM call failed after retries: %w", err) } go al.handleReasoning( - ctx, + turnCtx, response.Reasoning, - opts.Channel, - al.targetReasoningChannelID(opts.Channel), + ts.channel, + al.targetReasoningChannelID(ts.channel), ) al.emitEvent( EventKindLLMResponse, - turnScope.meta(iteration, "runLLMIteration", "turn.llm.response"), + ts.eventMeta("runTurn", "turn.llm.response"), LLMResponsePayload{ ContentLen: len(response.Content), ToolCalls: len(response.ToolCalls), @@ -1561,23 +1650,23 @@ func (al *AgentLoop) runLLMIteration( logger.DebugCF("agent", "LLM response", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, "content_chars": len(response.Content), "tool_calls": len(response.ToolCalls), "reasoning": response.Reasoning, - "target_channel": al.targetReasoningChannelID(opts.Channel), - "channel": opts.Channel, + "target_channel": al.targetReasoningChannelID(ts.channel), + "channel": ts.channel, }) - // Check if no tool calls - then check reasoning content if any - if len(response.ToolCalls) == 0 { + + if len(response.ToolCalls) == 0 || gracefulTerminal { finalContent = response.Content if finalContent == "" && response.ReasoningContent != "" { finalContent = response.ReasoningContent } logger.InfoCF("agent", "LLM response without tool calls (direct answer)", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, "content_chars": len(finalContent), }) @@ -1589,20 +1678,18 @@ func (al *AgentLoop) runLLMIteration( normalizedToolCalls = append(normalizedToolCalls, providers.NormalizeToolCall(tc)) } - // Log tool calls toolNames := make([]string, 0, len(normalizedToolCalls)) for _, tc := range normalizedToolCalls { toolNames = append(toolNames, tc.Name) } logger.InfoCF("agent", "LLM requested tool calls", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "tools": toolNames, "count": len(normalizedToolCalls), "iteration": iteration, }) - // Build assistant message with tool calls assistantMsg := providers.Message{ Role: "assistant", Content: response.Content, @@ -1610,13 +1697,11 @@ func (al *AgentLoop) runLLMIteration( } for _, tc := range normalizedToolCalls { argumentsJSON, _ := json.Marshal(tc.Arguments) - // Copy ExtraContent to ensure thought_signature is persisted for Gemini 3 extraContent := tc.ExtraContent thoughtSignature := "" if tc.Function != nil { thoughtSignature = tc.Function.ThoughtSignature } - assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ ID: tc.ID, Type: "function", @@ -1631,40 +1716,44 @@ func (al *AgentLoop) runLLMIteration( }) } messages = append(messages, assistantMsg) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddFullMessage(ts.sessionKey, assistantMsg) + ts.recordPersistedMessage(assistantMsg) + } - // Save assistant message with tool calls to session - agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) - - // Execute tool calls sequentially. After each tool completes, check - // for steering messages. If any are found, skip remaining tools. - var steeringAfterTools []providers.Message - + ts.setPhase(TurnPhaseTools) for i, tc := range normalizedToolCalls { + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + argsJSON, _ := json.Marshal(tc.Arguments) argsPreview := utils.Truncate(string(argsJSON), 200) logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "tool": tc.Name, "iteration": iteration, }) al.emitEvent( EventKindToolExecStart, - turnScope.meta(iteration, "runLLMIteration", "turn.tool.start"), + ts.eventMeta("runTurn", "turn.tool.start"), ToolExecStartPayload{ Tool: tc.Name, Arguments: cloneEventArguments(tc.Arguments), }, ) - // Create async callback for tools that implement AsyncExecutor. + toolCall := tc + toolIteration := iteration asyncCallback := func(_ context.Context, result *tools.ToolResult) { if !result.Silent && result.ForUser != "" { outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second) defer outCancel() _ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, + Channel: ts.channel, + ChatID: ts.chatID, Content: result.ForUser, }) } @@ -1679,17 +1768,17 @@ func (al *AgentLoop) runLLMIteration( logger.InfoCF("agent", "Async tool completed, publishing result", map[string]any{ - "tool": tc.Name, + "tool": toolCall.Name, "content_len": len(content), - "channel": opts.Channel, + "channel": ts.channel, }) al.emitEvent( EventKindFollowUpQueued, - turnScope.meta(iteration, "runLLMIteration", "turn.follow_up.queued"), + ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"), FollowUpQueuedPayload{ - SourceTool: tc.Name, - Channel: opts.Channel, - ChatID: opts.ChatID, + SourceTool: toolCall.Name, + Channel: ts.channel, + ChatID: ts.chatID, ContentLen: len(content), }, ) @@ -1698,33 +1787,37 @@ func (al *AgentLoop) runLLMIteration( defer pubCancel() _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{ Channel: "system", - SenderID: fmt.Sprintf("async:%s", tc.Name), - ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID), + SenderID: fmt.Sprintf("async:%s", toolCall.Name), + ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID), Content: content, }) } toolStart := time.Now() - toolResult := agent.Tools.ExecuteWithContext( - ctx, - tc.Name, - tc.Arguments, - opts.Channel, - opts.ChatID, + toolResult := ts.agent.Tools.ExecuteWithContext( + turnCtx, + toolCall.Name, + toolCall.Arguments, + ts.channel, + ts.chatID, asyncCallback, ) toolDuration := time.Since(toolStart) - // Process tool result - if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + + if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, + Channel: ts.channel, + ChatID: ts.chatID, Content: toolResult.ForUser, }) logger.DebugCF("agent", "Sent tool result to user", map[string]any{ - "tool": tc.Name, + "tool": toolCall.Name, "content_len": len(toolResult.ForUser), }) } @@ -1743,8 +1836,8 @@ func (al *AgentLoop) runLLMIteration( parts = append(parts, part) } al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, + Channel: ts.channel, + ChatID: ts.chatID, Parts: parts, }) } @@ -1757,13 +1850,13 @@ func (al *AgentLoop) runLLMIteration( toolResultMsg := providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: tc.ID, + ToolCallID: toolCall.ID, } al.emitEvent( EventKindToolExecEnd, - turnScope.meta(iteration, "runLLMIteration", "turn.tool.end"), + ts.eventMeta("runTurn", "turn.tool.end"), ToolExecEndPayload{ - Tool: tc.Name, + Tool: toolCall.Name, Duration: toolDuration, ForLLMLen: len(contentForLLM), ForUserLen: len(toolResult.ForUser), @@ -1772,67 +1865,136 @@ func (al *AgentLoop) runLLMIteration( }, ) messages = append(messages, toolResultMsg) - agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg) + ts.recordPersistedMessage(toolResultMsg) + } - // After EVERY tool (including the first and last), check for - // steering messages. If found and there are remaining tools, - // skip them all. if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { + pendingMessages = append(pendingMessages, steerMsgs...) + } + + skipReason := "" + skipMessage := "" + if len(pendingMessages) > 0 { + skipReason = "queued user steering message" + skipMessage = "Skipped due to queued user message." + } else if gracefulPending, _ := ts.gracefulInterruptRequested(); gracefulPending { + skipReason = "graceful interrupt requested" + skipMessage = "Skipped due to graceful interrupt." + } + + if skipReason != "" { remaining := len(normalizedToolCalls) - i - 1 if remaining > 0 { - logger.InfoCF("agent", "Steering interrupt: skipping remaining tools", + logger.InfoCF("agent", "Turn checkpoint: skipping remaining tools", map[string]any{ - "agent_id": agent.ID, - "completed": i + 1, - "skipped": remaining, - "total_tools": len(normalizedToolCalls), - "steering_count": len(steerMsgs), + "agent_id": ts.agent.ID, + "completed": i + 1, + "skipped": remaining, + "reason": skipReason, }) - - // Mark remaining tool calls as skipped for j := i + 1; j < len(normalizedToolCalls); j++ { skippedTC := normalizedToolCalls[j] al.emitEvent( EventKindToolExecSkipped, - turnScope.meta(iteration, "runLLMIteration", "turn.tool.skipped"), + ts.eventMeta("runTurn", "turn.tool.skipped"), ToolExecSkippedPayload{ Tool: skippedTC.Name, - Reason: "queued user steering message", + Reason: skipReason, }, ) - toolResultMsg := providers.Message{ + skippedMsg := providers.Message{ Role: "tool", - Content: "Skipped due to queued user message.", + Content: skipMessage, ToolCallID: skippedTC.ID, } - messages = append(messages, toolResultMsg) - agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + messages = append(messages, skippedMsg) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddFullMessage(ts.sessionKey, skippedMsg) + ts.recordPersistedMessage(skippedMsg) + } } } - steeringAfterTools = steerMsgs break } } - // If steering messages were captured during tool execution, they - // become pendingMessages for the next iteration of the inner loop. - if len(steeringAfterTools) > 0 { - pendingMessages = steeringAfterTools - } - - // Tick down TTL of discovered tools after processing tool results. - // Only reached when tool calls were made (the loop continues); - // the break on no-tool-call responses skips this. - // NOTE: This is safe because processMessage is sequential per agent. - // If per-agent concurrency is added, TTL consistency between - // ToProviderDefs and Get must be re-evaluated. - agent.Tools.TickTTL() + ts.agent.Tools.TickTTL() logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{ - "agent_id": agent.ID, "iteration": iteration, + "agent_id": ts.agent.ID, "iteration": iteration, }) } - return finalContent, iteration, nil + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + + if finalContent == "" { + finalContent = ts.opts.DefaultResponse + } + + ts.setPhase(TurnPhaseFinalizing) + ts.setFinalContent(finalContent) + if !ts.opts.NoHistory { + finalMsg := providers.Message{Role: "assistant", Content: finalContent} + ts.agent.Sessions.AddMessage(ts.sessionKey, finalMsg.Role, finalMsg.Content) + ts.recordPersistedMessage(finalMsg) + if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil { + turnStatus = TurnEndStatusError + al.emitEvent( + EventKindError, + ts.eventMeta("runTurn", "turn.error"), + ErrorPayload{ + Stage: "session_save", + Message: err.Error(), + }, + ) + return turnResult{}, err + } + } + + if ts.opts.EnableSummary { + al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope) + } + + ts.setPhase(TurnPhaseCompleted) + return turnResult{ + finalContent: finalContent, + status: turnStatus, + followUps: append([]bus.InboundMessage(nil), ts.followUps...), + }, nil +} + +func (al *AgentLoop) abortTurn(ts *turnState) (turnResult, error) { + ts.setPhase(TurnPhaseAborted) + if !ts.opts.NoHistory { + if err := ts.restoreSession(ts.agent); err != nil { + al.emitEvent( + EventKindError, + ts.eventMeta("abortTurn", "turn.error"), + ErrorPayload{ + Stage: "session_restore", + Message: err.Error(), + }, + ) + return turnResult{}, err + } + } + return turnResult{status: TurnEndStatusAborted}, nil +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } } // selectCandidates returns the model candidates and resolved model name to use diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index 90d1cc091..77c2e0c17 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -122,20 +122,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error { "content_len": len(msg.Content), "queue_len": al.steering.len(), }) - agentID := "" - if registry := al.GetRegistry(); registry != nil { + + meta := EventMeta{ + Source: "Steer", + TracePath: "turn.interrupt.received", + } + if ts := al.getActiveTurnState(); ts != nil { + meta = ts.eventMeta("Steer", "turn.interrupt.received") + } else if registry := al.GetRegistry(); registry != nil { if agent := registry.GetDefaultAgent(); agent != nil { - agentID = agent.ID + meta.AgentID = agent.ID } } al.emitEvent( EventKindInterruptReceived, - EventMeta{ - AgentID: agentID, - Source: "Steer", - TracePath: "turn.interrupt.received", - }, + meta, InterruptReceivedPayload{ + Kind: InterruptKindSteering, Role: msg.Role, ContentLen: len(msg.Content), QueueDepth: al.steering.len(), @@ -177,6 +180,10 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message { // // If no steering messages are pending, it returns an empty string. func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) { + if active := al.GetActiveTurn(); active != nil { + return "", fmt.Errorf("turn %s is still active", active.TurnID) + } + steeringMsgs := al.dequeueSteeringMessages() if len(steeringMsgs) == 0 { return "", nil @@ -187,6 +194,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s return "", fmt.Errorf("no default agent available") } + if tool, ok := agent.Tools.Get("message"); ok { + if resetter, ok := tool.(interface{ ResetSentInRound() }); ok { + resetter.ResetSentInRound() + } + } + // Build a combined user message from the steering messages. var contents []string for _, msg := range steeringMsgs { @@ -205,3 +218,44 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s SkipInitialSteeringPoll: true, }) } + +func (al *AgentLoop) InterruptGraceful(hint string) error { + ts := al.getActiveTurnState() + if ts == nil { + return fmt.Errorf("no active turn") + } + if !ts.requestGracefulInterrupt(hint) { + return fmt.Errorf("turn %s cannot accept graceful interrupt", ts.turnID) + } + + al.emitEvent( + EventKindInterruptReceived, + ts.eventMeta("InterruptGraceful", "turn.interrupt.received"), + InterruptReceivedPayload{ + Kind: InterruptKindGraceful, + HintLen: len(hint), + }, + ) + + return nil +} + +func (al *AgentLoop) InterruptHard() error { + ts := al.getActiveTurnState() + if ts == nil { + return fmt.Errorf("no active turn") + } + if !ts.requestHardAbort() { + return fmt.Errorf("turn %s is already aborting", ts.turnID) + } + + al.emitEvent( + EventKindInterruptReceived, + ts.eventMeta("InterruptHard", "turn.interrupt.received"), + InterruptReceivedPayload{ + Kind: InterruptKindHard, + }, + ) + + return nil +} diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go index e8cdb2344..f8c046ea9 100644 --- a/pkg/agent/steering_test.go +++ b/pkg/agent/steering_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "os" + "reflect" "sync" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -396,6 +398,103 @@ func (m *toolCallProvider) GetDefaultModel() string { return "tool-call-mock" } +type gracefulCaptureProvider struct { + mu sync.Mutex + calls int + toolCalls []providers.ToolCall + finalResp string + terminalMessages []providers.Message + terminalToolsCount int +} + +func (p *gracefulCaptureProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.mu.Lock() + defer p.mu.Unlock() + p.calls++ + + if p.calls == 1 { + return &providers.LLMResponse{ + ToolCalls: p.toolCalls, + }, nil + } + + p.terminalMessages = append([]providers.Message(nil), messages...) + p.terminalToolsCount = len(tools) + return &providers.LLMResponse{ + Content: p.finalResp, + }, nil +} + +func (p *gracefulCaptureProvider) GetDefaultModel() string { + return "graceful-capture-mock" +} + +type lateSteeringProvider struct { + mu sync.Mutex + calls int + firstCallStarted chan struct{} + releaseFirstCall chan struct{} + firstStartOnce sync.Once + secondCallMessages []providers.Message +} + +func (p *lateSteeringProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.mu.Lock() + p.calls++ + call := p.calls + p.mu.Unlock() + + if call == 1 { + p.firstStartOnce.Do(func() { close(p.firstCallStarted) }) + <-p.releaseFirstCall + return &providers.LLMResponse{Content: "first response"}, nil + } + + p.mu.Lock() + p.secondCallMessages = append([]providers.Message(nil), messages...) + p.mu.Unlock() + return &providers.LLMResponse{Content: "continued response"}, nil +} + +func (p *lateSteeringProvider) GetDefaultModel() string { + return "late-steering-mock" +} + +type interruptibleTool struct { + name string + started chan struct{} + once sync.Once +} + +func (t *interruptibleTool) Name() string { return t.name } +func (t *interruptibleTool) Description() string { return "interruptible tool for testing" } +func (t *interruptibleTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (t *interruptibleTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + if t.started != nil { + t.once.Do(func() { close(t.started) }) + } + <-ctx.Done() + return tools.ErrorResult(ctx.Err().Error()).WithError(ctx.Err()) +} + func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil { @@ -568,6 +667,425 @@ func TestAgentLoop_Steering_InitialPoll(t *testing.T) { } } +func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &lateSteeringProvider{ + firstCallStarted: make(chan struct{}), + releaseFirstCall: make(chan struct{}), + } + al := NewAgentLoop(cfg, msgBus, provider) + + runCtx, cancelRun := context.WithCancel(context.Background()) + defer cancelRun() + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- al.Run(runCtx) + }() + + first := bus.InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "first message", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + late := bus.InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "late append", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + + pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer pubCancel() + if err := msgBus.PublishInbound(pubCtx, first); err != nil { + t.Fatalf("publish first inbound: %v", err) + } + + select { + case <-provider.firstCallStarted: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for first provider call to start") + } + + if err := msgBus.PublishInbound(pubCtx, late); err != nil { + t.Fatalf("publish late inbound: %v", err) + } + + close(provider.releaseFirstCall) + + subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer subCancel() + + out1, ok := msgBus.SubscribeOutbound(subCtx) + if !ok { + t.Fatal("expected first outbound response") + } + if out1.Content != "first response" { + t.Fatalf("expected first response, got %q", out1.Content) + } + + out2, ok := msgBus.SubscribeOutbound(subCtx) + if !ok { + t.Fatal("expected continued outbound response") + } + if out2.Content != "continued response" { + t.Fatalf("expected continued response, got %q", out2.Content) + } + + cancelRun() + select { + case err := <-runErrCh: + if err != nil { + t.Fatalf("Run returned error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for Run to stop") + } + + provider.mu.Lock() + calls := provider.calls + secondMessages := append([]providers.Message(nil), provider.secondCallMessages...) + provider.mu.Unlock() + + if calls != 2 { + t.Fatalf("expected 2 provider calls, got %d", calls) + } + + foundLateMessage := false + for _, msg := range secondMessages { + if msg.Role == "user" && msg.Content == "late append" { + foundLateMessage = true + break + } + } + if !foundLateMessage { + t.Fatal("expected queued late message to be processed in an automatic follow-up turn") + } +} + +func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + tool1ExecCh := make(chan struct{}) + tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh} + tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond} + + provider := &gracefulCaptureProvider{ + toolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "tool_one", + Function: &providers.FunctionCall{ + Name: "tool_one", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + { + ID: "call_2", + Type: "function", + Name: "tool_two", + Function: &providers.FunctionCall{ + Name: "tool_two", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + }, + finalResp: "graceful summary", + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + al.RegisterTool(tool1) + al.RegisterTool(tool2) + sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + + sub := al.SubscribeEvents(32) + defer al.UnsubscribeEvents(sub.ID) + + type result struct { + resp string + err error + } + resultCh := make(chan result, 1) + go func() { + resp, err := al.ProcessDirectWithChannel( + context.Background(), + "do something", + sessionKey, + "test", + "chat1", + ) + resultCh <- result{resp: resp, err: err} + }() + + select { + case <-tool1ExecCh: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for tool_one to start") + } + + active := al.GetActiveTurn() + if active == nil { + t.Fatal("expected active turn while tool is running") + } + if active.SessionKey != sessionKey { + t.Fatalf("expected active session %q, got %q", sessionKey, active.SessionKey) + } + if active.Channel != "test" || active.ChatID != "chat1" { + t.Fatalf("unexpected active turn target: %#v", active) + } + + if err := al.InterruptGraceful("wrap it up"); err != nil { + t.Fatalf("InterruptGraceful failed: %v", err) + } + + select { + case r := <-resultCh: + if r.err != nil { + t.Fatalf("unexpected error: %v", r.err) + } + if r.resp != "graceful summary" { + t.Fatalf("expected graceful summary, got %q", r.resp) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for graceful interrupt result") + } + + if active := al.GetActiveTurn(); active != nil { + t.Fatalf("expected no active turn after completion, got %#v", active) + } + + provider.mu.Lock() + terminalMessages := append([]providers.Message(nil), provider.terminalMessages...) + terminalToolsCount := provider.terminalToolsCount + calls := provider.calls + provider.mu.Unlock() + + if calls != 2 { + t.Fatalf("expected 2 provider calls, got %d", calls) + } + if terminalToolsCount != 0 { + t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount) + } + + foundHint := false + foundSkipped := false + for _, msg := range terminalMessages { + if msg.Role == "user" && msg.Content == "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\nInterrupt hint: wrap it up" { + foundHint = true + } + if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." { + foundSkipped = true + } + } + if !foundHint { + t.Fatal("expected graceful terminal call to include interrupt hint message") + } + if !foundSkipped { + t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt") + } + + events := collectEventStream(sub.C) + interruptEvt, ok := findEvent(events, EventKindInterruptReceived) + if !ok { + t.Fatal("expected interrupt received event") + } + interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload) + if !ok { + t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload) + } + if interruptPayload.Kind != InterruptKindGraceful { + t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind) + } + + turnEndEvt, ok := findEvent(events, EventKindTurnEnd) + if !ok { + t.Fatal("expected turn end event") + } + turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload) + if !ok { + t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload) + } + if turnEndPayload.Status != TurnEndStatusCompleted { + t.Fatalf("expected completed turn after graceful interrupt, got %q", turnEndPayload.Status) + } +} + +func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &toolCallProvider{ + toolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "cancel_tool", + Function: &providers.FunctionCall{ + Name: "cancel_tool", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + }, + finalResp: "should not happen", + } + + al := NewAgentLoop(cfg, msgBus, provider) + started := make(chan struct{}) + al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started}) + sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + + originalHistory := []providers.Message{ + {Role: "user", Content: "before"}, + {Role: "assistant", Content: "after"}, + } + defaultAgent.Sessions.SetHistory(sessionKey, originalHistory) + + sub := al.SubscribeEvents(16) + defer al.UnsubscribeEvents(sub.ID) + + type result struct { + resp string + err error + } + resultCh := make(chan result, 1) + go func() { + resp, err := al.ProcessDirectWithChannel( + context.Background(), + "do work", + sessionKey, + "test", + "chat1", + ) + resultCh <- result{resp: resp, err: err} + }() + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for interruptible tool to start") + } + + if active := al.GetActiveTurn(); active == nil { + t.Fatal("expected active turn before hard abort") + } + + if err := al.InterruptHard(); err != nil { + t.Fatalf("InterruptHard failed: %v", err) + } + + select { + case r := <-resultCh: + if r.err != nil { + t.Fatalf("unexpected error: %v", r.err) + } + if r.resp != "" { + t.Fatalf("expected no final response after hard abort, got %q", r.resp) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for hard abort result") + } + + if active := al.GetActiveTurn(); active != nil { + t.Fatalf("expected no active turn after hard abort, got %#v", active) + } + + finalHistory := defaultAgent.Sessions.GetHistory(sessionKey) + if !reflect.DeepEqual(finalHistory, originalHistory) { + t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory) + } + + events := collectEventStream(sub.C) + interruptEvt, ok := findEvent(events, EventKindInterruptReceived) + if !ok { + t.Fatal("expected interrupt received event") + } + interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload) + if !ok { + t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload) + } + if interruptPayload.Kind != InterruptKindHard { + t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind) + } + + turnEndEvt, ok := findEvent(events, EventKindTurnEnd) + if !ok { + t.Fatal("expected turn end event") + } + turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload) + if !ok { + t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload) + } + if turnEndPayload.Status != TurnEndStatusAborted { + t.Fatalf("expected aborted turn, got %q", turnEndPayload.Status) + } +} + // capturingMockProvider captures messages sent to Chat for inspection. type capturingMockProvider struct { response string diff --git a/pkg/agent/turn.go b/pkg/agent/turn.go new file mode 100644 index 000000000..c44a4f80e --- /dev/null +++ b/pkg/agent/turn.go @@ -0,0 +1,309 @@ +package agent + +import ( + "context" + "reflect" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/providers" +) + +type TurnPhase string + +const ( + TurnPhaseSetup TurnPhase = "setup" + TurnPhaseRunning TurnPhase = "running" + TurnPhaseTools TurnPhase = "tools" + TurnPhaseFinalizing TurnPhase = "finalizing" + TurnPhaseCompleted TurnPhase = "completed" + TurnPhaseAborted TurnPhase = "aborted" +) + +type ActiveTurnInfo struct { + TurnID string + AgentID string + SessionKey string + Channel string + ChatID string + UserMessage string + Phase TurnPhase + Iteration int + StartedAt time.Time +} + +type turnResult struct { + finalContent string + status TurnEndStatus + followUps []bus.InboundMessage +} + +type turnState struct { + mu sync.RWMutex + + agent *AgentInstance + opts processOptions + scope turnEventScope + + turnID string + agentID string + sessionKey string + + channel string + chatID string + userMessage string + media []string + + phase TurnPhase + iteration int + startedAt time.Time + finalContent string + + pendingSteering []providers.Message + followUps []bus.InboundMessage + + gracefulInterrupt bool + gracefulInterruptHint string + gracefulTerminalUsed bool + hardAbort bool + providerCancel context.CancelFunc + turnCancel context.CancelFunc + + restorePointHistory []providers.Message + restorePointSummary string + persistedMessages []providers.Message +} + +func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState { + return &turnState{ + agent: agent, + opts: opts, + scope: scope, + turnID: scope.turnID, + agentID: agent.ID, + sessionKey: opts.SessionKey, + channel: opts.Channel, + chatID: opts.ChatID, + userMessage: opts.UserMessage, + media: append([]string(nil), opts.Media...), + phase: TurnPhaseSetup, + startedAt: time.Now(), + } +} + +func (al *AgentLoop) registerActiveTurn(ts *turnState) { + al.activeTurnMu.Lock() + defer al.activeTurnMu.Unlock() + al.activeTurn = ts +} + +func (al *AgentLoop) clearActiveTurn(ts *turnState) { + al.activeTurnMu.Lock() + defer al.activeTurnMu.Unlock() + if al.activeTurn == ts { + al.activeTurn = nil + } +} + +func (al *AgentLoop) getActiveTurnState() *turnState { + al.activeTurnMu.RLock() + defer al.activeTurnMu.RUnlock() + return al.activeTurn +} + +func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo { + ts := al.getActiveTurnState() + if ts == nil { + return nil + } + info := ts.snapshot() + return &info +} + +func (ts *turnState) snapshot() ActiveTurnInfo { + ts.mu.RLock() + defer ts.mu.RUnlock() + + return ActiveTurnInfo{ + TurnID: ts.turnID, + AgentID: ts.agentID, + SessionKey: ts.sessionKey, + Channel: ts.channel, + ChatID: ts.chatID, + UserMessage: ts.userMessage, + Phase: ts.phase, + Iteration: ts.iteration, + StartedAt: ts.startedAt, + } +} + +func (ts *turnState) setPhase(phase TurnPhase) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.phase = phase +} + +func (ts *turnState) setIteration(iteration int) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.iteration = iteration +} + +func (ts *turnState) currentIteration() int { + ts.mu.RLock() + defer ts.mu.RUnlock() + return ts.iteration +} + +func (ts *turnState) setFinalContent(content string) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.finalContent = content +} + +func (ts *turnState) finalContentLen() int { + ts.mu.RLock() + defer ts.mu.RUnlock() + return len(ts.finalContent) +} + +func (ts *turnState) setTurnCancel(cancel context.CancelFunc) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.turnCancel = cancel +} + +func (ts *turnState) setProviderCancel(cancel context.CancelFunc) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.providerCancel = cancel +} + +func (ts *turnState) clearProviderCancel(_ context.CancelFunc) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.providerCancel = nil +} + +func (ts *turnState) requestGracefulInterrupt(hint string) bool { + ts.mu.Lock() + defer ts.mu.Unlock() + if ts.hardAbort { + return false + } + ts.gracefulInterrupt = true + ts.gracefulInterruptHint = hint + return true +} + +func (ts *turnState) gracefulInterruptRequested() (bool, string) { + ts.mu.RLock() + defer ts.mu.RUnlock() + return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint +} + +func (ts *turnState) markGracefulTerminalUsed() { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.gracefulTerminalUsed = true +} + +func (ts *turnState) requestHardAbort() bool { + ts.mu.Lock() + if ts.hardAbort { + ts.mu.Unlock() + return false + } + ts.hardAbort = true + turnCancel := ts.turnCancel + providerCancel := ts.providerCancel + ts.mu.Unlock() + + if providerCancel != nil { + providerCancel() + } + if turnCancel != nil { + turnCancel() + } + return true +} + +func (ts *turnState) hardAbortRequested() bool { + ts.mu.RLock() + defer ts.mu.RUnlock() + return ts.hardAbort +} + +func (ts *turnState) eventMeta(source, tracePath string) EventMeta { + snap := ts.snapshot() + return EventMeta{ + AgentID: snap.AgentID, + TurnID: snap.TurnID, + SessionKey: snap.SessionKey, + Iteration: snap.Iteration, + Source: source, + TracePath: tracePath, + } +} + +func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.restorePointHistory = append([]providers.Message(nil), history...) + ts.restorePointSummary = summary +} + +func (ts *turnState) recordPersistedMessage(msg providers.Message) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.persistedMessages = append(ts.persistedMessages, msg) +} + +func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) { + history := agent.Sessions.GetHistory(ts.sessionKey) + summary := agent.Sessions.GetSummary(ts.sessionKey) + + ts.mu.RLock() + persisted := append([]providers.Message(nil), ts.persistedMessages...) + ts.mu.RUnlock() + + if matched := matchingTurnMessageTail(history, persisted); matched > 0 { + history = append([]providers.Message(nil), history[:len(history)-matched]...) + } + + ts.captureRestorePoint(history, summary) +} + +func (ts *turnState) restoreSession(agent *AgentInstance) error { + ts.mu.RLock() + history := append([]providers.Message(nil), ts.restorePointHistory...) + summary := ts.restorePointSummary + ts.mu.RUnlock() + + agent.Sessions.SetHistory(ts.sessionKey, history) + agent.Sessions.SetSummary(ts.sessionKey, summary) + return agent.Sessions.Save(ts.sessionKey) +} + +func matchingTurnMessageTail(history, persisted []providers.Message) int { + maxMatch := min(len(history), len(persisted)) + for size := maxMatch; size > 0; size-- { + if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) { + return size + } + } + return 0 +} + +func (ts *turnState) interruptHintMessage() providers.Message { + _, hint := ts.gracefulInterruptRequested() + content := "Interrupt requested. Stop scheduling tools and provide a short final summary." + if hint != "" { + content += "\n\nInterrupt hint: " + hint + } + return providers.Message{ + Role: "user", + Content: content, + } +}