mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): support btw side questions (#2532)
This commit is contained in:
@@ -111,6 +111,8 @@ func (p *llmHookTestProvider) GetDefaultModel() string {
|
||||
type llmObserverHook struct {
|
||||
eventCh chan Event
|
||||
lastInbound *bus.InboundContext
|
||||
lastRoute *routing.ResolvedRoute
|
||||
lastScope *session.SessionScope
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
@@ -129,6 +131,8 @@ func (h *llmObserverHook) BeforeLLM(
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
if req.Context != nil {
|
||||
h.lastInbound = cloneInboundContext(req.Context.Inbound)
|
||||
h.lastRoute = cloneResolvedRoute(req.Context.Route)
|
||||
h.lastScope = session.CloneScope(req.Context.Scope)
|
||||
}
|
||||
next := req.Clone()
|
||||
next.Model = "hook-model"
|
||||
@@ -230,6 +234,91 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_BtwCommand_UsesLLMHooks(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
|
||||
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, handled := al.handleCommand(context.Background(), bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
ChatType: "direct",
|
||||
SenderID: "hook-user",
|
||||
},
|
||||
Content: "/btw hello",
|
||||
}, agent, &processOptions{
|
||||
Dispatch: DispatchRequest{
|
||||
SessionKey: "session-1",
|
||||
InboundContext: &bus.InboundContext{
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
ChatType: "direct",
|
||||
SenderID: "hook-user",
|
||||
},
|
||||
RouteResult: &routing.ResolvedRoute{
|
||||
AgentID: "main",
|
||||
Channel: "cli",
|
||||
AccountID: routing.DefaultAccountID,
|
||||
SessionPolicy: routing.SessionPolicy{
|
||||
Dimensions: []string{"sender"},
|
||||
},
|
||||
MatchedBy: "default",
|
||||
},
|
||||
SessionScope: &session.SessionScope{
|
||||
Version: session.ScopeVersionV1,
|
||||
AgentID: "main",
|
||||
Channel: "cli",
|
||||
Account: routing.DefaultAccountID,
|
||||
Dimensions: []string{"sender"},
|
||||
Values: map[string]string{
|
||||
"sender": "hook-user",
|
||||
},
|
||||
},
|
||||
UserMessage: "/btw hello",
|
||||
},
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
SenderID: "hook-user",
|
||||
SenderDisplayName: "Hook User",
|
||||
})
|
||||
if !handled {
|
||||
t.Fatal("expected /btw command to be handled")
|
||||
}
|
||||
if response != "hooked content" {
|
||||
t.Fatalf("expected hooked content, got %q", response)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "hook-model" {
|
||||
t.Fatalf("expected model hook-model, got %q", lastModel)
|
||||
}
|
||||
if hook.lastInbound == nil {
|
||||
t.Fatal("expected hook to receive inbound context")
|
||||
}
|
||||
if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" {
|
||||
t.Fatalf("hook inbound context = %+v", hook.lastInbound)
|
||||
}
|
||||
if hook.lastInbound.ChatID != "direct" {
|
||||
t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID)
|
||||
}
|
||||
if hook.lastRoute == nil || hook.lastRoute.AgentID != "main" {
|
||||
t.Fatalf("expected hook route context for /btw, got %+v", hook.lastRoute)
|
||||
}
|
||||
if hook.lastScope == nil || hook.lastScope.Values["sender"] != "hook-user" {
|
||||
t.Fatalf("expected hook session scope for /btw, got %+v", hook.lastScope)
|
||||
}
|
||||
}
|
||||
|
||||
type toolHookProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
|
||||
@@ -29,6 +29,27 @@ func stripMessageMedia(messages []providers.Message) []providers.Message {
|
||||
return stripped
|
||||
}
|
||||
|
||||
func callLLMWithVisionUnsupportedRetry(
|
||||
messages []providers.Message,
|
||||
call func([]providers.Message) (*providers.LLMResponse, error),
|
||||
beforeRetry func(error),
|
||||
) (*providers.LLMResponse, []providers.Message, bool, error) {
|
||||
response, err := call(messages)
|
||||
if err == nil {
|
||||
return response, messages, false, nil
|
||||
}
|
||||
if !messagesContainMedia(messages) || !isVisionUnsupportedError(err) {
|
||||
return response, messages, false, err
|
||||
}
|
||||
|
||||
if beforeRetry != nil {
|
||||
beforeRetry(err)
|
||||
}
|
||||
stripped := stripMessageMedia(messages)
|
||||
response, err = call(stripped)
|
||||
return response, stripped, true, err
|
||||
}
|
||||
|
||||
func isVisionUnsupportedError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
|
||||
+510
-47
@@ -70,6 +70,8 @@ type AgentLoop struct {
|
||||
activeRequests sync.WaitGroup
|
||||
|
||||
reloadFunc func() error
|
||||
|
||||
providerFactory func(*config.ModelConfig) (providers.LLMProvider, string, error)
|
||||
}
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
@@ -159,6 +161,7 @@ func NewAgentLoop(
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
|
||||
}
|
||||
al.providerFactory = providers.CreateProviderFromConfig
|
||||
al.hooks = NewHookManager(eventBus)
|
||||
configureHookManagerFromConfig(al.hooks, cfg)
|
||||
al.contextManager = al.resolveContextManager()
|
||||
@@ -479,10 +482,12 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
// running. Only messages that resolve to the active turn scope are
|
||||
// redirected into steering; other inbound messages are requeued.
|
||||
drainCancel := func() {}
|
||||
if activeScope, activeAgentID, ok := al.resolveSteeringTarget(msg); ok {
|
||||
drainCtx, cancel := context.WithCancel(ctx)
|
||||
drainCancel = cancel
|
||||
go al.drainBusToSteering(drainCtx, activeScope, activeAgentID)
|
||||
if !isBtwCommand(msg.Content) {
|
||||
if activeScope, activeAgentID, ok := al.resolveSteeringTarget(msg); ok {
|
||||
drainCtx, cancel := context.WithCancel(ctx)
|
||||
drainCancel = cancel
|
||||
go al.drainBusToSteering(drainCtx, ctx, activeScope, activeAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
// Process message
|
||||
@@ -604,7 +609,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
// active scope into the steering queue. Messages from other scopes are requeued
|
||||
// so they can be processed normally after the active turn. It drains all
|
||||
// immediately available messages, blocking for the first one until ctx is done.
|
||||
func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, activeAgentID string) {
|
||||
func (al *AgentLoop) drainBusToSteering(ctx, priorityCtx context.Context, activeScope, activeAgentID string) {
|
||||
blocking := true
|
||||
var requeue []bus.InboundMessage
|
||||
defer func() {
|
||||
@@ -656,6 +661,17 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, active
|
||||
// Transcribe audio if needed before steering, so the agent sees text.
|
||||
msg, _ = al.transcribeAudioInMessage(ctx, msg)
|
||||
|
||||
// Handle priority commands (e.g. /btw) outside the steering queue, without
|
||||
// blocking this drain from enqueueing later messages for the active turn.
|
||||
if isBtwCommand(msg.Content) {
|
||||
priorityMsg := msg
|
||||
go al.handlePriorityCommandAsync(priorityCtx, priorityMsg)
|
||||
// A priority command is not a steering interrupt. Keep waiting for the
|
||||
// next inbound message while the active turn is still running.
|
||||
blocking = true
|
||||
continue
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Redirecting inbound message to steering queue",
|
||||
map[string]any{
|
||||
"channel": msg.Channel,
|
||||
@@ -1532,6 +1548,359 @@ func (al *AgentLoop) ProcessHeartbeat(
|
||||
})
|
||||
}
|
||||
|
||||
func sideQuestionModelName(agent *AgentInstance, usedLight bool) string {
|
||||
if agent == nil {
|
||||
return ""
|
||||
}
|
||||
if usedLight && agent.Router != nil {
|
||||
if lightModel := strings.TrimSpace(agent.Router.LightModel()); lightModel != "" {
|
||||
return lightModel
|
||||
}
|
||||
}
|
||||
return agent.Model
|
||||
}
|
||||
|
||||
func modelNameFromIdentityKey(identityKey string) string {
|
||||
const prefix = "model_name:"
|
||||
if strings.HasPrefix(identityKey, prefix) {
|
||||
return strings.TrimSpace(strings.TrimPrefix(identityKey, prefix))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func closeProviderIfStateful(provider providers.LLMProvider) {
|
||||
if stateful, ok := provider.(providers.StatefulProvider); ok {
|
||||
stateful.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func cloneLLMOptions(src map[string]any) map[string]any {
|
||||
dst := make(map[string]any, len(src)+1)
|
||||
for key, value := range src {
|
||||
dst[key] = value
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func (al *AgentLoop) isolatedSideQuestionProvider(
|
||||
agent *AgentInstance,
|
||||
baseModelName string,
|
||||
candidate providers.FallbackCandidate,
|
||||
) (providers.LLMProvider, string, func(), error) {
|
||||
if agent == nil {
|
||||
return nil, "", func() {}, fmt.Errorf("no agent available for /btw")
|
||||
}
|
||||
|
||||
modelCfg, err := al.sideQuestionModelConfig(agent, baseModelName, candidate)
|
||||
if err != nil {
|
||||
return nil, "", func() {}, err
|
||||
}
|
||||
|
||||
factory := al.providerFactory
|
||||
if factory == nil {
|
||||
factory = providers.CreateProviderFromConfig
|
||||
}
|
||||
|
||||
provider, modelID, err := factory(modelCfg)
|
||||
if err != nil {
|
||||
return nil, "", func() {}, err
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
closeProviderIfStateful(provider)
|
||||
}
|
||||
return provider, modelID, cleanup, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) sideQuestionModelConfig(
|
||||
agent *AgentInstance,
|
||||
baseModelName string,
|
||||
candidate providers.FallbackCandidate,
|
||||
) (*config.ModelConfig, error) {
|
||||
if agent == nil {
|
||||
return nil, fmt.Errorf("no agent available for /btw")
|
||||
}
|
||||
|
||||
if name := modelNameFromIdentityKey(candidate.IdentityKey); name != "" {
|
||||
return resolvedModelConfig(al.GetConfig(), name, agent.Workspace)
|
||||
}
|
||||
|
||||
baseModelName = strings.TrimSpace(baseModelName)
|
||||
modelCfg, err := resolvedModelConfig(al.GetConfig(), baseModelName, agent.Workspace)
|
||||
if err != nil {
|
||||
model := strings.TrimSpace(baseModelName)
|
||||
if candidate.Model != "" {
|
||||
model = candidate.Model
|
||||
}
|
||||
if candidate.Provider != "" && candidate.Model != "" {
|
||||
model = providers.NormalizeProvider(candidate.Provider) + "/" + candidate.Model
|
||||
} else {
|
||||
model = ensureProtocolModel(model)
|
||||
}
|
||||
return &config.ModelConfig{
|
||||
ModelName: baseModelName,
|
||||
Model: model,
|
||||
Workspace: agent.Workspace,
|
||||
}, nil
|
||||
}
|
||||
|
||||
clone := *modelCfg
|
||||
if candidate.Provider != "" && candidate.Model != "" {
|
||||
clone.Model = providers.NormalizeProvider(candidate.Provider) + "/" + candidate.Model
|
||||
}
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) askSideQuestion(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
opts *processOptions,
|
||||
question string,
|
||||
) (string, error) {
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no agent available for /btw")
|
||||
}
|
||||
|
||||
question = strings.TrimSpace(question)
|
||||
if question == "" {
|
||||
return "", fmt.Errorf("Usage: /btw <question>")
|
||||
}
|
||||
|
||||
if opts != nil {
|
||||
normalizeProcessOptionsInPlace(opts)
|
||||
}
|
||||
var media []string
|
||||
var channel, chatID, senderID, senderDisplayName string
|
||||
if opts != nil {
|
||||
media = opts.Media
|
||||
channel = opts.Channel
|
||||
chatID = opts.ChatID
|
||||
senderID = opts.SenderID
|
||||
senderDisplayName = opts.SenderDisplayName
|
||||
}
|
||||
|
||||
var history []providers.Message
|
||||
var summary string
|
||||
if opts != nil {
|
||||
if !opts.NoHistory {
|
||||
if resp, err := al.contextManager.Assemble(ctx, &AssembleRequest{
|
||||
SessionKey: opts.SessionKey,
|
||||
Budget: agent.ContextWindow,
|
||||
MaxTokens: agent.MaxTokens,
|
||||
}); err == nil && resp != nil {
|
||||
history = resp.History
|
||||
summary = resp.Summary
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages := agent.ContextBuilder.BuildMessages(
|
||||
history,
|
||||
summary,
|
||||
question,
|
||||
media,
|
||||
channel,
|
||||
chatID,
|
||||
senderID,
|
||||
senderDisplayName,
|
||||
)
|
||||
|
||||
maxMediaSize := al.GetConfig().Agents.Defaults.GetMaxMediaSize()
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
activeCandidates, activeModel, usedLight := al.selectCandidates(agent, question, messages)
|
||||
selectedModelName := sideQuestionModelName(agent, usedLight)
|
||||
|
||||
llmOpts := map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID + ":btw",
|
||||
}
|
||||
|
||||
hookModelChanged := false
|
||||
callProvider := func(
|
||||
ctx context.Context,
|
||||
candidate providers.FallbackCandidate,
|
||||
model string,
|
||||
forceModel bool,
|
||||
callMessages []providers.Message,
|
||||
) (*providers.LLMResponse, error) {
|
||||
provider, providerModel, cleanup, err := al.isolatedSideQuestionProvider(agent, selectedModelName, candidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
if !forceModel || strings.TrimSpace(model) == "" {
|
||||
model = providerModel
|
||||
}
|
||||
callOpts := llmOpts
|
||||
if _, exists := callOpts["thinking_level"]; !exists && agent.ThinkingLevel != ThinkingOff {
|
||||
if tc, ok := provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
|
||||
callOpts = cloneLLMOptions(llmOpts)
|
||||
callOpts["thinking_level"] = string(agent.ThinkingLevel)
|
||||
}
|
||||
}
|
||||
return provider.Chat(ctx, callMessages, nil, model, callOpts)
|
||||
}
|
||||
|
||||
turnCtx := newTurnContext(nil, nil, nil)
|
||||
if opts != nil {
|
||||
turnCtx = newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope)
|
||||
}
|
||||
llmModel := activeModel
|
||||
if al.hooks != nil {
|
||||
llmReq, decision := al.hooks.BeforeLLM(ctx, &LLMHookRequest{
|
||||
Meta: EventMeta{
|
||||
Source: "askSideQuestion",
|
||||
TracePath: "turn.llm.request",
|
||||
turnContext: cloneTurnContext(turnCtx),
|
||||
},
|
||||
Context: cloneTurnContext(turnCtx),
|
||||
Model: llmModel,
|
||||
Messages: messages,
|
||||
Tools: nil,
|
||||
Options: llmOpts,
|
||||
GracefulTerminal: false,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmReq != nil {
|
||||
if strings.TrimSpace(llmReq.Model) != "" && llmReq.Model != llmModel {
|
||||
hookModelChanged = true
|
||||
}
|
||||
llmModel = llmReq.Model
|
||||
messages = llmReq.Messages
|
||||
llmOpts = llmReq.Options
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
reason := decision.Reason
|
||||
if reason == "" {
|
||||
reason = "hook requested turn abort"
|
||||
}
|
||||
return "", fmt.Errorf("hook aborted turn during before_llm: %s", reason)
|
||||
case HookActionHardAbort:
|
||||
reason := decision.Reason
|
||||
if reason == "" {
|
||||
reason = "hook requested turn abort"
|
||||
}
|
||||
return "", fmt.Errorf("hook aborted turn during before_llm: %s", reason)
|
||||
}
|
||||
}
|
||||
if hookModelChanged {
|
||||
// Hook-selected models must not continue through the pre-hook fallback
|
||||
// candidate list, otherwise fallback execution would call the original
|
||||
// candidate model and silently ignore the hook decision.
|
||||
activeCandidates = nil
|
||||
}
|
||||
|
||||
callSideLLM := func(callMessages []providers.Message) (*providers.LLMResponse, error) {
|
||||
if len(activeCandidates) > 1 && al.fallback != nil {
|
||||
fbResult, err := al.fallback.Execute(
|
||||
ctx,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, providerName, model string) (*providers.LLMResponse, error) {
|
||||
candidate := providers.FallbackCandidate{Provider: providerName, Model: model}
|
||||
for _, activeCandidate := range activeCandidates {
|
||||
if activeCandidate.Provider == providerName && activeCandidate.Model == model {
|
||||
candidate = activeCandidate
|
||||
break
|
||||
}
|
||||
}
|
||||
return callProvider(ctx, candidate, model, false, callMessages)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
|
||||
var candidate providers.FallbackCandidate
|
||||
if len(activeCandidates) > 0 {
|
||||
candidate = activeCandidates[0]
|
||||
}
|
||||
return callProvider(ctx, candidate, llmModel, hookModelChanged, callMessages)
|
||||
}
|
||||
|
||||
resp, _, _, err := callLLMWithVisionUnsupportedRetry(
|
||||
messages,
|
||||
callSideLLM,
|
||||
func(originalErr error) {
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
EventMeta{
|
||||
Source: "askSideQuestion",
|
||||
TracePath: "turn.llm.retry",
|
||||
turnContext: cloneTurnContext(turnCtx),
|
||||
},
|
||||
LLMRetryPayload{
|
||||
Attempt: 1,
|
||||
MaxRetries: 1,
|
||||
Reason: "vision_unsupported",
|
||||
Error: originalErr.Error(),
|
||||
Backoff: 0,
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp == nil {
|
||||
return "", nil
|
||||
}
|
||||
resp, err = al.applySideQuestionAfterLLM(ctx, turnCtx, llmModel, resp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sideQuestionResponseContent(resp), nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) applySideQuestionAfterLLM(
|
||||
ctx context.Context,
|
||||
turnCtx *TurnContext,
|
||||
model string,
|
||||
response *providers.LLMResponse,
|
||||
) (*providers.LLMResponse, error) {
|
||||
if response == nil || al.hooks == nil {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
llmResp, decision := al.hooks.AfterLLM(ctx, &LLMHookResponse{
|
||||
Meta: EventMeta{
|
||||
Source: "askSideQuestion",
|
||||
TracePath: "turn.llm.response",
|
||||
turnContext: cloneTurnContext(turnCtx),
|
||||
},
|
||||
Context: cloneTurnContext(turnCtx),
|
||||
Model: model,
|
||||
Response: response,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmResp != nil && llmResp.Response != nil {
|
||||
response = llmResp.Response
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
reason := decision.Reason
|
||||
if reason == "" {
|
||||
reason = "hook requested turn abort"
|
||||
}
|
||||
return nil, fmt.Errorf("hook aborted turn during after_llm: %s", reason)
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func sideQuestionResponseContent(response *providers.LLMResponse) string {
|
||||
if response == nil {
|
||||
return ""
|
||||
}
|
||||
if response.Content != "" {
|
||||
return response.Content
|
||||
}
|
||||
return response.ReasoningContent
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||
msg = bus.NormalizeInboundMessage(msg)
|
||||
|
||||
@@ -2363,10 +2732,42 @@ turnLoop:
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
maxRetries := 2
|
||||
callHasMedia := messagesContainMedia(callMessages)
|
||||
didStripMedia := false
|
||||
for retry := 0; retry <= maxRetries; retry++ {
|
||||
response, err = callLLM(callMessages, providerToolDefs)
|
||||
response, callMessages, _, err = callLLMWithVisionUnsupportedRetry(
|
||||
callMessages,
|
||||
func(messagesForRetry []providers.Message) (*providers.LLMResponse, error) {
|
||||
return callLLM(messagesForRetry, providerToolDefs)
|
||||
},
|
||||
func(originalErr error) {
|
||||
if !ts.opts.NoHistory {
|
||||
history = ts.agent.Sessions.GetHistory(ts.sessionKey)
|
||||
ts.agent.Sessions.SetHistory(ts.sessionKey, stripMessageMedia(history))
|
||||
|
||||
// Keep persistedMessages aligned so abort restore-point trimming remains correct.
|
||||
ts.mu.Lock()
|
||||
for i := range ts.persistedMessages {
|
||||
ts.persistedMessages[i].Media = nil
|
||||
}
|
||||
ts.mu.Unlock()
|
||||
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
}
|
||||
|
||||
messages = stripMessageMedia(messages)
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: 1,
|
||||
MaxRetries: 1,
|
||||
Reason: "vision_unsupported",
|
||||
Error: originalErr.Error(),
|
||||
Backoff: 0,
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
@@ -2375,45 +2776,6 @@ turnLoop:
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
// If the provider/model doesn't support multimodal inputs, retry once with media stripped
|
||||
// so the session doesn't get "stuck" after a user sends an image.
|
||||
if callHasMedia && !didStripMedia && isVisionUnsupportedError(err) {
|
||||
didStripMedia = true
|
||||
if !ts.opts.NoHistory {
|
||||
history = ts.agent.Sessions.GetHistory(ts.sessionKey)
|
||||
ts.agent.Sessions.SetHistory(ts.sessionKey, stripMessageMedia(history))
|
||||
|
||||
// Keep persistedMessages aligned so abort restore-point trimming remains correct.
|
||||
ts.mu.Lock()
|
||||
for i := range ts.persistedMessages {
|
||||
ts.persistedMessages[i].Media = nil
|
||||
}
|
||||
ts.mu.Unlock()
|
||||
|
||||
ts.refreshRestorePointFromSession(ts.agent)
|
||||
}
|
||||
|
||||
messages = stripMessageMedia(messages)
|
||||
callMessages = stripMessageMedia(callMessages)
|
||||
callHasMedia = false
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: 1,
|
||||
MaxRetries: 1,
|
||||
Reason: "vision_unsupported",
|
||||
Error: err.Error(),
|
||||
Backoff: 0,
|
||||
},
|
||||
)
|
||||
response, err = callLLM(callMessages, providerToolDefs)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
isTimeoutError := errors.Is(err, context.DeadlineExceeded) ||
|
||||
strings.Contains(errMsg, "deadline exceeded") ||
|
||||
@@ -3748,6 +4110,11 @@ func activeSkillNames(agent *AgentInstance, opts processOptions) []string {
|
||||
return resolved
|
||||
}
|
||||
|
||||
func isBtwCommand(content string) bool {
|
||||
cmdName, ok := commands.CommandName(content)
|
||||
return ok && cmdName == "btw"
|
||||
}
|
||||
|
||||
func (al *AgentLoop) applyExplicitSkillCommand(
|
||||
raw string,
|
||||
agent *AgentInstance,
|
||||
@@ -3856,6 +4223,9 @@ func (al *AgentLoop) buildCommandsRuntime(
|
||||
if agent.ContextBuilder != nil {
|
||||
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
|
||||
}
|
||||
rt.AskSideQuestion = func(ctx context.Context, question string) (string, error) {
|
||||
return al.askSideQuestion(ctx, agent, opts, question)
|
||||
}
|
||||
rt.GetModelInfo = func() (string, string) {
|
||||
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
|
||||
}
|
||||
@@ -3975,6 +4345,99 @@ func mapCommandError(result commands.ExecuteResult) string {
|
||||
return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) tryHandlePriorityCommand(ctx context.Context, msg bus.InboundMessage) (bool, bus.OutboundMessage) {
|
||||
if !isBtwCommand(msg.Content) {
|
||||
return false, bus.OutboundMessage{}
|
||||
}
|
||||
|
||||
route, agent, err := al.resolveMessageRoute(msg)
|
||||
if err != nil || agent == nil {
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", fmt.Sprintf("Error resolving route for /btw: %v", err), nil)
|
||||
return true, bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Context: outboundContextFromInbound(
|
||||
&msg.Context,
|
||||
msg.Channel,
|
||||
msg.ChatID,
|
||||
msg.Context.ReplyToMessageID,
|
||||
),
|
||||
Content: fmt.Sprintf("Error processing message: %v", err),
|
||||
}
|
||||
}
|
||||
logger.WarnCF("agent", "/btw command unavailable: no agent resolved", nil)
|
||||
return true, bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Context: outboundContextFromInbound(
|
||||
&msg.Context,
|
||||
msg.Channel,
|
||||
msg.ChatID,
|
||||
msg.Context.ReplyToMessageID,
|
||||
),
|
||||
Content: "Command unavailable in current context.",
|
||||
}
|
||||
}
|
||||
|
||||
allocation := al.allocateRouteSession(route, msg)
|
||||
sessionKey := resolveScopeKey(allocation.SessionKey, msg.SessionKey)
|
||||
msg.SessionKey = sessionKey
|
||||
opts := processOptions{
|
||||
Dispatch: DispatchRequest{
|
||||
SessionKey: sessionKey,
|
||||
SessionAliases: buildSessionAliases(sessionKey, append(allocation.SessionAliases, msg.SessionKey)...),
|
||||
InboundContext: cloneInboundContext(&msg.Context),
|
||||
RouteResult: cloneResolvedRoute(&route),
|
||||
SessionScope: session.CloneScope(&allocation.Scope),
|
||||
UserMessage: msg.Content,
|
||||
Media: append([]string(nil), msg.Media...),
|
||||
},
|
||||
SessionKey: sessionKey,
|
||||
SenderID: msg.SenderID,
|
||||
SenderDisplayName: msg.Sender.DisplayName,
|
||||
}
|
||||
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
response, handled := al.handleCommand(cmdCtx, msg, agent, &opts)
|
||||
if !handled {
|
||||
return false, bus.OutboundMessage{}
|
||||
}
|
||||
agentID, outboundSessionKey, scope := outboundTurnMetadata(agent.ID, sessionKey, &allocation.Scope)
|
||||
return true, bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Context: outboundContextFromInbound(
|
||||
&msg.Context,
|
||||
msg.Channel,
|
||||
msg.ChatID,
|
||||
msg.Context.ReplyToMessageID,
|
||||
),
|
||||
AgentID: agentID,
|
||||
SessionKey: outboundSessionKey,
|
||||
Scope: scope,
|
||||
Content: response,
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handlePriorityCommandAsync(ctx context.Context, msg bus.InboundMessage) {
|
||||
handled, outbound := al.tryHandlePriorityCommand(ctx, msg)
|
||||
if !handled || outbound.Content == "" {
|
||||
return
|
||||
}
|
||||
|
||||
publishCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
if err := al.bus.PublishOutbound(publishCtx, outbound); err != nil {
|
||||
logger.WarnCF("agent", "Failed to publish priority command response", map[string]any{
|
||||
"error": err.Error(),
|
||||
"channel": outbound.Channel,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// isNativeSearchProvider reports whether the given LLM provider implements
|
||||
// NativeSearchCapable and returns true for SupportsNativeSearch.
|
||||
func isNativeSearchProvider(p providers.LLMProvider) bool {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -80,6 +81,7 @@ func newStartedTestChannelManager(
|
||||
|
||||
type recordingProvider struct {
|
||||
lastMessages []providers.Message
|
||||
lastModel string
|
||||
}
|
||||
|
||||
func (r *recordingProvider) Chat(
|
||||
@@ -90,6 +92,7 @@ func (r *recordingProvider) Chat(
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastMessages = append([]providers.Message(nil), messages...)
|
||||
r.lastModel = model
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
@@ -100,6 +103,47 @@ func (r *recordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type closeTrackingProvider struct {
|
||||
recordingProvider
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (p *closeTrackingProvider) Close() {
|
||||
p.closed = true
|
||||
}
|
||||
|
||||
type modelRewriteHook struct {
|
||||
model string
|
||||
}
|
||||
|
||||
func (h modelRewriteHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
next := req.Clone()
|
||||
next.Model = h.model
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h modelRewriteHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func useTestSideQuestionProvider(al *AgentLoop, provider providers.LLMProvider) {
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
model := provider.GetDefaultModel()
|
||||
if mc != nil {
|
||||
if _, modelID := providers.ExtractProtocol(mc.Model); modelID != "" {
|
||||
model = modelID
|
||||
}
|
||||
}
|
||||
return provider, model, nil
|
||||
}
|
||||
}
|
||||
|
||||
func newTestAgentLoop(
|
||||
t *testing.T,
|
||||
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
|
||||
@@ -235,6 +279,305 @@ func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandRunsWithoutPersistingHistory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
defaultAgent := al.GetRegistry().GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain side effects",
|
||||
}
|
||||
route, _, err := al.resolveMessageRoute(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMessageRoute() error = %v", err)
|
||||
}
|
||||
allocation := al.allocateRouteSession(route, msg)
|
||||
sessionKey := resolveScopeKey(allocation.SessionKey, msg.SessionKey)
|
||||
initialHistory := []providers.Message{
|
||||
{Role: "user", Content: "We decided to avoid global state."},
|
||||
{Role: "assistant", Content: "Right, keep it request-scoped."},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, initialHistory)
|
||||
defaultAgent.Sessions.SetSummary(sessionKey, "The team decided to keep state request-scoped.")
|
||||
|
||||
response, err := al.processMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if len(provider.lastMessages) == 0 {
|
||||
t.Fatal("provider did not receive any messages")
|
||||
}
|
||||
if len(provider.lastMessages) != 4 {
|
||||
t.Fatalf("provider messages len = %d, want 4 (system + prior history + user)", len(provider.lastMessages))
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(provider.lastMessages[1:3], initialHistory) {
|
||||
t.Fatalf("provider history = %#v, want %#v", provider.lastMessages[1:3], initialHistory)
|
||||
}
|
||||
|
||||
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
||||
if lastMessage.Role != "user" || lastMessage.Content != "explain side effects" {
|
||||
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
|
||||
}
|
||||
|
||||
history := al.GetRegistry().GetDefaultAgent().Sessions.GetHistory(sessionKey)
|
||||
if !reflect.DeepEqual(history, initialHistory) {
|
||||
t.Fatalf("session history = %#v, want %#v", history, initialHistory)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandIncludesRequestContextAndMedia(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "discord",
|
||||
SenderID: "discord:123",
|
||||
Sender: bus.SenderInfo{
|
||||
DisplayName: "Alice",
|
||||
},
|
||||
ChatID: "group-1",
|
||||
Content: "/btw describe this image",
|
||||
Media: []string{"media://image-1"},
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if len(provider.lastMessages) == 0 {
|
||||
t.Fatal("provider did not receive any messages")
|
||||
}
|
||||
|
||||
systemPrompt := provider.lastMessages[0].Content
|
||||
if !strings.Contains(systemPrompt, "## Current Session\nChannel: discord\nChat ID: group-1") {
|
||||
t.Fatalf("system prompt missing current session context:\n%s", systemPrompt)
|
||||
}
|
||||
if !strings.Contains(systemPrompt, "## Current Sender\nCurrent sender: Alice (ID: discord:123)") {
|
||||
t.Fatalf("system prompt missing current sender context:\n%s", systemPrompt)
|
||||
}
|
||||
|
||||
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
||||
if lastMessage.Role != "user" || lastMessage.Content != "describe this image" {
|
||||
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
|
||||
}
|
||||
if !reflect.DeepEqual(lastMessage.Media, []string{"media://image-1"}) {
|
||||
t.Fatalf("last provider media = %#v, want media ref", lastMessage.Media)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandUsesIsolatedProvider(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
mainProvider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, mainProvider)
|
||||
var sideProvider *closeTrackingProvider
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
sideProvider = &closeTrackingProvider{}
|
||||
return sideProvider, "isolated-model", nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain isolation",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if len(mainProvider.lastMessages) != 0 {
|
||||
t.Fatalf("main provider was used for /btw: %+v", mainProvider.lastMessages)
|
||||
}
|
||||
if sideProvider == nil {
|
||||
t.Fatal("side question provider factory was not called")
|
||||
}
|
||||
if !sideProvider.closed {
|
||||
t.Fatal("isolated stateful /btw provider was not closed")
|
||||
}
|
||||
if len(sideProvider.lastMessages) == 0 {
|
||||
t.Fatal("isolated provider did not receive messages")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandRetriesWithoutMediaOnVisionUnsupported(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &visionUnsupportedMediaProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw describe this image",
|
||||
Media: []string{"data:image/png;base64,abc123"},
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "ok" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "ok")
|
||||
}
|
||||
if provider.calls != 2 {
|
||||
t.Fatalf("calls = %d, want %d (fail with media, then retry without media)", provider.calls, 2)
|
||||
}
|
||||
if !slices.Equal(provider.mediaSeen, []bool{true, false}) {
|
||||
t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandUsesProviderFactoryModel(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "lb-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "lb-model", Model: "openai/lb-model-a"},
|
||||
{ModelName: "lb-model", Model: "openai/lb-model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
var wantModel string
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
if mc == nil {
|
||||
t.Fatal("expected model config")
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(mc.Model)
|
||||
wantModel = "factory-" + modelID
|
||||
return provider, wantModel, nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain load balancing",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if provider.lastModel != wantModel {
|
||||
t.Fatalf("/btw model = %q, want provider factory model %q", provider.lastModel, wantModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandHookModelBypassesFallbackCandidates(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "primary-model",
|
||||
ModelFallbacks: []string{"fallback-model"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "hook-model"})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain hook routing",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if provider.lastModel != "hook-model" {
|
||||
t.Fatalf("/btw model = %q, want hook-selected model", provider.lastModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleCommand_UseCommandRejectsUnknownSkill(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
|
||||
+488
-8
@@ -405,7 +405,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
al.drainBusToSteering(ctx, activeScope, activeAgentID)
|
||||
al.drainBusToSteering(ctx, ctx, activeScope, activeAgentID)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -566,12 +566,14 @@ func (p *lateSteeringProvider) GetDefaultModel() string {
|
||||
}
|
||||
|
||||
type blockingDirectProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
firstStarted chan struct{}
|
||||
releaseFirst chan struct{}
|
||||
firstResp string
|
||||
finalResp string
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
firstStarted chan struct{}
|
||||
releaseFirst chan struct{}
|
||||
secondStarted chan struct{}
|
||||
releaseSecond chan struct{}
|
||||
firstResp string
|
||||
finalResp string
|
||||
}
|
||||
|
||||
func (p *blockingDirectProvider) Chat(
|
||||
@@ -586,11 +588,15 @@ func (p *blockingDirectProvider) Chat(
|
||||
call := p.calls
|
||||
firstStarted := p.firstStarted
|
||||
releaseFirst := p.releaseFirst
|
||||
secondStarted := p.secondStarted
|
||||
releaseSecond := p.releaseSecond
|
||||
firstResp := p.firstResp
|
||||
finalResp := p.finalResp
|
||||
if call == 1 && p.firstStarted != nil {
|
||||
close(p.firstStarted)
|
||||
p.firstStarted = nil
|
||||
}
|
||||
if call == 2 && p.secondStarted != nil {
|
||||
close(p.secondStarted)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
@@ -604,6 +610,14 @@ func (p *blockingDirectProvider) Chat(
|
||||
}
|
||||
|
||||
_ = firstStarted
|
||||
_ = secondStarted
|
||||
if call == 2 && releaseSecond != nil {
|
||||
select {
|
||||
case <-releaseSecond:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
return &providers.LLMResponse{Content: finalResp}, nil
|
||||
}
|
||||
|
||||
@@ -611,6 +625,73 @@ func (p *blockingDirectProvider) GetDefaultModel() string {
|
||||
return "blocking-direct-mock"
|
||||
}
|
||||
|
||||
type blockedBtwWithFollowupProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
firstStarted chan struct{}
|
||||
releaseFirst chan struct{}
|
||||
secondStarted chan struct{}
|
||||
releaseSecond chan struct{}
|
||||
thirdStarted chan struct{}
|
||||
thirdMessages []providers.Message
|
||||
}
|
||||
|
||||
func (p *blockedBtwWithFollowupProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.calls++
|
||||
call := p.calls
|
||||
firstStarted := p.firstStarted
|
||||
releaseFirst := p.releaseFirst
|
||||
secondStarted := p.secondStarted
|
||||
releaseSecond := p.releaseSecond
|
||||
thirdStarted := p.thirdStarted
|
||||
if call == 1 && p.firstStarted != nil {
|
||||
close(p.firstStarted)
|
||||
}
|
||||
if call == 2 && p.secondStarted != nil {
|
||||
close(p.secondStarted)
|
||||
}
|
||||
if call == 3 {
|
||||
p.thirdMessages = append([]providers.Message(nil), messages...)
|
||||
if p.thirdStarted != nil {
|
||||
close(p.thirdStarted)
|
||||
}
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
switch call {
|
||||
case 1:
|
||||
_ = firstStarted
|
||||
select {
|
||||
case <-releaseFirst:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return &providers.LLMResponse{Content: "long turn finished"}, nil
|
||||
case 2:
|
||||
_ = secondStarted
|
||||
select {
|
||||
case <-releaseSecond:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return &providers.LLMResponse{Content: "btw delayed reply"}, nil
|
||||
default:
|
||||
_ = thirdStarted
|
||||
return &providers.LLMResponse{Content: "continued after follow-up"}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *blockedBtwWithFollowupProvider) GetDefaultModel() string {
|
||||
return "blocked-btw-followup-mock"
|
||||
}
|
||||
|
||||
type interruptibleTool struct {
|
||||
name string
|
||||
started chan struct{}
|
||||
@@ -1010,6 +1091,405 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_BtwCommandBypassesQueuedTurn(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &blockingDirectProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
firstResp: "long turn finished",
|
||||
finalResp: "btw immediate reply",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "execute sleep 60, then send OK",
|
||||
}
|
||||
btw := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "/btw what is the current progress?",
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
|
||||
t.Fatalf("publish first inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first LLM call to start")
|
||||
}
|
||||
|
||||
messageTool, ok := al.GetRegistry().GetDefaultAgent().Tools.Get("message")
|
||||
var mt *tools.MessageTool
|
||||
if !ok {
|
||||
mt = tools.NewMessageTool()
|
||||
al.RegisterTool(mt)
|
||||
} else {
|
||||
var typeOK bool
|
||||
mt, typeOK = messageTool.(*tools.MessageTool)
|
||||
if !typeOK {
|
||||
t.Fatal("expected message tool type")
|
||||
}
|
||||
}
|
||||
mt.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
return nil
|
||||
})
|
||||
if result := mt.Execute(context.Background(), map[string]any{
|
||||
"channel": "test",
|
||||
"chat_id": "chat1",
|
||||
"content": "already sent from busy turn",
|
||||
}); result == nil || result.IsError {
|
||||
t.Fatalf("message tool setup result = %+v, want successful send", result)
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
|
||||
t.Fatalf("publish /btw inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "btw immediate reply" {
|
||||
t.Fatalf("expected /btw reply before long turn completion, got %q", outbound.Content)
|
||||
}
|
||||
if outbound.AgentID != routing.DefaultAgentID {
|
||||
t.Fatalf("expected /btw outbound agent_id %q, got %q", routing.DefaultAgentID, outbound.AgentID)
|
||||
}
|
||||
route, _, err := al.resolveMessageRoute(btw)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMessageRoute(/btw) error = %v", err)
|
||||
}
|
||||
expectedSessionKey := resolveScopeKey(al.allocateRouteSession(route, btw).SessionKey, btw.SessionKey)
|
||||
if outbound.SessionKey != expectedSessionKey {
|
||||
t.Fatalf("expected /btw outbound session_key %q, got %q", expectedSessionKey, outbound.SessionKey)
|
||||
}
|
||||
if outbound.Scope == nil ||
|
||||
outbound.Scope.AgentID != routing.DefaultAgentID ||
|
||||
outbound.Scope.Channel != "test" {
|
||||
t.Fatalf(
|
||||
"expected /btw outbound scope for agent %q on test channel, got %+v",
|
||||
routing.DefaultAgentID,
|
||||
outbound.Scope,
|
||||
)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for /btw outbound response")
|
||||
}
|
||||
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 {
|
||||
t.Fatalf("expected /btw to bypass steering queue, got %v", msgs)
|
||||
}
|
||||
|
||||
close(provider.releaseFirst)
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("expected busy turn final response to stay suppressed, got %q", outbound.Content)
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
callCount := provider.calls
|
||||
provider.mu.Unlock()
|
||||
if callCount != 2 {
|
||||
t.Fatalf("provider call count = %d, want 2", callCount)
|
||||
}
|
||||
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_BtwCommandSurvivesActiveTurnCompletion(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &blockingDirectProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
secondStarted: make(chan struct{}),
|
||||
releaseSecond: make(chan struct{}),
|
||||
firstResp: "long turn finished",
|
||||
finalResp: "btw delayed reply",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "execute a long turn",
|
||||
}
|
||||
btw := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "/btw can you still answer?",
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
|
||||
t.Fatalf("publish first inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first LLM call to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
|
||||
t.Fatalf("publish /btw inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.secondStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for /btw LLM call to start")
|
||||
}
|
||||
|
||||
close(provider.releaseFirst)
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "long turn finished" {
|
||||
t.Fatalf("expected first outbound to be long turn response, got %q", outbound.Content)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for long turn response")
|
||||
}
|
||||
|
||||
close(provider.releaseSecond)
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "btw delayed reply" {
|
||||
t.Fatalf("expected /btw response after drain cancellation, got %q", outbound.Content)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for delayed /btw response")
|
||||
}
|
||||
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_BlockedBtwDoesNotBlockFollowupContinuation(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &blockedBtwWithFollowupProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
secondStarted: make(chan struct{}),
|
||||
releaseSecond: make(chan struct{}),
|
||||
thirdStarted: make(chan struct{}),
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "execute a long turn",
|
||||
}
|
||||
btw := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "/btw this side question blocks",
|
||||
}
|
||||
followup := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "normal follow-up while btw is blocked",
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
|
||||
t.Fatalf("publish first inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first LLM call to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
|
||||
t.Fatalf("publish /btw inbound: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-provider.secondStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for /btw LLM call to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, followup); err != nil {
|
||||
t.Fatalf("publish follow-up inbound: %v", err)
|
||||
}
|
||||
close(provider.releaseFirst)
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "continued after follow-up" {
|
||||
t.Fatalf("expected continuation response before /btw release, got %q", outbound.Content)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for follow-up continuation response")
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
thirdMessages := append([]providers.Message(nil), provider.thirdMessages...)
|
||||
provider.mu.Unlock()
|
||||
foundFollowup := false
|
||||
for _, msg := range thirdMessages {
|
||||
if msg.Role == "user" && msg.Content == followup.Content {
|
||||
foundFollowup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundFollowup {
|
||||
t.Fatalf("continuation messages did not include follow-up: %+v", thirdMessages)
|
||||
}
|
||||
|
||||
close(provider.releaseSecond)
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "btw delayed reply" {
|
||||
t.Fatalf("expected delayed /btw response, got %q", outbound.Content)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for delayed /btw response")
|
||||
}
|
||||
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_AgentForSession_UsesStoredScopeMetadata(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user