mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(agent): Agent Looper refactor phase2, restructure pipeline and rename loop files to agent (#2585)
* refactor(agent): introduce interfaces for MessageBus and ChannelManager Phase 2 of loop.go refactor — dependency inversion using adapter pattern. - Add interfaces.MessageBus and interfaces.ChannelManager interfaces - Create adapters/messagebus.go wrapping *bus.MessageBus - Create adapters/channelmanager.go wrapping *channels.Manager - Update AgentLoop to use interfaces instead of concrete types - Update registerSharedTools to accept interfaces.MessageBus Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor(agent): restructure pipeline and rename loop files Pipeline refactoring: - Split pipeline.go (1400 lines) into focused files: - pipeline_setup.go (~115 lines): SetupTurn method - pipeline_llm.go (~519 lines): CallLLM method - pipeline_execute.go (~693 lines): ExecuteTools method - pipeline_finalize.go (~78 lines): Finalize method - Pipeline struct and NewPipeline remain in pipeline.go (~39 lines) Agent file renaming: - Rename loop_*.go to agent_*.go for consistent naming: - loop.go -> agent.go, loop_message.go -> agent_message.go, etc. - Merge turn.go + turn_exec.go into turn_state.go - Rename loop_turn.go -> turn_coord.go Documentation: - Update docs/pipeline-restructuring-plan.md - Add docs/agent-rename-plan.md Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * fix(agent): code format fixed * refactor(agent): code test file added/renamed * docs(agent): update agent refactor docs * fix(agent): fix agent hardAbortX --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent/interfaces"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
)
|
||||
|
||||
// channelManagerAdapter wraps *channels.Manager to implement interfaces.ChannelManager.
|
||||
type channelManagerAdapter struct {
|
||||
inner *channels.Manager
|
||||
}
|
||||
|
||||
// NewChannelManager creates an adapter for *channels.Manager.
|
||||
func NewChannelManager(inner *channels.Manager) interfaces.ChannelManager {
|
||||
return &channelManagerAdapter{inner: inner}
|
||||
}
|
||||
|
||||
func (a *channelManagerAdapter) GetChannel(name string) (channels.Channel, bool) {
|
||||
return a.inner.GetChannel(name)
|
||||
}
|
||||
|
||||
func (a *channelManagerAdapter) GetEnabledChannels() []string {
|
||||
return a.inner.GetEnabledChannels()
|
||||
}
|
||||
|
||||
func (a *channelManagerAdapter) InvokeTypingStop(channel, chatID string) {
|
||||
a.inner.InvokeTypingStop(channel, chatID)
|
||||
}
|
||||
|
||||
func (a *channelManagerAdapter) SendMessage(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
return a.inner.SendMessage(ctx, msg)
|
||||
}
|
||||
|
||||
func (a *channelManagerAdapter) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
return a.inner.SendMedia(ctx, msg)
|
||||
}
|
||||
|
||||
func (a *channelManagerAdapter) SendPlaceholder(ctx context.Context, channel, chatID string) bool {
|
||||
return a.inner.SendPlaceholder(ctx, channel, chatID)
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent/interfaces"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
)
|
||||
|
||||
// messageBusAdapter wraps *bus.MessageBus to implement interfaces.MessageBus.
|
||||
type messageBusAdapter struct {
|
||||
inner *bus.MessageBus
|
||||
}
|
||||
|
||||
// NewMessageBus creates an adapter for *bus.MessageBus.
|
||||
func NewMessageBus(inner *bus.MessageBus) interfaces.MessageBus {
|
||||
return &messageBusAdapter{inner: inner}
|
||||
}
|
||||
|
||||
func (a *messageBusAdapter) PublishInbound(ctx context.Context, msg bus.InboundMessage) error {
|
||||
return a.inner.PublishInbound(ctx, msg)
|
||||
}
|
||||
|
||||
func (a *messageBusAdapter) PublishOutbound(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
return a.inner.PublishOutbound(ctx, msg)
|
||||
}
|
||||
|
||||
func (a *messageBusAdapter) PublishOutboundMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
return a.inner.PublishOutboundMedia(ctx, msg)
|
||||
}
|
||||
|
||||
func (a *messageBusAdapter) InboundChan() <-chan bus.InboundMessage {
|
||||
return a.inner.InboundChan()
|
||||
}
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent/interfaces"
|
||||
"github.com/sipeed/picoclaw/pkg/audio/asr"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
@@ -32,7 +32,7 @@ import (
|
||||
|
||||
type AgentLoop struct {
|
||||
// Core dependencies
|
||||
bus *bus.MessageBus
|
||||
bus interfaces.MessageBus
|
||||
cfg *config.Config
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
@@ -45,7 +45,7 @@ type AgentLoop struct {
|
||||
running atomic.Bool
|
||||
contextManager ContextManager
|
||||
fallback *providers.FallbackChain
|
||||
channelManager *channels.Manager
|
||||
channelManager interfaces.ChannelManager
|
||||
mediaStore media.MediaStore
|
||||
transcriber asr.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
@@ -495,7 +495,8 @@ func (al *AgentLoop) runAgentLoop(
|
||||
newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope),
|
||||
)
|
||||
ts := newTurnState(agent, opts, turnScope)
|
||||
result, err := al.runTurn(ctx, ts)
|
||||
pipeline := NewPipeline(al)
|
||||
result, err := al.runTurn(ctx, ts, pipeline)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -48,24 +48,6 @@ func (al *AgentLoop) emitEvent(kind EventKind, meta EventMeta, payload any) {
|
||||
al.eventBus.Emit(evt)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) hookAbortError(ts *turnState, stage string, decision HookDecision) error {
|
||||
reason := decision.Reason
|
||||
if reason == "" {
|
||||
reason = "hook requested turn abort"
|
||||
}
|
||||
|
||||
err := fmt.Errorf("hook aborted turn during %s: %s", stage, reason)
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
ts.eventMeta("hooks", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "hook." + stage,
|
||||
Message: err.Error(),
|
||||
},
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (al *AgentLoop) logEvent(evt Event) {
|
||||
fields := map[string]any{
|
||||
"event_kind": evt.Kind.String(),
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent/interfaces"
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
@@ -79,7 +80,7 @@ func NewAgentLoop(
|
||||
func registerSharedTools(
|
||||
al *AgentLoop,
|
||||
cfg *config.Config,
|
||||
msgBus *bus.MessageBus,
|
||||
msgBus interfaces.MessageBus,
|
||||
registry *AgentRegistry,
|
||||
provider providers.LLMProvider,
|
||||
) {
|
||||
@@ -709,9 +709,10 @@ func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
al.channelManager = newStartedTestChannelManager(t, al.bus, al.mediaStore, "discord", &errorMediaChannel{
|
||||
sendErr: errors.New("channel unavailable"),
|
||||
})
|
||||
al.channelManager = newStartedTestChannelManager(t,
|
||||
al.bus.(*bus.MessageBus), al.mediaStore, "discord", &errorMediaChannel{
|
||||
sendErr: errors.New("channel unavailable"),
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
)
|
||||
|
||||
// MessageBus publishes inbound and outbound messages.
|
||||
// It is the primary communication channel for the agent loop.
|
||||
type MessageBus interface {
|
||||
// PublishInbound sends an inbound message to be processed.
|
||||
PublishInbound(ctx context.Context, msg bus.InboundMessage) error
|
||||
|
||||
// PublishOutbound sends an outbound message to the appropriate channel.
|
||||
PublishOutbound(ctx context.Context, msg bus.OutboundMessage) error
|
||||
|
||||
// PublishOutboundMedia sends an outbound media message.
|
||||
PublishOutboundMedia(ctx context.Context, msg bus.OutboundMediaMessage) error
|
||||
|
||||
// InboundChan returns the channel for receiving inbound messages.
|
||||
InboundChan() <-chan bus.InboundMessage
|
||||
}
|
||||
|
||||
// ChannelManager manages channel lifecycle and provides channel access.
|
||||
type ChannelManager interface {
|
||||
// GetChannel returns the channel with the given name.
|
||||
GetChannel(name string) (channels.Channel, bool)
|
||||
|
||||
// GetEnabledChannels returns the list of enabled channel names.
|
||||
GetEnabledChannels() []string
|
||||
|
||||
// InvokeTypingStop signals that typing has stopped.
|
||||
InvokeTypingStop(channel, chatID string)
|
||||
|
||||
// SendMessage sends a text message to the specified channel and chat.
|
||||
SendMessage(ctx context.Context, msg bus.OutboundMessage) error
|
||||
|
||||
// SendMedia sends a media message to the specified channel and chat.
|
||||
SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error
|
||||
|
||||
// SendPlaceholder sends a placeholder message (e.g., for audio transcription).
|
||||
SendPlaceholder(ctx context.Context, channel, chatID string) bool
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,40 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/agent/interfaces"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// Pipeline holds the runtime dependencies used by Pipeline methods.
|
||||
// It is constructed by runTurn via NewPipeline and passed to sub-methods
|
||||
// so that the coordinator can delegate phase execution.
|
||||
type Pipeline struct {
|
||||
Bus interfaces.MessageBus
|
||||
Cfg *config.Config
|
||||
ContextManager ContextManager
|
||||
Hooks *HookManager
|
||||
Fallback *providers.FallbackChain
|
||||
ChannelManager interfaces.ChannelManager
|
||||
MediaStore media.MediaStore
|
||||
Steering any // TODO: *Steering
|
||||
al *AgentLoop
|
||||
}
|
||||
|
||||
// NewPipeline creates a Pipeline from an AgentLoop instance.
|
||||
func NewPipeline(al *AgentLoop) *Pipeline {
|
||||
return &Pipeline{
|
||||
Bus: al.bus,
|
||||
Cfg: al.GetConfig(),
|
||||
ContextManager: al.contextManager,
|
||||
Hooks: al.hooks,
|
||||
Fallback: al.fallback,
|
||||
ChannelManager: al.channelManager,
|
||||
MediaStore: al.mediaStore,
|
||||
Steering: al.steering,
|
||||
al: al,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,700 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// ExecuteTools executes the tool loop, handling BeforeTool/ApproveTool/AfterTool hooks,
|
||||
// tool execution with async callbacks, media delivery, and steering injection.
|
||||
// Returns ToolControl indicating what the coordinator should do next:
|
||||
// - ToolControlContinue: all tool results handled, pendingMessages or steering exists, continue turn
|
||||
// - ToolControlBreak: tool loop exited, proceed to coordinator's hardAbort/finalContent/finalize
|
||||
func (p *Pipeline) ExecuteTools(
|
||||
ctx context.Context,
|
||||
turnCtx context.Context,
|
||||
ts *turnState,
|
||||
exec *turnExecution,
|
||||
iteration int,
|
||||
) ToolControl {
|
||||
al := p.al
|
||||
normalizedToolCalls := exec.normalizedToolCalls
|
||||
|
||||
ts.setPhase(TurnPhaseTools)
|
||||
messages := exec.messages
|
||||
|
||||
toolLoop:
|
||||
for i, tc := range normalizedToolCalls {
|
||||
if ts.hardAbortRequested() {
|
||||
exec.abortedByHardAbort = true
|
||||
return ToolControlBreak
|
||||
}
|
||||
|
||||
toolName := tc.Name
|
||||
toolArgs := cloneStringAnyMap(tc.Arguments)
|
||||
|
||||
if al.hooks != nil {
|
||||
toolReq, decision := al.hooks.BeforeTool(turnCtx, &ToolCallHookRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.before"),
|
||||
Context: cloneTurnContext(ts.turnCtx),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if toolReq != nil {
|
||||
toolName = toolReq.Tool
|
||||
toolArgs = toolReq.Arguments
|
||||
}
|
||||
case HookActionRespond:
|
||||
if toolReq != nil && toolReq.HookResult != nil {
|
||||
hookResult := toolReq.HookResult
|
||||
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call (hook respond): %s(%s)", toolName, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
al.emitEvent(
|
||||
EventKindToolExecStart,
|
||||
ts.eventMeta("runTurn", "turn.tool.start"),
|
||||
ToolExecStartPayload{
|
||||
Tool: toolName,
|
||||
Arguments: cloneEventArguments(toolArgs),
|
||||
},
|
||||
)
|
||||
|
||||
if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() &&
|
||||
ts.channel != "" &&
|
||||
!ts.opts.SuppressToolFeedback {
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
feedbackPreview := utils.Truncate(
|
||||
string(argsJSON),
|
||||
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
|
||||
)
|
||||
feedbackMsg := utils.FormatToolFeedbackMessage(toolName, feedbackPreview)
|
||||
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Content: feedbackMsg,
|
||||
})
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
toolDuration := time.Duration(0)
|
||||
|
||||
shouldSendForUser := !hookResult.Silent && hookResult.ForUser != "" &&
|
||||
(ts.opts.SendResponse || hookResult.ResponseHandled)
|
||||
if shouldSendForUser {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Raw: map[string]string{
|
||||
"is_tool_call": "true",
|
||||
},
|
||||
},
|
||||
Content: hookResult.ForUser,
|
||||
})
|
||||
}
|
||||
|
||||
if len(hookResult.Media) > 0 && hookResult.ResponseHandled {
|
||||
parts := make([]bus.MediaPart, 0, len(hookResult.Media))
|
||||
for _, ref := range hookResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
part.ContentType = meta.ContentType
|
||||
part.Type = inferMediaType(meta.Filename, meta.ContentType)
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
outboundMedia := bus.OutboundMediaMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Parts: parts,
|
||||
}
|
||||
if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) {
|
||||
if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil {
|
||||
logger.WarnCF("agent", "Failed to deliver hook media",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
hookResult.IsError = true
|
||||
hookResult.ForLLM = fmt.Sprintf("failed to deliver attachment: %v", err)
|
||||
}
|
||||
} else if al.bus != nil {
|
||||
al.bus.PublishOutboundMedia(ctx, outboundMedia)
|
||||
hookResult.ResponseHandled = false
|
||||
}
|
||||
}
|
||||
|
||||
if !hookResult.ResponseHandled {
|
||||
exec.allResponsesHandled = false
|
||||
}
|
||||
|
||||
contentForLLM := hookResult.ContentForLLM()
|
||||
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
|
||||
contentForLLM = al.cfg.FilterSensitiveData(contentForLLM)
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
|
||||
if len(hookResult.Media) > 0 && !hookResult.ResponseHandled {
|
||||
hookResult.ArtifactTags = buildArtifactTags(al.mediaStore, hookResult.Media)
|
||||
contentForLLM = hookResult.ContentForLLM()
|
||||
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
|
||||
contentForLLM = al.cfg.FilterSensitiveData(contentForLLM)
|
||||
}
|
||||
toolResultMsg.Content = contentForLLM
|
||||
toolResultMsg.Media = append(toolResultMsg.Media, hookResult.Media...)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindToolExecEnd,
|
||||
ts.eventMeta("runTurn", "turn.tool.end"),
|
||||
ToolExecEndPayload{
|
||||
Tool: toolName,
|
||||
Duration: toolDuration,
|
||||
ForLLMLen: len(contentForLLM),
|
||||
ForUserLen: len(hookResult.ForUser),
|
||||
IsError: hookResult.IsError,
|
||||
Async: hookResult.Async,
|
||||
},
|
||||
)
|
||||
|
||||
messages = append(messages, toolResultMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg)
|
||||
ts.recordPersistedMessage(toolResultMsg)
|
||||
ts.ingestMessage(turnCtx, al, toolResultMsg)
|
||||
}
|
||||
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
exec.pendingMessages = append(exec.pendingMessages, steerMsgs...)
|
||||
}
|
||||
|
||||
skipReason := ""
|
||||
skipMessage := ""
|
||||
if len(exec.pendingMessages) > 0 {
|
||||
skipReason = "queued user steering message"
|
||||
skipMessage = "Skipped due to queued user message."
|
||||
} else if gracefulPending, _ := ts.gracefulInterruptRequested(); gracefulPending {
|
||||
skipReason = "graceful interrupt requested"
|
||||
skipMessage = "Skipped due to graceful interrupt."
|
||||
}
|
||||
|
||||
if skipReason != "" {
|
||||
remaining := len(normalizedToolCalls) - i - 1
|
||||
if remaining > 0 {
|
||||
logger.InfoCF("agent", "Turn checkpoint: skipping remaining tools after hook respond",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"completed": i + 1,
|
||||
"skipped": remaining,
|
||||
"reason": skipReason,
|
||||
})
|
||||
for j := i + 1; j < len(normalizedToolCalls); j++ {
|
||||
skippedTC := normalizedToolCalls[j]
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: skippedTC.Name,
|
||||
Reason: skipReason,
|
||||
},
|
||||
)
|
||||
skippedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: skipMessage,
|
||||
ToolCallID: skippedTC.ID,
|
||||
}
|
||||
messages = append(messages, skippedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, skippedMsg)
|
||||
ts.recordPersistedMessage(skippedMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
break toolLoop
|
||||
}
|
||||
|
||||
if ts.pendingResults != nil {
|
||||
select {
|
||||
case result, ok := <-ts.pendingResults:
|
||||
if ok && result != nil && result.ForLLM != "" {
|
||||
content := al.cfg.FilterSensitiveData(result.ForLLM)
|
||||
msg := providers.Message{Role: "user", Content: fmt.Sprintf("[SubTurn Result] %s", content)}
|
||||
messages = append(messages, msg)
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, msg)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
logger.WarnCF("agent", "Hook returned respond action but no HookResult provided",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"action": "respond",
|
||||
})
|
||||
case HookActionDenyTool:
|
||||
exec.allResponsesHandled = false
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: toolName,
|
||||
Reason: denyContent,
|
||||
},
|
||||
)
|
||||
deniedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: denyContent,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, deniedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
|
||||
ts.recordPersistedMessage(deniedMsg)
|
||||
}
|
||||
continue
|
||||
case HookActionAbortTurn:
|
||||
exec.abortedByHook = true
|
||||
return ToolControlBreak
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
exec.abortedByHardAbort = true
|
||||
return ToolControlBreak
|
||||
}
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
approval := al.hooks.ApproveTool(turnCtx, &ToolApprovalRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.approve"),
|
||||
Context: cloneTurnContext(ts.turnCtx),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
})
|
||||
if !approval.Approved {
|
||||
exec.allResponsesHandled = false
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: toolName,
|
||||
Reason: denyContent,
|
||||
},
|
||||
)
|
||||
deniedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: denyContent,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, deniedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
|
||||
ts.recordPersistedMessage(deniedMsg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", toolName, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"iteration": iteration,
|
||||
})
|
||||
al.emitEvent(
|
||||
EventKindToolExecStart,
|
||||
ts.eventMeta("runTurn", "turn.tool.start"),
|
||||
ToolExecStartPayload{
|
||||
Tool: toolName,
|
||||
Arguments: cloneEventArguments(toolArgs),
|
||||
},
|
||||
)
|
||||
|
||||
if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() &&
|
||||
ts.channel != "" &&
|
||||
!ts.opts.SuppressToolFeedback {
|
||||
feedbackPreview := utils.Truncate(
|
||||
string(argsJSON),
|
||||
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
|
||||
)
|
||||
feedbackMsg := utils.FormatToolFeedbackMessage(tc.Name, feedbackPreview)
|
||||
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurn(ts, feedbackMsg))
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
toolCallID := tc.ID
|
||||
asyncToolName := toolName
|
||||
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer outCancel()
|
||||
_ = al.bus.PublishOutbound(outCtx, outboundMessageForTurn(ts, result.ForUser))
|
||||
}
|
||||
|
||||
content := result.ContentForLLM()
|
||||
if content == "" {
|
||||
return
|
||||
}
|
||||
|
||||
content = al.cfg.FilterSensitiveData(content)
|
||||
|
||||
logger.InfoCF("agent", "Async tool completed, publishing result",
|
||||
map[string]any{
|
||||
"tool": asyncToolName,
|
||||
"content_len": len(content),
|
||||
"channel": ts.channel,
|
||||
})
|
||||
al.emitEvent(
|
||||
EventKindFollowUpQueued,
|
||||
ts.scope.meta(iteration, "runTurn", "turn.follow_up.queued"),
|
||||
FollowUpQueuedPayload{
|
||||
SourceTool: asyncToolName,
|
||||
ContentLen: len(content),
|
||||
},
|
||||
)
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "system",
|
||||
ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID),
|
||||
ChatType: "direct",
|
||||
SenderID: fmt.Sprintf("async:%s", asyncToolName),
|
||||
},
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
toolStart := time.Now()
|
||||
execCtx := tools.WithToolInboundContext(
|
||||
turnCtx,
|
||||
ts.channel,
|
||||
ts.chatID,
|
||||
ts.opts.Dispatch.MessageID(),
|
||||
ts.opts.Dispatch.ReplyToMessageID(),
|
||||
)
|
||||
execCtx = tools.WithToolSessionContext(
|
||||
execCtx,
|
||||
ts.agent.ID,
|
||||
ts.sessionKey,
|
||||
ts.opts.Dispatch.SessionScope,
|
||||
)
|
||||
toolResult := ts.agent.Tools.ExecuteWithContext(
|
||||
execCtx,
|
||||
toolName,
|
||||
toolArgs,
|
||||
ts.channel,
|
||||
ts.chatID,
|
||||
asyncCallback,
|
||||
)
|
||||
toolDuration := time.Since(toolStart)
|
||||
|
||||
if ts.hardAbortRequested() {
|
||||
exec.abortedByHardAbort = true
|
||||
return ToolControlBreak
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
toolResp, decision := al.hooks.AfterTool(turnCtx, &ToolResultHookResponse{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.after"),
|
||||
Context: cloneTurnContext(ts.turnCtx),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Result: toolResult,
|
||||
Duration: toolDuration,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if toolResp != nil {
|
||||
if toolResp.Tool != "" {
|
||||
toolName = toolResp.Tool
|
||||
}
|
||||
if toolResp.Result != nil {
|
||||
toolResult = toolResp.Result
|
||||
}
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
exec.abortedByHook = true
|
||||
return ToolControlBreak
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
exec.abortedByHardAbort = true
|
||||
return ToolControlBreak
|
||||
}
|
||||
}
|
||||
|
||||
if toolResult == nil {
|
||||
toolResult = tools.ErrorResult("hook returned nil tool result")
|
||||
}
|
||||
|
||||
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
|
||||
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
|
||||
for _, ref := range toolResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
part.ContentType = meta.ContentType
|
||||
part.Type = inferMediaType(meta.Filename, meta.ContentType)
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
outboundMedia := bus.OutboundMediaMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Context: outboundContextFromInbound(
|
||||
ts.opts.Dispatch.InboundContext,
|
||||
ts.channel,
|
||||
ts.chatID,
|
||||
ts.opts.Dispatch.ReplyToMessageID(),
|
||||
),
|
||||
AgentID: ts.agent.ID,
|
||||
SessionKey: ts.sessionKey,
|
||||
Scope: outboundScopeFromSessionScope(ts.opts.Dispatch.SessionScope),
|
||||
Parts: parts,
|
||||
}
|
||||
if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) {
|
||||
if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil {
|
||||
logger.WarnCF("agent", "Failed to deliver handled tool media",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
toolResult = tools.ErrorResult(fmt.Sprintf("failed to deliver attachment: %v", err)).WithError(err)
|
||||
}
|
||||
} else if al.bus != nil {
|
||||
al.bus.PublishOutboundMedia(ctx, outboundMedia)
|
||||
toolResult.ResponseHandled = false
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
|
||||
toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media)
|
||||
}
|
||||
|
||||
if !toolResult.ResponseHandled {
|
||||
exec.allResponsesHandled = false
|
||||
}
|
||||
|
||||
shouldSendForUser := !toolResult.Silent &&
|
||||
toolResult.ForUser != "" &&
|
||||
(ts.opts.SendResponse || toolResult.ResponseHandled)
|
||||
if shouldSendForUser {
|
||||
al.bus.PublishOutbound(ctx, outboundMessageForTurn(ts, toolResult.ForUser))
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]any{
|
||||
"tool": toolName,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
contentForLLM := toolResult.ContentForLLM()
|
||||
|
||||
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
|
||||
contentForLLM = al.cfg.FilterSensitiveData(contentForLLM)
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: toolCallID,
|
||||
}
|
||||
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
|
||||
toolResultMsg.Media = append(toolResultMsg.Media, toolResult.Media...)
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindToolExecEnd,
|
||||
ts.eventMeta("runTurn", "turn.tool.end"),
|
||||
ToolExecEndPayload{
|
||||
Tool: toolName,
|
||||
Duration: toolDuration,
|
||||
ForLLMLen: len(contentForLLM),
|
||||
ForUserLen: len(toolResult.ForUser),
|
||||
IsError: toolResult.IsError,
|
||||
Async: toolResult.Async,
|
||||
},
|
||||
)
|
||||
messages = append(messages, toolResultMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg)
|
||||
ts.recordPersistedMessage(toolResultMsg)
|
||||
ts.ingestMessage(turnCtx, al, toolResultMsg)
|
||||
}
|
||||
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
exec.pendingMessages = append(exec.pendingMessages, steerMsgs...)
|
||||
}
|
||||
|
||||
skipReason := ""
|
||||
skipMessage := ""
|
||||
if len(exec.pendingMessages) > 0 {
|
||||
skipReason = "queued user steering message"
|
||||
skipMessage = "Skipped due to queued user message."
|
||||
} else if gracefulPending, _ := ts.gracefulInterruptRequested(); gracefulPending {
|
||||
skipReason = "graceful interrupt requested"
|
||||
skipMessage = "Skipped due to graceful interrupt."
|
||||
}
|
||||
|
||||
if skipReason != "" {
|
||||
remaining := len(normalizedToolCalls) - i - 1
|
||||
if remaining > 0 {
|
||||
logger.InfoCF("agent", "Turn checkpoint: skipping remaining tools",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"completed": i + 1,
|
||||
"skipped": remaining,
|
||||
"reason": skipReason,
|
||||
})
|
||||
for j := i + 1; j < len(normalizedToolCalls); j++ {
|
||||
skippedTC := normalizedToolCalls[j]
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: skippedTC.Name,
|
||||
Reason: skipReason,
|
||||
},
|
||||
)
|
||||
skippedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: skipMessage,
|
||||
ToolCallID: skippedTC.ID,
|
||||
}
|
||||
messages = append(messages, skippedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, skippedMsg)
|
||||
ts.recordPersistedMessage(skippedMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
break toolLoop
|
||||
}
|
||||
|
||||
if ts.pendingResults != nil {
|
||||
select {
|
||||
case result, ok := <-ts.pendingResults:
|
||||
if ok && result != nil && result.ForLLM != "" {
|
||||
content := al.cfg.FilterSensitiveData(result.ForLLM)
|
||||
msg := providers.Message{Role: "user", Content: fmt.Sprintf("[SubTurn Result] %s", content)}
|
||||
messages = append(messages, msg)
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, msg)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
exec.messages = messages
|
||||
|
||||
// Continue if pending steering exists (regardless of allResponsesHandled).
|
||||
// This covers the case where tools were partially executed and skipped due to steering,
|
||||
// but one tool had ResponseHandled=false (so allResponsesHandled=false).
|
||||
if len(exec.pendingMessages) > 0 {
|
||||
logger.InfoCF("agent", "Pending steering after partial tool execution; continuing turn",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"pending_count": len(exec.pendingMessages),
|
||||
"allResponsesHandled": exec.allResponsesHandled,
|
||||
})
|
||||
exec.allResponsesHandled = false
|
||||
return ToolControlContinue
|
||||
}
|
||||
|
||||
// Poll for newly arrived steering
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
logger.InfoCF("agent", "Steering arrived after tool delivery; continuing turn",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"steering_count": len(steerMsgs),
|
||||
})
|
||||
exec.pendingMessages = append(exec.pendingMessages, steerMsgs...)
|
||||
exec.allResponsesHandled = false
|
||||
return ToolControlContinue
|
||||
}
|
||||
|
||||
// No pending steering: finalize or break depending on allResponsesHandled
|
||||
if exec.allResponsesHandled {
|
||||
summaryMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: handledToolResponseSummary,
|
||||
}
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, summaryMsg.Role, summaryMsg.Content)
|
||||
ts.recordPersistedMessage(summaryMsg)
|
||||
ts.ingestMessage(turnCtx, al, summaryMsg)
|
||||
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
|
||||
logger.WarnCF("agent", "Failed to save session after tool delivery",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
if ts.opts.EnableSummary {
|
||||
al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
})
|
||||
}
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
ts.setFinalContent("")
|
||||
logger.InfoCF("agent", "Tool output satisfied delivery; ending turn without follow-up LLM",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"tool_count": len(normalizedToolCalls),
|
||||
})
|
||||
return ToolControlBreak
|
||||
}
|
||||
|
||||
// allResponsesHandled=false and no pending steering: continue so coordinator
|
||||
// makes another LLM call. The tool result is in messages and the LLM will
|
||||
// return it as finalContent in the next iteration.
|
||||
ts.agent.Tools.TickTTL()
|
||||
logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{
|
||||
"agent_id": ts.agent.ID, "iteration": iteration,
|
||||
})
|
||||
return ToolControlContinue
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// Finalize handles turn finalization, either:
|
||||
// - Early return when allResponsesHandled=true (ExecuteTools already finalized)
|
||||
// - Normal finalization for allResponsesHandled=false (sets finalContent, saves session, compact)
|
||||
func (p *Pipeline) Finalize(
|
||||
ctx context.Context,
|
||||
turnCtx context.Context,
|
||||
ts *turnState,
|
||||
exec *turnExecution,
|
||||
turnStatus TurnEndStatus,
|
||||
finalContent string,
|
||||
) (turnResult, error) {
|
||||
al := p.al
|
||||
|
||||
// When allResponsesHandled=true, ExecuteTools already finalized
|
||||
// (added handledToolResponseSummary, saved session, set phase to Completed).
|
||||
// But still check for hard abort - if requested, abort the turn.
|
||||
if exec.allResponsesHandled {
|
||||
if ts.hardAbortRequested() {
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
return turnResult{
|
||||
finalContent: finalContent,
|
||||
status: turnStatus,
|
||||
followUps: append([]bus.InboundMessage(nil), ts.followUps...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseFinalizing)
|
||||
ts.setFinalContent(finalContent)
|
||||
if !ts.opts.NoHistory {
|
||||
finalMsg := providers.Message{Role: "assistant", Content: finalContent}
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, finalMsg.Role, finalMsg.Content)
|
||||
ts.recordPersistedMessage(finalMsg)
|
||||
ts.ingestMessage(turnCtx, al, finalMsg)
|
||||
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
ts.eventMeta("runTurn", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "session_save",
|
||||
Message: err.Error(),
|
||||
},
|
||||
)
|
||||
return turnResult{status: TurnEndStatusError}, err
|
||||
}
|
||||
}
|
||||
|
||||
if ts.opts.EnableSummary {
|
||||
al.contextManager.Compact(
|
||||
turnCtx,
|
||||
&CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
return turnResult{
|
||||
finalContent: finalContent,
|
||||
status: turnStatus,
|
||||
followUps: append([]bus.InboundMessage(nil), ts.followUps...),
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,525 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// CallLLM performs an LLM call with fallback support, hook invocation, and retry logic.
|
||||
// It handles PreLLM setup, the actual LLM invocation with retry, and AfterLLM processing.
|
||||
// Returns Control indicating what the coordinator should do next.
|
||||
func (p *Pipeline) CallLLM(
|
||||
ctx context.Context,
|
||||
turnCtx context.Context,
|
||||
ts *turnState,
|
||||
exec *turnExecution,
|
||||
iteration int,
|
||||
) (Control, error) {
|
||||
al := p.al
|
||||
maxMediaSize := p.Cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
|
||||
// PreLLM: resolve media refs (except on iteration 1 where user media is already resolved)
|
||||
if iteration > 1 {
|
||||
exec.messages = resolveMediaRefs(exec.messages, p.MediaStore, maxMediaSize)
|
||||
}
|
||||
|
||||
// PreLLM: graceful terminal handling
|
||||
exec.gracefulTerminal, _ = ts.gracefulInterruptRequested()
|
||||
exec.providerToolDefs = ts.agent.Tools.ToProviderDefs()
|
||||
|
||||
// Native web search support
|
||||
_, hasWebSearch := ts.agent.Tools.Get("web_search")
|
||||
exec.useNativeSearch = al.cfg.Tools.Web.PreferNative && hasWebSearch &&
|
||||
func() bool {
|
||||
if ns, ok := ts.agent.Provider.(interface{ SupportsNativeSearch() bool }); ok {
|
||||
return ns.SupportsNativeSearch()
|
||||
}
|
||||
return false
|
||||
}()
|
||||
|
||||
if exec.useNativeSearch {
|
||||
filtered := make([]providers.ToolDefinition, 0, len(exec.providerToolDefs))
|
||||
for _, td := range exec.providerToolDefs {
|
||||
if td.Function.Name != "web_search" {
|
||||
filtered = append(filtered, td)
|
||||
}
|
||||
}
|
||||
exec.providerToolDefs = filtered
|
||||
}
|
||||
|
||||
exec.callMessages = exec.messages
|
||||
if exec.gracefulTerminal {
|
||||
exec.callMessages = append(append([]providers.Message(nil), exec.messages...), ts.interruptHintMessage())
|
||||
exec.providerToolDefs = nil
|
||||
ts.markGracefulTerminalUsed()
|
||||
}
|
||||
|
||||
exec.llmOpts = map[string]any{
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
"prompt_cache_key": ts.agent.ID,
|
||||
}
|
||||
if exec.useNativeSearch {
|
||||
exec.llmOpts["native_search"] = true
|
||||
}
|
||||
if ts.agent.ThinkingLevel != ThinkingOff {
|
||||
if tc, ok := ts.agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
|
||||
exec.llmOpts["thinking_level"] = string(ts.agent.ThinkingLevel)
|
||||
} else {
|
||||
logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring",
|
||||
map[string]any{"agent_id": ts.agent.ID, "thinking_level": string(ts.agent.ThinkingLevel)})
|
||||
}
|
||||
}
|
||||
|
||||
exec.llmModel = exec.activeModel
|
||||
|
||||
// BeforeLLM hook
|
||||
if p.Hooks != nil {
|
||||
llmReq, decision := p.Hooks.BeforeLLM(turnCtx, &LLMHookRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
Context: cloneTurnContext(ts.turnCtx),
|
||||
Model: exec.llmModel,
|
||||
Messages: exec.callMessages,
|
||||
Tools: exec.providerToolDefs,
|
||||
Options: exec.llmOpts,
|
||||
GracefulTerminal: exec.gracefulTerminal,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmReq != nil {
|
||||
exec.llmModel = llmReq.Model
|
||||
exec.callMessages = llmReq.Messages
|
||||
exec.providerToolDefs = llmReq.Tools
|
||||
exec.llmOpts = llmReq.Options
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
exec.abortedByHook = true
|
||||
return ControlBreak, nil
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
exec.abortedByHardAbort = true
|
||||
return ControlBreak, nil
|
||||
}
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRequest,
|
||||
ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
LLMRequestPayload{
|
||||
Model: exec.llmModel,
|
||||
MessagesCount: len(exec.callMessages),
|
||||
ToolsCount: len(exec.providerToolDefs),
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
Temperature: ts.agent.Temperature,
|
||||
},
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": exec.llmModel,
|
||||
"messages_count": len(exec.callMessages),
|
||||
"tools_count": len(exec.providerToolDefs),
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
"system_prompt_len": len(exec.callMessages[0].Content),
|
||||
})
|
||||
logger.DebugCF("agent", "Full LLM request",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"messages_json": formatMessagesForLog(exec.callMessages),
|
||||
"tools_json": formatToolsForLog(exec.providerToolDefs),
|
||||
})
|
||||
|
||||
// LLM call closure with fallback support
|
||||
callLLM := func(messagesForCall []providers.Message, toolDefsForCall []providers.ToolDefinition) (*providers.LLMResponse, error) {
|
||||
providerCtx, providerCancel := context.WithCancel(turnCtx)
|
||||
ts.setProviderCancel(providerCancel)
|
||||
defer func() {
|
||||
providerCancel()
|
||||
ts.clearProviderCancel(providerCancel)
|
||||
}()
|
||||
|
||||
al.activeRequests.Add(1)
|
||||
defer al.activeRequests.Done()
|
||||
|
||||
if len(exec.activeCandidates) > 1 && p.Fallback != nil {
|
||||
fbResult, fbErr := p.Fallback.Execute(
|
||||
providerCtx,
|
||||
exec.activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
candidateProvider := exec.activeProvider
|
||||
if cp, ok := ts.agent.CandidateProviders[providers.ModelKey(provider, model)]; ok {
|
||||
candidateProvider = cp
|
||||
}
|
||||
return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, exec.llmOpts)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
return nil, fbErr
|
||||
}
|
||||
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
|
||||
logger.InfoCF(
|
||||
"agent",
|
||||
fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]any{"agent_id": ts.agent.ID, "iteration": iteration},
|
||||
)
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return exec.activeProvider.Chat(providerCtx, messagesForCall, toolDefsForCall, exec.llmModel, exec.llmOpts)
|
||||
}
|
||||
|
||||
// Retry loop
|
||||
var err error
|
||||
maxRetries := 2
|
||||
for retry := 0; retry <= maxRetries; retry++ {
|
||||
exec.response, err = callLLM(exec.callMessages, exec.providerToolDefs)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if ts.hardAbortRequested() && errors.Is(err, context.Canceled) {
|
||||
_ = ts.requestHardAbort()
|
||||
exec.abortedByHardAbort = true
|
||||
return ControlBreak, nil
|
||||
}
|
||||
|
||||
// Retry without media if vision is unsupported
|
||||
if hasMediaRefs(exec.callMessages) && isVisionUnsupportedError(err) && retry < maxRetries {
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: retry + 1,
|
||||
MaxRetries: maxRetries,
|
||||
Reason: "vision_unsupported",
|
||||
Error: err.Error(),
|
||||
Backoff: 0,
|
||||
},
|
||||
)
|
||||
logger.WarnCF("agent", "Vision unsupported, retrying without media", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
})
|
||||
exec.callMessages = stripMessageMedia(exec.callMessages)
|
||||
if !ts.opts.NoHistory {
|
||||
exec.history = stripMessageMedia(exec.history)
|
||||
ts.agent.Sessions.SetHistory(ts.sessionKey, exec.history)
|
||||
for i := range ts.persistedMessages {
|
||||
ts.persistedMessages[i].Media = nil
|
||||
}
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
isTimeoutError := errors.Is(err, context.DeadlineExceeded) ||
|
||||
strings.Contains(errMsg, "deadline exceeded") ||
|
||||
strings.Contains(errMsg, "client.timeout") ||
|
||||
strings.Contains(errMsg, "timed out") ||
|
||||
strings.Contains(errMsg, "timeout exceeded")
|
||||
|
||||
isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") ||
|
||||
strings.Contains(errMsg, "context window") ||
|
||||
strings.Contains(errMsg, "context_window") ||
|
||||
strings.Contains(errMsg, "maximum context length") ||
|
||||
strings.Contains(errMsg, "token limit") ||
|
||||
strings.Contains(errMsg, "too many tokens") ||
|
||||
strings.Contains(errMsg, "max_tokens") ||
|
||||
strings.Contains(errMsg, "invalidparameter") ||
|
||||
strings.Contains(errMsg, "prompt is too long") ||
|
||||
strings.Contains(errMsg, "request too large"))
|
||||
|
||||
if isTimeoutError && retry < maxRetries {
|
||||
backoff := time.Duration(retry+1) * 5 * time.Second
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: retry + 1,
|
||||
MaxRetries: maxRetries,
|
||||
Reason: "timeout",
|
||||
Error: err.Error(),
|
||||
Backoff: backoff,
|
||||
},
|
||||
)
|
||||
logger.WarnCF("agent", "Timeout error, retrying after backoff", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
"backoff": backoff.String(),
|
||||
})
|
||||
if sleepErr := sleepWithContext(turnCtx, backoff); sleepErr != nil {
|
||||
if ts.hardAbortRequested() {
|
||||
_ = ts.requestHardAbort()
|
||||
return ControlBreak, nil
|
||||
}
|
||||
err = sleepErr
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if isContextError && retry < maxRetries && !ts.opts.NoHistory {
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: retry + 1,
|
||||
MaxRetries: maxRetries,
|
||||
Reason: "context_limit",
|
||||
Error: err.Error(),
|
||||
},
|
||||
)
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Context window error detected, attempting compression",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
},
|
||||
)
|
||||
|
||||
if retry == 0 && !constants.IsInternalChannel(ts.channel) {
|
||||
al.bus.PublishOutbound(ctx, outboundMessageForTurn(
|
||||
ts,
|
||||
"Context window exceeded. Compressing history and retrying...",
|
||||
))
|
||||
}
|
||||
|
||||
if compactErr := p.ContextManager.Compact(ctx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonRetry,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
}); compactErr != nil {
|
||||
logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"error": compactErr.Error(),
|
||||
})
|
||||
}
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
if asmResp, asmErr := p.ContextManager.Assemble(ctx, &AssembleRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
}); asmErr == nil && asmResp != nil {
|
||||
exec.history = asmResp.History
|
||||
exec.summary = asmResp.Summary
|
||||
}
|
||||
exec.messages = ts.agent.ContextBuilder.BuildMessages(
|
||||
exec.history, exec.summary, "",
|
||||
nil, ts.channel, ts.chatID, ts.opts.Dispatch.SenderID(), ts.opts.SenderDisplayName,
|
||||
activeSkillNames(ts.agent, ts.opts)...,
|
||||
)
|
||||
exec.callMessages = exec.messages
|
||||
if exec.gracefulTerminal {
|
||||
msgs := append([]providers.Message(nil), exec.messages...)
|
||||
exec.callMessages = append(msgs, ts.interruptHintMessage())
|
||||
}
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
ts.eventMeta("runTurn", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "llm",
|
||||
Message: err.Error(),
|
||||
},
|
||||
)
|
||||
logger.ErrorCF("agent", "LLM call failed",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": exec.llmModel,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ControlBreak, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
// AfterLLM hook
|
||||
if p.Hooks != nil {
|
||||
llmResp, decision := p.Hooks.AfterLLM(turnCtx, &LLMHookResponse{
|
||||
Meta: ts.eventMeta("runTurn", "turn.llm.response"),
|
||||
Context: cloneTurnContext(ts.turnCtx),
|
||||
Model: exec.llmModel,
|
||||
Response: exec.response,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmResp != nil && llmResp.Response != nil {
|
||||
exec.response = llmResp.Response
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
exec.abortedByHook = true
|
||||
return ControlBreak, nil
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
exec.abortedByHardAbort = true
|
||||
return ControlBreak, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Save finishReason to turnState for SubTurn truncation detection
|
||||
if innerTS := turnStateFromContext(ctx); innerTS != nil {
|
||||
innerTS.SetLastFinishReason(exec.response.FinishReason)
|
||||
if exec.response.Usage != nil {
|
||||
innerTS.SetLastUsage(exec.response.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
reasoningContent := exec.response.Reasoning
|
||||
if reasoningContent == "" {
|
||||
reasoningContent = exec.response.ReasoningContent
|
||||
}
|
||||
if ts.channel == "pico" {
|
||||
go al.publishPicoReasoning(turnCtx, reasoningContent, ts.chatID)
|
||||
} else {
|
||||
go al.handleReasoning(
|
||||
turnCtx,
|
||||
reasoningContent,
|
||||
ts.channel,
|
||||
al.targetReasoningChannelID(ts.channel),
|
||||
)
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindLLMResponse,
|
||||
ts.eventMeta("runTurn", "turn.llm.response"),
|
||||
LLMResponsePayload{
|
||||
ContentLen: len(exec.response.Content),
|
||||
ToolCalls: len(exec.response.ToolCalls),
|
||||
HasReasoning: exec.response.Reasoning != "" || exec.response.ReasoningContent != "",
|
||||
},
|
||||
)
|
||||
|
||||
llmResponseFields := map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"content_chars": len(exec.response.Content),
|
||||
"tool_calls": len(exec.response.ToolCalls),
|
||||
"reasoning": exec.response.Reasoning,
|
||||
"target_channel": al.targetReasoningChannelID(ts.channel),
|
||||
"channel": ts.channel,
|
||||
}
|
||||
if exec.response.Usage != nil {
|
||||
llmResponseFields["prompt_tokens"] = exec.response.Usage.PromptTokens
|
||||
llmResponseFields["completion_tokens"] = exec.response.Usage.CompletionTokens
|
||||
llmResponseFields["total_tokens"] = exec.response.Usage.TotalTokens
|
||||
}
|
||||
logger.DebugCF("agent", "LLM response", llmResponseFields)
|
||||
|
||||
if al.bus != nil && ts.channel == "pico" && len(exec.response.ToolCalls) > 0 && ts.opts.AllowInterimPicoPublish {
|
||||
if strings.TrimSpace(exec.response.Content) != "" {
|
||||
outCtx, outCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
publishErr := al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Content: exec.response.Content,
|
||||
})
|
||||
outCancel()
|
||||
if publishErr != nil {
|
||||
logger.WarnCF("agent", "Failed to publish pico interim tool-call content", map[string]any{
|
||||
"error": publishErr.Error(),
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"iteration": iteration,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No-tool-call path: steering check and direct response
|
||||
if len(exec.response.ToolCalls) == 0 || exec.gracefulTerminal {
|
||||
responseContent := exec.response.Content
|
||||
if responseContent == "" && exec.response.ReasoningContent != "" && ts.channel != "pico" {
|
||||
responseContent = exec.response.ReasoningContent
|
||||
}
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
logger.InfoCF("agent", "Steering arrived after direct LLM response; continuing turn",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"steering_count": len(steerMsgs),
|
||||
})
|
||||
exec.pendingMessages = append(exec.pendingMessages, steerMsgs...)
|
||||
return ControlContinue, nil
|
||||
}
|
||||
exec.finalContent = responseContent
|
||||
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"content_chars": len(exec.finalContent),
|
||||
})
|
||||
return ControlBreak, nil
|
||||
}
|
||||
|
||||
// Tool-call path: normalize and prepare for tool execution
|
||||
exec.normalizedToolCalls = make([]providers.ToolCall, 0, len(exec.response.ToolCalls))
|
||||
for _, tc := range exec.response.ToolCalls {
|
||||
exec.normalizedToolCalls = append(exec.normalizedToolCalls, providers.NormalizeToolCall(tc))
|
||||
}
|
||||
|
||||
toolNames := make([]string, 0, len(exec.normalizedToolCalls))
|
||||
for _, tc := range exec.normalizedToolCalls {
|
||||
toolNames = append(toolNames, tc.Name)
|
||||
}
|
||||
logger.InfoCF("agent", "LLM requested tool calls",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tools": toolNames,
|
||||
"count": len(exec.normalizedToolCalls),
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
exec.allResponsesHandled = len(exec.normalizedToolCalls) > 0
|
||||
assistantMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: exec.response.Content,
|
||||
ReasoningContent: exec.response.ReasoningContent,
|
||||
}
|
||||
for _, tc := range exec.normalizedToolCalls {
|
||||
argumentsJSON, _ := json.Marshal(tc.Arguments)
|
||||
extraContent := tc.ExtraContent
|
||||
thoughtSignature := ""
|
||||
if tc.Function != nil {
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
}
|
||||
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Name: tc.Name,
|
||||
Function: &providers.FunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: string(argumentsJSON),
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
ExtraContent: extraContent,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
})
|
||||
}
|
||||
exec.messages = append(exec.messages, assistantMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, assistantMsg)
|
||||
ts.recordPersistedMessage(assistantMsg)
|
||||
ts.ingestMessage(turnCtx, al, assistantMsg)
|
||||
}
|
||||
|
||||
return ControlToolLoop, nil
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// SetupTurn extracts the one-time initialization phase, returning a
|
||||
// turnExecution populated with history, messages, and candidate selection.
|
||||
// It replaces lines 56-145 of the original runTurn.
|
||||
func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution, error) {
|
||||
cfg := p.Cfg
|
||||
maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
|
||||
var history []providers.Message
|
||||
var summary string
|
||||
if !ts.opts.NoHistory {
|
||||
if resp, err := p.ContextManager.Assemble(ctx, &AssembleRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
}); err == nil && resp != nil {
|
||||
history = resp.History
|
||||
summary = resp.Summary
|
||||
}
|
||||
}
|
||||
ts.captureRestorePoint(history, summary)
|
||||
|
||||
messages := ts.agent.ContextBuilder.BuildMessages(
|
||||
history,
|
||||
summary,
|
||||
ts.userMessage,
|
||||
ts.media,
|
||||
ts.channel,
|
||||
ts.chatID,
|
||||
ts.opts.Dispatch.SenderID(),
|
||||
ts.opts.SenderDisplayName,
|
||||
activeSkillNames(ts.agent, ts.opts)...,
|
||||
)
|
||||
|
||||
messages = resolveMediaRefs(messages, p.MediaStore, maxMediaSize)
|
||||
|
||||
if !ts.opts.NoHistory {
|
||||
toolDefs := ts.agent.Tools.ToProviderDefs()
|
||||
if isOverContextBudget(ts.agent.ContextWindow, messages, toolDefs, ts.agent.MaxTokens) {
|
||||
logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call",
|
||||
map[string]any{"session_key": ts.sessionKey})
|
||||
if err := p.ContextManager.Compact(ctx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonProactive,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Proactive compact failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
if resp, err := p.ContextManager.Assemble(ctx, &AssembleRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
}); err == nil && resp != nil {
|
||||
history = resp.History
|
||||
summary = resp.Summary
|
||||
}
|
||||
messages = ts.agent.ContextBuilder.BuildMessages(
|
||||
history, summary, ts.userMessage,
|
||||
ts.media, ts.channel, ts.chatID,
|
||||
ts.opts.Dispatch.SenderID(), ts.opts.SenderDisplayName,
|
||||
activeSkillNames(ts.agent, ts.opts)...,
|
||||
)
|
||||
messages = resolveMediaRefs(messages, p.MediaStore, maxMediaSize)
|
||||
}
|
||||
}
|
||||
|
||||
if !ts.opts.NoHistory && (strings.TrimSpace(ts.userMessage) != "" || len(ts.media) > 0) {
|
||||
rootMsg := providers.Message{
|
||||
Role: "user",
|
||||
Content: ts.userMessage,
|
||||
Media: append([]string(nil), ts.media...),
|
||||
}
|
||||
if len(rootMsg.Media) > 0 {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, rootMsg)
|
||||
} else {
|
||||
ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
|
||||
}
|
||||
ts.recordPersistedMessage(rootMsg)
|
||||
ts.ingestMessage(ctx, p.al, rootMsg)
|
||||
}
|
||||
|
||||
activeCandidates, activeModel, usedLight := p.al.selectCandidates(ts.agent, ts.userMessage, messages)
|
||||
activeProvider := ts.agent.Provider
|
||||
if usedLight && ts.agent.LightProvider != nil {
|
||||
activeProvider = ts.agent.LightProvider
|
||||
}
|
||||
|
||||
exec := newTurnExecution(
|
||||
ts.agent,
|
||||
ts.opts,
|
||||
history,
|
||||
summary,
|
||||
messages,
|
||||
)
|
||||
exec.activeCandidates = activeCandidates
|
||||
exec.activeModel = activeModel
|
||||
exec.activeProvider = activeProvider
|
||||
exec.usedLight = usedLight
|
||||
|
||||
return exec, nil
|
||||
}
|
||||
@@ -462,7 +462,8 @@ func spawnSubTurn(
|
||||
}()
|
||||
|
||||
// 8. Execute sub-turn via the real agent loop.
|
||||
turnRes, turnErr := al.runTurn(childCtx, childTS)
|
||||
pipeline := NewPipeline(al)
|
||||
turnRes, turnErr := al.runTurn(childCtx, childTS, pipeline)
|
||||
|
||||
// Release the concurrency semaphore immediately after runTurn completes,
|
||||
// before the cleanup defer runs. This prevents a deadlock where:
|
||||
|
||||
@@ -0,0 +1,624 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipeline) (turnResult, error) {
|
||||
turnCtx, turnCancel := context.WithCancel(ctx)
|
||||
defer turnCancel()
|
||||
ts.setTurnCancel(turnCancel)
|
||||
|
||||
// Inject turnState and AgentLoop into context so tools (e.g. spawn) can retrieve them.
|
||||
turnCtx = withTurnState(turnCtx, ts)
|
||||
turnCtx = WithAgentLoop(turnCtx, al)
|
||||
|
||||
al.registerActiveTurn(ts)
|
||||
defer al.clearActiveTurn(ts)
|
||||
|
||||
turnStatus := TurnEndStatusCompleted
|
||||
defer func() {
|
||||
al.emitEvent(
|
||||
EventKindTurnEnd,
|
||||
ts.eventMeta("runTurn", "turn.end"),
|
||||
TurnEndPayload{
|
||||
Status: turnStatus,
|
||||
Iterations: ts.currentIteration(),
|
||||
Duration: time.Since(ts.startedAt),
|
||||
FinalContentLen: ts.finalContentLen(),
|
||||
},
|
||||
)
|
||||
}()
|
||||
|
||||
al.emitEvent(
|
||||
EventKindTurnStart,
|
||||
ts.eventMeta("runTurn", "turn.start"),
|
||||
TurnStartPayload{
|
||||
UserMessage: ts.userMessage,
|
||||
MediaCount: len(ts.media),
|
||||
},
|
||||
)
|
||||
|
||||
// SetupTurn extracts the one-time initialization phase.
|
||||
exec, err := pipeline.SetupTurn(turnCtx, ts)
|
||||
if err != nil {
|
||||
return turnResult{}, err
|
||||
}
|
||||
|
||||
// Convenience references to exec fields used throughout the turn loop.
|
||||
messages := exec.messages
|
||||
pendingMessages := exec.pendingMessages
|
||||
maxMediaSize := pipeline.Cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
finalContent := exec.finalContent
|
||||
|
||||
for ts.currentIteration() < ts.agent.MaxIterations || len(exec.pendingMessages) > 0 || func() bool {
|
||||
graceful, _ := ts.gracefulInterruptRequested()
|
||||
return graceful
|
||||
}() {
|
||||
if ts.hardAbortRequested() {
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
iteration := ts.currentIteration() + 1
|
||||
ts.setIteration(iteration)
|
||||
ts.setPhase(TurnPhaseRunning)
|
||||
|
||||
if iteration > 1 {
|
||||
// For subsequent iterations, read from exec.pendingMessages which
|
||||
// is where ExecuteTools (or initial poll) deposits steering.
|
||||
// We do NOT call dequeueSteeringMessagesForScope here because
|
||||
// steering was already consumed from al.steering by ExecuteTools.
|
||||
if len(exec.pendingMessages) > 0 {
|
||||
pendingMessages = append(pendingMessages, exec.pendingMessages...)
|
||||
exec.pendingMessages = nil
|
||||
}
|
||||
} else if !ts.opts.SkipInitialSteeringPoll {
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScopeWithFallback(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if parent turn has ended (SubTurn support from HEAD)
|
||||
if ts.parentTurnState != nil && ts.IsParentEnded() {
|
||||
if !ts.critical {
|
||||
logger.InfoCF("agent", "Parent turn ended, non-critical SubTurn exiting gracefully", map[string]any{
|
||||
"agent_id": ts.agentID,
|
||||
"iteration": iteration,
|
||||
"turn_id": ts.turnID,
|
||||
})
|
||||
break
|
||||
}
|
||||
logger.InfoCF("agent", "Parent turn ended, critical SubTurn continues running", map[string]any{
|
||||
"agent_id": ts.agentID,
|
||||
"iteration": iteration,
|
||||
"turn_id": ts.turnID,
|
||||
})
|
||||
}
|
||||
|
||||
// Poll for pending SubTurn results (from HEAD)
|
||||
if ts.pendingResults != nil {
|
||||
select {
|
||||
case result, ok := <-ts.pendingResults:
|
||||
if ok && result != nil && result.ForLLM != "" {
|
||||
content := al.cfg.FilterSensitiveData(result.ForLLM)
|
||||
msg := providers.Message{Role: "user", Content: fmt.Sprintf("[SubTurn Result] %s", content)}
|
||||
pendingMessages = append(pendingMessages, msg)
|
||||
}
|
||||
default:
|
||||
// No results available
|
||||
}
|
||||
}
|
||||
|
||||
// Inject pending steering messages
|
||||
if len(pendingMessages) > 0 {
|
||||
resolvedPending := resolveMediaRefs(pendingMessages, al.mediaStore, maxMediaSize)
|
||||
totalContentLen := 0
|
||||
for i, pm := range pendingMessages {
|
||||
messages = append(messages, resolvedPending[i])
|
||||
totalContentLen += len(pm.Content)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
|
||||
ts.recordPersistedMessage(pm)
|
||||
ts.ingestMessage(turnCtx, al, pm)
|
||||
}
|
||||
logger.InfoCF("agent", "Injected steering message into context",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"content_len": len(pm.Content),
|
||||
"media_count": len(pm.Media),
|
||||
})
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindSteeringInjected,
|
||||
ts.eventMeta("runTurn", "turn.steering.injected"),
|
||||
SteeringInjectedPayload{
|
||||
Count: len(pendingMessages),
|
||||
TotalContentLen: totalContentLen,
|
||||
},
|
||||
)
|
||||
// Clear exec.pendingMessages after injection so InitialSteeringMessages
|
||||
// are not re-injected on subsequent iterations (Issue 2 fix).
|
||||
exec.pendingMessages = nil
|
||||
}
|
||||
// Always sync messages into exec.messages so CallLLM sees the updated state
|
||||
exec.messages = messages
|
||||
|
||||
logger.DebugCF("agent", "LLM iteration",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"max": ts.agent.MaxIterations,
|
||||
})
|
||||
|
||||
// Execute LLM call via Pipeline
|
||||
ts.setPhase(TurnPhaseRunning)
|
||||
ctrl, callErr := pipeline.CallLLM(ctx, turnCtx, ts, exec, iteration)
|
||||
if callErr != nil {
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, callErr
|
||||
}
|
||||
messages = exec.messages
|
||||
pendingMessages = exec.pendingMessages
|
||||
finalContent = exec.finalContent
|
||||
|
||||
switch ctrl {
|
||||
case ControlContinue:
|
||||
continue
|
||||
case ControlBreak:
|
||||
// Hard abort: delegate to abortTurn (sets TurnEndStatusAborted)
|
||||
if exec.abortedByHardAbort {
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
// Hook abort (HookActionAbortTurn): sets TurnEndStatusError, returns error
|
||||
if exec.abortedByHook {
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, fmt.Errorf("hook requested turn abort")
|
||||
}
|
||||
// Ensure empty response falls back to DefaultResponse
|
||||
if finalContent == "" {
|
||||
finalContent = ts.opts.DefaultResponse
|
||||
}
|
||||
return pipeline.Finalize(ctx, turnCtx, ts, exec, turnStatus, finalContent)
|
||||
case ControlToolLoop:
|
||||
// Execute tools via Pipeline
|
||||
toolCtrl := pipeline.ExecuteTools(ctx, turnCtx, ts, exec, iteration)
|
||||
switch toolCtrl {
|
||||
case ToolControlContinue:
|
||||
// Re-read exec.messages since ExecuteTools may have updated it
|
||||
// (added tool results/skipped messages) before returning ControlContinue
|
||||
messages = exec.messages
|
||||
continue
|
||||
case ToolControlBreak:
|
||||
// Hard abort: delegate to abortTurn (sets TurnEndStatusAborted)
|
||||
if exec.abortedByHardAbort {
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
// Hook abort (HookActionAbortTurn): sets TurnEndStatusError, returns error
|
||||
if exec.abortedByHook {
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, fmt.Errorf("hook requested turn abort")
|
||||
}
|
||||
// ExecuteTools returned ControlBreak:
|
||||
// - allResponsesHandled=true: finalize without DefaultResponse (exec.finalContent empty)
|
||||
// - allResponsesHandled=false: coordinator applies DefaultResponse before finalize
|
||||
if exec.allResponsesHandled {
|
||||
finalContent = ""
|
||||
}
|
||||
return pipeline.Finalize(ctx, turnCtx, ts, exec, turnStatus, finalContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ts.hardAbortRequested() {
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
if finalContent == "" {
|
||||
if ts.currentIteration() >= ts.agent.MaxIterations && ts.agent.MaxIterations > 0 {
|
||||
finalContent = toolLimitResponse
|
||||
} else {
|
||||
finalContent = ts.opts.DefaultResponse
|
||||
}
|
||||
}
|
||||
|
||||
// Check hard abort before finalizing (may have been set during tool execution)
|
||||
if ts.hardAbortRequested() {
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
return pipeline.Finalize(ctx, turnCtx, ts, exec, turnStatus, finalContent)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) abortTurn(ts *turnState) (turnResult, error) {
|
||||
ts.setPhase(TurnPhaseAborted)
|
||||
if !ts.opts.NoHistory {
|
||||
if err := ts.restoreSession(ts.agent); err != nil {
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
ts.eventMeta("abortTurn", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "session_restore",
|
||||
Message: err.Error(),
|
||||
},
|
||||
)
|
||||
return turnResult{}, err
|
||||
}
|
||||
}
|
||||
return turnResult{status: TurnEndStatusAborted}, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) selectCandidates(
|
||||
agent *AgentInstance,
|
||||
userMsg string,
|
||||
history []providers.Message,
|
||||
) (candidates []providers.FallbackCandidate, model string, usedLight bool) {
|
||||
if agent.Router == nil || len(agent.LightCandidates) == 0 {
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model), false
|
||||
}
|
||||
|
||||
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
|
||||
if !usedLight {
|
||||
logger.DebugCF("agent", "Model routing: primary model selected",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model), false
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Model routing: light model selected",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"light_model": agent.Router.LightModel(),
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resolveContextManager() ContextManager {
|
||||
name := al.cfg.Agents.Defaults.ContextManager
|
||||
if name == "" || name == "legacy" {
|
||||
return &legacyContextManager{al: al}
|
||||
}
|
||||
factory, ok := lookupContextManager(name)
|
||||
if !ok {
|
||||
logger.WarnCF("agent", "Unknown context manager, falling back to legacy", map[string]any{
|
||||
"name": name,
|
||||
})
|
||||
return &legacyContextManager{al: al}
|
||||
}
|
||||
cm, err := factory(al.cfg.Agents.Defaults.ContextManagerConfig, al)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to create context manager, falling back to legacy", map[string]any{
|
||||
"name": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return &legacyContextManager{al: al}
|
||||
}
|
||||
return cm
|
||||
}
|
||||
|
||||
func (al *AgentLoop) askSideQuestion(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
opts *processOptions,
|
||||
question string,
|
||||
) (string, error) {
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("askSideQuestion: no agent available for /btw")
|
||||
}
|
||||
|
||||
question = strings.TrimSpace(question)
|
||||
if question == "" {
|
||||
return "", fmt.Errorf("askSideQuestion: %w", fmt.Errorf("Usage: /btw <question>"))
|
||||
}
|
||||
|
||||
if opts != nil {
|
||||
normalizeProcessOptionsInPlace(opts)
|
||||
}
|
||||
|
||||
var media []string
|
||||
var channel, chatID, senderID, senderDisplayName string
|
||||
if opts != nil {
|
||||
media = opts.Media
|
||||
channel = opts.Channel
|
||||
chatID = opts.ChatID
|
||||
senderID = opts.SenderID
|
||||
senderDisplayName = opts.SenderDisplayName
|
||||
}
|
||||
|
||||
// Build messages with context but WITHOUT adding to session history
|
||||
var history []providers.Message
|
||||
var summary string
|
||||
if opts != nil && !opts.NoHistory {
|
||||
if resp, err := al.contextManager.Assemble(ctx, &AssembleRequest{
|
||||
SessionKey: opts.SessionKey,
|
||||
Budget: agent.ContextWindow,
|
||||
MaxTokens: agent.MaxTokens,
|
||||
}); err == nil && resp != nil {
|
||||
history = resp.History
|
||||
summary = resp.Summary
|
||||
}
|
||||
}
|
||||
|
||||
messages := agent.ContextBuilder.BuildMessages(
|
||||
history,
|
||||
summary,
|
||||
question,
|
||||
media,
|
||||
channel,
|
||||
chatID,
|
||||
senderID,
|
||||
senderDisplayName,
|
||||
)
|
||||
|
||||
maxMediaSize := al.GetConfig().Agents.Defaults.GetMaxMediaSize()
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
activeCandidates, activeModel, usedLight := al.selectCandidates(agent, question, messages)
|
||||
selectedModelName := sideQuestionModelName(agent, usedLight)
|
||||
|
||||
llmOpts := map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID + ":btw",
|
||||
}
|
||||
|
||||
hookModelChanged := false
|
||||
callProvider := func(
|
||||
ctx context.Context,
|
||||
candidate providers.FallbackCandidate,
|
||||
model string,
|
||||
forceModel bool,
|
||||
callMessages []providers.Message,
|
||||
) (*providers.LLMResponse, error) {
|
||||
provider, providerModel, cleanup, err := al.isolatedSideQuestionProvider(agent, selectedModelName, candidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
if !forceModel || strings.TrimSpace(model) == "" {
|
||||
model = providerModel
|
||||
}
|
||||
callOpts := llmOpts
|
||||
if _, exists := callOpts["thinking_level"]; !exists && agent.ThinkingLevel != ThinkingOff {
|
||||
if tc, ok := provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
|
||||
callOpts = shallowCloneLLMOptions(llmOpts)
|
||||
callOpts["thinking_level"] = string(agent.ThinkingLevel)
|
||||
}
|
||||
}
|
||||
return provider.Chat(ctx, callMessages, nil, model, callOpts)
|
||||
}
|
||||
|
||||
turnCtx := newTurnContext(nil, nil, nil)
|
||||
if opts != nil {
|
||||
turnCtx = newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope)
|
||||
}
|
||||
llmModel := activeModel
|
||||
if al.hooks != nil {
|
||||
llmReq, decision := al.hooks.BeforeLLM(ctx, &LLMHookRequest{
|
||||
Meta: EventMeta{
|
||||
Source: "askSideQuestion",
|
||||
TracePath: "turn.llm.request",
|
||||
turnContext: cloneTurnContext(turnCtx),
|
||||
},
|
||||
Context: cloneTurnContext(turnCtx),
|
||||
Model: llmModel,
|
||||
Messages: messages,
|
||||
Tools: nil,
|
||||
Options: llmOpts,
|
||||
GracefulTerminal: false,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmReq != nil {
|
||||
if strings.TrimSpace(llmReq.Model) != "" && llmReq.Model != llmModel {
|
||||
hookModelChanged = true
|
||||
}
|
||||
llmModel = llmReq.Model
|
||||
messages = llmReq.Messages
|
||||
llmOpts = llmReq.Options
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
reason := decision.Reason
|
||||
if reason == "" {
|
||||
reason = "hook requested turn abort"
|
||||
}
|
||||
return "", fmt.Errorf("hook aborted turn during before_llm: %s", reason)
|
||||
case HookActionHardAbort:
|
||||
reason := decision.Reason
|
||||
if reason == "" {
|
||||
reason = "hook requested turn abort"
|
||||
}
|
||||
return "", fmt.Errorf("hook aborted turn during before_llm: %s", reason)
|
||||
}
|
||||
}
|
||||
if hookModelChanged {
|
||||
// Hook-selected models must not continue through the pre-hook fallback
|
||||
// candidate list, otherwise fallback execution would call the original
|
||||
// candidate model and silently ignore the hook decision.
|
||||
activeCandidates = nil
|
||||
}
|
||||
|
||||
callSideLLM := func(callMessages []providers.Message) (*providers.LLMResponse, error) {
|
||||
if len(activeCandidates) > 1 && al.fallback != nil {
|
||||
fbResult, err := al.fallback.Execute(
|
||||
ctx,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, providerName, model string) (*providers.LLMResponse, error) {
|
||||
candidate := providers.FallbackCandidate{Provider: providerName, Model: model}
|
||||
for _, activeCandidate := range activeCandidates {
|
||||
if activeCandidate.Provider == providerName && activeCandidate.Model == model {
|
||||
candidate = activeCandidate
|
||||
break
|
||||
}
|
||||
}
|
||||
return callProvider(ctx, candidate, model, false, callMessages)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
|
||||
var candidate providers.FallbackCandidate
|
||||
if len(activeCandidates) > 0 {
|
||||
candidate = activeCandidates[0]
|
||||
}
|
||||
return callProvider(ctx, candidate, llmModel, hookModelChanged, callMessages)
|
||||
}
|
||||
|
||||
// Retry without media if vision is unsupported
|
||||
// Note: Vision retry is only applied to the initial call. If fallback chain
|
||||
// is used, vision errors from fallback providers will not trigger retry.
|
||||
var resp *providers.LLMResponse
|
||||
var err error
|
||||
resp, err = callSideLLM(messages)
|
||||
if err != nil && hasMediaRefs(messages) && isVisionUnsupportedError(err) {
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
EventMeta{
|
||||
Source: "askSideQuestion",
|
||||
TracePath: "turn.llm.retry",
|
||||
turnContext: cloneTurnContext(turnCtx),
|
||||
},
|
||||
LLMRetryPayload{
|
||||
Attempt: 1,
|
||||
MaxRetries: 1,
|
||||
Reason: "vision_unsupported",
|
||||
Error: err.Error(),
|
||||
Backoff: 0,
|
||||
},
|
||||
)
|
||||
messagesWithoutMedia := stripMessageMedia(messages)
|
||||
resp, err = callSideLLM(messagesWithoutMedia)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Apply after_llm hooks
|
||||
if al.hooks != nil {
|
||||
llmResp, decision := al.hooks.AfterLLM(ctx, &LLMHookResponse{
|
||||
Meta: EventMeta{
|
||||
Source: "askSideQuestion",
|
||||
TracePath: "turn.llm.response",
|
||||
turnContext: cloneTurnContext(turnCtx),
|
||||
},
|
||||
Context: cloneTurnContext(turnCtx),
|
||||
Model: llmModel,
|
||||
Response: resp,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmResp != nil && llmResp.Response != nil {
|
||||
resp = llmResp.Response
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
reason := decision.Reason
|
||||
if reason == "" {
|
||||
reason = "hook requested turn abort"
|
||||
}
|
||||
return "", fmt.Errorf("hook aborted turn during after_llm: %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
return sideQuestionResponseContent(resp), nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) isolatedSideQuestionProvider(
|
||||
agent *AgentInstance,
|
||||
baseModelName string,
|
||||
candidate providers.FallbackCandidate,
|
||||
) (providers.LLMProvider, string, func(), error) {
|
||||
if agent == nil {
|
||||
return nil, "", func() {}, fmt.Errorf("isolatedSideQuestionProvider: no agent available for /btw")
|
||||
}
|
||||
|
||||
modelCfg, err := al.sideQuestionModelConfig(agent, baseModelName, candidate)
|
||||
if err != nil {
|
||||
return nil, "", func() {}, fmt.Errorf("isolatedSideQuestionProvider: %w", err)
|
||||
}
|
||||
|
||||
factory := al.providerFactory
|
||||
if factory == nil {
|
||||
factory = providers.CreateProviderFromConfig
|
||||
}
|
||||
provider, modelID, err := factory(modelCfg)
|
||||
if err != nil {
|
||||
return nil, "", func() {}, fmt.Errorf("isolatedSideQuestionProvider: %w", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
closeProviderIfStateful(provider)
|
||||
}
|
||||
return provider, modelID, cleanup, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) sideQuestionModelConfig(
|
||||
agent *AgentInstance,
|
||||
baseModelName string,
|
||||
candidate providers.FallbackCandidate,
|
||||
) (*config.ModelConfig, error) {
|
||||
if agent == nil {
|
||||
return nil, fmt.Errorf("sideQuestionModelConfig: no agent available for /btw")
|
||||
}
|
||||
|
||||
// If candidate has an identity key, use that
|
||||
if name := modelNameFromIdentityKey(candidate.IdentityKey); name != "" {
|
||||
modelCfg, err := resolvedModelConfig(al.GetConfig(), name, agent.Workspace)
|
||||
if err == nil {
|
||||
return modelCfg, nil
|
||||
}
|
||||
// Fallback: create a minimal config if lookup fails
|
||||
}
|
||||
|
||||
// Otherwise, clean up the base model name and use it
|
||||
baseModelName = strings.TrimSpace(baseModelName)
|
||||
modelCfg, err := resolvedModelConfig(al.GetConfig(), baseModelName, agent.Workspace)
|
||||
if err != nil {
|
||||
// Fallback: create a minimal config for test scenarios
|
||||
model := strings.TrimSpace(baseModelName)
|
||||
if candidate.Model != "" {
|
||||
model = candidate.Model
|
||||
}
|
||||
if candidate.Provider != "" && candidate.Model != "" {
|
||||
model = providers.NormalizeProvider(candidate.Provider) + "/" + candidate.Model
|
||||
} else {
|
||||
model = ensureProtocolModel(model)
|
||||
}
|
||||
return &config.ModelConfig{
|
||||
ModelName: baseModelName,
|
||||
Model: model,
|
||||
Workspace: agent.Workspace,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// If candidate specifies a different provider/model, override
|
||||
clone := *modelCfg
|
||||
if candidate.Provider != "" && candidate.Model != "" {
|
||||
clone.Model = providers.NormalizeProvider(candidate.Provider) + "/" + candidate.Model
|
||||
}
|
||||
return &clone, nil
|
||||
}
|
||||
@@ -0,0 +1,551 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Mock Providers for turn_coord Tests
|
||||
// =============================================================================
|
||||
|
||||
// simpleConvProvider returns a simple text response without tools
|
||||
type simpleConvProvider struct{}
|
||||
|
||||
func (p *simpleConvProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Hello! How can I help you today?",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *simpleConvProvider) GetDefaultModel() string {
|
||||
return "simple-model"
|
||||
}
|
||||
|
||||
// toolCallRespProvider returns a tool call response
|
||||
type toolCallRespProvider struct {
|
||||
toolName string
|
||||
toolArgs map[string]any
|
||||
response string
|
||||
callCount int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (p *toolCallRespProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.callCount++
|
||||
count := p.callCount
|
||||
p.mu.Unlock()
|
||||
|
||||
// First call returns a tool call, subsequent calls return final response
|
||||
if count == 1 {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Let me search for that information.",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Name: p.toolName,
|
||||
Arguments: p.toolArgs,
|
||||
},
|
||||
},
|
||||
FinishReason: "tool_calls",
|
||||
}, nil
|
||||
}
|
||||
return &providers.LLMResponse{
|
||||
Content: p.response,
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *toolCallRespProvider) GetDefaultModel() string {
|
||||
return "tool-model"
|
||||
}
|
||||
|
||||
// errorProvider simulates various error conditions
|
||||
type errorProvider struct {
|
||||
errType string
|
||||
callCount int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (p *errorProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.callCount++
|
||||
p.mu.Unlock()
|
||||
|
||||
switch p.errType {
|
||||
case "timeout":
|
||||
return nil, context.DeadlineExceeded
|
||||
case "context_length":
|
||||
return nil, errors.New("context_length_exceeded")
|
||||
case "vision":
|
||||
return nil, errors.New("vision_unsupported")
|
||||
default:
|
||||
return nil, errors.New("unknown error")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *errorProvider) GetDefaultModel() string {
|
||||
return "error-model"
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Test Helper Functions
|
||||
// =============================================================================
|
||||
|
||||
func newTurnCoordTestLoop(t *testing.T, provider providers.LLMProvider) (*AgentLoop, *AgentInstance, func()) {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
return al, agent, func() {
|
||||
al.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestProcessOpts(sessionKey string) processOptions {
|
||||
return processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: "cli",
|
||||
ChatID: "test-chat",
|
||||
UserMessage: "test message",
|
||||
DefaultResponse: "I couldn't process your request.",
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
NoHistory: false,
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Pipeline Method Tests: SetupTurn
|
||||
// =============================================================================
|
||||
|
||||
func TestPipeline_SetupTurn_BasicInitialization(t *testing.T) {
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
if exec == nil {
|
||||
t.Fatal("expected non-nil turnExecution")
|
||||
}
|
||||
if len(exec.messages) == 0 {
|
||||
t.Error("expected messages to be populated")
|
||||
}
|
||||
if exec.iteration != 0 {
|
||||
t.Errorf("expected iteration 0, got %d", exec.iteration)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Pipeline Method Tests: CallLLM
|
||||
// =============================================================================
|
||||
|
||||
func TestPipeline_CallLLM_SimpleResponse(t *testing.T) {
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("CallLLM failed: %v", err)
|
||||
}
|
||||
if ctrl != ControlBreak {
|
||||
t.Errorf("expected ControlBreak, got %v", ctrl)
|
||||
}
|
||||
if exec.response == nil {
|
||||
t.Fatal("expected non-nil response")
|
||||
}
|
||||
if exec.response.Content == "" {
|
||||
t.Error("expected non-empty content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_CallLLM_WithToolCall(t *testing.T) {
|
||||
provider := &toolCallRespProvider{
|
||||
toolName: "web_search",
|
||||
toolArgs: map[string]any{"query": "test"},
|
||||
response: "Found information about test.",
|
||||
}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("CallLLM failed: %v", err)
|
||||
}
|
||||
if ctrl != ControlToolLoop {
|
||||
t.Errorf("expected ControlToolLoop, got %v", ctrl)
|
||||
}
|
||||
if len(exec.normalizedToolCalls) == 0 {
|
||||
t.Fatal("expected tool calls")
|
||||
}
|
||||
if exec.normalizedToolCalls[0].Name != "web_search" {
|
||||
t.Errorf("expected tool name 'web_search', got %q", exec.normalizedToolCalls[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_CallLLM_TimeoutRetry(t *testing.T) {
|
||||
errorPrv := &errorProvider{errType: "timeout"}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
// Should retry and eventually fail after max retries
|
||||
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
if err == nil {
|
||||
t.Error("expected error after retries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_CallLLM_ContextLengthError(t *testing.T) {
|
||||
errorPrv := &errorProvider{errType: "context_length"}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
// Should trigger context compression and retry
|
||||
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
// May succeed after compression or fail - either is acceptable
|
||||
t.Logf("CallLLM result after context error: err=%v", err)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Pipeline Method Tests: ExecuteTools
|
||||
// =============================================================================
|
||||
|
||||
func TestPipeline_ExecuteTools_NoTools(t *testing.T) {
|
||||
// Provider returns no tool calls, so ExecuteTools should not be called
|
||||
// This test verifies the ControlBreak path from CallLLM
|
||||
provider := &simpleConvProvider{}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
// First CallLLM returns ControlBreak (no tools)
|
||||
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("CallLLM failed: %v", err)
|
||||
}
|
||||
|
||||
if ctrl != ControlBreak {
|
||||
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
||||
}
|
||||
// No tools to execute, Finalize should be called directly
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// runTurn Integration Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestRunTurn_SimpleConversation(t *testing.T) {
|
||||
provider := &simpleConvProvider{}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
opts := makeTestProcessOpts("test-session-simple")
|
||||
|
||||
ts := newTurnState(agent, opts, turnEventScope{
|
||||
turnID: "turn-simple",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
result, err := al.runTurn(context.Background(), ts, pipeline)
|
||||
if err != nil {
|
||||
t.Fatalf("runTurn failed: %v", err)
|
||||
}
|
||||
if result.status != TurnEndStatusCompleted {
|
||||
t.Errorf("expected status Completed, got %v", result.status)
|
||||
}
|
||||
if result.finalContent == "" {
|
||||
t.Error("expected non-empty finalContent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunTurn_MaxIterations(t *testing.T) {
|
||||
// Provider always returns tool calls, should hit max iterations
|
||||
provider := &toolCallRespProvider{
|
||||
toolName: "search",
|
||||
toolArgs: map[string]any{"q": "x"},
|
||||
response: "done",
|
||||
}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
// Override max iterations to 2
|
||||
agent.MaxIterations = 2
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
opts := makeTestProcessOpts("test-session-maxiter")
|
||||
|
||||
ts := newTurnState(agent, opts, turnEventScope{
|
||||
turnID: "turn-maxiter",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
result, err := al.runTurn(context.Background(), ts, pipeline)
|
||||
if err != nil {
|
||||
t.Fatalf("runTurn failed: %v", err)
|
||||
}
|
||||
// Should complete due to max iterations
|
||||
if result.status != TurnEndStatusCompleted {
|
||||
t.Errorf("expected status Completed, got %v", result.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunTurn_HardAbort(t *testing.T) {
|
||||
// Provider simulates a slow response, but we'll abort mid-turn
|
||||
slowProvider := &slowMockProvider{delay: 10 * time.Second}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, slowProvider)
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
opts := makeTestProcessOpts("test-session-abort")
|
||||
|
||||
ts := newTurnState(agent, opts, turnEventScope{
|
||||
turnID: "turn-abort",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
// Run in goroutine with abort after short delay
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
al.runTurn(context.Background(), ts, pipeline)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give it a moment to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Request hard abort
|
||||
ts.requestHardAbort()
|
||||
|
||||
// Wait for runTurn to complete
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("runTurn did not complete after abort")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunTurn_SteeringMessageInjection(t *testing.T) {
|
||||
provider := &simpleConvProvider{}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
opts := makeTestProcessOpts("test-session-steering")
|
||||
|
||||
ts := newTurnState(agent, opts, turnEventScope{
|
||||
turnID: "turn-steering",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
// Enqueue steering message before runTurn
|
||||
steeringMsg := providers.Message{
|
||||
Role: "user",
|
||||
Content: "Steering message",
|
||||
}
|
||||
al.Steer(steeringMsg)
|
||||
|
||||
result, err := al.runTurn(context.Background(), ts, pipeline)
|
||||
if err != nil {
|
||||
t.Fatalf("runTurn failed: %v", err)
|
||||
}
|
||||
if result.status != TurnEndStatusCompleted {
|
||||
t.Errorf("expected status Completed, got %v", result.status)
|
||||
}
|
||||
// Steering message should have been injected
|
||||
}
|
||||
|
||||
func TestRunTurn_GracefulInterrupt(t *testing.T) {
|
||||
provider := &toolCallRespProvider{
|
||||
toolName: "search",
|
||||
toolArgs: map[string]any{"q": "test"},
|
||||
response: "Final response after interrupt",
|
||||
}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
opts := makeTestProcessOpts("test-session-graceful")
|
||||
|
||||
ts := newTurnState(agent, opts, turnEventScope{
|
||||
turnID: "turn-graceful",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
// Run in goroutine with graceful interrupt after first iteration
|
||||
done := make(chan struct{})
|
||||
var result turnResult
|
||||
|
||||
go func() {
|
||||
result, _ = al.runTurn(context.Background(), ts, pipeline)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give it a moment to start first iteration
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Request graceful interrupt
|
||||
ts.requestGracefulInterrupt("Please stop")
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("runTurn did not complete after graceful interrupt")
|
||||
}
|
||||
|
||||
// Should complete gracefully
|
||||
if result.status != TurnEndStatusCompleted {
|
||||
t.Errorf("expected status Completed, got %v", result.status)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// turnState Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestTurnState_GracefulInterruptRequested(t *testing.T) {
|
||||
ts := &turnState{
|
||||
gracefulInterrupt: false,
|
||||
gracefulInterruptHint: "",
|
||||
}
|
||||
|
||||
// Initially should not be requested
|
||||
requested, _ := ts.gracefulInterruptRequested()
|
||||
if requested {
|
||||
t.Error("expected no interrupt initially")
|
||||
}
|
||||
|
||||
// Request interrupt
|
||||
ts.requestGracefulInterrupt("test hint")
|
||||
|
||||
requested, hint := ts.gracefulInterruptRequested()
|
||||
if !requested {
|
||||
t.Error("expected interrupt to be requested")
|
||||
}
|
||||
if hint != "test hint" {
|
||||
t.Errorf("expected hint 'test hint', got %q", hint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTurnState_HardAbortRequested(t *testing.T) {
|
||||
ts := &turnState{
|
||||
hardAbort: false,
|
||||
}
|
||||
|
||||
if ts.hardAbortRequested() {
|
||||
t.Error("expected no hard abort initially")
|
||||
}
|
||||
|
||||
ts.requestHardAbort()
|
||||
|
||||
if !ts.hardAbortRequested() {
|
||||
t.Error("expected hard abort to be requested")
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
@@ -14,6 +16,10 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// TurnPhase - represents the current phase of a turn
|
||||
// =============================================================================
|
||||
|
||||
type TurnPhase string
|
||||
|
||||
const (
|
||||
@@ -25,6 +31,65 @@ const (
|
||||
TurnPhaseAborted TurnPhase = "aborted"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Control signals - returned from Pipeline methods to drive runTurn's coordinator loop
|
||||
// =============================================================================
|
||||
|
||||
type Control int
|
||||
|
||||
const (
|
||||
// ControlContinue tells the coordinator to jump back to the top of the turn loop
|
||||
// (equivalent to the original "goto turnLoop").
|
||||
ControlContinue Control = iota
|
||||
// ControlBreak tells the coordinator to exit the turn loop and proceed to Finalize.
|
||||
ControlBreak
|
||||
// ControlToolLoop tells the coordinator to execute the tool loop.
|
||||
ControlToolLoop
|
||||
)
|
||||
|
||||
// ToolControl signals returned from ExecuteTools to drive tool loop iteration.
|
||||
type ToolControl int
|
||||
|
||||
const (
|
||||
// ToolControlContinue tells the tool loop to jump to the next iteration
|
||||
// (pendingMessages arrived, SubTurn results, etc.).
|
||||
ToolControlContinue ToolControl = iota
|
||||
// ToolControlBreak tells the tool loop to exit and return to the coordinator.
|
||||
ToolControlBreak
|
||||
// ToolControlFinalize tells the coordinator that all tool responses were
|
||||
// handled and the turn should finalize without another LLM call.
|
||||
ToolControlFinalize
|
||||
)
|
||||
|
||||
// LLMPhase indicates which phase the turn is executing in.
|
||||
type LLMPhase int
|
||||
|
||||
const (
|
||||
LLMPhaseSetup LLMPhase = iota
|
||||
LLMPhasePreLLM
|
||||
LLMPhaseLLMCall
|
||||
LLMPhaseProcessing
|
||||
LLMPhaseToolLoop
|
||||
LLMPhaseTools
|
||||
LLMPhaseFinalizing
|
||||
LLMPhaseCompleted
|
||||
LLMPhaseAborted
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// turnResult - returned from runTurn
|
||||
// =============================================================================
|
||||
|
||||
type turnResult struct {
|
||||
finalContent string
|
||||
status TurnEndStatus
|
||||
followUps []bus.InboundMessage
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// ActiveTurnInfo - public info about an active turn
|
||||
// =============================================================================
|
||||
|
||||
type ActiveTurnInfo struct {
|
||||
TurnID string
|
||||
AgentID string
|
||||
@@ -40,12 +105,70 @@ type ActiveTurnInfo struct {
|
||||
ChildTurnIDs []string
|
||||
}
|
||||
|
||||
type turnResult struct {
|
||||
// =============================================================================
|
||||
// turnExecution - mutable state that persists across turn loop iterations
|
||||
// =============================================================================
|
||||
|
||||
type turnExecution struct {
|
||||
// Core message state (accumulates throughout the turn)
|
||||
messages []providers.Message // built from ContextBuilder, grows per-iteration
|
||||
pendingMessages []providers.Message // steering/SubTurn messages awaiting injection
|
||||
history []providers.Message // from ContextManager.Assemble
|
||||
summary string
|
||||
|
||||
// Turn output
|
||||
finalContent string
|
||||
status TurnEndStatus
|
||||
followUps []bus.InboundMessage
|
||||
|
||||
// Iteration tracking
|
||||
iteration int
|
||||
|
||||
// Per-iteration state set by Pipeline.PreLLM
|
||||
activeCandidates []providers.FallbackCandidate
|
||||
activeModel string
|
||||
activeProvider providers.LLMProvider
|
||||
usedLight bool
|
||||
|
||||
// LLM call per-iteration state
|
||||
response *providers.LLMResponse
|
||||
normalizedToolCalls []providers.ToolCall
|
||||
allResponsesHandled bool
|
||||
callMessages []providers.Message
|
||||
providerToolDefs []providers.ToolDefinition
|
||||
llmModel string
|
||||
llmOpts map[string]any
|
||||
gracefulTerminal bool
|
||||
useNativeSearch bool
|
||||
|
||||
// Phase tracking
|
||||
phase LLMPhase
|
||||
|
||||
// Abort signaling for coordinator (set by Pipeline methods)
|
||||
abortedByHardAbort bool // true when hard abort triggered during LLM/tools
|
||||
abortedByHook bool // true when HookActionAbortTurn triggered
|
||||
}
|
||||
|
||||
// newTurnExecution creates a turnExecution initialized from turnState and options.
|
||||
func newTurnExecution(
|
||||
agent *AgentInstance,
|
||||
opts processOptions,
|
||||
history []providers.Message,
|
||||
summary string,
|
||||
messages []providers.Message,
|
||||
) *turnExecution {
|
||||
return &turnExecution{
|
||||
history: history,
|
||||
summary: summary,
|
||||
messages: messages,
|
||||
pendingMessages: append([]providers.Message(nil), opts.InitialSteeringMessages...),
|
||||
iteration: 0,
|
||||
phase: LLMPhaseSetup,
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// turnState - the full state for a turn, constructed once per turn
|
||||
// =============================================================================
|
||||
|
||||
type turnState struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
@@ -109,6 +232,10 @@ type turnState struct {
|
||||
al *AgentLoop
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// turnState constructors and active turn management
|
||||
// =============================================================================
|
||||
|
||||
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
|
||||
ts := &turnState{
|
||||
agent: agent,
|
||||
@@ -194,6 +321,10 @@ func (al *AgentLoop) GetActiveTurnBySession(sessionKey string) *ActiveTurnInfo {
|
||||
return &info
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// turnState - getters and setters
|
||||
// =============================================================================
|
||||
|
||||
func (ts *turnState) snapshot() ActiveTurnInfo {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
@@ -402,7 +533,9 @@ func (ts *turnState) interruptHintMessage() providers.Message {
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SubTurn-related methods
|
||||
// =============================================================================
|
||||
|
||||
// Finish marks the turn as finished and closes the pendingResults channel
|
||||
func (ts *turnState) Finish(isHardAbort bool) {
|
||||
@@ -493,7 +626,9 @@ func (ts *turnState) SetLastUsage(usage *providers.UsageInfo) {
|
||||
ts.lastUsage = usage
|
||||
}
|
||||
|
||||
// Context helper functions for SubTurn
|
||||
// =============================================================================
|
||||
// Context helper functions for turnState
|
||||
// =============================================================================
|
||||
|
||||
type turnStateKeyType struct{}
|
||||
|
||||
Reference in New Issue
Block a user