diff --git a/pkg/agent/eventbus_test.go b/pkg/agent/eventbus_test.go index 13f2f2282..9acc6ddd8 100644 --- a/pkg/agent/eventbus_test.go +++ b/pkg/agent/eventbus_test.go @@ -334,6 +334,9 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) { if interruptPayload.Role != "user" { t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role) } + if interruptPayload.Kind != InterruptKindSteering { + t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind) + } if interruptPayload.ContentLen != len("change course") { t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen) } diff --git a/pkg/agent/events.go b/pkg/agent/events.go index fae5033a3..95e4c90d0 100644 --- a/pkg/agent/events.go +++ b/pkg/agent/events.go @@ -105,6 +105,8 @@ const ( TurnEndStatusCompleted TurnEndStatus = "completed" // TurnEndStatusError indicates the turn ended because of an error. TurnEndStatusError TurnEndStatus = "error" + // TurnEndStatusAborted indicates the turn was hard-aborted and rolled back. + TurnEndStatusAborted TurnEndStatus = "aborted" ) // TurnStartPayload describes the start of a turn. @@ -215,11 +217,21 @@ type FollowUpQueuedPayload struct { ContentLen int } -// InterruptReceivedPayload describes a queued soft interrupt. +type InterruptKind string + +const ( + InterruptKindSteering InterruptKind = "steering" + InterruptKindGraceful InterruptKind = "graceful" + InterruptKindHard InterruptKind = "hard_abort" +) + +// InterruptReceivedPayload describes accepted turn-control input. type InterruptReceivedPayload struct { + Kind InterruptKind Role string ContentLen int QueueDepth int + HintLen int } // SubTurnSpawnPayload describes the creation of a child turn. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 877dbbd94..f54482ae8 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -50,6 +50,8 @@ type AgentLoop struct { mcp mcpRuntime steering *steeringQueue mu sync.RWMutex + activeTurnMu sync.RWMutex + activeTurn *turnState turnSeq atomic.Uint64 // Track active requests for safe provider cleanup activeRequests sync.WaitGroup @@ -69,6 +71,12 @@ type processOptions struct { SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue) } +type continuationTarget struct { + SessionKey string + Channel string + ChatID string +} + const ( defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." sessionKeyAgentPrefix = "agent:" @@ -292,38 +300,46 @@ func (al *AgentLoop) Run(ctx context.Context) error { } if response != "" { - // Check if the message tool already sent a response during this round. - // If so, skip publishing to avoid duplicate messages to the user. - // Use default agent's tools to check (message tool is shared). - alreadySent := false - defaultAgent := al.GetRegistry().GetDefaultAgent() - if defaultAgent != nil { - if tool, ok := defaultAgent.Tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() - } - } + al.publishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, response) + } + + target, targetErr := al.buildContinuationTarget(msg) + if targetErr != nil { + logger.WarnCF("agent", "Failed to build steering continuation target", + map[string]any{ + "channel": msg.Channel, + "error": targetErr.Error(), + }) + return + } + if target == nil { + return + } + + for al.pendingSteeringCount() > 0 { + logger.InfoCF("agent", "Continuing queued steering after turn end", + map[string]any{ + "channel": target.Channel, + "chat_id": target.ChatID, + "session_key": target.SessionKey, + "queue_depth": al.pendingSteeringCount(), + }) + + continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID) + if continueErr != nil { + logger.WarnCF("agent", "Failed to continue queued steering", + map[string]any{ + "channel": target.Channel, + "chat_id": target.ChatID, + "error": continueErr.Error(), + }) + return + } + if continued == "" { + return } - if !alreadySent { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) - logger.InfoCF("agent", "Published outbound response", - map[string]any{ - "channel": msg.Channel, - "chat_id": msg.ChatID, - "content_len": len(response), - }) - } else { - logger.DebugCF( - "agent", - "Skipped outbound (message tool already sent)", - map[string]any{"channel": msg.Channel}, - ) - } + al.publishResponseIfNeeded(ctx, target.Channel, target.ChatID, continued) } }() } @@ -369,6 +385,67 @@ func (al *AgentLoop) Stop() { al.running.Store(false) } +func (al *AgentLoop) publishResponseIfNeeded(ctx context.Context, channel, chatID, response string) { + if response == "" { + return + } + + alreadySent := false + defaultAgent := al.GetRegistry().GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } + } + } + + if alreadySent { + logger.DebugCF( + "agent", + "Skipped outbound (message tool already sent)", + map[string]any{"channel": channel}, + ) + return + } + + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: response, + }) + logger.InfoCF("agent", "Published outbound response", + map[string]any{ + "channel": channel, + "chat_id": chatID, + "content_len": len(response), + }) +} + +func (al *AgentLoop) pendingSteeringCount() int { + if al.steering == nil { + return 0 + } + return al.steering.len() +} + +func (al *AgentLoop) buildContinuationTarget(msg bus.InboundMessage) (*continuationTarget, error) { + if msg.Channel == "system" { + return nil, nil + } + + route, _, err := al.resolveMessageRoute(msg) + if err != nil { + return nil, err + } + + return &continuationTarget{ + SessionKey: resolveScopeKey(route, msg.SessionKey), + Channel: msg.Channel, + ChatID: msg.ChatID, + }, nil +} + // Close releases resources held by agent session stores. Call after Stop. func (al *AgentLoop) Close() { mcpManager := al.mcp.takeManager() @@ -543,9 +620,11 @@ func (al *AgentLoop) logEvent(evt Event) { fields["chat_id"] = payload.ChatID fields["content_len"] = payload.ContentLen case InterruptReceivedPayload: + fields["interrupt_kind"] = payload.Kind fields["role"] = payload.Role fields["content_len"] = payload.ContentLen fields["queue_depth"] = payload.QueueDepth + fields["hint_len"] = payload.HintLen case SubTurnSpawnPayload: fields["child_agent_id"] = payload.AgentID fields["label"] = payload.Label @@ -1071,153 +1150,63 @@ func (al *AgentLoop) processSystemMessage( }) } -// runAgentLoop is the core message processing logic. +// runAgentLoop remains the top-level shell that starts a turn and publishes +// any post-turn work. runTurn owns the full turn lifecycle. func (al *AgentLoop) runAgentLoop( ctx context.Context, agent *AgentInstance, opts processOptions, ) (string, error) { - turnScope := al.newTurnEventScope(agent.ID, opts.SessionKey) - turnStartedAt := time.Now() - turnIterations := 0 - turnFinalContentLen := 0 - turnStatus := TurnEndStatusCompleted - defer func() { - al.emitEvent( - EventKindTurnEnd, - turnScope.meta(turnIterations, "runAgentLoop", "turn.end"), - TurnEndPayload{ - Status: turnStatus, - Iterations: turnIterations, - Duration: time.Since(turnStartedAt), - FinalContentLen: turnFinalContentLen, - }, - ) - }() - - al.emitEvent( - EventKindTurnStart, - turnScope.meta(0, "runAgentLoop", "turn.start"), - TurnStartPayload{ - Channel: opts.Channel, - ChatID: opts.ChatID, - UserMessage: opts.UserMessage, - MediaCount: len(opts.Media), - }, - ) - - // 0. Record last channel for heartbeat notifications (skip internal channels and cli) - if opts.Channel != "" && opts.ChatID != "" { - if !constants.IsInternalChannel(opts.Channel) { - channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) - if err := al.RecordLastChannel(channelKey); err != nil { - logger.WarnCF( - "agent", - "Failed to record last channel", - map[string]any{"error": err.Error()}, - ) - } - } - } - - // 1. Build messages (skip history for heartbeat) - var history []providers.Message - var summary string - if !opts.NoHistory { - history = agent.Sessions.GetHistory(opts.SessionKey) - summary = agent.Sessions.GetSummary(opts.SessionKey) - } - messages := agent.ContextBuilder.BuildMessages( - history, - summary, - opts.UserMessage, - opts.Media, - opts.Channel, - opts.ChatID, - ) - - // Resolve media:// refs: images→base64 data URLs, non-images→local paths in content - cfg := al.GetConfig() - maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize() - messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) - - // 1.5. Proactive context budget check: compress before LLM call - // rather than waiting for a 400 context-length error. - if !opts.NoHistory { - toolDefs := agent.Tools.ToProviderDefs() - if isOverContextBudget(agent.ContextWindow, messages, toolDefs, agent.MaxTokens) { - logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call", - map[string]any{"session_key": opts.SessionKey}) - if compression, ok := al.forceCompression(agent, opts.SessionKey); ok { - al.emitEvent( - EventKindContextCompress, - turnScope.meta(0, "runAgentLoop", "turn.context.compress"), - ContextCompressPayload{ - Reason: ContextCompressReasonProactive, - DroppedMessages: compression.DroppedMessages, - RemainingMessages: compression.RemainingMessages, - }, - ) - } - newHistory := agent.Sessions.GetHistory(opts.SessionKey) - newSummary := agent.Sessions.GetSummary(opts.SessionKey) - messages = agent.ContextBuilder.BuildMessages( - newHistory, newSummary, opts.UserMessage, - opts.Media, opts.Channel, opts.ChatID, + if opts.Channel != "" && opts.ChatID != "" && !constants.IsInternalChannel(opts.Channel) { + channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) + if err := al.RecordLastChannel(channelKey); err != nil { + logger.WarnCF( + "agent", + "Failed to record last channel", + map[string]any{"error": err.Error()}, ) - messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) } } - // 2. Save user message to session - agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) - - // 3. Run LLM iteration loop - finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts, turnScope) - turnIterations = iteration + ts := newTurnState(agent, opts, al.newTurnEventScope(agent.ID, opts.SessionKey)) + result, err := al.runTurn(ctx, ts) if err != nil { - turnStatus = TurnEndStatusError return "", err } - - // If last tool had ForUser content and we already sent it, we might not need to send final response - // This is controlled by the tool's Silent flag and ForUser content - - // 4. Handle empty response - if finalContent == "" { - finalContent = opts.DefaultResponse - } - turnFinalContentLen = len(finalContent) - - // 5. Save final assistant message to session - agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent) - agent.Sessions.Save(opts.SessionKey) - - // 6. Optional: summarization - if opts.EnableSummary { - al.maybeSummarize(agent, opts.SessionKey, turnScope) + if result.status == TurnEndStatusAborted { + return "", nil } - // 7. Optional: send response via bus - if opts.SendResponse { + for _, followUp := range result.followUps { + if pubErr := al.bus.PublishInbound(ctx, followUp); pubErr != nil { + logger.WarnCF("agent", "Failed to publish follow-up after turn", + map[string]any{ + "turn_id": ts.turnID, + "error": pubErr.Error(), + }) + } + } + + if opts.SendResponse && result.finalContent != "" { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, - Content: finalContent, + Content: result.finalContent, }) } - // 8. Log response - responsePreview := utils.Truncate(finalContent, 120) - logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), - map[string]any{ - "agent_id": agent.ID, - "session_key": opts.SessionKey, - "iterations": iteration, - "final_length": len(finalContent), - }) + if result.finalContent != "" { + responsePreview := utils.Truncate(result.finalContent, 120) + logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), + map[string]any{ + "agent_id": agent.ID, + "session_key": opts.SessionKey, + "iterations": ts.currentIteration(), + "final_length": len(result.finalContent), + }) + } - return finalContent, nil + return result.finalContent, nil } func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string) { @@ -1276,54 +1265,135 @@ func (al *AgentLoop) handleReasoning( } } -// runLLMIteration executes the LLM call loop with tool handling. -func (al *AgentLoop) runLLMIteration( - ctx context.Context, - agent *AgentInstance, - messages []providers.Message, - opts processOptions, - turnScope turnEventScope, -) (string, int, error) { - iteration := 0 - var finalContent string - var pendingMessages []providers.Message +func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, error) { + turnCtx, turnCancel := context.WithCancel(ctx) + defer turnCancel() + ts.setTurnCancel(turnCancel) - // Poll for steering messages at loop start (in case the user typed while - // the agent was setting up), unless the caller already provided initial - // steering messages (e.g. Continue). - if !opts.SkipInitialSteeringPoll { - if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 { - pendingMessages = msgs + al.registerActiveTurn(ts) + defer al.clearActiveTurn(ts) + + turnStatus := TurnEndStatusCompleted + defer func() { + al.emitEvent( + EventKindTurnEnd, + ts.eventMeta("runTurn", "turn.end"), + TurnEndPayload{ + Status: turnStatus, + Iterations: ts.currentIteration(), + Duration: time.Since(ts.startedAt), + FinalContentLen: ts.finalContentLen(), + }, + ) + }() + + al.emitEvent( + EventKindTurnStart, + ts.eventMeta("runTurn", "turn.start"), + TurnStartPayload{ + Channel: ts.channel, + ChatID: ts.chatID, + UserMessage: ts.userMessage, + MediaCount: len(ts.media), + }, + ) + + var history []providers.Message + var summary string + if !ts.opts.NoHistory { + history = ts.agent.Sessions.GetHistory(ts.sessionKey) + summary = ts.agent.Sessions.GetSummary(ts.sessionKey) + } + ts.captureRestorePoint(history, summary) + + messages := ts.agent.ContextBuilder.BuildMessages( + history, + summary, + ts.userMessage, + ts.media, + ts.channel, + ts.chatID, + ) + + cfg := al.GetConfig() + maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize() + messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) + + if !ts.opts.NoHistory { + toolDefs := ts.agent.Tools.ToProviderDefs() + if isOverContextBudget(ts.agent.ContextWindow, messages, toolDefs, ts.agent.MaxTokens) { + logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call", + map[string]any{"session_key": ts.sessionKey}) + if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok { + al.emitEvent( + EventKindContextCompress, + ts.eventMeta("runTurn", "turn.context.compress"), + ContextCompressPayload{ + Reason: ContextCompressReasonProactive, + DroppedMessages: compression.DroppedMessages, + RemainingMessages: compression.RemainingMessages, + }, + ) + ts.refreshRestorePointFromSession(ts.agent) + } + newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey) + newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey) + messages = ts.agent.ContextBuilder.BuildMessages( + newHistory, newSummary, ts.userMessage, + ts.media, ts.channel, ts.chatID, + ) + messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) } } - // Determine effective model tier for this conversation turn. - // selectCandidates evaluates routing once and the decision is sticky for - // all tool-follow-up iterations within the same turn so that a multi-step - // tool chain doesn't switch models mid-way through. - activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages) + if !ts.opts.NoHistory { + rootMsg := providers.Message{Role: "user", Content: ts.userMessage} + ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content) + ts.recordPersistedMessage(rootMsg) + } - for iteration < agent.MaxIterations || len(pendingMessages) > 0 { - iteration++ + activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages) + var pendingMessages []providers.Message + var finalContent string + + for ts.currentIteration() < ts.agent.MaxIterations || len(pendingMessages) > 0 || func() bool { + graceful, _ := ts.gracefulInterruptRequested() + return graceful + }() { + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + + iteration := ts.currentIteration() + 1 + ts.setIteration(iteration) + ts.setPhase(TurnPhaseRunning) + + if iteration > 1 || !ts.opts.SkipInitialSteeringPoll { + if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { + pendingMessages = append(pendingMessages, steerMsgs...) + } + } - // Inject pending steering messages into the conversation context - // before the next LLM call. if len(pendingMessages) > 0 { totalContentLen := 0 for _, pm := range pendingMessages { messages = append(messages, pm) - agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content) totalContentLen += len(pm.Content) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddMessage(ts.sessionKey, pm.Role, pm.Content) + ts.recordPersistedMessage(pm) + } logger.InfoCF("agent", "Injected steering message into context", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, "content_len": len(pm.Content), }) } al.emitEvent( EventKindSteeringInjected, - turnScope.meta(iteration, "runLLMIteration", "turn.steering.injected"), + ts.eventMeta("runTurn", "turn.steering.injected"), SteeringInjectedPayload{ Count: len(pendingMessages), TotalContentLen: totalContentLen, @@ -1334,78 +1404,81 @@ func (al *AgentLoop) runLLMIteration( logger.DebugCF("agent", "LLM iteration", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, - "max": agent.MaxIterations, + "max": ts.agent.MaxIterations, }) - // Build tool definitions - providerToolDefs := agent.Tools.ToProviderDefs() + gracefulTerminal, _ := ts.gracefulInterruptRequested() + providerToolDefs := ts.agent.Tools.ToProviderDefs() + callMessages := messages + if gracefulTerminal { + callMessages = append(append([]providers.Message(nil), messages...), ts.interruptHintMessage()) + providerToolDefs = nil + ts.markGracefulTerminalUsed() + } + al.emitEvent( EventKindLLMRequest, - turnScope.meta(iteration, "runLLMIteration", "turn.llm.request"), + ts.eventMeta("runTurn", "turn.llm.request"), LLMRequestPayload{ Model: activeModel, - MessagesCount: len(messages), + MessagesCount: len(callMessages), ToolsCount: len(providerToolDefs), - MaxTokens: agent.MaxTokens, - Temperature: agent.Temperature, + MaxTokens: ts.agent.MaxTokens, + Temperature: ts.agent.Temperature, }, ) - // Log LLM request details logger.DebugCF("agent", "LLM request", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, "model": activeModel, - "messages_count": len(messages), + "messages_count": len(callMessages), "tools_count": len(providerToolDefs), - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, - "system_prompt_len": len(messages[0].Content), + "max_tokens": ts.agent.MaxTokens, + "temperature": ts.agent.Temperature, + "system_prompt_len": len(callMessages[0].Content), }) - - // Log full messages (detailed) logger.DebugCF("agent", "Full LLM request", map[string]any{ "iteration": iteration, - "messages_json": formatMessagesForLog(messages), + "messages_json": formatMessagesForLog(callMessages), "tools_json": formatToolsForLog(providerToolDefs), }) - // Call LLM with fallback chain if multiple candidates are configured. - var response *providers.LLMResponse - var err error - llmOpts := map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, - "prompt_cache_key": agent.ID, + "max_tokens": ts.agent.MaxTokens, + "temperature": ts.agent.Temperature, + "prompt_cache_key": ts.agent.ID, } - // parseThinkingLevel guarantees ThinkingOff for empty/unknown values, - // so checking != ThinkingOff is sufficient. - if agent.ThinkingLevel != ThinkingOff { - if tc, ok := agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() { - llmOpts["thinking_level"] = string(agent.ThinkingLevel) + if ts.agent.ThinkingLevel != ThinkingOff { + if tc, ok := ts.agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() { + llmOpts["thinking_level"] = string(ts.agent.ThinkingLevel) } else { logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring", - map[string]any{"agent_id": agent.ID, "thinking_level": string(agent.ThinkingLevel)}) + map[string]any{"agent_id": ts.agent.ID, "thinking_level": string(ts.agent.ThinkingLevel)}) } } - callLLM := func() (*providers.LLMResponse, error) { + callLLM := func(messagesForCall []providers.Message, toolDefsForCall []providers.ToolDefinition) (*providers.LLMResponse, error) { + providerCtx, providerCancel := context.WithCancel(turnCtx) + ts.setProviderCancel(providerCancel) + defer func() { + providerCancel() + ts.clearProviderCancel(providerCancel) + }() + al.activeRequests.Add(1) defer al.activeRequests.Done() - // TODO(eventbus): emit EventKindLLMDelta when providers expose - // streaming callbacks instead of only the final Chat response. if len(activeCandidates) > 1 && al.fallback != nil { fbResult, fbErr := al.fallback.Execute( - ctx, + providerCtx, activeCandidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { - return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts) + return ts.agent.Provider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts) }, ) if fbErr != nil { @@ -1416,32 +1489,34 @@ func (al *AgentLoop) runLLMIteration( "agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts", fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1), - map[string]any{"agent_id": agent.ID, "iteration": iteration}, + map[string]any{"agent_id": ts.agent.ID, "iteration": iteration}, ) } return fbResult.Response, nil } - return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts) + return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, activeModel, llmOpts) } - // Retry loop for context/token errors + var response *providers.LLMResponse + var err error maxRetries := 2 for retry := 0; retry <= maxRetries; retry++ { - response, err = callLLM() + response, err = callLLM(callMessages, providerToolDefs) if err == nil { break } + if ts.hardAbortRequested() && errors.Is(err, context.Canceled) { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } errMsg := strings.ToLower(err.Error()) - - // Check if this is a network/HTTP timeout — not a context window error. isTimeoutError := errors.Is(err, context.DeadlineExceeded) || strings.Contains(errMsg, "deadline exceeded") || strings.Contains(errMsg, "client.timeout") || strings.Contains(errMsg, "timed out") || strings.Contains(errMsg, "timeout exceeded") - // Detect real context window / token limit errors, excluding network timeouts. isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") || strings.Contains(errMsg, "context window") || strings.Contains(errMsg, "maximum context length") || @@ -1456,7 +1531,7 @@ func (al *AgentLoop) runLLMIteration( backoff := time.Duration(retry+1) * 5 * time.Second al.emitEvent( EventKindLLMRetry, - turnScope.meta(iteration, "runLLMIteration", "turn.llm.retry"), + ts.eventMeta("runTurn", "turn.llm.retry"), LLMRetryPayload{ Attempt: retry + 1, MaxRetries: maxRetries, @@ -1470,14 +1545,21 @@ func (al *AgentLoop) runLLMIteration( "retry": retry, "backoff": backoff.String(), }) - time.Sleep(backoff) + if sleepErr := sleepWithContext(turnCtx, backoff); sleepErr != nil { + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + err = sleepErr + break + } continue } - if isContextError && retry < maxRetries { + if isContextError && retry < maxRetries && !ts.opts.NoHistory { al.emitEvent( EventKindLLMRetry, - turnScope.meta(iteration, "runLLMIteration", "turn.llm.retry"), + ts.eventMeta("runTurn", "turn.llm.retry"), LLMRetryPayload{ Attempt: retry + 1, MaxRetries: maxRetries, @@ -1494,40 +1576,47 @@ func (al *AgentLoop) runLLMIteration( }, ) - if retry == 0 && !constants.IsInternalChannel(opts.Channel) { + if retry == 0 && !constants.IsInternalChannel(ts.channel) { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, + Channel: ts.channel, + ChatID: ts.chatID, Content: "Context window exceeded. Compressing history and retrying...", }) } - if compression, ok := al.forceCompression(agent, opts.SessionKey); ok { + if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok { al.emitEvent( EventKindContextCompress, - turnScope.meta(iteration, "runLLMIteration", "turn.context.compress"), + ts.eventMeta("runTurn", "turn.context.compress"), ContextCompressPayload{ Reason: ContextCompressReasonRetry, DroppedMessages: compression.DroppedMessages, RemainingMessages: compression.RemainingMessages, }, ) + ts.refreshRestorePointFromSession(ts.agent) } - newHistory := agent.Sessions.GetHistory(opts.SessionKey) - newSummary := agent.Sessions.GetSummary(opts.SessionKey) - messages = agent.ContextBuilder.BuildMessages( + + newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey) + newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey) + messages = ts.agent.ContextBuilder.BuildMessages( newHistory, newSummary, "", - nil, opts.Channel, opts.ChatID, + nil, ts.channel, ts.chatID, ) + callMessages = messages + if gracefulTerminal { + callMessages = append(append([]providers.Message(nil), messages...), ts.interruptHintMessage()) + } continue } break } if err != nil { + turnStatus = TurnEndStatusError al.emitEvent( EventKindError, - turnScope.meta(iteration, "runLLMIteration", "turn.error"), + ts.eventMeta("runTurn", "turn.error"), ErrorPayload{ Stage: "llm", Message: err.Error(), @@ -1535,23 +1624,23 @@ func (al *AgentLoop) runLLMIteration( ) logger.ErrorCF("agent", "LLM call failed", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, "model": activeModel, "error": err.Error(), }) - return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) + return turnResult{}, fmt.Errorf("LLM call failed after retries: %w", err) } go al.handleReasoning( - ctx, + turnCtx, response.Reasoning, - opts.Channel, - al.targetReasoningChannelID(opts.Channel), + ts.channel, + al.targetReasoningChannelID(ts.channel), ) al.emitEvent( EventKindLLMResponse, - turnScope.meta(iteration, "runLLMIteration", "turn.llm.response"), + ts.eventMeta("runTurn", "turn.llm.response"), LLMResponsePayload{ ContentLen: len(response.Content), ToolCalls: len(response.ToolCalls), @@ -1561,23 +1650,23 @@ func (al *AgentLoop) runLLMIteration( logger.DebugCF("agent", "LLM response", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, "content_chars": len(response.Content), "tool_calls": len(response.ToolCalls), "reasoning": response.Reasoning, - "target_channel": al.targetReasoningChannelID(opts.Channel), - "channel": opts.Channel, + "target_channel": al.targetReasoningChannelID(ts.channel), + "channel": ts.channel, }) - // Check if no tool calls - then check reasoning content if any - if len(response.ToolCalls) == 0 { + + if len(response.ToolCalls) == 0 || gracefulTerminal { finalContent = response.Content if finalContent == "" && response.ReasoningContent != "" { finalContent = response.ReasoningContent } logger.InfoCF("agent", "LLM response without tool calls (direct answer)", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "iteration": iteration, "content_chars": len(finalContent), }) @@ -1589,20 +1678,18 @@ func (al *AgentLoop) runLLMIteration( normalizedToolCalls = append(normalizedToolCalls, providers.NormalizeToolCall(tc)) } - // Log tool calls toolNames := make([]string, 0, len(normalizedToolCalls)) for _, tc := range normalizedToolCalls { toolNames = append(toolNames, tc.Name) } logger.InfoCF("agent", "LLM requested tool calls", map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "tools": toolNames, "count": len(normalizedToolCalls), "iteration": iteration, }) - // Build assistant message with tool calls assistantMsg := providers.Message{ Role: "assistant", Content: response.Content, @@ -1610,13 +1697,11 @@ func (al *AgentLoop) runLLMIteration( } for _, tc := range normalizedToolCalls { argumentsJSON, _ := json.Marshal(tc.Arguments) - // Copy ExtraContent to ensure thought_signature is persisted for Gemini 3 extraContent := tc.ExtraContent thoughtSignature := "" if tc.Function != nil { thoughtSignature = tc.Function.ThoughtSignature } - assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ ID: tc.ID, Type: "function", @@ -1631,40 +1716,44 @@ func (al *AgentLoop) runLLMIteration( }) } messages = append(messages, assistantMsg) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddFullMessage(ts.sessionKey, assistantMsg) + ts.recordPersistedMessage(assistantMsg) + } - // Save assistant message with tool calls to session - agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) - - // Execute tool calls sequentially. After each tool completes, check - // for steering messages. If any are found, skip remaining tools. - var steeringAfterTools []providers.Message - + ts.setPhase(TurnPhaseTools) for i, tc := range normalizedToolCalls { + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + argsJSON, _ := json.Marshal(tc.Arguments) argsPreview := utils.Truncate(string(argsJSON), 200) logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), map[string]any{ - "agent_id": agent.ID, + "agent_id": ts.agent.ID, "tool": tc.Name, "iteration": iteration, }) al.emitEvent( EventKindToolExecStart, - turnScope.meta(iteration, "runLLMIteration", "turn.tool.start"), + ts.eventMeta("runTurn", "turn.tool.start"), ToolExecStartPayload{ Tool: tc.Name, Arguments: cloneEventArguments(tc.Arguments), }, ) - // Create async callback for tools that implement AsyncExecutor. + toolCall := tc + toolIteration := iteration asyncCallback := func(_ context.Context, result *tools.ToolResult) { if !result.Silent && result.ForUser != "" { outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second) defer outCancel() _ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, + Channel: ts.channel, + ChatID: ts.chatID, Content: result.ForUser, }) } @@ -1679,17 +1768,17 @@ func (al *AgentLoop) runLLMIteration( logger.InfoCF("agent", "Async tool completed, publishing result", map[string]any{ - "tool": tc.Name, + "tool": toolCall.Name, "content_len": len(content), - "channel": opts.Channel, + "channel": ts.channel, }) al.emitEvent( EventKindFollowUpQueued, - turnScope.meta(iteration, "runLLMIteration", "turn.follow_up.queued"), + ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"), FollowUpQueuedPayload{ - SourceTool: tc.Name, - Channel: opts.Channel, - ChatID: opts.ChatID, + SourceTool: toolCall.Name, + Channel: ts.channel, + ChatID: ts.chatID, ContentLen: len(content), }, ) @@ -1698,33 +1787,37 @@ func (al *AgentLoop) runLLMIteration( defer pubCancel() _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{ Channel: "system", - SenderID: fmt.Sprintf("async:%s", tc.Name), - ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID), + SenderID: fmt.Sprintf("async:%s", toolCall.Name), + ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID), Content: content, }) } toolStart := time.Now() - toolResult := agent.Tools.ExecuteWithContext( - ctx, - tc.Name, - tc.Arguments, - opts.Channel, - opts.ChatID, + toolResult := ts.agent.Tools.ExecuteWithContext( + turnCtx, + toolCall.Name, + toolCall.Arguments, + ts.channel, + ts.chatID, asyncCallback, ) toolDuration := time.Since(toolStart) - // Process tool result - if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + + if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, + Channel: ts.channel, + ChatID: ts.chatID, Content: toolResult.ForUser, }) logger.DebugCF("agent", "Sent tool result to user", map[string]any{ - "tool": tc.Name, + "tool": toolCall.Name, "content_len": len(toolResult.ForUser), }) } @@ -1743,8 +1836,8 @@ func (al *AgentLoop) runLLMIteration( parts = append(parts, part) } al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, + Channel: ts.channel, + ChatID: ts.chatID, Parts: parts, }) } @@ -1757,13 +1850,13 @@ func (al *AgentLoop) runLLMIteration( toolResultMsg := providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: tc.ID, + ToolCallID: toolCall.ID, } al.emitEvent( EventKindToolExecEnd, - turnScope.meta(iteration, "runLLMIteration", "turn.tool.end"), + ts.eventMeta("runTurn", "turn.tool.end"), ToolExecEndPayload{ - Tool: tc.Name, + Tool: toolCall.Name, Duration: toolDuration, ForLLMLen: len(contentForLLM), ForUserLen: len(toolResult.ForUser), @@ -1772,67 +1865,136 @@ func (al *AgentLoop) runLLMIteration( }, ) messages = append(messages, toolResultMsg) - agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg) + ts.recordPersistedMessage(toolResultMsg) + } - // After EVERY tool (including the first and last), check for - // steering messages. If found and there are remaining tools, - // skip them all. if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 { + pendingMessages = append(pendingMessages, steerMsgs...) + } + + skipReason := "" + skipMessage := "" + if len(pendingMessages) > 0 { + skipReason = "queued user steering message" + skipMessage = "Skipped due to queued user message." + } else if gracefulPending, _ := ts.gracefulInterruptRequested(); gracefulPending { + skipReason = "graceful interrupt requested" + skipMessage = "Skipped due to graceful interrupt." + } + + if skipReason != "" { remaining := len(normalizedToolCalls) - i - 1 if remaining > 0 { - logger.InfoCF("agent", "Steering interrupt: skipping remaining tools", + logger.InfoCF("agent", "Turn checkpoint: skipping remaining tools", map[string]any{ - "agent_id": agent.ID, - "completed": i + 1, - "skipped": remaining, - "total_tools": len(normalizedToolCalls), - "steering_count": len(steerMsgs), + "agent_id": ts.agent.ID, + "completed": i + 1, + "skipped": remaining, + "reason": skipReason, }) - - // Mark remaining tool calls as skipped for j := i + 1; j < len(normalizedToolCalls); j++ { skippedTC := normalizedToolCalls[j] al.emitEvent( EventKindToolExecSkipped, - turnScope.meta(iteration, "runLLMIteration", "turn.tool.skipped"), + ts.eventMeta("runTurn", "turn.tool.skipped"), ToolExecSkippedPayload{ Tool: skippedTC.Name, - Reason: "queued user steering message", + Reason: skipReason, }, ) - toolResultMsg := providers.Message{ + skippedMsg := providers.Message{ Role: "tool", - Content: "Skipped due to queued user message.", + Content: skipMessage, ToolCallID: skippedTC.ID, } - messages = append(messages, toolResultMsg) - agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + messages = append(messages, skippedMsg) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddFullMessage(ts.sessionKey, skippedMsg) + ts.recordPersistedMessage(skippedMsg) + } } } - steeringAfterTools = steerMsgs break } } - // If steering messages were captured during tool execution, they - // become pendingMessages for the next iteration of the inner loop. - if len(steeringAfterTools) > 0 { - pendingMessages = steeringAfterTools - } - - // Tick down TTL of discovered tools after processing tool results. - // Only reached when tool calls were made (the loop continues); - // the break on no-tool-call responses skips this. - // NOTE: This is safe because processMessage is sequential per agent. - // If per-agent concurrency is added, TTL consistency between - // ToProviderDefs and Get must be re-evaluated. - agent.Tools.TickTTL() + ts.agent.Tools.TickTTL() logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{ - "agent_id": agent.ID, "iteration": iteration, + "agent_id": ts.agent.ID, "iteration": iteration, }) } - return finalContent, iteration, nil + if ts.hardAbortRequested() { + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + + if finalContent == "" { + finalContent = ts.opts.DefaultResponse + } + + ts.setPhase(TurnPhaseFinalizing) + ts.setFinalContent(finalContent) + if !ts.opts.NoHistory { + finalMsg := providers.Message{Role: "assistant", Content: finalContent} + ts.agent.Sessions.AddMessage(ts.sessionKey, finalMsg.Role, finalMsg.Content) + ts.recordPersistedMessage(finalMsg) + if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil { + turnStatus = TurnEndStatusError + al.emitEvent( + EventKindError, + ts.eventMeta("runTurn", "turn.error"), + ErrorPayload{ + Stage: "session_save", + Message: err.Error(), + }, + ) + return turnResult{}, err + } + } + + if ts.opts.EnableSummary { + al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope) + } + + ts.setPhase(TurnPhaseCompleted) + return turnResult{ + finalContent: finalContent, + status: turnStatus, + followUps: append([]bus.InboundMessage(nil), ts.followUps...), + }, nil +} + +func (al *AgentLoop) abortTurn(ts *turnState) (turnResult, error) { + ts.setPhase(TurnPhaseAborted) + if !ts.opts.NoHistory { + if err := ts.restoreSession(ts.agent); err != nil { + al.emitEvent( + EventKindError, + ts.eventMeta("abortTurn", "turn.error"), + ErrorPayload{ + Stage: "session_restore", + Message: err.Error(), + }, + ) + return turnResult{}, err + } + } + return turnResult{status: TurnEndStatusAborted}, nil +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } } // selectCandidates returns the model candidates and resolved model name to use diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index 90d1cc091..77c2e0c17 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -122,20 +122,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error { "content_len": len(msg.Content), "queue_len": al.steering.len(), }) - agentID := "" - if registry := al.GetRegistry(); registry != nil { + + meta := EventMeta{ + Source: "Steer", + TracePath: "turn.interrupt.received", + } + if ts := al.getActiveTurnState(); ts != nil { + meta = ts.eventMeta("Steer", "turn.interrupt.received") + } else if registry := al.GetRegistry(); registry != nil { if agent := registry.GetDefaultAgent(); agent != nil { - agentID = agent.ID + meta.AgentID = agent.ID } } al.emitEvent( EventKindInterruptReceived, - EventMeta{ - AgentID: agentID, - Source: "Steer", - TracePath: "turn.interrupt.received", - }, + meta, InterruptReceivedPayload{ + Kind: InterruptKindSteering, Role: msg.Role, ContentLen: len(msg.Content), QueueDepth: al.steering.len(), @@ -177,6 +180,10 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message { // // If no steering messages are pending, it returns an empty string. func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) { + if active := al.GetActiveTurn(); active != nil { + return "", fmt.Errorf("turn %s is still active", active.TurnID) + } + steeringMsgs := al.dequeueSteeringMessages() if len(steeringMsgs) == 0 { return "", nil @@ -187,6 +194,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s return "", fmt.Errorf("no default agent available") } + if tool, ok := agent.Tools.Get("message"); ok { + if resetter, ok := tool.(interface{ ResetSentInRound() }); ok { + resetter.ResetSentInRound() + } + } + // Build a combined user message from the steering messages. var contents []string for _, msg := range steeringMsgs { @@ -205,3 +218,44 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s SkipInitialSteeringPoll: true, }) } + +func (al *AgentLoop) InterruptGraceful(hint string) error { + ts := al.getActiveTurnState() + if ts == nil { + return fmt.Errorf("no active turn") + } + if !ts.requestGracefulInterrupt(hint) { + return fmt.Errorf("turn %s cannot accept graceful interrupt", ts.turnID) + } + + al.emitEvent( + EventKindInterruptReceived, + ts.eventMeta("InterruptGraceful", "turn.interrupt.received"), + InterruptReceivedPayload{ + Kind: InterruptKindGraceful, + HintLen: len(hint), + }, + ) + + return nil +} + +func (al *AgentLoop) InterruptHard() error { + ts := al.getActiveTurnState() + if ts == nil { + return fmt.Errorf("no active turn") + } + if !ts.requestHardAbort() { + return fmt.Errorf("turn %s is already aborting", ts.turnID) + } + + al.emitEvent( + EventKindInterruptReceived, + ts.eventMeta("InterruptHard", "turn.interrupt.received"), + InterruptReceivedPayload{ + Kind: InterruptKindHard, + }, + ) + + return nil +} diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go index e8cdb2344..bb5d42c73 100644 --- a/pkg/agent/steering_test.go +++ b/pkg/agent/steering_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "os" + "reflect" "sync" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -396,6 +398,103 @@ func (m *toolCallProvider) GetDefaultModel() string { return "tool-call-mock" } +type gracefulCaptureProvider struct { + mu sync.Mutex + calls int + toolCalls []providers.ToolCall + finalResp string + terminalMessages []providers.Message + terminalToolsCount int +} + +func (p *gracefulCaptureProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.mu.Lock() + defer p.mu.Unlock() + p.calls++ + + if p.calls == 1 { + return &providers.LLMResponse{ + ToolCalls: p.toolCalls, + }, nil + } + + p.terminalMessages = append([]providers.Message(nil), messages...) + p.terminalToolsCount = len(tools) + return &providers.LLMResponse{ + Content: p.finalResp, + }, nil +} + +func (p *gracefulCaptureProvider) GetDefaultModel() string { + return "graceful-capture-mock" +} + +type lateSteeringProvider struct { + mu sync.Mutex + calls int + firstCallStarted chan struct{} + releaseFirstCall chan struct{} + firstStartOnce sync.Once + secondCallMessages []providers.Message +} + +func (p *lateSteeringProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.mu.Lock() + p.calls++ + call := p.calls + p.mu.Unlock() + + if call == 1 { + p.firstStartOnce.Do(func() { close(p.firstCallStarted) }) + <-p.releaseFirstCall + return &providers.LLMResponse{Content: "first response"}, nil + } + + p.mu.Lock() + p.secondCallMessages = append([]providers.Message(nil), messages...) + p.mu.Unlock() + return &providers.LLMResponse{Content: "continued response"}, nil +} + +func (p *lateSteeringProvider) GetDefaultModel() string { + return "late-steering-mock" +} + +type interruptibleTool struct { + name string + started chan struct{} + once sync.Once +} + +func (t *interruptibleTool) Name() string { return t.name } +func (t *interruptibleTool) Description() string { return "interruptible tool for testing" } +func (t *interruptibleTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (t *interruptibleTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + if t.started != nil { + t.once.Do(func() { close(t.started) }) + } + <-ctx.Done() + return tools.ErrorResult(ctx.Err().Error()).WithError(ctx.Err()) +} + func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil { @@ -568,6 +667,427 @@ func TestAgentLoop_Steering_InitialPoll(t *testing.T) { } } +func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &lateSteeringProvider{ + firstCallStarted: make(chan struct{}), + releaseFirstCall: make(chan struct{}), + } + al := NewAgentLoop(cfg, msgBus, provider) + + runCtx, cancelRun := context.WithCancel(context.Background()) + defer cancelRun() + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- al.Run(runCtx) + }() + + first := bus.InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "first message", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + late := bus.InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "late append", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + + pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer pubCancel() + if err := msgBus.PublishInbound(pubCtx, first); err != nil { + t.Fatalf("publish first inbound: %v", err) + } + + select { + case <-provider.firstCallStarted: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for first provider call to start") + } + + if err := msgBus.PublishInbound(pubCtx, late); err != nil { + t.Fatalf("publish late inbound: %v", err) + } + + close(provider.releaseFirstCall) + + subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer subCancel() + + out1, ok := msgBus.SubscribeOutbound(subCtx) + if !ok { + t.Fatal("expected first outbound response") + } + if out1.Content != "first response" { + t.Fatalf("expected first response, got %q", out1.Content) + } + + out2, ok := msgBus.SubscribeOutbound(subCtx) + if !ok { + t.Fatal("expected continued outbound response") + } + if out2.Content != "continued response" { + t.Fatalf("expected continued response, got %q", out2.Content) + } + + cancelRun() + select { + case err := <-runErrCh: + if err != nil { + t.Fatalf("Run returned error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for Run to stop") + } + + provider.mu.Lock() + calls := provider.calls + secondMessages := append([]providers.Message(nil), provider.secondCallMessages...) + provider.mu.Unlock() + + if calls != 2 { + t.Fatalf("expected 2 provider calls, got %d", calls) + } + + foundLateMessage := false + for _, msg := range secondMessages { + if msg.Role == "user" && msg.Content == "late append" { + foundLateMessage = true + break + } + } + if !foundLateMessage { + t.Fatal("expected queued late message to be processed in an automatic follow-up turn") + } +} + +func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + tool1ExecCh := make(chan struct{}) + tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh} + tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond} + + provider := &gracefulCaptureProvider{ + toolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "tool_one", + Function: &providers.FunctionCall{ + Name: "tool_one", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + { + ID: "call_2", + Type: "function", + Name: "tool_two", + Function: &providers.FunctionCall{ + Name: "tool_two", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + }, + finalResp: "graceful summary", + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, provider) + al.RegisterTool(tool1) + al.RegisterTool(tool2) + sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + + sub := al.SubscribeEvents(32) + defer al.UnsubscribeEvents(sub.ID) + + type result struct { + resp string + err error + } + resultCh := make(chan result, 1) + go func() { + resp, err := al.ProcessDirectWithChannel( + context.Background(), + "do something", + sessionKey, + "test", + "chat1", + ) + resultCh <- result{resp: resp, err: err} + }() + + select { + case <-tool1ExecCh: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for tool_one to start") + } + + active := al.GetActiveTurn() + if active == nil { + t.Fatal("expected active turn while tool is running") + } + if active.SessionKey != sessionKey { + t.Fatalf("expected active session %q, got %q", sessionKey, active.SessionKey) + } + if active.Channel != "test" || active.ChatID != "chat1" { + t.Fatalf("unexpected active turn target: %#v", active) + } + + if err := al.InterruptGraceful("wrap it up"); err != nil { + t.Fatalf("InterruptGraceful failed: %v", err) + } + + select { + case r := <-resultCh: + if r.err != nil { + t.Fatalf("unexpected error: %v", r.err) + } + if r.resp != "graceful summary" { + t.Fatalf("expected graceful summary, got %q", r.resp) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for graceful interrupt result") + } + + if active := al.GetActiveTurn(); active != nil { + t.Fatalf("expected no active turn after completion, got %#v", active) + } + + provider.mu.Lock() + terminalMessages := append([]providers.Message(nil), provider.terminalMessages...) + terminalToolsCount := provider.terminalToolsCount + calls := provider.calls + provider.mu.Unlock() + + if calls != 2 { + t.Fatalf("expected 2 provider calls, got %d", calls) + } + if terminalToolsCount != 0 { + t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount) + } + + foundHint := false + foundSkipped := false + expectedHint := "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\n" + + "Interrupt hint: wrap it up" + for _, msg := range terminalMessages { + if msg.Role == "user" && msg.Content == expectedHint { + foundHint = true + } + if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." { + foundSkipped = true + } + } + if !foundHint { + t.Fatal("expected graceful terminal call to include interrupt hint message") + } + if !foundSkipped { + t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt") + } + + events := collectEventStream(sub.C) + interruptEvt, ok := findEvent(events, EventKindInterruptReceived) + if !ok { + t.Fatal("expected interrupt received event") + } + interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload) + if !ok { + t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload) + } + if interruptPayload.Kind != InterruptKindGraceful { + t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind) + } + + turnEndEvt, ok := findEvent(events, EventKindTurnEnd) + if !ok { + t.Fatal("expected turn end event") + } + turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload) + if !ok { + t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload) + } + if turnEndPayload.Status != TurnEndStatusCompleted { + t.Fatalf("expected completed turn after graceful interrupt, got %q", turnEndPayload.Status) + } +} + +func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &toolCallProvider{ + toolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "cancel_tool", + Function: &providers.FunctionCall{ + Name: "cancel_tool", + Arguments: "{}", + }, + Arguments: map[string]any{}, + }, + }, + finalResp: "should not happen", + } + + al := NewAgentLoop(cfg, msgBus, provider) + started := make(chan struct{}) + al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started}) + sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + + originalHistory := []providers.Message{ + {Role: "user", Content: "before"}, + {Role: "assistant", Content: "after"}, + } + defaultAgent.Sessions.SetHistory(sessionKey, originalHistory) + + sub := al.SubscribeEvents(16) + defer al.UnsubscribeEvents(sub.ID) + + type result struct { + resp string + err error + } + resultCh := make(chan result, 1) + go func() { + resp, err := al.ProcessDirectWithChannel( + context.Background(), + "do work", + sessionKey, + "test", + "chat1", + ) + resultCh <- result{resp: resp, err: err} + }() + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for interruptible tool to start") + } + + if active := al.GetActiveTurn(); active == nil { + t.Fatal("expected active turn before hard abort") + } + + if err := al.InterruptHard(); err != nil { + t.Fatalf("InterruptHard failed: %v", err) + } + + select { + case r := <-resultCh: + if r.err != nil { + t.Fatalf("unexpected error: %v", r.err) + } + if r.resp != "" { + t.Fatalf("expected no final response after hard abort, got %q", r.resp) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for hard abort result") + } + + if active := al.GetActiveTurn(); active != nil { + t.Fatalf("expected no active turn after hard abort, got %#v", active) + } + + finalHistory := defaultAgent.Sessions.GetHistory(sessionKey) + if !reflect.DeepEqual(finalHistory, originalHistory) { + t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory) + } + + events := collectEventStream(sub.C) + interruptEvt, ok := findEvent(events, EventKindInterruptReceived) + if !ok { + t.Fatal("expected interrupt received event") + } + interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload) + if !ok { + t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload) + } + if interruptPayload.Kind != InterruptKindHard { + t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind) + } + + turnEndEvt, ok := findEvent(events, EventKindTurnEnd) + if !ok { + t.Fatal("expected turn end event") + } + turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload) + if !ok { + t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload) + } + if turnEndPayload.Status != TurnEndStatusAborted { + t.Fatalf("expected aborted turn, got %q", turnEndPayload.Status) + } +} + // capturingMockProvider captures messages sent to Chat for inspection. type capturingMockProvider struct { response string diff --git a/pkg/agent/turn.go b/pkg/agent/turn.go new file mode 100644 index 000000000..358dae2b4 --- /dev/null +++ b/pkg/agent/turn.go @@ -0,0 +1,308 @@ +package agent + +import ( + "context" + "reflect" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/providers" +) + +type TurnPhase string + +const ( + TurnPhaseSetup TurnPhase = "setup" + TurnPhaseRunning TurnPhase = "running" + TurnPhaseTools TurnPhase = "tools" + TurnPhaseFinalizing TurnPhase = "finalizing" + TurnPhaseCompleted TurnPhase = "completed" + TurnPhaseAborted TurnPhase = "aborted" +) + +type ActiveTurnInfo struct { + TurnID string + AgentID string + SessionKey string + Channel string + ChatID string + UserMessage string + Phase TurnPhase + Iteration int + StartedAt time.Time +} + +type turnResult struct { + finalContent string + status TurnEndStatus + followUps []bus.InboundMessage +} + +type turnState struct { + mu sync.RWMutex + + agent *AgentInstance + opts processOptions + scope turnEventScope + + turnID string + agentID string + sessionKey string + + channel string + chatID string + userMessage string + media []string + + phase TurnPhase + iteration int + startedAt time.Time + finalContent string + + followUps []bus.InboundMessage + + gracefulInterrupt bool + gracefulInterruptHint string + gracefulTerminalUsed bool + hardAbort bool + providerCancel context.CancelFunc + turnCancel context.CancelFunc + + restorePointHistory []providers.Message + restorePointSummary string + persistedMessages []providers.Message +} + +func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState { + return &turnState{ + agent: agent, + opts: opts, + scope: scope, + turnID: scope.turnID, + agentID: agent.ID, + sessionKey: opts.SessionKey, + channel: opts.Channel, + chatID: opts.ChatID, + userMessage: opts.UserMessage, + media: append([]string(nil), opts.Media...), + phase: TurnPhaseSetup, + startedAt: time.Now(), + } +} + +func (al *AgentLoop) registerActiveTurn(ts *turnState) { + al.activeTurnMu.Lock() + defer al.activeTurnMu.Unlock() + al.activeTurn = ts +} + +func (al *AgentLoop) clearActiveTurn(ts *turnState) { + al.activeTurnMu.Lock() + defer al.activeTurnMu.Unlock() + if al.activeTurn == ts { + al.activeTurn = nil + } +} + +func (al *AgentLoop) getActiveTurnState() *turnState { + al.activeTurnMu.RLock() + defer al.activeTurnMu.RUnlock() + return al.activeTurn +} + +func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo { + ts := al.getActiveTurnState() + if ts == nil { + return nil + } + info := ts.snapshot() + return &info +} + +func (ts *turnState) snapshot() ActiveTurnInfo { + ts.mu.RLock() + defer ts.mu.RUnlock() + + return ActiveTurnInfo{ + TurnID: ts.turnID, + AgentID: ts.agentID, + SessionKey: ts.sessionKey, + Channel: ts.channel, + ChatID: ts.chatID, + UserMessage: ts.userMessage, + Phase: ts.phase, + Iteration: ts.iteration, + StartedAt: ts.startedAt, + } +} + +func (ts *turnState) setPhase(phase TurnPhase) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.phase = phase +} + +func (ts *turnState) setIteration(iteration int) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.iteration = iteration +} + +func (ts *turnState) currentIteration() int { + ts.mu.RLock() + defer ts.mu.RUnlock() + return ts.iteration +} + +func (ts *turnState) setFinalContent(content string) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.finalContent = content +} + +func (ts *turnState) finalContentLen() int { + ts.mu.RLock() + defer ts.mu.RUnlock() + return len(ts.finalContent) +} + +func (ts *turnState) setTurnCancel(cancel context.CancelFunc) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.turnCancel = cancel +} + +func (ts *turnState) setProviderCancel(cancel context.CancelFunc) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.providerCancel = cancel +} + +func (ts *turnState) clearProviderCancel(_ context.CancelFunc) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.providerCancel = nil +} + +func (ts *turnState) requestGracefulInterrupt(hint string) bool { + ts.mu.Lock() + defer ts.mu.Unlock() + if ts.hardAbort { + return false + } + ts.gracefulInterrupt = true + ts.gracefulInterruptHint = hint + return true +} + +func (ts *turnState) gracefulInterruptRequested() (bool, string) { + ts.mu.RLock() + defer ts.mu.RUnlock() + return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint +} + +func (ts *turnState) markGracefulTerminalUsed() { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.gracefulTerminalUsed = true +} + +func (ts *turnState) requestHardAbort() bool { + ts.mu.Lock() + if ts.hardAbort { + ts.mu.Unlock() + return false + } + ts.hardAbort = true + turnCancel := ts.turnCancel + providerCancel := ts.providerCancel + ts.mu.Unlock() + + if providerCancel != nil { + providerCancel() + } + if turnCancel != nil { + turnCancel() + } + return true +} + +func (ts *turnState) hardAbortRequested() bool { + ts.mu.RLock() + defer ts.mu.RUnlock() + return ts.hardAbort +} + +func (ts *turnState) eventMeta(source, tracePath string) EventMeta { + snap := ts.snapshot() + return EventMeta{ + AgentID: snap.AgentID, + TurnID: snap.TurnID, + SessionKey: snap.SessionKey, + Iteration: snap.Iteration, + Source: source, + TracePath: tracePath, + } +} + +func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.restorePointHistory = append([]providers.Message(nil), history...) + ts.restorePointSummary = summary +} + +func (ts *turnState) recordPersistedMessage(msg providers.Message) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.persistedMessages = append(ts.persistedMessages, msg) +} + +func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) { + history := agent.Sessions.GetHistory(ts.sessionKey) + summary := agent.Sessions.GetSummary(ts.sessionKey) + + ts.mu.RLock() + persisted := append([]providers.Message(nil), ts.persistedMessages...) + ts.mu.RUnlock() + + if matched := matchingTurnMessageTail(history, persisted); matched > 0 { + history = append([]providers.Message(nil), history[:len(history)-matched]...) + } + + ts.captureRestorePoint(history, summary) +} + +func (ts *turnState) restoreSession(agent *AgentInstance) error { + ts.mu.RLock() + history := append([]providers.Message(nil), ts.restorePointHistory...) + summary := ts.restorePointSummary + ts.mu.RUnlock() + + agent.Sessions.SetHistory(ts.sessionKey, history) + agent.Sessions.SetSummary(ts.sessionKey, summary) + return agent.Sessions.Save(ts.sessionKey) +} + +func matchingTurnMessageTail(history, persisted []providers.Message) int { + maxMatch := min(len(history), len(persisted)) + for size := maxMatch; size > 0; size-- { + if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) { + return size + } + } + return 0 +} + +func (ts *turnState) interruptHintMessage() providers.Message { + _, hint := ts.gracefulInterruptRequested() + content := "Interrupt requested. Stop scheduling tools and provide a short final summary." + if hint != "" { + content += "\n\nInterrupt hint: " + hint + } + return providers.Message{ + Role: "user", + Content: content, + } +}