feat(agent): support btw side questions (#2532)

This commit is contained in:
lxowalle
2026-04-16 10:53:09 +08:00
committed by GitHub
parent a8d0b03515
commit e22b4e1eee
23 changed files with 1737 additions and 70 deletions
+89
View File
@@ -111,6 +111,8 @@ func (p *llmHookTestProvider) GetDefaultModel() string {
type llmObserverHook struct {
eventCh chan Event
lastInbound *bus.InboundContext
lastRoute *routing.ResolvedRoute
lastScope *session.SessionScope
}
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
@@ -129,6 +131,8 @@ func (h *llmObserverHook) BeforeLLM(
) (*LLMHookRequest, HookDecision, error) {
if req.Context != nil {
h.lastInbound = cloneInboundContext(req.Context.Inbound)
h.lastRoute = cloneResolvedRoute(req.Context.Route)
h.lastScope = session.CloneScope(req.Context.Scope)
}
next := req.Clone()
next.Model = "hook-model"
@@ -230,6 +234,91 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
}
}
func TestAgentLoop_BtwCommand_UsesLLMHooks(t *testing.T) {
provider := &llmHookTestProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
useTestSideQuestionProvider(al, provider)
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
response, handled := al.handleCommand(context.Background(), bus.InboundMessage{
Context: bus.InboundContext{
Channel: "cli",
ChatID: "direct",
ChatType: "direct",
SenderID: "hook-user",
},
Content: "/btw hello",
}, agent, &processOptions{
Dispatch: DispatchRequest{
SessionKey: "session-1",
InboundContext: &bus.InboundContext{
Channel: "cli",
ChatID: "direct",
ChatType: "direct",
SenderID: "hook-user",
},
RouteResult: &routing.ResolvedRoute{
AgentID: "main",
Channel: "cli",
AccountID: routing.DefaultAccountID,
SessionPolicy: routing.SessionPolicy{
Dimensions: []string{"sender"},
},
MatchedBy: "default",
},
SessionScope: &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "cli",
Account: routing.DefaultAccountID,
Dimensions: []string{"sender"},
Values: map[string]string{
"sender": "hook-user",
},
},
UserMessage: "/btw hello",
},
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
SenderID: "hook-user",
SenderDisplayName: "Hook User",
})
if !handled {
t.Fatal("expected /btw command to be handled")
}
if response != "hooked content" {
t.Fatalf("expected hooked content, got %q", response)
}
provider.mu.Lock()
lastModel := provider.lastModel
provider.mu.Unlock()
if lastModel != "hook-model" {
t.Fatalf("expected model hook-model, got %q", lastModel)
}
if hook.lastInbound == nil {
t.Fatal("expected hook to receive inbound context")
}
if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" {
t.Fatalf("hook inbound context = %+v", hook.lastInbound)
}
if hook.lastInbound.ChatID != "direct" {
t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID)
}
if hook.lastRoute == nil || hook.lastRoute.AgentID != "main" {
t.Fatalf("expected hook route context for /btw, got %+v", hook.lastRoute)
}
if hook.lastScope == nil || hook.lastScope.Values["sender"] != "hook-user" {
t.Fatalf("expected hook session scope for /btw, got %+v", hook.lastScope)
}
}
type toolHookProvider struct {
mu sync.Mutex
calls int
+21
View File
@@ -29,6 +29,27 @@ func stripMessageMedia(messages []providers.Message) []providers.Message {
return stripped
}
func callLLMWithVisionUnsupportedRetry(
messages []providers.Message,
call func([]providers.Message) (*providers.LLMResponse, error),
beforeRetry func(error),
) (*providers.LLMResponse, []providers.Message, bool, error) {
response, err := call(messages)
if err == nil {
return response, messages, false, nil
}
if !messagesContainMedia(messages) || !isVisionUnsupportedError(err) {
return response, messages, false, err
}
if beforeRetry != nil {
beforeRetry(err)
}
stripped := stripMessageMedia(messages)
response, err = call(stripped)
return response, stripped, true, err
}
func isVisionUnsupportedError(err error) bool {
if err == nil {
return false
+510 -47
View File
@@ -70,6 +70,8 @@ type AgentLoop struct {
activeRequests sync.WaitGroup
reloadFunc func() error
providerFactory func(*config.ModelConfig) (providers.LLMProvider, string, error)
}
// processOptions configures how a message is processed
@@ -159,6 +161,7 @@ func NewAgentLoop(
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
al.providerFactory = providers.CreateProviderFromConfig
al.hooks = NewHookManager(eventBus)
configureHookManagerFromConfig(al.hooks, cfg)
al.contextManager = al.resolveContextManager()
@@ -479,10 +482,12 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// running. Only messages that resolve to the active turn scope are
// redirected into steering; other inbound messages are requeued.
drainCancel := func() {}
if activeScope, activeAgentID, ok := al.resolveSteeringTarget(msg); ok {
drainCtx, cancel := context.WithCancel(ctx)
drainCancel = cancel
go al.drainBusToSteering(drainCtx, activeScope, activeAgentID)
if !isBtwCommand(msg.Content) {
if activeScope, activeAgentID, ok := al.resolveSteeringTarget(msg); ok {
drainCtx, cancel := context.WithCancel(ctx)
drainCancel = cancel
go al.drainBusToSteering(drainCtx, ctx, activeScope, activeAgentID)
}
}
// Process message
@@ -604,7 +609,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// active scope into the steering queue. Messages from other scopes are requeued
// so they can be processed normally after the active turn. It drains all
// immediately available messages, blocking for the first one until ctx is done.
func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, activeAgentID string) {
func (al *AgentLoop) drainBusToSteering(ctx, priorityCtx context.Context, activeScope, activeAgentID string) {
blocking := true
var requeue []bus.InboundMessage
defer func() {
@@ -656,6 +661,17 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, active
// Transcribe audio if needed before steering, so the agent sees text.
msg, _ = al.transcribeAudioInMessage(ctx, msg)
// Handle priority commands (e.g. /btw) outside the steering queue, without
// blocking this drain from enqueueing later messages for the active turn.
if isBtwCommand(msg.Content) {
priorityMsg := msg
go al.handlePriorityCommandAsync(priorityCtx, priorityMsg)
// A priority command is not a steering interrupt. Keep waiting for the
// next inbound message while the active turn is still running.
blocking = true
continue
}
logger.InfoCF("agent", "Redirecting inbound message to steering queue",
map[string]any{
"channel": msg.Channel,
@@ -1532,6 +1548,359 @@ func (al *AgentLoop) ProcessHeartbeat(
})
}
func sideQuestionModelName(agent *AgentInstance, usedLight bool) string {
if agent == nil {
return ""
}
if usedLight && agent.Router != nil {
if lightModel := strings.TrimSpace(agent.Router.LightModel()); lightModel != "" {
return lightModel
}
}
return agent.Model
}
func modelNameFromIdentityKey(identityKey string) string {
const prefix = "model_name:"
if strings.HasPrefix(identityKey, prefix) {
return strings.TrimSpace(strings.TrimPrefix(identityKey, prefix))
}
return ""
}
func closeProviderIfStateful(provider providers.LLMProvider) {
if stateful, ok := provider.(providers.StatefulProvider); ok {
stateful.Close()
}
}
func cloneLLMOptions(src map[string]any) map[string]any {
dst := make(map[string]any, len(src)+1)
for key, value := range src {
dst[key] = value
}
return dst
}
func (al *AgentLoop) isolatedSideQuestionProvider(
agent *AgentInstance,
baseModelName string,
candidate providers.FallbackCandidate,
) (providers.LLMProvider, string, func(), error) {
if agent == nil {
return nil, "", func() {}, fmt.Errorf("no agent available for /btw")
}
modelCfg, err := al.sideQuestionModelConfig(agent, baseModelName, candidate)
if err != nil {
return nil, "", func() {}, err
}
factory := al.providerFactory
if factory == nil {
factory = providers.CreateProviderFromConfig
}
provider, modelID, err := factory(modelCfg)
if err != nil {
return nil, "", func() {}, err
}
cleanup := func() {
closeProviderIfStateful(provider)
}
return provider, modelID, cleanup, nil
}
func (al *AgentLoop) sideQuestionModelConfig(
agent *AgentInstance,
baseModelName string,
candidate providers.FallbackCandidate,
) (*config.ModelConfig, error) {
if agent == nil {
return nil, fmt.Errorf("no agent available for /btw")
}
if name := modelNameFromIdentityKey(candidate.IdentityKey); name != "" {
return resolvedModelConfig(al.GetConfig(), name, agent.Workspace)
}
baseModelName = strings.TrimSpace(baseModelName)
modelCfg, err := resolvedModelConfig(al.GetConfig(), baseModelName, agent.Workspace)
if err != nil {
model := strings.TrimSpace(baseModelName)
if candidate.Model != "" {
model = candidate.Model
}
if candidate.Provider != "" && candidate.Model != "" {
model = providers.NormalizeProvider(candidate.Provider) + "/" + candidate.Model
} else {
model = ensureProtocolModel(model)
}
return &config.ModelConfig{
ModelName: baseModelName,
Model: model,
Workspace: agent.Workspace,
}, nil
}
clone := *modelCfg
if candidate.Provider != "" && candidate.Model != "" {
clone.Model = providers.NormalizeProvider(candidate.Provider) + "/" + candidate.Model
}
return &clone, nil
}
func (al *AgentLoop) askSideQuestion(
ctx context.Context,
agent *AgentInstance,
opts *processOptions,
question string,
) (string, error) {
if agent == nil {
return "", fmt.Errorf("no agent available for /btw")
}
question = strings.TrimSpace(question)
if question == "" {
return "", fmt.Errorf("Usage: /btw <question>")
}
if opts != nil {
normalizeProcessOptionsInPlace(opts)
}
var media []string
var channel, chatID, senderID, senderDisplayName string
if opts != nil {
media = opts.Media
channel = opts.Channel
chatID = opts.ChatID
senderID = opts.SenderID
senderDisplayName = opts.SenderDisplayName
}
var history []providers.Message
var summary string
if opts != nil {
if !opts.NoHistory {
if resp, err := al.contextManager.Assemble(ctx, &AssembleRequest{
SessionKey: opts.SessionKey,
Budget: agent.ContextWindow,
MaxTokens: agent.MaxTokens,
}); err == nil && resp != nil {
history = resp.History
summary = resp.Summary
}
}
}
messages := agent.ContextBuilder.BuildMessages(
history,
summary,
question,
media,
channel,
chatID,
senderID,
senderDisplayName,
)
maxMediaSize := al.GetConfig().Agents.Defaults.GetMaxMediaSize()
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
activeCandidates, activeModel, usedLight := al.selectCandidates(agent, question, messages)
selectedModelName := sideQuestionModelName(agent, usedLight)
llmOpts := map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID + ":btw",
}
hookModelChanged := false
callProvider := func(
ctx context.Context,
candidate providers.FallbackCandidate,
model string,
forceModel bool,
callMessages []providers.Message,
) (*providers.LLMResponse, error) {
provider, providerModel, cleanup, err := al.isolatedSideQuestionProvider(agent, selectedModelName, candidate)
if err != nil {
return nil, err
}
defer cleanup()
if !forceModel || strings.TrimSpace(model) == "" {
model = providerModel
}
callOpts := llmOpts
if _, exists := callOpts["thinking_level"]; !exists && agent.ThinkingLevel != ThinkingOff {
if tc, ok := provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
callOpts = cloneLLMOptions(llmOpts)
callOpts["thinking_level"] = string(agent.ThinkingLevel)
}
}
return provider.Chat(ctx, callMessages, nil, model, callOpts)
}
turnCtx := newTurnContext(nil, nil, nil)
if opts != nil {
turnCtx = newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope)
}
llmModel := activeModel
if al.hooks != nil {
llmReq, decision := al.hooks.BeforeLLM(ctx, &LLMHookRequest{
Meta: EventMeta{
Source: "askSideQuestion",
TracePath: "turn.llm.request",
turnContext: cloneTurnContext(turnCtx),
},
Context: cloneTurnContext(turnCtx),
Model: llmModel,
Messages: messages,
Tools: nil,
Options: llmOpts,
GracefulTerminal: false,
})
switch decision.normalizedAction() {
case HookActionContinue, HookActionModify:
if llmReq != nil {
if strings.TrimSpace(llmReq.Model) != "" && llmReq.Model != llmModel {
hookModelChanged = true
}
llmModel = llmReq.Model
messages = llmReq.Messages
llmOpts = llmReq.Options
}
case HookActionAbortTurn:
reason := decision.Reason
if reason == "" {
reason = "hook requested turn abort"
}
return "", fmt.Errorf("hook aborted turn during before_llm: %s", reason)
case HookActionHardAbort:
reason := decision.Reason
if reason == "" {
reason = "hook requested turn abort"
}
return "", fmt.Errorf("hook aborted turn during before_llm: %s", reason)
}
}
if hookModelChanged {
// Hook-selected models must not continue through the pre-hook fallback
// candidate list, otherwise fallback execution would call the original
// candidate model and silently ignore the hook decision.
activeCandidates = nil
}
callSideLLM := func(callMessages []providers.Message) (*providers.LLMResponse, error) {
if len(activeCandidates) > 1 && al.fallback != nil {
fbResult, err := al.fallback.Execute(
ctx,
activeCandidates,
func(ctx context.Context, providerName, model string) (*providers.LLMResponse, error) {
candidate := providers.FallbackCandidate{Provider: providerName, Model: model}
for _, activeCandidate := range activeCandidates {
if activeCandidate.Provider == providerName && activeCandidate.Model == model {
candidate = activeCandidate
break
}
}
return callProvider(ctx, candidate, model, false, callMessages)
},
)
if err != nil {
return nil, err
}
return fbResult.Response, nil
}
var candidate providers.FallbackCandidate
if len(activeCandidates) > 0 {
candidate = activeCandidates[0]
}
return callProvider(ctx, candidate, llmModel, hookModelChanged, callMessages)
}
resp, _, _, err := callLLMWithVisionUnsupportedRetry(
messages,
callSideLLM,
func(originalErr error) {
al.emitEvent(
EventKindLLMRetry,
EventMeta{
Source: "askSideQuestion",
TracePath: "turn.llm.retry",
turnContext: cloneTurnContext(turnCtx),
},
LLMRetryPayload{
Attempt: 1,
MaxRetries: 1,
Reason: "vision_unsupported",
Error: originalErr.Error(),
Backoff: 0,
},
)
},
)
if err != nil {
return "", err
}
if resp == nil {
return "", nil
}
resp, err = al.applySideQuestionAfterLLM(ctx, turnCtx, llmModel, resp)
if err != nil {
return "", err
}
return sideQuestionResponseContent(resp), nil
}
func (al *AgentLoop) applySideQuestionAfterLLM(
ctx context.Context,
turnCtx *TurnContext,
model string,
response *providers.LLMResponse,
) (*providers.LLMResponse, error) {
if response == nil || al.hooks == nil {
return response, nil
}
llmResp, decision := al.hooks.AfterLLM(ctx, &LLMHookResponse{
Meta: EventMeta{
Source: "askSideQuestion",
TracePath: "turn.llm.response",
turnContext: cloneTurnContext(turnCtx),
},
Context: cloneTurnContext(turnCtx),
Model: model,
Response: response,
})
switch decision.normalizedAction() {
case HookActionContinue, HookActionModify:
if llmResp != nil && llmResp.Response != nil {
response = llmResp.Response
}
case HookActionAbortTurn, HookActionHardAbort:
reason := decision.Reason
if reason == "" {
reason = "hook requested turn abort"
}
return nil, fmt.Errorf("hook aborted turn during after_llm: %s", reason)
}
return response, nil
}
func sideQuestionResponseContent(response *providers.LLMResponse) string {
if response == nil {
return ""
}
if response.Content != "" {
return response.Content
}
return response.ReasoningContent
}
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
msg = bus.NormalizeInboundMessage(msg)
@@ -2363,10 +2732,42 @@ turnLoop:
var response *providers.LLMResponse
var err error
maxRetries := 2
callHasMedia := messagesContainMedia(callMessages)
didStripMedia := false
for retry := 0; retry <= maxRetries; retry++ {
response, err = callLLM(callMessages, providerToolDefs)
response, callMessages, _, err = callLLMWithVisionUnsupportedRetry(
callMessages,
func(messagesForRetry []providers.Message) (*providers.LLMResponse, error) {
return callLLM(messagesForRetry, providerToolDefs)
},
func(originalErr error) {
if !ts.opts.NoHistory {
history = ts.agent.Sessions.GetHistory(ts.sessionKey)
ts.agent.Sessions.SetHistory(ts.sessionKey, stripMessageMedia(history))
// Keep persistedMessages aligned so abort restore-point trimming remains correct.
ts.mu.Lock()
for i := range ts.persistedMessages {
ts.persistedMessages[i].Media = nil
}
ts.mu.Unlock()
ts.refreshRestorePointFromSession(ts.agent)
}
messages = stripMessageMedia(messages)
al.emitEvent(
EventKindLLMRetry,
ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: 1,
MaxRetries: 1,
Reason: "vision_unsupported",
Error: originalErr.Error(),
Backoff: 0,
},
)
},
)
if err == nil {
break
}
@@ -2375,45 +2776,6 @@ turnLoop:
return al.abortTurn(ts)
}
// If the provider/model doesn't support multimodal inputs, retry once with media stripped
// so the session doesn't get "stuck" after a user sends an image.
if callHasMedia && !didStripMedia && isVisionUnsupportedError(err) {
didStripMedia = true
if !ts.opts.NoHistory {
history = ts.agent.Sessions.GetHistory(ts.sessionKey)
ts.agent.Sessions.SetHistory(ts.sessionKey, stripMessageMedia(history))
// Keep persistedMessages aligned so abort restore-point trimming remains correct.
ts.mu.Lock()
for i := range ts.persistedMessages {
ts.persistedMessages[i].Media = nil
}
ts.mu.Unlock()
ts.refreshRestorePointFromSession(ts.agent)
}
messages = stripMessageMedia(messages)
callMessages = stripMessageMedia(callMessages)
callHasMedia = false
al.emitEvent(
EventKindLLMRetry,
ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: 1,
MaxRetries: 1,
Reason: "vision_unsupported",
Error: err.Error(),
Backoff: 0,
},
)
response, err = callLLM(callMessages, providerToolDefs)
if err == nil {
break
}
}
errMsg := strings.ToLower(err.Error())
isTimeoutError := errors.Is(err, context.DeadlineExceeded) ||
strings.Contains(errMsg, "deadline exceeded") ||
@@ -3748,6 +4110,11 @@ func activeSkillNames(agent *AgentInstance, opts processOptions) []string {
return resolved
}
func isBtwCommand(content string) bool {
cmdName, ok := commands.CommandName(content)
return ok && cmdName == "btw"
}
func (al *AgentLoop) applyExplicitSkillCommand(
raw string,
agent *AgentInstance,
@@ -3856,6 +4223,9 @@ func (al *AgentLoop) buildCommandsRuntime(
if agent.ContextBuilder != nil {
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
}
rt.AskSideQuestion = func(ctx context.Context, question string) (string, error) {
return al.askSideQuestion(ctx, agent, opts, question)
}
rt.GetModelInfo = func() (string, string) {
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
}
@@ -3975,6 +4345,99 @@ func mapCommandError(result commands.ExecuteResult) string {
return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err)
}
func (al *AgentLoop) tryHandlePriorityCommand(ctx context.Context, msg bus.InboundMessage) (bool, bus.OutboundMessage) {
if !isBtwCommand(msg.Content) {
return false, bus.OutboundMessage{}
}
route, agent, err := al.resolveMessageRoute(msg)
if err != nil || agent == nil {
if err != nil {
logger.ErrorCF("agent", fmt.Sprintf("Error resolving route for /btw: %v", err), nil)
return true, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Context: outboundContextFromInbound(
&msg.Context,
msg.Channel,
msg.ChatID,
msg.Context.ReplyToMessageID,
),
Content: fmt.Sprintf("Error processing message: %v", err),
}
}
logger.WarnCF("agent", "/btw command unavailable: no agent resolved", nil)
return true, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Context: outboundContextFromInbound(
&msg.Context,
msg.Channel,
msg.ChatID,
msg.Context.ReplyToMessageID,
),
Content: "Command unavailable in current context.",
}
}
allocation := al.allocateRouteSession(route, msg)
sessionKey := resolveScopeKey(allocation.SessionKey, msg.SessionKey)
msg.SessionKey = sessionKey
opts := processOptions{
Dispatch: DispatchRequest{
SessionKey: sessionKey,
SessionAliases: buildSessionAliases(sessionKey, append(allocation.SessionAliases, msg.SessionKey)...),
InboundContext: cloneInboundContext(&msg.Context),
RouteResult: cloneResolvedRoute(&route),
SessionScope: session.CloneScope(&allocation.Scope),
UserMessage: msg.Content,
Media: append([]string(nil), msg.Media...),
},
SessionKey: sessionKey,
SenderID: msg.SenderID,
SenderDisplayName: msg.Sender.DisplayName,
}
cmdCtx, cancel := context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
response, handled := al.handleCommand(cmdCtx, msg, agent, &opts)
if !handled {
return false, bus.OutboundMessage{}
}
agentID, outboundSessionKey, scope := outboundTurnMetadata(agent.ID, sessionKey, &allocation.Scope)
return true, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Context: outboundContextFromInbound(
&msg.Context,
msg.Channel,
msg.ChatID,
msg.Context.ReplyToMessageID,
),
AgentID: agentID,
SessionKey: outboundSessionKey,
Scope: scope,
Content: response,
}
}
func (al *AgentLoop) handlePriorityCommandAsync(ctx context.Context, msg bus.InboundMessage) {
handled, outbound := al.tryHandlePriorityCommand(ctx, msg)
if !handled || outbound.Content == "" {
return
}
publishCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := al.bus.PublishOutbound(publishCtx, outbound); err != nil {
logger.WarnCF("agent", "Failed to publish priority command response", map[string]any{
"error": err.Error(),
"channel": outbound.Channel,
})
}
}
// isNativeSearchProvider reports whether the given LLM provider implements
// NativeSearchCapable and returns true for SupportsNativeSearch.
func isNativeSearchProvider(p providers.LLMProvider) bool {
+343
View File
@@ -9,6 +9,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"slices"
"strings"
"testing"
@@ -80,6 +81,7 @@ func newStartedTestChannelManager(
type recordingProvider struct {
lastMessages []providers.Message
lastModel string
}
func (r *recordingProvider) Chat(
@@ -90,6 +92,7 @@ func (r *recordingProvider) Chat(
opts map[string]any,
) (*providers.LLMResponse, error) {
r.lastMessages = append([]providers.Message(nil), messages...)
r.lastModel = model
return &providers.LLMResponse{
Content: "Mock response",
ToolCalls: []providers.ToolCall{},
@@ -100,6 +103,47 @@ func (r *recordingProvider) GetDefaultModel() string {
return "mock-model"
}
type closeTrackingProvider struct {
recordingProvider
closed bool
}
func (p *closeTrackingProvider) Close() {
p.closed = true
}
type modelRewriteHook struct {
model string
}
func (h modelRewriteHook) BeforeLLM(
ctx context.Context,
req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
next := req.Clone()
next.Model = h.model
return next, HookDecision{Action: HookActionModify}, nil
}
func (h modelRewriteHook) AfterLLM(
ctx context.Context,
resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
}
func useTestSideQuestionProvider(al *AgentLoop, provider providers.LLMProvider) {
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
model := provider.GetDefaultModel()
if mc != nil {
if _, modelID := providers.ExtractProtocol(mc.Model); modelID != "" {
model = modelID
}
}
return provider, model, nil
}
}
func newTestAgentLoop(
t *testing.T,
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
@@ -235,6 +279,305 @@ func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
}
}
func TestProcessMessage_BtwCommandRunsWithoutPersistingHistory(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
defaultAgent := al.GetRegistry().GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
msg := bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw explain side effects",
}
route, _, err := al.resolveMessageRoute(msg)
if err != nil {
t.Fatalf("resolveMessageRoute() error = %v", err)
}
allocation := al.allocateRouteSession(route, msg)
sessionKey := resolveScopeKey(allocation.SessionKey, msg.SessionKey)
initialHistory := []providers.Message{
{Role: "user", Content: "We decided to avoid global state."},
{Role: "assistant", Content: "Right, keep it request-scoped."},
}
defaultAgent.Sessions.SetHistory(sessionKey, initialHistory)
defaultAgent.Sessions.SetSummary(sessionKey, "The team decided to keep state request-scoped.")
response, err := al.processMessage(context.Background(), msg)
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if len(provider.lastMessages) == 0 {
t.Fatal("provider did not receive any messages")
}
if len(provider.lastMessages) != 4 {
t.Fatalf("provider messages len = %d, want 4 (system + prior history + user)", len(provider.lastMessages))
}
if !reflect.DeepEqual(provider.lastMessages[1:3], initialHistory) {
t.Fatalf("provider history = %#v, want %#v", provider.lastMessages[1:3], initialHistory)
}
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
if lastMessage.Role != "user" || lastMessage.Content != "explain side effects" {
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
}
history := al.GetRegistry().GetDefaultAgent().Sessions.GetHistory(sessionKey)
if !reflect.DeepEqual(history, initialHistory) {
t.Fatalf("session history = %#v, want %#v", history, initialHistory)
}
}
func TestProcessMessage_BtwCommandIncludesRequestContextAndMedia(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "discord",
SenderID: "discord:123",
Sender: bus.SenderInfo{
DisplayName: "Alice",
},
ChatID: "group-1",
Content: "/btw describe this image",
Media: []string{"media://image-1"},
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if len(provider.lastMessages) == 0 {
t.Fatal("provider did not receive any messages")
}
systemPrompt := provider.lastMessages[0].Content
if !strings.Contains(systemPrompt, "## Current Session\nChannel: discord\nChat ID: group-1") {
t.Fatalf("system prompt missing current session context:\n%s", systemPrompt)
}
if !strings.Contains(systemPrompt, "## Current Sender\nCurrent sender: Alice (ID: discord:123)") {
t.Fatalf("system prompt missing current sender context:\n%s", systemPrompt)
}
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
if lastMessage.Role != "user" || lastMessage.Content != "describe this image" {
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
}
if !reflect.DeepEqual(lastMessage.Media, []string{"media://image-1"}) {
t.Fatalf("last provider media = %#v, want media ref", lastMessage.Media)
}
}
func TestProcessMessage_BtwCommandUsesIsolatedProvider(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
mainProvider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, mainProvider)
var sideProvider *closeTrackingProvider
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
sideProvider = &closeTrackingProvider{}
return sideProvider, "isolated-model", nil
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw explain isolation",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if len(mainProvider.lastMessages) != 0 {
t.Fatalf("main provider was used for /btw: %+v", mainProvider.lastMessages)
}
if sideProvider == nil {
t.Fatal("side question provider factory was not called")
}
if !sideProvider.closed {
t.Fatal("isolated stateful /btw provider was not closed")
}
if len(sideProvider.lastMessages) == 0 {
t.Fatal("isolated provider did not receive messages")
}
}
func TestProcessMessage_BtwCommandRetriesWithoutMediaOnVisionUnsupported(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &visionUnsupportedMediaProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw describe this image",
Media: []string{"data:image/png;base64,abc123"},
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "ok" {
t.Fatalf("processMessage() response = %q, want %q", response, "ok")
}
if provider.calls != 2 {
t.Fatalf("calls = %d, want %d (fail with media, then retry without media)", provider.calls, 2)
}
if !slices.Equal(provider.mediaSeen, []bool{true, false}) {
t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false})
}
}
func TestProcessMessage_BtwCommandUsesProviderFactoryModel(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "lb-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
ModelList: []*config.ModelConfig{
{ModelName: "lb-model", Model: "openai/lb-model-a"},
{ModelName: "lb-model", Model: "openai/lb-model-b"},
},
}
msgBus := bus.NewMessageBus()
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
var wantModel string
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
if mc == nil {
t.Fatal("expected model config")
}
_, modelID := providers.ExtractProtocol(mc.Model)
wantModel = "factory-" + modelID
return provider, wantModel, nil
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw explain load balancing",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if provider.lastModel != wantModel {
t.Fatalf("/btw model = %q, want provider factory model %q", provider.lastModel, wantModel)
}
}
func TestProcessMessage_BtwCommandHookModelBypassesFallbackCandidates(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "primary-model",
ModelFallbacks: []string{"fallback-model"},
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "hook-model"})); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw explain hook routing",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if provider.lastModel != "hook-model" {
t.Fatalf("/btw model = %q, want hook-selected model", provider.lastModel)
}
}
func TestHandleCommand_UseCommandRejectsUnknownSkill(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
+488 -8
View File
@@ -405,7 +405,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
done := make(chan struct{})
go func() {
al.drainBusToSteering(ctx, activeScope, activeAgentID)
al.drainBusToSteering(ctx, ctx, activeScope, activeAgentID)
close(done)
}()
@@ -566,12 +566,14 @@ func (p *lateSteeringProvider) GetDefaultModel() string {
}
type blockingDirectProvider struct {
mu sync.Mutex
calls int
firstStarted chan struct{}
releaseFirst chan struct{}
firstResp string
finalResp string
mu sync.Mutex
calls int
firstStarted chan struct{}
releaseFirst chan struct{}
secondStarted chan struct{}
releaseSecond chan struct{}
firstResp string
finalResp string
}
func (p *blockingDirectProvider) Chat(
@@ -586,11 +588,15 @@ func (p *blockingDirectProvider) Chat(
call := p.calls
firstStarted := p.firstStarted
releaseFirst := p.releaseFirst
secondStarted := p.secondStarted
releaseSecond := p.releaseSecond
firstResp := p.firstResp
finalResp := p.finalResp
if call == 1 && p.firstStarted != nil {
close(p.firstStarted)
p.firstStarted = nil
}
if call == 2 && p.secondStarted != nil {
close(p.secondStarted)
}
p.mu.Unlock()
@@ -604,6 +610,14 @@ func (p *blockingDirectProvider) Chat(
}
_ = firstStarted
_ = secondStarted
if call == 2 && releaseSecond != nil {
select {
case <-releaseSecond:
case <-ctx.Done():
return nil, ctx.Err()
}
}
return &providers.LLMResponse{Content: finalResp}, nil
}
@@ -611,6 +625,73 @@ func (p *blockingDirectProvider) GetDefaultModel() string {
return "blocking-direct-mock"
}
type blockedBtwWithFollowupProvider struct {
mu sync.Mutex
calls int
firstStarted chan struct{}
releaseFirst chan struct{}
secondStarted chan struct{}
releaseSecond chan struct{}
thirdStarted chan struct{}
thirdMessages []providers.Message
}
func (p *blockedBtwWithFollowupProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
p.calls++
call := p.calls
firstStarted := p.firstStarted
releaseFirst := p.releaseFirst
secondStarted := p.secondStarted
releaseSecond := p.releaseSecond
thirdStarted := p.thirdStarted
if call == 1 && p.firstStarted != nil {
close(p.firstStarted)
}
if call == 2 && p.secondStarted != nil {
close(p.secondStarted)
}
if call == 3 {
p.thirdMessages = append([]providers.Message(nil), messages...)
if p.thirdStarted != nil {
close(p.thirdStarted)
}
}
p.mu.Unlock()
switch call {
case 1:
_ = firstStarted
select {
case <-releaseFirst:
case <-ctx.Done():
return nil, ctx.Err()
}
return &providers.LLMResponse{Content: "long turn finished"}, nil
case 2:
_ = secondStarted
select {
case <-releaseSecond:
case <-ctx.Done():
return nil, ctx.Err()
}
return &providers.LLMResponse{Content: "btw delayed reply"}, nil
default:
_ = thirdStarted
return &providers.LLMResponse{Content: "continued after follow-up"}, nil
}
}
func (p *blockedBtwWithFollowupProvider) GetDefaultModel() string {
return "blocked-btw-followup-mock"
}
type interruptibleTool struct {
name string
started chan struct{}
@@ -1010,6 +1091,405 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
}
}
func TestAgentLoop_Steering_BtwCommandBypassesQueuedTurn(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &blockingDirectProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
firstResp: "long turn finished",
finalResp: "btw immediate reply",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "execute sleep 60, then send OK",
}
btw := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "/btw what is the current progress?",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
messageTool, ok := al.GetRegistry().GetDefaultAgent().Tools.Get("message")
var mt *tools.MessageTool
if !ok {
mt = tools.NewMessageTool()
al.RegisterTool(mt)
} else {
var typeOK bool
mt, typeOK = messageTool.(*tools.MessageTool)
if !typeOK {
t.Fatal("expected message tool type")
}
}
mt.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return nil
})
if result := mt.Execute(context.Background(), map[string]any{
"channel": "test",
"chat_id": "chat1",
"content": "already sent from busy turn",
}); result == nil || result.IsError {
t.Fatalf("message tool setup result = %+v, want successful send", result)
}
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
t.Fatalf("publish /btw inbound: %v", err)
}
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "btw immediate reply" {
t.Fatalf("expected /btw reply before long turn completion, got %q", outbound.Content)
}
if outbound.AgentID != routing.DefaultAgentID {
t.Fatalf("expected /btw outbound agent_id %q, got %q", routing.DefaultAgentID, outbound.AgentID)
}
route, _, err := al.resolveMessageRoute(btw)
if err != nil {
t.Fatalf("resolveMessageRoute(/btw) error = %v", err)
}
expectedSessionKey := resolveScopeKey(al.allocateRouteSession(route, btw).SessionKey, btw.SessionKey)
if outbound.SessionKey != expectedSessionKey {
t.Fatalf("expected /btw outbound session_key %q, got %q", expectedSessionKey, outbound.SessionKey)
}
if outbound.Scope == nil ||
outbound.Scope.AgentID != routing.DefaultAgentID ||
outbound.Scope.Channel != "test" {
t.Fatalf(
"expected /btw outbound scope for agent %q on test channel, got %+v",
routing.DefaultAgentID,
outbound.Scope,
)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /btw outbound response")
}
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 {
t.Fatalf("expected /btw to bypass steering queue, got %v", msgs)
}
close(provider.releaseFirst)
select {
case outbound := <-msgBus.OutboundChan():
t.Fatalf("expected busy turn final response to stay suppressed, got %q", outbound.Content)
case <-time.After(2 * time.Second):
}
provider.mu.Lock()
callCount := provider.calls
provider.mu.Unlock()
if callCount != 2 {
t.Fatalf("provider call count = %d, want 2", callCount)
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}
func TestAgentLoop_Steering_BtwCommandSurvivesActiveTurnCompletion(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &blockingDirectProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
secondStarted: make(chan struct{}),
releaseSecond: make(chan struct{}),
firstResp: "long turn finished",
finalResp: "btw delayed reply",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "execute a long turn",
}
btw := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "/btw can you still answer?",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
t.Fatalf("publish /btw inbound: %v", err)
}
select {
case <-provider.secondStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /btw LLM call to start")
}
close(provider.releaseFirst)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "long turn finished" {
t.Fatalf("expected first outbound to be long turn response, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for long turn response")
}
close(provider.releaseSecond)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "btw delayed reply" {
t.Fatalf("expected /btw response after drain cancellation, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for delayed /btw response")
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}
func TestAgentLoop_Steering_BlockedBtwDoesNotBlockFollowupContinuation(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &blockedBtwWithFollowupProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
secondStarted: make(chan struct{}),
releaseSecond: make(chan struct{}),
thirdStarted: make(chan struct{}),
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "execute a long turn",
}
btw := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "/btw this side question blocks",
}
followup := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "normal follow-up while btw is blocked",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
t.Fatalf("publish /btw inbound: %v", err)
}
select {
case <-provider.secondStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /btw LLM call to start")
}
if err := msgBus.PublishInbound(pubCtx, followup); err != nil {
t.Fatalf("publish follow-up inbound: %v", err)
}
close(provider.releaseFirst)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "continued after follow-up" {
t.Fatalf("expected continuation response before /btw release, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for follow-up continuation response")
}
provider.mu.Lock()
thirdMessages := append([]providers.Message(nil), provider.thirdMessages...)
provider.mu.Unlock()
foundFollowup := false
for _, msg := range thirdMessages {
if msg.Role == "user" && msg.Content == followup.Content {
foundFollowup = true
break
}
}
if !foundFollowup {
t.Fatalf("continuation messages did not include follow-up: %+v", thirdMessages)
}
close(provider.releaseSecond)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "btw delayed reply" {
t.Fatalf("expected delayed /btw response, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for delayed /btw response")
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}
func TestAgentLoop_AgentForSession_UsesStoredScopeMetadata(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {