mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): support btw side questions (#2532)
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -80,6 +81,7 @@ func newStartedTestChannelManager(
|
||||
|
||||
type recordingProvider struct {
|
||||
lastMessages []providers.Message
|
||||
lastModel string
|
||||
}
|
||||
|
||||
func (r *recordingProvider) Chat(
|
||||
@@ -90,6 +92,7 @@ func (r *recordingProvider) Chat(
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastMessages = append([]providers.Message(nil), messages...)
|
||||
r.lastModel = model
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
@@ -100,6 +103,47 @@ func (r *recordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type closeTrackingProvider struct {
|
||||
recordingProvider
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (p *closeTrackingProvider) Close() {
|
||||
p.closed = true
|
||||
}
|
||||
|
||||
type modelRewriteHook struct {
|
||||
model string
|
||||
}
|
||||
|
||||
func (h modelRewriteHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
next := req.Clone()
|
||||
next.Model = h.model
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h modelRewriteHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func useTestSideQuestionProvider(al *AgentLoop, provider providers.LLMProvider) {
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
model := provider.GetDefaultModel()
|
||||
if mc != nil {
|
||||
if _, modelID := providers.ExtractProtocol(mc.Model); modelID != "" {
|
||||
model = modelID
|
||||
}
|
||||
}
|
||||
return provider, model, nil
|
||||
}
|
||||
}
|
||||
|
||||
func newTestAgentLoop(
|
||||
t *testing.T,
|
||||
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
|
||||
@@ -235,6 +279,305 @@ func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandRunsWithoutPersistingHistory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
defaultAgent := al.GetRegistry().GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain side effects",
|
||||
}
|
||||
route, _, err := al.resolveMessageRoute(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMessageRoute() error = %v", err)
|
||||
}
|
||||
allocation := al.allocateRouteSession(route, msg)
|
||||
sessionKey := resolveScopeKey(allocation.SessionKey, msg.SessionKey)
|
||||
initialHistory := []providers.Message{
|
||||
{Role: "user", Content: "We decided to avoid global state."},
|
||||
{Role: "assistant", Content: "Right, keep it request-scoped."},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, initialHistory)
|
||||
defaultAgent.Sessions.SetSummary(sessionKey, "The team decided to keep state request-scoped.")
|
||||
|
||||
response, err := al.processMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if len(provider.lastMessages) == 0 {
|
||||
t.Fatal("provider did not receive any messages")
|
||||
}
|
||||
if len(provider.lastMessages) != 4 {
|
||||
t.Fatalf("provider messages len = %d, want 4 (system + prior history + user)", len(provider.lastMessages))
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(provider.lastMessages[1:3], initialHistory) {
|
||||
t.Fatalf("provider history = %#v, want %#v", provider.lastMessages[1:3], initialHistory)
|
||||
}
|
||||
|
||||
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
||||
if lastMessage.Role != "user" || lastMessage.Content != "explain side effects" {
|
||||
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
|
||||
}
|
||||
|
||||
history := al.GetRegistry().GetDefaultAgent().Sessions.GetHistory(sessionKey)
|
||||
if !reflect.DeepEqual(history, initialHistory) {
|
||||
t.Fatalf("session history = %#v, want %#v", history, initialHistory)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandIncludesRequestContextAndMedia(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "discord",
|
||||
SenderID: "discord:123",
|
||||
Sender: bus.SenderInfo{
|
||||
DisplayName: "Alice",
|
||||
},
|
||||
ChatID: "group-1",
|
||||
Content: "/btw describe this image",
|
||||
Media: []string{"media://image-1"},
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if len(provider.lastMessages) == 0 {
|
||||
t.Fatal("provider did not receive any messages")
|
||||
}
|
||||
|
||||
systemPrompt := provider.lastMessages[0].Content
|
||||
if !strings.Contains(systemPrompt, "## Current Session\nChannel: discord\nChat ID: group-1") {
|
||||
t.Fatalf("system prompt missing current session context:\n%s", systemPrompt)
|
||||
}
|
||||
if !strings.Contains(systemPrompt, "## Current Sender\nCurrent sender: Alice (ID: discord:123)") {
|
||||
t.Fatalf("system prompt missing current sender context:\n%s", systemPrompt)
|
||||
}
|
||||
|
||||
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
||||
if lastMessage.Role != "user" || lastMessage.Content != "describe this image" {
|
||||
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
|
||||
}
|
||||
if !reflect.DeepEqual(lastMessage.Media, []string{"media://image-1"}) {
|
||||
t.Fatalf("last provider media = %#v, want media ref", lastMessage.Media)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandUsesIsolatedProvider(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
mainProvider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, mainProvider)
|
||||
var sideProvider *closeTrackingProvider
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
sideProvider = &closeTrackingProvider{}
|
||||
return sideProvider, "isolated-model", nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain isolation",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if len(mainProvider.lastMessages) != 0 {
|
||||
t.Fatalf("main provider was used for /btw: %+v", mainProvider.lastMessages)
|
||||
}
|
||||
if sideProvider == nil {
|
||||
t.Fatal("side question provider factory was not called")
|
||||
}
|
||||
if !sideProvider.closed {
|
||||
t.Fatal("isolated stateful /btw provider was not closed")
|
||||
}
|
||||
if len(sideProvider.lastMessages) == 0 {
|
||||
t.Fatal("isolated provider did not receive messages")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandRetriesWithoutMediaOnVisionUnsupported(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &visionUnsupportedMediaProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw describe this image",
|
||||
Media: []string{"data:image/png;base64,abc123"},
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "ok" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "ok")
|
||||
}
|
||||
if provider.calls != 2 {
|
||||
t.Fatalf("calls = %d, want %d (fail with media, then retry without media)", provider.calls, 2)
|
||||
}
|
||||
if !slices.Equal(provider.mediaSeen, []bool{true, false}) {
|
||||
t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandUsesProviderFactoryModel(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "lb-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "lb-model", Model: "openai/lb-model-a"},
|
||||
{ModelName: "lb-model", Model: "openai/lb-model-b"},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
var wantModel string
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
if mc == nil {
|
||||
t.Fatal("expected model config")
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(mc.Model)
|
||||
wantModel = "factory-" + modelID
|
||||
return provider, wantModel, nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain load balancing",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if provider.lastModel != wantModel {
|
||||
t.Fatalf("/btw model = %q, want provider factory model %q", provider.lastModel, wantModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandHookModelBypassesFallbackCandidates(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "primary-model",
|
||||
ModelFallbacks: []string{"fallback-model"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "hook-model"})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain hook routing",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if provider.lastModel != "hook-model" {
|
||||
t.Fatalf("/btw model = %q, want hook-selected model", provider.lastModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleCommand_UseCommandRejectsUnknownSkill(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
|
||||
Reference in New Issue
Block a user