mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(agent): honor explicit thinking off (#2898)
* fix(agent): honor explicit thinking off * fix(agent): address thinking off lint failures * Clarify unset thinking level display * fix ci
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -138,6 +139,121 @@ func (r *recordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type thinkingRecordingProvider struct {
|
||||
lastOptions map[string]any
|
||||
}
|
||||
|
||||
func (r *thinkingRecordingProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastOptions = maps.Clone(opts)
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *thinkingRecordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
func (r *thinkingRecordingProvider) SupportsThinking() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type thinkingOptionRecordingProvider struct {
|
||||
lastOptions map[string]any
|
||||
}
|
||||
|
||||
func (r *thinkingOptionRecordingProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastOptions = maps.Clone(opts)
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *thinkingOptionRecordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type reasoningOptionRecordingProvider struct {
|
||||
lastOptions map[string]any
|
||||
}
|
||||
|
||||
func (r *reasoningOptionRecordingProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastOptions = maps.Clone(opts)
|
||||
return &providers.LLMResponse{
|
||||
Content: "final answer",
|
||||
ReasoningContent: "thinking trace",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *reasoningOptionRecordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type reasoningResponseProvider struct{}
|
||||
|
||||
func (p *reasoningResponseProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ReasoningContent: "thinking trace",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *reasoningResponseProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type sideQuestionFallbackTestProvider struct {
|
||||
model string
|
||||
}
|
||||
|
||||
func (p *sideQuestionFallbackTestProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
if p.model == "test-model" {
|
||||
return nil, context.DeadlineExceeded
|
||||
}
|
||||
return &providers.LLMResponse{
|
||||
ReasoningContent: "thinking trace",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *sideQuestionFallbackTestProvider) GetDefaultModel() string {
|
||||
return p.model
|
||||
}
|
||||
|
||||
type modelRewriteHook struct {
|
||||
model string
|
||||
}
|
||||
@@ -386,6 +502,463 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_DoesNotPassImplicitThinkingOffToCapableProvider(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &thinkingRecordingProvider{}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if _, ok := provider.lastOptions["thinking_level"]; ok {
|
||||
t.Fatalf("thinking_level option should be omitted when unset, got %#v", provider.lastOptions["thinking_level"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_PassesExplicitThinkingOffToCapableProvider(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{{
|
||||
ModelName: "test-model",
|
||||
Model: "test-model",
|
||||
ThinkingLevel: "off",
|
||||
}},
|
||||
}
|
||||
|
||||
provider := &thinkingRecordingProvider{}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if got := provider.lastOptions["thinking_level"]; got != "off" {
|
||||
t.Fatalf("thinking_level option = %#v, want %q", got, "off")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_PassesExplicitThinkingOffToProviderWithoutThinkingCapability(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{{
|
||||
ModelName: "test-model",
|
||||
Model: "test-model",
|
||||
ThinkingLevel: "off",
|
||||
}},
|
||||
}
|
||||
|
||||
provider := &thinkingOptionRecordingProvider{}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if got := provider.lastOptions["thinking_level"]; got != "off" {
|
||||
t.Fatalf("thinking_level option = %#v, want %q", got, "off")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SuppressesReasoningWhenThinkingOff(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{{
|
||||
ModelName: "test-model",
|
||||
Model: "test-model",
|
||||
ThinkingLevel: "off",
|
||||
}},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &reasoningResponseProvider{})
|
||||
|
||||
response, err := al.runAgentLoop(
|
||||
context.Background(),
|
||||
al.GetRegistry().GetDefaultAgent(),
|
||||
processOptions{
|
||||
SessionKey: "agent:main:pico:chat-1",
|
||||
Channel: "pico",
|
||||
ChatID: "chat-1",
|
||||
UserMessage: "hello",
|
||||
SendResponse: false,
|
||||
DefaultResponse: defaultResponse,
|
||||
NoHistory: true,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("expected no reasoning outbound when thinking is off, got %+v", outbound)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BeforeLLMModelRewriteReevaluatesThinkingOff(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "plain-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "plain-model",
|
||||
Model: "openai/plain-model",
|
||||
},
|
||||
{
|
||||
ModelName: "off-model",
|
||||
Model: "openai/off-model",
|
||||
ThinkingLevel: "off",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &reasoningOptionRecordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "off-model"})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
SenderID: "user1",
|
||||
ChatID: "pico:test-session",
|
||||
Content: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "final answer" {
|
||||
t.Fatalf("processMessage() response = %q, want final answer", response)
|
||||
}
|
||||
if got := provider.lastOptions["thinking_level"]; got != "off" {
|
||||
t.Fatalf("thinking_level option = %#v, want off after hook model rewrite", got)
|
||||
}
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("expected no reasoning outbound after hook rewrote to off model, got %+v", outbound)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BeforeLLMModelRewriteDoesNotLeakThinkingOff(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "off-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "off-model",
|
||||
Model: "openai/off-model",
|
||||
ThinkingLevel: "off",
|
||||
},
|
||||
{
|
||||
ModelName: "plain-model",
|
||||
Model: "openai/plain-model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &reasoningOptionRecordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "plain-model"})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
SenderID: "user1",
|
||||
ChatID: "pico:test-session",
|
||||
Content: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "final answer" {
|
||||
t.Fatalf("processMessage() response = %q, want final answer", response)
|
||||
}
|
||||
if _, ok := provider.lastOptions["thinking_level"]; ok {
|
||||
t.Fatalf(
|
||||
"thinking_level option should be cleared after hook rewrote away from off model, got %#v",
|
||||
provider.lastOptions["thinking_level"],
|
||||
)
|
||||
}
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "thinking trace" {
|
||||
t.Fatalf("reasoning outbound content = %q, want thinking trace", outbound.Content)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("expected reasoning outbound after hook rewrote away from off model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandSuppressesReasoningWhenThinkingOff(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{{
|
||||
ModelName: "test-model",
|
||||
Model: "openai/test-model",
|
||||
ThinkingLevel: "off",
|
||||
}},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &sideQuestionFallbackTestProvider{model: "test-model"})
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
model := ""
|
||||
if mc != nil {
|
||||
_, model = providers.ExtractProtocol(mc)
|
||||
}
|
||||
if model == "" {
|
||||
model = "test-model"
|
||||
}
|
||||
return &sideQuestionFallbackTestProvider{model: model}, model, nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain privately",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if strings.Contains(response, "thinking trace") {
|
||||
t.Fatalf("processMessage() response = %q, should not expose reasoning with thinking off", response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwHookModelRewriteReevaluatesThinkingOff(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "plain-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "plain-model",
|
||||
Model: "openai/plain-model",
|
||||
},
|
||||
{
|
||||
ModelName: "off-model",
|
||||
Model: "openai/off-model",
|
||||
ThinkingLevel: "off",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &sideQuestionFallbackTestProvider{model: "plain-model"})
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
model := ""
|
||||
if mc != nil {
|
||||
_, model = providers.ExtractProtocol(mc)
|
||||
}
|
||||
if model == "" {
|
||||
model = "plain-model"
|
||||
}
|
||||
return &sideQuestionFallbackTestProvider{model: model}, model, nil
|
||||
}
|
||||
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "off-model"})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain privately",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if strings.Contains(response, "thinking trace") {
|
||||
t.Fatalf(
|
||||
"processMessage() response = %q, should not expose reasoning after hook rewrote to off model",
|
||||
response,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwHookModelRewriteDoesNotLeakThinkingOff(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "off-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "off-model",
|
||||
Model: "openai/off-model",
|
||||
ThinkingLevel: "off",
|
||||
},
|
||||
{
|
||||
ModelName: "plain-model",
|
||||
Model: "openai/plain-model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &sideQuestionFallbackTestProvider{model: "off-model"})
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
model := ""
|
||||
if mc != nil {
|
||||
_, model = providers.ExtractProtocol(mc)
|
||||
}
|
||||
if model == "" {
|
||||
model = "off-model"
|
||||
}
|
||||
return &sideQuestionFallbackTestProvider{model: model}, model, nil
|
||||
}
|
||||
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "plain-model"})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain privately",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "thinking trace" {
|
||||
t.Fatalf("processMessage() response = %q, want reasoning after hook rewrote away from off model", response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwFallbackDoesNotInheritPrimaryThinkingOff(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
ModelFallbacks: []string{"openai/fallback-model"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{{
|
||||
ModelName: "test-model",
|
||||
Model: "openai/test-model",
|
||||
ThinkingLevel: "off",
|
||||
}},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &sideQuestionFallbackTestProvider{model: "test-model"})
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
model := ""
|
||||
if mc != nil {
|
||||
_, model = providers.ExtractProtocol(mc)
|
||||
}
|
||||
if model == "" {
|
||||
model = "test-model"
|
||||
}
|
||||
return &sideQuestionFallbackTestProvider{model: model}, model, nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain fallback reasoning",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "thinking trace" {
|
||||
t.Fatalf("processMessage() response = %q, want fallback reasoning when fallback has no off", response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
skillDir := filepath.Join(tmpDir, "skills", "shell")
|
||||
@@ -3178,6 +3751,300 @@ func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_FallbackReceivesExplicitThinkingOff(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "rate limit exceeded",
|
||||
"type": "rate_limit_error",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer primaryServer.Close()
|
||||
|
||||
fallbackCalls := 0
|
||||
fallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fallbackCalls++
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("fallback server path = %q, want /chat/completions", r.URL.Path)
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode fallback request: %v", err)
|
||||
}
|
||||
if got := req["model"]; got != "doubao-seed-1-6-flash-250828" {
|
||||
t.Fatalf("fallback request model = %#v, want doubao-seed-1-6-flash-250828", got)
|
||||
}
|
||||
thinking, ok := req["thinking"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("fallback request thinking = %#v, want map", req["thinking"])
|
||||
}
|
||||
if got := thinking["type"]; got != "disabled" {
|
||||
t.Fatalf("fallback request thinking.type = %#v, want disabled", got)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "fallback reply"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("encode fallback response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer fallbackServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "primary-model",
|
||||
ModelFallbacks: []string{"doubao-fallback"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "primary-model",
|
||||
Model: "openrouter/primary-model",
|
||||
APIBase: primaryServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("primary-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "doubao-fallback",
|
||||
Model: "openai/doubao-seed-1-6-flash-250828",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
ThinkingLevel: "off",
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
})
|
||||
|
||||
if resp != "fallback reply" {
|
||||
t.Fatalf("response = %q, want fallback reply", resp)
|
||||
}
|
||||
if fallbackCalls != 1 {
|
||||
t.Fatalf("fallback server calls = %d, want 1", fallbackCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_PrimaryThinkingOffDoesNotLeakToFallback(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "rate limit exceeded",
|
||||
"type": "rate_limit_error",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer primaryServer.Close()
|
||||
|
||||
fallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("fallback server path = %q, want /chat/completions", r.URL.Path)
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode fallback request: %v", err)
|
||||
}
|
||||
if _, ok := req["thinking"]; ok {
|
||||
t.Fatalf("fallback request should not inherit primary thinking off, got thinking=%#v", req["thinking"])
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "fallback reply"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("encode fallback response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer fallbackServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "primary-model",
|
||||
ModelFallbacks: []string{"doubao-fallback"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "primary-model",
|
||||
Model: "openrouter/primary-model",
|
||||
APIBase: primaryServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("primary-key"),
|
||||
ThinkingLevel: "off",
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "doubao-fallback",
|
||||
Model: "openai/doubao-seed-1-6-flash-250828",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
})
|
||||
if resp != "fallback reply" {
|
||||
t.Fatalf("response = %q, want fallback reply", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_FallbackThinkingOffUsesCandidateIdentity(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "rate limit exceeded",
|
||||
"type": "rate_limit_error",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer primaryServer.Close()
|
||||
|
||||
fallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("fallback server path = %q, want /chat/completions", r.URL.Path)
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode fallback request: %v", err)
|
||||
}
|
||||
thinking, ok := req["thinking"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("fallback request thinking = %#v, want map", req["thinking"])
|
||||
}
|
||||
if got := thinking["type"]; got != "disabled" {
|
||||
t.Fatalf("fallback request thinking.type = %#v, want disabled", got)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "fallback reply"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("encode fallback response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer fallbackServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "primary-model",
|
||||
ModelFallbacks: []string{"doubao-off"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "primary-model",
|
||||
Model: "openrouter/primary-model",
|
||||
APIBase: primaryServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("primary-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "doubao-default",
|
||||
Model: "openai/doubao-seed-1-6-flash-250828",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "doubao-off",
|
||||
Model: "openai/doubao-seed-1-6-flash-250828",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
ThinkingLevel: "off",
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
})
|
||||
if resp != "fallback reply" {
|
||||
t.Fatalf("response = %q, want fallback reply", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered verifies
|
||||
// that when a candidate has no model_list entry it is absent from CandidateProviders
|
||||
// and the fallback closure falls back to activeProvider instead of panicking.
|
||||
|
||||
Reference in New Issue
Block a user