feat(agent): support btw side questions (#2532)

This commit is contained in:
lxowalle
2026-04-16 10:53:09 +08:00
committed by GitHub
parent a8d0b03515
commit e22b4e1eee
23 changed files with 1737 additions and 70 deletions
+510 -47
View File
@@ -70,6 +70,8 @@ type AgentLoop struct {
activeRequests sync.WaitGroup
reloadFunc func() error
providerFactory func(*config.ModelConfig) (providers.LLMProvider, string, error)
}
// processOptions configures how a message is processed
@@ -159,6 +161,7 @@ func NewAgentLoop(
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
al.providerFactory = providers.CreateProviderFromConfig
al.hooks = NewHookManager(eventBus)
configureHookManagerFromConfig(al.hooks, cfg)
al.contextManager = al.resolveContextManager()
@@ -479,10 +482,12 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// running. Only messages that resolve to the active turn scope are
// redirected into steering; other inbound messages are requeued.
drainCancel := func() {}
if activeScope, activeAgentID, ok := al.resolveSteeringTarget(msg); ok {
drainCtx, cancel := context.WithCancel(ctx)
drainCancel = cancel
go al.drainBusToSteering(drainCtx, activeScope, activeAgentID)
if !isBtwCommand(msg.Content) {
if activeScope, activeAgentID, ok := al.resolveSteeringTarget(msg); ok {
drainCtx, cancel := context.WithCancel(ctx)
drainCancel = cancel
go al.drainBusToSteering(drainCtx, ctx, activeScope, activeAgentID)
}
}
// Process message
@@ -604,7 +609,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// active scope into the steering queue. Messages from other scopes are requeued
// so they can be processed normally after the active turn. It drains all
// immediately available messages, blocking for the first one until ctx is done.
func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, activeAgentID string) {
func (al *AgentLoop) drainBusToSteering(ctx, priorityCtx context.Context, activeScope, activeAgentID string) {
blocking := true
var requeue []bus.InboundMessage
defer func() {
@@ -656,6 +661,17 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, active
// Transcribe audio if needed before steering, so the agent sees text.
msg, _ = al.transcribeAudioInMessage(ctx, msg)
// Handle priority commands (e.g. /btw) outside the steering queue, without
// blocking this drain from enqueueing later messages for the active turn.
if isBtwCommand(msg.Content) {
priorityMsg := msg
go al.handlePriorityCommandAsync(priorityCtx, priorityMsg)
// A priority command is not a steering interrupt. Keep waiting for the
// next inbound message while the active turn is still running.
blocking = true
continue
}
logger.InfoCF("agent", "Redirecting inbound message to steering queue",
map[string]any{
"channel": msg.Channel,
@@ -1532,6 +1548,359 @@ func (al *AgentLoop) ProcessHeartbeat(
})
}
func sideQuestionModelName(agent *AgentInstance, usedLight bool) string {
if agent == nil {
return ""
}
if usedLight && agent.Router != nil {
if lightModel := strings.TrimSpace(agent.Router.LightModel()); lightModel != "" {
return lightModel
}
}
return agent.Model
}
func modelNameFromIdentityKey(identityKey string) string {
const prefix = "model_name:"
if strings.HasPrefix(identityKey, prefix) {
return strings.TrimSpace(strings.TrimPrefix(identityKey, prefix))
}
return ""
}
func closeProviderIfStateful(provider providers.LLMProvider) {
if stateful, ok := provider.(providers.StatefulProvider); ok {
stateful.Close()
}
}
func cloneLLMOptions(src map[string]any) map[string]any {
dst := make(map[string]any, len(src)+1)
for key, value := range src {
dst[key] = value
}
return dst
}
func (al *AgentLoop) isolatedSideQuestionProvider(
agent *AgentInstance,
baseModelName string,
candidate providers.FallbackCandidate,
) (providers.LLMProvider, string, func(), error) {
if agent == nil {
return nil, "", func() {}, fmt.Errorf("no agent available for /btw")
}
modelCfg, err := al.sideQuestionModelConfig(agent, baseModelName, candidate)
if err != nil {
return nil, "", func() {}, err
}
factory := al.providerFactory
if factory == nil {
factory = providers.CreateProviderFromConfig
}
provider, modelID, err := factory(modelCfg)
if err != nil {
return nil, "", func() {}, err
}
cleanup := func() {
closeProviderIfStateful(provider)
}
return provider, modelID, cleanup, nil
}
func (al *AgentLoop) sideQuestionModelConfig(
agent *AgentInstance,
baseModelName string,
candidate providers.FallbackCandidate,
) (*config.ModelConfig, error) {
if agent == nil {
return nil, fmt.Errorf("no agent available for /btw")
}
if name := modelNameFromIdentityKey(candidate.IdentityKey); name != "" {
return resolvedModelConfig(al.GetConfig(), name, agent.Workspace)
}
baseModelName = strings.TrimSpace(baseModelName)
modelCfg, err := resolvedModelConfig(al.GetConfig(), baseModelName, agent.Workspace)
if err != nil {
model := strings.TrimSpace(baseModelName)
if candidate.Model != "" {
model = candidate.Model
}
if candidate.Provider != "" && candidate.Model != "" {
model = providers.NormalizeProvider(candidate.Provider) + "/" + candidate.Model
} else {
model = ensureProtocolModel(model)
}
return &config.ModelConfig{
ModelName: baseModelName,
Model: model,
Workspace: agent.Workspace,
}, nil
}
clone := *modelCfg
if candidate.Provider != "" && candidate.Model != "" {
clone.Model = providers.NormalizeProvider(candidate.Provider) + "/" + candidate.Model
}
return &clone, nil
}
func (al *AgentLoop) askSideQuestion(
ctx context.Context,
agent *AgentInstance,
opts *processOptions,
question string,
) (string, error) {
if agent == nil {
return "", fmt.Errorf("no agent available for /btw")
}
question = strings.TrimSpace(question)
if question == "" {
return "", fmt.Errorf("Usage: /btw <question>")
}
if opts != nil {
normalizeProcessOptionsInPlace(opts)
}
var media []string
var channel, chatID, senderID, senderDisplayName string
if opts != nil {
media = opts.Media
channel = opts.Channel
chatID = opts.ChatID
senderID = opts.SenderID
senderDisplayName = opts.SenderDisplayName
}
var history []providers.Message
var summary string
if opts != nil {
if !opts.NoHistory {
if resp, err := al.contextManager.Assemble(ctx, &AssembleRequest{
SessionKey: opts.SessionKey,
Budget: agent.ContextWindow,
MaxTokens: agent.MaxTokens,
}); err == nil && resp != nil {
history = resp.History
summary = resp.Summary
}
}
}
messages := agent.ContextBuilder.BuildMessages(
history,
summary,
question,
media,
channel,
chatID,
senderID,
senderDisplayName,
)
maxMediaSize := al.GetConfig().Agents.Defaults.GetMaxMediaSize()
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
activeCandidates, activeModel, usedLight := al.selectCandidates(agent, question, messages)
selectedModelName := sideQuestionModelName(agent, usedLight)
llmOpts := map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID + ":btw",
}
hookModelChanged := false
callProvider := func(
ctx context.Context,
candidate providers.FallbackCandidate,
model string,
forceModel bool,
callMessages []providers.Message,
) (*providers.LLMResponse, error) {
provider, providerModel, cleanup, err := al.isolatedSideQuestionProvider(agent, selectedModelName, candidate)
if err != nil {
return nil, err
}
defer cleanup()
if !forceModel || strings.TrimSpace(model) == "" {
model = providerModel
}
callOpts := llmOpts
if _, exists := callOpts["thinking_level"]; !exists && agent.ThinkingLevel != ThinkingOff {
if tc, ok := provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
callOpts = cloneLLMOptions(llmOpts)
callOpts["thinking_level"] = string(agent.ThinkingLevel)
}
}
return provider.Chat(ctx, callMessages, nil, model, callOpts)
}
turnCtx := newTurnContext(nil, nil, nil)
if opts != nil {
turnCtx = newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope)
}
llmModel := activeModel
if al.hooks != nil {
llmReq, decision := al.hooks.BeforeLLM(ctx, &LLMHookRequest{
Meta: EventMeta{
Source: "askSideQuestion",
TracePath: "turn.llm.request",
turnContext: cloneTurnContext(turnCtx),
},
Context: cloneTurnContext(turnCtx),
Model: llmModel,
Messages: messages,
Tools: nil,
Options: llmOpts,
GracefulTerminal: false,
})
switch decision.normalizedAction() {
case HookActionContinue, HookActionModify:
if llmReq != nil {
if strings.TrimSpace(llmReq.Model) != "" && llmReq.Model != llmModel {
hookModelChanged = true
}
llmModel = llmReq.Model
messages = llmReq.Messages
llmOpts = llmReq.Options
}
case HookActionAbortTurn:
reason := decision.Reason
if reason == "" {
reason = "hook requested turn abort"
}
return "", fmt.Errorf("hook aborted turn during before_llm: %s", reason)
case HookActionHardAbort:
reason := decision.Reason
if reason == "" {
reason = "hook requested turn abort"
}
return "", fmt.Errorf("hook aborted turn during before_llm: %s", reason)
}
}
if hookModelChanged {
// Hook-selected models must not continue through the pre-hook fallback
// candidate list, otherwise fallback execution would call the original
// candidate model and silently ignore the hook decision.
activeCandidates = nil
}
callSideLLM := func(callMessages []providers.Message) (*providers.LLMResponse, error) {
if len(activeCandidates) > 1 && al.fallback != nil {
fbResult, err := al.fallback.Execute(
ctx,
activeCandidates,
func(ctx context.Context, providerName, model string) (*providers.LLMResponse, error) {
candidate := providers.FallbackCandidate{Provider: providerName, Model: model}
for _, activeCandidate := range activeCandidates {
if activeCandidate.Provider == providerName && activeCandidate.Model == model {
candidate = activeCandidate
break
}
}
return callProvider(ctx, candidate, model, false, callMessages)
},
)
if err != nil {
return nil, err
}
return fbResult.Response, nil
}
var candidate providers.FallbackCandidate
if len(activeCandidates) > 0 {
candidate = activeCandidates[0]
}
return callProvider(ctx, candidate, llmModel, hookModelChanged, callMessages)
}
resp, _, _, err := callLLMWithVisionUnsupportedRetry(
messages,
callSideLLM,
func(originalErr error) {
al.emitEvent(
EventKindLLMRetry,
EventMeta{
Source: "askSideQuestion",
TracePath: "turn.llm.retry",
turnContext: cloneTurnContext(turnCtx),
},
LLMRetryPayload{
Attempt: 1,
MaxRetries: 1,
Reason: "vision_unsupported",
Error: originalErr.Error(),
Backoff: 0,
},
)
},
)
if err != nil {
return "", err
}
if resp == nil {
return "", nil
}
resp, err = al.applySideQuestionAfterLLM(ctx, turnCtx, llmModel, resp)
if err != nil {
return "", err
}
return sideQuestionResponseContent(resp), nil
}
func (al *AgentLoop) applySideQuestionAfterLLM(
ctx context.Context,
turnCtx *TurnContext,
model string,
response *providers.LLMResponse,
) (*providers.LLMResponse, error) {
if response == nil || al.hooks == nil {
return response, nil
}
llmResp, decision := al.hooks.AfterLLM(ctx, &LLMHookResponse{
Meta: EventMeta{
Source: "askSideQuestion",
TracePath: "turn.llm.response",
turnContext: cloneTurnContext(turnCtx),
},
Context: cloneTurnContext(turnCtx),
Model: model,
Response: response,
})
switch decision.normalizedAction() {
case HookActionContinue, HookActionModify:
if llmResp != nil && llmResp.Response != nil {
response = llmResp.Response
}
case HookActionAbortTurn, HookActionHardAbort:
reason := decision.Reason
if reason == "" {
reason = "hook requested turn abort"
}
return nil, fmt.Errorf("hook aborted turn during after_llm: %s", reason)
}
return response, nil
}
func sideQuestionResponseContent(response *providers.LLMResponse) string {
if response == nil {
return ""
}
if response.Content != "" {
return response.Content
}
return response.ReasoningContent
}
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
msg = bus.NormalizeInboundMessage(msg)
@@ -2363,10 +2732,42 @@ turnLoop:
var response *providers.LLMResponse
var err error
maxRetries := 2
callHasMedia := messagesContainMedia(callMessages)
didStripMedia := false
for retry := 0; retry <= maxRetries; retry++ {
response, err = callLLM(callMessages, providerToolDefs)
response, callMessages, _, err = callLLMWithVisionUnsupportedRetry(
callMessages,
func(messagesForRetry []providers.Message) (*providers.LLMResponse, error) {
return callLLM(messagesForRetry, providerToolDefs)
},
func(originalErr error) {
if !ts.opts.NoHistory {
history = ts.agent.Sessions.GetHistory(ts.sessionKey)
ts.agent.Sessions.SetHistory(ts.sessionKey, stripMessageMedia(history))
// Keep persistedMessages aligned so abort restore-point trimming remains correct.
ts.mu.Lock()
for i := range ts.persistedMessages {
ts.persistedMessages[i].Media = nil
}
ts.mu.Unlock()
ts.refreshRestorePointFromSession(ts.agent)
}
messages = stripMessageMedia(messages)
al.emitEvent(
EventKindLLMRetry,
ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: 1,
MaxRetries: 1,
Reason: "vision_unsupported",
Error: originalErr.Error(),
Backoff: 0,
},
)
},
)
if err == nil {
break
}
@@ -2375,45 +2776,6 @@ turnLoop:
return al.abortTurn(ts)
}
// If the provider/model doesn't support multimodal inputs, retry once with media stripped
// so the session doesn't get "stuck" after a user sends an image.
if callHasMedia && !didStripMedia && isVisionUnsupportedError(err) {
didStripMedia = true
if !ts.opts.NoHistory {
history = ts.agent.Sessions.GetHistory(ts.sessionKey)
ts.agent.Sessions.SetHistory(ts.sessionKey, stripMessageMedia(history))
// Keep persistedMessages aligned so abort restore-point trimming remains correct.
ts.mu.Lock()
for i := range ts.persistedMessages {
ts.persistedMessages[i].Media = nil
}
ts.mu.Unlock()
ts.refreshRestorePointFromSession(ts.agent)
}
messages = stripMessageMedia(messages)
callMessages = stripMessageMedia(callMessages)
callHasMedia = false
al.emitEvent(
EventKindLLMRetry,
ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: 1,
MaxRetries: 1,
Reason: "vision_unsupported",
Error: err.Error(),
Backoff: 0,
},
)
response, err = callLLM(callMessages, providerToolDefs)
if err == nil {
break
}
}
errMsg := strings.ToLower(err.Error())
isTimeoutError := errors.Is(err, context.DeadlineExceeded) ||
strings.Contains(errMsg, "deadline exceeded") ||
@@ -3748,6 +4110,11 @@ func activeSkillNames(agent *AgentInstance, opts processOptions) []string {
return resolved
}
func isBtwCommand(content string) bool {
cmdName, ok := commands.CommandName(content)
return ok && cmdName == "btw"
}
func (al *AgentLoop) applyExplicitSkillCommand(
raw string,
agent *AgentInstance,
@@ -3856,6 +4223,9 @@ func (al *AgentLoop) buildCommandsRuntime(
if agent.ContextBuilder != nil {
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
}
rt.AskSideQuestion = func(ctx context.Context, question string) (string, error) {
return al.askSideQuestion(ctx, agent, opts, question)
}
rt.GetModelInfo = func() (string, string) {
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
}
@@ -3975,6 +4345,99 @@ func mapCommandError(result commands.ExecuteResult) string {
return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err)
}
func (al *AgentLoop) tryHandlePriorityCommand(ctx context.Context, msg bus.InboundMessage) (bool, bus.OutboundMessage) {
if !isBtwCommand(msg.Content) {
return false, bus.OutboundMessage{}
}
route, agent, err := al.resolveMessageRoute(msg)
if err != nil || agent == nil {
if err != nil {
logger.ErrorCF("agent", fmt.Sprintf("Error resolving route for /btw: %v", err), nil)
return true, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Context: outboundContextFromInbound(
&msg.Context,
msg.Channel,
msg.ChatID,
msg.Context.ReplyToMessageID,
),
Content: fmt.Sprintf("Error processing message: %v", err),
}
}
logger.WarnCF("agent", "/btw command unavailable: no agent resolved", nil)
return true, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Context: outboundContextFromInbound(
&msg.Context,
msg.Channel,
msg.ChatID,
msg.Context.ReplyToMessageID,
),
Content: "Command unavailable in current context.",
}
}
allocation := al.allocateRouteSession(route, msg)
sessionKey := resolveScopeKey(allocation.SessionKey, msg.SessionKey)
msg.SessionKey = sessionKey
opts := processOptions{
Dispatch: DispatchRequest{
SessionKey: sessionKey,
SessionAliases: buildSessionAliases(sessionKey, append(allocation.SessionAliases, msg.SessionKey)...),
InboundContext: cloneInboundContext(&msg.Context),
RouteResult: cloneResolvedRoute(&route),
SessionScope: session.CloneScope(&allocation.Scope),
UserMessage: msg.Content,
Media: append([]string(nil), msg.Media...),
},
SessionKey: sessionKey,
SenderID: msg.SenderID,
SenderDisplayName: msg.Sender.DisplayName,
}
cmdCtx, cancel := context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
response, handled := al.handleCommand(cmdCtx, msg, agent, &opts)
if !handled {
return false, bus.OutboundMessage{}
}
agentID, outboundSessionKey, scope := outboundTurnMetadata(agent.ID, sessionKey, &allocation.Scope)
return true, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Context: outboundContextFromInbound(
&msg.Context,
msg.Channel,
msg.ChatID,
msg.Context.ReplyToMessageID,
),
AgentID: agentID,
SessionKey: outboundSessionKey,
Scope: scope,
Content: response,
}
}
func (al *AgentLoop) handlePriorityCommandAsync(ctx context.Context, msg bus.InboundMessage) {
handled, outbound := al.tryHandlePriorityCommand(ctx, msg)
if !handled || outbound.Content == "" {
return
}
publishCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := al.bus.PublishOutbound(publishCtx, outbound); err != nil {
logger.WarnCF("agent", "Failed to publish priority command response", map[string]any{
"error": err.Error(),
"channel": outbound.Channel,
})
}
}
// isNativeSearchProvider reports whether the given LLM provider implements
// NativeSearchCapable and returns true for SupportsNativeSearch.
func isNativeSearchProvider(p providers.LLMProvider) bool {