feat(agent): support btw side questions (#2532)

This commit is contained in:
lxowalle
2026-04-16 10:53:09 +08:00
committed by GitHub
parent a8d0b03515
commit e22b4e1eee
23 changed files with 1737 additions and 70 deletions
+343
View File
@@ -9,6 +9,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"slices"
"strings"
"testing"
@@ -80,6 +81,7 @@ func newStartedTestChannelManager(
type recordingProvider struct {
lastMessages []providers.Message
lastModel string
}
func (r *recordingProvider) Chat(
@@ -90,6 +92,7 @@ func (r *recordingProvider) Chat(
opts map[string]any,
) (*providers.LLMResponse, error) {
r.lastMessages = append([]providers.Message(nil), messages...)
r.lastModel = model
return &providers.LLMResponse{
Content: "Mock response",
ToolCalls: []providers.ToolCall{},
@@ -100,6 +103,47 @@ func (r *recordingProvider) GetDefaultModel() string {
return "mock-model"
}
type closeTrackingProvider struct {
recordingProvider
closed bool
}
func (p *closeTrackingProvider) Close() {
p.closed = true
}
type modelRewriteHook struct {
model string
}
func (h modelRewriteHook) BeforeLLM(
ctx context.Context,
req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
next := req.Clone()
next.Model = h.model
return next, HookDecision{Action: HookActionModify}, nil
}
func (h modelRewriteHook) AfterLLM(
ctx context.Context,
resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
}
func useTestSideQuestionProvider(al *AgentLoop, provider providers.LLMProvider) {
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
model := provider.GetDefaultModel()
if mc != nil {
if _, modelID := providers.ExtractProtocol(mc.Model); modelID != "" {
model = modelID
}
}
return provider, model, nil
}
}
func newTestAgentLoop(
t *testing.T,
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
@@ -235,6 +279,305 @@ func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
}
}
func TestProcessMessage_BtwCommandRunsWithoutPersistingHistory(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
defaultAgent := al.GetRegistry().GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
msg := bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw explain side effects",
}
route, _, err := al.resolveMessageRoute(msg)
if err != nil {
t.Fatalf("resolveMessageRoute() error = %v", err)
}
allocation := al.allocateRouteSession(route, msg)
sessionKey := resolveScopeKey(allocation.SessionKey, msg.SessionKey)
initialHistory := []providers.Message{
{Role: "user", Content: "We decided to avoid global state."},
{Role: "assistant", Content: "Right, keep it request-scoped."},
}
defaultAgent.Sessions.SetHistory(sessionKey, initialHistory)
defaultAgent.Sessions.SetSummary(sessionKey, "The team decided to keep state request-scoped.")
response, err := al.processMessage(context.Background(), msg)
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if len(provider.lastMessages) == 0 {
t.Fatal("provider did not receive any messages")
}
if len(provider.lastMessages) != 4 {
t.Fatalf("provider messages len = %d, want 4 (system + prior history + user)", len(provider.lastMessages))
}
if !reflect.DeepEqual(provider.lastMessages[1:3], initialHistory) {
t.Fatalf("provider history = %#v, want %#v", provider.lastMessages[1:3], initialHistory)
}
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
if lastMessage.Role != "user" || lastMessage.Content != "explain side effects" {
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
}
history := al.GetRegistry().GetDefaultAgent().Sessions.GetHistory(sessionKey)
if !reflect.DeepEqual(history, initialHistory) {
t.Fatalf("session history = %#v, want %#v", history, initialHistory)
}
}
func TestProcessMessage_BtwCommandIncludesRequestContextAndMedia(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "discord",
SenderID: "discord:123",
Sender: bus.SenderInfo{
DisplayName: "Alice",
},
ChatID: "group-1",
Content: "/btw describe this image",
Media: []string{"media://image-1"},
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if len(provider.lastMessages) == 0 {
t.Fatal("provider did not receive any messages")
}
systemPrompt := provider.lastMessages[0].Content
if !strings.Contains(systemPrompt, "## Current Session\nChannel: discord\nChat ID: group-1") {
t.Fatalf("system prompt missing current session context:\n%s", systemPrompt)
}
if !strings.Contains(systemPrompt, "## Current Sender\nCurrent sender: Alice (ID: discord:123)") {
t.Fatalf("system prompt missing current sender context:\n%s", systemPrompt)
}
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
if lastMessage.Role != "user" || lastMessage.Content != "describe this image" {
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
}
if !reflect.DeepEqual(lastMessage.Media, []string{"media://image-1"}) {
t.Fatalf("last provider media = %#v, want media ref", lastMessage.Media)
}
}
func TestProcessMessage_BtwCommandUsesIsolatedProvider(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
mainProvider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, mainProvider)
var sideProvider *closeTrackingProvider
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
sideProvider = &closeTrackingProvider{}
return sideProvider, "isolated-model", nil
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw explain isolation",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if len(mainProvider.lastMessages) != 0 {
t.Fatalf("main provider was used for /btw: %+v", mainProvider.lastMessages)
}
if sideProvider == nil {
t.Fatal("side question provider factory was not called")
}
if !sideProvider.closed {
t.Fatal("isolated stateful /btw provider was not closed")
}
if len(sideProvider.lastMessages) == 0 {
t.Fatal("isolated provider did not receive messages")
}
}
func TestProcessMessage_BtwCommandRetriesWithoutMediaOnVisionUnsupported(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &visionUnsupportedMediaProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw describe this image",
Media: []string{"data:image/png;base64,abc123"},
}))
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "ok" {
t.Fatalf("processMessage() response = %q, want %q", response, "ok")
}
if provider.calls != 2 {
t.Fatalf("calls = %d, want %d (fail with media, then retry without media)", provider.calls, 2)
}
if !slices.Equal(provider.mediaSeen, []bool{true, false}) {
t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false})
}
}
func TestProcessMessage_BtwCommandUsesProviderFactoryModel(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "lb-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
ModelList: []*config.ModelConfig{
{ModelName: "lb-model", Model: "openai/lb-model-a"},
{ModelName: "lb-model", Model: "openai/lb-model-b"},
},
}
msgBus := bus.NewMessageBus()
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
var wantModel string
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
if mc == nil {
t.Fatal("expected model config")
}
_, modelID := providers.ExtractProtocol(mc.Model)
wantModel = "factory-" + modelID
return provider, wantModel, nil
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw explain load balancing",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if provider.lastModel != wantModel {
t.Fatalf("/btw model = %q, want provider factory model %q", provider.lastModel, wantModel)
}
}
func TestProcessMessage_BtwCommandHookModelBypassesFallbackCandidates(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "primary-model",
ModelFallbacks: []string{"fallback-model"},
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "hook-model"})); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw explain hook routing",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
}
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if provider.lastModel != "hook-model" {
t.Fatalf("/btw model = %q, want hook-selected model", provider.lastModel)
}
}
func TestHandleCommand_UseCommandRejectsUnknownSkill(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{