mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
fix(agent): honor explicit thinking off (#2898)
* fix(agent): honor explicit thinking off
* fix(agent): address thinking off lint failures
* Clarify unset thinking level display
* fix ci
This commit is contained in:
@@ -319,6 +319,7 @@ func (al *AgentLoop) buildCommandsRuntime(
|
||||
agent.Provider = nextProvider
|
||||
agent.Candidates = nextCandidates
|
||||
agent.ThinkingLevel = parseThinkingLevel(modelCfg.ThinkingLevel)
|
||||
agent.ThinkingLevelConfigured = isConfiguredThinkingLevel(modelCfg.ThinkingLevel)
|
||||
|
||||
if oldProvider != nil && oldProvider != nextProvider {
|
||||
if stateful, ok := oldProvider.(providers.StatefulProvider); ok {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -138,6 +139,121 @@ func (r *recordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type thinkingRecordingProvider struct {
|
||||
lastOptions map[string]any
|
||||
}
|
||||
|
||||
func (r *thinkingRecordingProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastOptions = maps.Clone(opts)
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *thinkingRecordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
func (r *thinkingRecordingProvider) SupportsThinking() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type thinkingOptionRecordingProvider struct {
|
||||
lastOptions map[string]any
|
||||
}
|
||||
|
||||
func (r *thinkingOptionRecordingProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastOptions = maps.Clone(opts)
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *thinkingOptionRecordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type reasoningOptionRecordingProvider struct {
|
||||
lastOptions map[string]any
|
||||
}
|
||||
|
||||
func (r *reasoningOptionRecordingProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastOptions = maps.Clone(opts)
|
||||
return &providers.LLMResponse{
|
||||
Content: "final answer",
|
||||
ReasoningContent: "thinking trace",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *reasoningOptionRecordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type reasoningResponseProvider struct{}
|
||||
|
||||
func (p *reasoningResponseProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ReasoningContent: "thinking trace",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *reasoningResponseProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type sideQuestionFallbackTestProvider struct {
|
||||
model string
|
||||
}
|
||||
|
||||
func (p *sideQuestionFallbackTestProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
if p.model == "test-model" {
|
||||
return nil, context.DeadlineExceeded
|
||||
}
|
||||
return &providers.LLMResponse{
|
||||
ReasoningContent: "thinking trace",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *sideQuestionFallbackTestProvider) GetDefaultModel() string {
|
||||
return p.model
|
||||
}
|
||||
|
||||
// modelRewriteHook is a test hook value carrying a target model name.
// NOTE(review): its methods are declared outside this view; the tests mount it
// via NamedHook("rewrite-model", ...) — presumably it rewrites the request
// model before the LLM call; confirm against the method definitions.
type modelRewriteHook struct {
|
||||
// model is the model name supplied when the hook is constructed.
model string
|
||||
}
|
||||
@@ -386,6 +502,463 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_DoesNotPassImplicitThinkingOffToCapableProvider(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &thinkingRecordingProvider{}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if _, ok := provider.lastOptions["thinking_level"]; ok {
|
||||
t.Fatalf("thinking_level option should be omitted when unset, got %#v", provider.lastOptions["thinking_level"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_PassesExplicitThinkingOffToCapableProvider(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{{
|
||||
ModelName: "test-model",
|
||||
Model: "test-model",
|
||||
ThinkingLevel: "off",
|
||||
}},
|
||||
}
|
||||
|
||||
provider := &thinkingRecordingProvider{}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if got := provider.lastOptions["thinking_level"]; got != "off" {
|
||||
t.Fatalf("thinking_level option = %#v, want %q", got, "off")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_PassesExplicitThinkingOffToProviderWithoutThinkingCapability(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{{
|
||||
ModelName: "test-model",
|
||||
Model: "test-model",
|
||||
ThinkingLevel: "off",
|
||||
}},
|
||||
}
|
||||
|
||||
provider := &thinkingOptionRecordingProvider{}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if got := provider.lastOptions["thinking_level"]; got != "off" {
|
||||
t.Fatalf("thinking_level option = %#v, want %q", got, "off")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SuppressesReasoningWhenThinkingOff(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{{
|
||||
ModelName: "test-model",
|
||||
Model: "test-model",
|
||||
ThinkingLevel: "off",
|
||||
}},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &reasoningResponseProvider{})
|
||||
|
||||
response, err := al.runAgentLoop(
|
||||
context.Background(),
|
||||
al.GetRegistry().GetDefaultAgent(),
|
||||
processOptions{
|
||||
SessionKey: "agent:main:pico:chat-1",
|
||||
Channel: "pico",
|
||||
ChatID: "chat-1",
|
||||
UserMessage: "hello",
|
||||
SendResponse: false,
|
||||
DefaultResponse: defaultResponse,
|
||||
NoHistory: true,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("expected no reasoning outbound when thinking is off, got %+v", outbound)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BeforeLLMModelRewriteReevaluatesThinkingOff(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "plain-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "plain-model",
|
||||
Model: "openai/plain-model",
|
||||
},
|
||||
{
|
||||
ModelName: "off-model",
|
||||
Model: "openai/off-model",
|
||||
ThinkingLevel: "off",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &reasoningOptionRecordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "off-model"})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
SenderID: "user1",
|
||||
ChatID: "pico:test-session",
|
||||
Content: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "final answer" {
|
||||
t.Fatalf("processMessage() response = %q, want final answer", response)
|
||||
}
|
||||
if got := provider.lastOptions["thinking_level"]; got != "off" {
|
||||
t.Fatalf("thinking_level option = %#v, want off after hook model rewrite", got)
|
||||
}
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("expected no reasoning outbound after hook rewrote to off model, got %+v", outbound)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BeforeLLMModelRewriteDoesNotLeakThinkingOff(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
ModelName: "off-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "off-model",
|
||||
Model: "openai/off-model",
|
||||
ThinkingLevel: "off",
|
||||
},
|
||||
{
|
||||
ModelName: "plain-model",
|
||||
Model: "openai/plain-model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &reasoningOptionRecordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "plain-model"})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "pico",
|
||||
SenderID: "user1",
|
||||
ChatID: "pico:test-session",
|
||||
Content: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "final answer" {
|
||||
t.Fatalf("processMessage() response = %q, want final answer", response)
|
||||
}
|
||||
if _, ok := provider.lastOptions["thinking_level"]; ok {
|
||||
t.Fatalf(
|
||||
"thinking_level option should be cleared after hook rewrote away from off model, got %#v",
|
||||
provider.lastOptions["thinking_level"],
|
||||
)
|
||||
}
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "thinking trace" {
|
||||
t.Fatalf("reasoning outbound content = %q, want thinking trace", outbound.Content)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("expected reasoning outbound after hook rewrote away from off model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwCommandSuppressesReasoningWhenThinkingOff(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{{
|
||||
ModelName: "test-model",
|
||||
Model: "openai/test-model",
|
||||
ThinkingLevel: "off",
|
||||
}},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &sideQuestionFallbackTestProvider{model: "test-model"})
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
model := ""
|
||||
if mc != nil {
|
||||
_, model = providers.ExtractProtocol(mc)
|
||||
}
|
||||
if model == "" {
|
||||
model = "test-model"
|
||||
}
|
||||
return &sideQuestionFallbackTestProvider{model: model}, model, nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain privately",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if strings.Contains(response, "thinking trace") {
|
||||
t.Fatalf("processMessage() response = %q, should not expose reasoning with thinking off", response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwHookModelRewriteReevaluatesThinkingOff(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "plain-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "plain-model",
|
||||
Model: "openai/plain-model",
|
||||
},
|
||||
{
|
||||
ModelName: "off-model",
|
||||
Model: "openai/off-model",
|
||||
ThinkingLevel: "off",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &sideQuestionFallbackTestProvider{model: "plain-model"})
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
model := ""
|
||||
if mc != nil {
|
||||
_, model = providers.ExtractProtocol(mc)
|
||||
}
|
||||
if model == "" {
|
||||
model = "plain-model"
|
||||
}
|
||||
return &sideQuestionFallbackTestProvider{model: model}, model, nil
|
||||
}
|
||||
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "off-model"})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain privately",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if strings.Contains(response, "thinking trace") {
|
||||
t.Fatalf(
|
||||
"processMessage() response = %q, should not expose reasoning after hook rewrote to off model",
|
||||
response,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwHookModelRewriteDoesNotLeakThinkingOff(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "off-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "off-model",
|
||||
Model: "openai/off-model",
|
||||
ThinkingLevel: "off",
|
||||
},
|
||||
{
|
||||
ModelName: "plain-model",
|
||||
Model: "openai/plain-model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &sideQuestionFallbackTestProvider{model: "off-model"})
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
model := ""
|
||||
if mc != nil {
|
||||
_, model = providers.ExtractProtocol(mc)
|
||||
}
|
||||
if model == "" {
|
||||
model = "off-model"
|
||||
}
|
||||
return &sideQuestionFallbackTestProvider{model: model}, model, nil
|
||||
}
|
||||
if err := al.MountHook(NamedHook("rewrite-model", modelRewriteHook{model: "plain-model"})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain privately",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "thinking trace" {
|
||||
t.Fatalf("processMessage() response = %q, want reasoning after hook rewrote away from off model", response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_BtwFallbackDoesNotInheritPrimaryThinkingOff(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
ModelFallbacks: []string{"openai/fallback-model"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{{
|
||||
ModelName: "test-model",
|
||||
Model: "openai/test-model",
|
||||
ThinkingLevel: "off",
|
||||
}},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &sideQuestionFallbackTestProvider{model: "test-model"})
|
||||
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
|
||||
model := ""
|
||||
if mc != nil {
|
||||
_, model = providers.ExtractProtocol(mc)
|
||||
}
|
||||
if model == "" {
|
||||
model = "test-model"
|
||||
}
|
||||
return &sideQuestionFallbackTestProvider{model: model}, model, nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/btw explain fallback reasoning",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "thinking trace" {
|
||||
t.Fatalf("processMessage() response = %q, want fallback reasoning when fallback has no off", response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
skillDir := filepath.Join(tmpDir, "skills", "shell")
|
||||
@@ -3178,6 +3751,300 @@ func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_FallbackReceivesExplicitThinkingOff(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "rate limit exceeded",
|
||||
"type": "rate_limit_error",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer primaryServer.Close()
|
||||
|
||||
fallbackCalls := 0
|
||||
fallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fallbackCalls++
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("fallback server path = %q, want /chat/completions", r.URL.Path)
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode fallback request: %v", err)
|
||||
}
|
||||
if got := req["model"]; got != "doubao-seed-1-6-flash-250828" {
|
||||
t.Fatalf("fallback request model = %#v, want doubao-seed-1-6-flash-250828", got)
|
||||
}
|
||||
thinking, ok := req["thinking"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("fallback request thinking = %#v, want map", req["thinking"])
|
||||
}
|
||||
if got := thinking["type"]; got != "disabled" {
|
||||
t.Fatalf("fallback request thinking.type = %#v, want disabled", got)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "fallback reply"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("encode fallback response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer fallbackServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "primary-model",
|
||||
ModelFallbacks: []string{"doubao-fallback"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "primary-model",
|
||||
Model: "openrouter/primary-model",
|
||||
APIBase: primaryServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("primary-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "doubao-fallback",
|
||||
Model: "openai/doubao-seed-1-6-flash-250828",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
ThinkingLevel: "off",
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
})
|
||||
|
||||
if resp != "fallback reply" {
|
||||
t.Fatalf("response = %q, want fallback reply", resp)
|
||||
}
|
||||
if fallbackCalls != 1 {
|
||||
t.Fatalf("fallback server calls = %d, want 1", fallbackCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_PrimaryThinkingOffDoesNotLeakToFallback(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "rate limit exceeded",
|
||||
"type": "rate_limit_error",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer primaryServer.Close()
|
||||
|
||||
fallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("fallback server path = %q, want /chat/completions", r.URL.Path)
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode fallback request: %v", err)
|
||||
}
|
||||
if _, ok := req["thinking"]; ok {
|
||||
t.Fatalf("fallback request should not inherit primary thinking off, got thinking=%#v", req["thinking"])
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "fallback reply"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("encode fallback response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer fallbackServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "primary-model",
|
||||
ModelFallbacks: []string{"doubao-fallback"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "primary-model",
|
||||
Model: "openrouter/primary-model",
|
||||
APIBase: primaryServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("primary-key"),
|
||||
ThinkingLevel: "off",
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "doubao-fallback",
|
||||
Model: "openai/doubao-seed-1-6-flash-250828",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
})
|
||||
if resp != "fallback reply" {
|
||||
t.Fatalf("response = %q, want fallback reply", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_FallbackThinkingOffUsesCandidateIdentity(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "rate limit exceeded",
|
||||
"type": "rate_limit_error",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer primaryServer.Close()
|
||||
|
||||
fallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("fallback server path = %q, want /chat/completions", r.URL.Path)
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode fallback request: %v", err)
|
||||
}
|
||||
thinking, ok := req["thinking"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("fallback request thinking = %#v, want map", req["thinking"])
|
||||
}
|
||||
if got := thinking["type"]; got != "disabled" {
|
||||
t.Fatalf("fallback request thinking.type = %#v, want disabled", got)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "fallback reply"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("encode fallback response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer fallbackServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "primary-model",
|
||||
ModelFallbacks: []string{"doubao-off"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "primary-model",
|
||||
Model: "openrouter/primary-model",
|
||||
APIBase: primaryServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("primary-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "doubao-default",
|
||||
Model: "openai/doubao-seed-1-6-flash-250828",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "doubao-off",
|
||||
Model: "openai/doubao-seed-1-6-flash-250828",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
ThinkingLevel: "off",
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
})
|
||||
if resp != "fallback reply" {
|
||||
t.Fatalf("response = %q, want fallback reply", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered verifies
|
||||
// that when a candidate has no model_list entry it is absent from CandidateProviders
|
||||
// and the fallback closure falls back to activeProvider instead of panicking.
|
||||
|
||||
@@ -560,6 +560,7 @@ func sideQuestionModelName(agent *AgentInstance, usedLight bool) string {
|
||||
}
|
||||
|
||||
func modelNameFromIdentityKey(identityKey string) string {
|
||||
identityKey = strings.TrimSpace(identityKey)
|
||||
if identityKey == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ type AgentInstance struct {
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
ThinkingLevel ThinkingLevel
|
||||
ThinkingLevelConfigured bool
|
||||
ContextWindow int
|
||||
SummarizeMessageThreshold int
|
||||
SummarizeTokenPercent int
|
||||
@@ -184,6 +185,7 @@ func NewAgentInstance(
|
||||
thinkingLevelStr = mc.ThinkingLevel
|
||||
}
|
||||
thinkingLevel := parseThinkingLevel(thinkingLevelStr)
|
||||
thinkingLevelConfigured := isConfiguredThinkingLevel(thinkingLevelStr)
|
||||
|
||||
summarizeMessageThreshold := defaults.SummarizeMessageThreshold
|
||||
if summarizeMessageThreshold == 0 {
|
||||
@@ -251,6 +253,7 @@ func NewAgentInstance(
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
ThinkingLevel: thinkingLevel,
|
||||
ThinkingLevelConfigured: thinkingLevelConfigured,
|
||||
ContextWindow: contextWindow,
|
||||
SummarizeMessageThreshold: summarizeMessageThreshold,
|
||||
SummarizeTokenPercent: summarizeTokenPercent,
|
||||
|
||||
+24
-13
@@ -73,14 +73,7 @@ func (p *Pipeline) CallLLM(
|
||||
if exec.useNativeSearch {
|
||||
exec.llmOpts["native_search"] = true
|
||||
}
|
||||
if ts.agent.ThinkingLevel != ThinkingOff {
|
||||
if tc, ok := ts.agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
|
||||
exec.llmOpts["thinking_level"] = string(ts.agent.ThinkingLevel)
|
||||
} else {
|
||||
logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring",
|
||||
map[string]any{"agent_id": ts.agent.ID, "thinking_level": string(ts.agent.ThinkingLevel)})
|
||||
}
|
||||
}
|
||||
applyTurnThinkingOptions(exec, ts.agent, exec.activeProvider, true)
|
||||
|
||||
exec.llmModel = exec.activeModel
|
||||
|
||||
@@ -105,6 +98,7 @@ func (p *Pipeline) CallLLM(
|
||||
exec.llmOpts = llmReq.Options
|
||||
if strings.TrimSpace(exec.llmModel) != "" && exec.llmModel != prevModel {
|
||||
p.applyBeforeLLMModelRewrite(ts, exec)
|
||||
applyTurnThinkingOptions(exec, ts.agent, exec.activeProvider, true)
|
||||
}
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
@@ -172,21 +166,33 @@ func (p *Pipeline) CallLLM(
|
||||
}
|
||||
|
||||
if len(exec.activeCandidates) > 1 && p.Fallback != nil {
|
||||
fbResult, fbErr := p.Fallback.Execute(
|
||||
fbResult, fbErr := p.Fallback.ExecuteCandidate(
|
||||
providerCtx,
|
||||
exec.activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
func(ctx context.Context, candidate providers.FallbackCandidate) (*providers.LLMResponse, error) {
|
||||
candidateProvider, err := providerForFallbackCandidate(
|
||||
ts.agent,
|
||||
exec.activeProvider,
|
||||
exec.activeCandidates,
|
||||
provider,
|
||||
model,
|
||||
candidate.Provider,
|
||||
candidate.Model,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, exec.llmOpts)
|
||||
callOpts := shallowCloneLLMOptions(exec.llmOpts)
|
||||
delete(callOpts, "thinking_level")
|
||||
candidateCfg := resolveActiveModelConfig(
|
||||
p.Cfg,
|
||||
ts.agent.Workspace,
|
||||
[]providers.FallbackCandidate{candidate},
|
||||
candidate.Model,
|
||||
p.Cfg.Agents.Defaults.Provider,
|
||||
)
|
||||
candidateThinking := thinkingSettingsFromModelConfig(candidateCfg)
|
||||
applyThinkingOption(callOpts, candidateProvider, candidateThinking, true, ts.agent.ID)
|
||||
exec.suppressReasoning = shouldSuppressReasoningFor(candidateThinking)
|
||||
return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, candidate.Model, callOpts)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
@@ -469,6 +475,11 @@ func (p *Pipeline) CallLLM(
|
||||
}
|
||||
}
|
||||
|
||||
if exec.suppressReasoning {
|
||||
exec.response.Reasoning = ""
|
||||
exec.response.ReasoningContent = ""
|
||||
exec.response.ReasoningDetails = nil
|
||||
}
|
||||
reasoningContent := responseReasoningContent(exec.response)
|
||||
shouldPublishPicoToolCallInterim := ts.channel == "pico" && len(exec.response.ToolCalls) > 0
|
||||
if shouldPublishPicoToolCallInterim {
|
||||
|
||||
@@ -85,7 +85,7 @@ func (p *Pipeline) tryConfiguredStreamingLLM(
|
||||
exec.llmOpts,
|
||||
func(chunk providers.StreamChunk) {
|
||||
recordChunk()
|
||||
if strings.TrimSpace(chunk.ReasoningContent) != "" {
|
||||
if !exec.suppressReasoning && strings.TrimSpace(chunk.ReasoningContent) != "" {
|
||||
publisher.UpdateReasoning(ctx, chunk.ReasoningContent)
|
||||
}
|
||||
if strings.TrimSpace(chunk.Content) != "" {
|
||||
|
||||
@@ -517,6 +517,46 @@ func TestConfiguredStreamingStreamsPicoReasoningBeforeAnswerContent(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfiguredStreamingSuppressesPicoReasoningWhenThinkingOff(t *testing.T) {
|
||||
cfg := newConfiguredStreamingTestConfig(t, true, true, nil)
|
||||
cfg.ModelList[0].ThinkingLevel = "off"
|
||||
streamer := &recordingStreamer{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
msgBus.SetStreamDelegate(configuredStreamingDelegate{streamer: streamer})
|
||||
provider := &configuredStreamingProvider{
|
||||
eventPlan: []configuredStreamingEventCall{{
|
||||
chunks: []providers.StreamChunk{
|
||||
{ReasoningContent: "thinking"},
|
||||
{Content: "answer"},
|
||||
},
|
||||
response: &providers.LLMResponse{
|
||||
Content: "answer",
|
||||
ReasoningContent: "thinking",
|
||||
},
|
||||
}},
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
got := runConfiguredStreamingTurn(t, al, "pico")
|
||||
if got != "answer" {
|
||||
t.Fatalf("response = %q, want answer", got)
|
||||
}
|
||||
if len(streamer.reasoningUpdates) != 0 {
|
||||
t.Fatalf("reasoning updates = %v, want none when thinking is off", streamer.reasoningUpdates)
|
||||
}
|
||||
if len(streamer.reasoningFinalized) != 0 {
|
||||
t.Fatalf("reasoning finalized = %v, want none when thinking is off", streamer.reasoningFinalized)
|
||||
}
|
||||
if len(streamer.updates) != 1 || streamer.updates[0] != "answer" {
|
||||
t.Fatalf("content updates = %v, want [answer]", streamer.updates)
|
||||
}
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("expected no reasoning outbound when thinking is off, got %+v", outbound)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfiguredStreamingFinalFlushFailureAfterVisibleOutputReturnsErrorWithoutFallbackOrCancel(t *testing.T) {
|
||||
cfg := newConfiguredStreamingTestConfig(t, true, true, nil)
|
||||
streamer := &failingFinalizeStreamer{err: errors.New("final failed")}
|
||||
|
||||
+91
-1
@@ -1,6 +1,12 @@
|
||||
package agent
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ThinkingLevel controls how the provider sends thinking parameters.
|
||||
//
|
||||
@@ -37,3 +43,87 @@ func parseThinkingLevel(level string) ThinkingLevel {
|
||||
return ThinkingOff
|
||||
}
|
||||
}
|
||||
|
||||
// isConfiguredThinkingLevel reports whether level, after trimming surrounding
// whitespace and lower-casing, is one of the recognized thinking level names.
// Empty and unrecognized strings yield false.
func isConfiguredThinkingLevel(level string) bool {
	normalized := strings.ToLower(strings.TrimSpace(level))
	switch normalized {
	case "off", "low", "medium", "high", "xhigh", "adaptive":
		return true
	}
	return false
}
|
||||
|
||||
// thinkingSettings pairs a thinking level with a flag recording whether that
// level was explicitly configured. The zero value means "nothing configured".
type thinkingSettings struct {
|
||||
// level is the thinking level to apply when configured is true.
level ThinkingLevel
|
||||
// configured distinguishes an explicit setting (including an explicit
// "off") from an unset one.
configured bool
|
||||
}
|
||||
|
||||
func thinkingSettingsFromModelConfig(mc *config.ModelConfig) thinkingSettings {
|
||||
if mc == nil || !isConfiguredThinkingLevel(mc.ThinkingLevel) {
|
||||
return thinkingSettings{}
|
||||
}
|
||||
return thinkingSettings{
|
||||
level: parseThinkingLevel(mc.ThinkingLevel),
|
||||
configured: true,
|
||||
}
|
||||
}
|
||||
|
||||
func activeThinkingSettings(agent *AgentInstance, modelCfg *config.ModelConfig) thinkingSettings {
|
||||
if settings := thinkingSettingsFromModelConfig(modelCfg); settings.configured {
|
||||
return settings
|
||||
}
|
||||
if modelCfg == nil && agent != nil {
|
||||
return thinkingSettings{
|
||||
level: agent.ThinkingLevel,
|
||||
configured: agent.ThinkingLevelConfigured,
|
||||
}
|
||||
}
|
||||
return thinkingSettings{}
|
||||
}
|
||||
|
||||
func applyThinkingOption(
|
||||
opts map[string]any,
|
||||
provider providers.LLMProvider,
|
||||
settings thinkingSettings,
|
||||
warnUnsupported bool,
|
||||
agentID string,
|
||||
) {
|
||||
if opts == nil || !settings.configured {
|
||||
return
|
||||
}
|
||||
if settings.level == ThinkingOff {
|
||||
opts["thinking_level"] = string(settings.level)
|
||||
return
|
||||
}
|
||||
if tc, ok := provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
|
||||
opts["thinking_level"] = string(settings.level)
|
||||
return
|
||||
}
|
||||
if warnUnsupported {
|
||||
logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring",
|
||||
map[string]any{"agent_id": agentID, "thinking_level": string(settings.level)})
|
||||
}
|
||||
}
|
||||
|
||||
func applyTurnThinkingOptions(
|
||||
exec *turnExecution,
|
||||
agent *AgentInstance,
|
||||
provider providers.LLMProvider,
|
||||
warnUnsupported bool,
|
||||
) {
|
||||
if exec == nil || exec.llmOpts == nil {
|
||||
return
|
||||
}
|
||||
delete(exec.llmOpts, "thinking_level")
|
||||
settings := activeThinkingSettings(agent, exec.activeModelConfig)
|
||||
agentID := ""
|
||||
if agent != nil {
|
||||
agentID = agent.ID
|
||||
}
|
||||
applyThinkingOption(exec.llmOpts, provider, settings, warnUnsupported, agentID)
|
||||
exec.suppressReasoning = shouldSuppressReasoningFor(settings)
|
||||
}
|
||||
|
||||
func shouldSuppressReasoningFor(settings thinkingSettings) bool {
|
||||
return settings.configured && settings.level == ThinkingOff
|
||||
}
|
||||
|
||||
+49
-23
@@ -423,6 +423,7 @@ func (al *AgentLoop) askSideQuestion(
|
||||
}
|
||||
|
||||
hookModelChanged := false
|
||||
sideSuppressReasoning := false
|
||||
callProvider := func(
|
||||
ctx context.Context,
|
||||
candidate providers.FallbackCandidate,
|
||||
@@ -430,7 +431,15 @@ func (al *AgentLoop) askSideQuestion(
|
||||
forceModel bool,
|
||||
callMessages []providers.Message,
|
||||
) (*providers.LLMResponse, error) {
|
||||
provider, providerModel, cleanup, err := al.isolatedSideQuestionProvider(agent, selectedModelName, candidate)
|
||||
baseModelName := selectedModelName
|
||||
if forceModel && strings.TrimSpace(model) != "" {
|
||||
baseModelName = model
|
||||
}
|
||||
provider, providerModel, modelCfg, cleanup, err := al.isolatedSideQuestionProvider(
|
||||
agent,
|
||||
baseModelName,
|
||||
candidate,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -439,10 +448,12 @@ func (al *AgentLoop) askSideQuestion(
|
||||
model = providerModel
|
||||
}
|
||||
callOpts := llmOpts
|
||||
if _, exists := callOpts["thinking_level"]; !exists && agent.ThinkingLevel != ThinkingOff {
|
||||
if tc, ok := provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
|
||||
settings := thinkingSettingsFromModelConfig(modelCfg)
|
||||
sideSuppressReasoning = shouldSuppressReasoningFor(settings)
|
||||
if _, exists := callOpts["thinking_level"]; !exists {
|
||||
if settings.configured {
|
||||
callOpts = shallowCloneLLMOptions(llmOpts)
|
||||
callOpts["thinking_level"] = string(agent.ThinkingLevel)
|
||||
applyThinkingOption(callOpts, provider, settings, false, agent.ID)
|
||||
}
|
||||
}
|
||||
return provider.Chat(ctx, callMessages, nil, model, callOpts)
|
||||
@@ -500,18 +511,11 @@ func (al *AgentLoop) askSideQuestion(
|
||||
|
||||
callSideLLM := func(callMessages []providers.Message) (*providers.LLMResponse, error) {
|
||||
if len(activeCandidates) > 1 && al.fallback != nil {
|
||||
fbResult, err := al.fallback.Execute(
|
||||
fbResult, err := al.fallback.ExecuteCandidate(
|
||||
ctx,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, providerName, model string) (*providers.LLMResponse, error) {
|
||||
candidate := providers.FallbackCandidate{Provider: providerName, Model: model}
|
||||
for _, activeCandidate := range activeCandidates {
|
||||
if activeCandidate.Provider == providerName && activeCandidate.Model == model {
|
||||
candidate = activeCandidate
|
||||
break
|
||||
}
|
||||
}
|
||||
return callProvider(ctx, candidate, model, false, callMessages)
|
||||
func(ctx context.Context, candidate providers.FallbackCandidate) (*providers.LLMResponse, error) {
|
||||
return callProvider(ctx, candidate, candidate.Model, false, callMessages)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
@@ -584,6 +588,11 @@ func (al *AgentLoop) askSideQuestion(
|
||||
return "", fmt.Errorf("hook aborted turn during after_llm: %s", reason)
|
||||
}
|
||||
}
|
||||
if sideSuppressReasoning {
|
||||
resp.Reasoning = ""
|
||||
resp.ReasoningContent = ""
|
||||
resp.ReasoningDetails = nil
|
||||
}
|
||||
|
||||
return sideQuestionResponseContent(resp), nil
|
||||
}
|
||||
@@ -592,14 +601,14 @@ func (al *AgentLoop) isolatedSideQuestionProvider(
|
||||
agent *AgentInstance,
|
||||
baseModelName string,
|
||||
candidate providers.FallbackCandidate,
|
||||
) (providers.LLMProvider, string, func(), error) {
|
||||
) (providers.LLMProvider, string, *config.ModelConfig, func(), error) {
|
||||
if agent == nil {
|
||||
return nil, "", func() {}, fmt.Errorf("isolatedSideQuestionProvider: no agent available for /btw")
|
||||
return nil, "", nil, func() {}, fmt.Errorf("isolatedSideQuestionProvider: no agent available for /btw")
|
||||
}
|
||||
|
||||
modelCfg, err := al.sideQuestionModelConfig(agent, baseModelName, candidate)
|
||||
if err != nil {
|
||||
return nil, "", func() {}, fmt.Errorf("isolatedSideQuestionProvider: %w", err)
|
||||
return nil, "", nil, func() {}, fmt.Errorf("isolatedSideQuestionProvider: %w", err)
|
||||
}
|
||||
|
||||
factory := al.providerFactory
|
||||
@@ -608,13 +617,13 @@ func (al *AgentLoop) isolatedSideQuestionProvider(
|
||||
}
|
||||
provider, modelID, err := factory(modelCfg)
|
||||
if err != nil {
|
||||
return nil, "", func() {}, fmt.Errorf("isolatedSideQuestionProvider: %w", err)
|
||||
return nil, "", nil, func() {}, fmt.Errorf("isolatedSideQuestionProvider: %w", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
closeProviderIfStateful(provider)
|
||||
}
|
||||
return provider, modelID, cleanup, nil
|
||||
return provider, modelID, modelCfg, cleanup, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) sideQuestionModelConfig(
|
||||
@@ -626,7 +635,15 @@ func (al *AgentLoop) sideQuestionModelConfig(
|
||||
return nil, fmt.Errorf("sideQuestionModelConfig: no agent available for /btw")
|
||||
}
|
||||
|
||||
// If candidate has an identity key, use that
|
||||
if name := modelAliasFromCandidateIdentityKey(candidate.IdentityKey); name != "" {
|
||||
modelCfg, err := resolvedModelConfig(al.GetConfig(), name, agent.Workspace)
|
||||
if err == nil {
|
||||
return modelCfg, nil
|
||||
}
|
||||
// Fallback: create a minimal config if lookup fails
|
||||
}
|
||||
|
||||
// Older identity keys used provider/model; keep resolving those by model.
|
||||
if name := modelNameFromIdentityKey(candidate.IdentityKey); name != "" {
|
||||
modelCfg, err := resolvedModelConfig(al.GetConfig(), name, agent.Workspace)
|
||||
if err == nil {
|
||||
@@ -635,6 +652,18 @@ func (al *AgentLoop) sideQuestionModelConfig(
|
||||
// Fallback: create a minimal config if lookup fails
|
||||
}
|
||||
|
||||
if candidate.Provider != "" && candidate.Model != "" {
|
||||
candidateRef := providers.NormalizeProvider(candidate.Provider) + "/" + candidate.Model
|
||||
if modelCfg, err := resolvedModelConfig(al.GetConfig(), candidateRef, agent.Workspace); err == nil {
|
||||
return modelCfg, nil
|
||||
}
|
||||
return &config.ModelConfig{
|
||||
ModelName: candidateRef,
|
||||
Model: candidateRef,
|
||||
Workspace: agent.Workspace,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Otherwise, clean up the base model name and use it
|
||||
baseModelName = strings.TrimSpace(baseModelName)
|
||||
modelCfg, err := resolvedModelConfig(al.GetConfig(), baseModelName, agent.Workspace)
|
||||
@@ -658,8 +687,5 @@ func (al *AgentLoop) sideQuestionModelConfig(
|
||||
|
||||
// If candidate specifies a different provider/model, override
|
||||
clone := *modelCfg
|
||||
if candidate.Provider != "" && candidate.Model != "" {
|
||||
clone.Model = providers.NormalizeProvider(candidate.Provider) + "/" + candidate.Model
|
||||
}
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
@@ -138,6 +138,7 @@ type turnExecution struct {
|
||||
allResponsesHandled bool
|
||||
streamingPublisher *streamingChunkPublisher
|
||||
streamingFallback bool
|
||||
suppressReasoning bool
|
||||
callMessages []providers.Message
|
||||
providerToolDefs []providers.ToolDefinition
|
||||
llmModel string
|
||||
|
||||
@@ -118,6 +118,22 @@ func (fc *FallbackChain) Execute(
|
||||
ctx context.Context,
|
||||
candidates []FallbackCandidate,
|
||||
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
|
||||
) (*FallbackResult, error) {
|
||||
return fc.ExecuteCandidate(
|
||||
ctx,
|
||||
candidates,
|
||||
func(ctx context.Context, candidate FallbackCandidate) (*LLMResponse, error) {
|
||||
return run(ctx, candidate.Provider, candidate.Model)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// ExecuteCandidate runs the fallback chain and passes the complete candidate
|
||||
// to the caller so model-list identity metadata remains available.
|
||||
func (fc *FallbackChain) ExecuteCandidate(
|
||||
ctx context.Context,
|
||||
candidates []FallbackCandidate,
|
||||
run func(ctx context.Context, candidate FallbackCandidate) (*LLMResponse, error),
|
||||
) (*FallbackResult, error) {
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("fallback: no candidates configured")
|
||||
@@ -181,7 +197,7 @@ func (fc *FallbackChain) Execute(
|
||||
|
||||
// Execute the run function.
|
||||
start := time.Now()
|
||||
resp, err := run(ctx, candidate.Provider, candidate.Model)
|
||||
resp, err := run(ctx, candidate)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
|
||||
@@ -193,6 +193,8 @@ func (p *Provider) buildRequestBody(
|
||||
}
|
||||
}
|
||||
|
||||
p.applyThinkingControl(requestBody, model, options)
|
||||
|
||||
// Merge extra body fields configured per-provider/model.
|
||||
// These are injected last so they take precedence over defaults.
|
||||
maps.Copy(requestBody, p.extraBody)
|
||||
@@ -200,6 +202,81 @@ func (p *Provider) buildRequestBody(
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func (p *Provider) applyThinkingControl(requestBody map[string]any, model string, options map[string]any) {
|
||||
level, ok := normalizedThinkingLevel(options)
|
||||
if !ok || level != "off" {
|
||||
return
|
||||
}
|
||||
|
||||
switch p.thinkingControlKind(model) {
|
||||
case "thinking_type":
|
||||
requestBody["thinking"] = map[string]any{"type": "disabled"}
|
||||
case "enable_thinking":
|
||||
requestBody["enable_thinking"] = false
|
||||
}
|
||||
}
|
||||
|
||||
// normalizedThinkingLevel reads options["thinking_level"] and returns it
// trimmed and lower-cased, together with true, when it is a string naming a
// recognized level. A missing key, a non-string value, or an unrecognized
// name returns ("", false). A nil options map is treated as missing the key.
func normalizedThinkingLevel(options map[string]any) (string, bool) {
	raw, isString := options["thinking_level"].(string)
	if !isString {
		return "", false
	}

	level := strings.ToLower(strings.TrimSpace(raw))
	switch level {
	case "off", "low", "medium", "high", "xhigh", "adaptive":
		return level, true
	}
	return "", false
}
|
||||
|
||||
func (p *Provider) thinkingControlKind(model string) string {
|
||||
providerName := strings.ToLower(strings.TrimSpace(p.providerName))
|
||||
lowerModel := strings.ToLower(strings.TrimSpace(model))
|
||||
|
||||
switch providerName {
|
||||
case "volcengine":
|
||||
return "thinking_type"
|
||||
case "zhipu", "zai":
|
||||
return "thinking_type"
|
||||
case "qwen", "qwen-portal", "qwen-intl", "qwen-international", "dashscope-intl", "qwen-us", "dashscope-us":
|
||||
return "enable_thinking"
|
||||
case "modelscope":
|
||||
if strings.Contains(lowerModel, "qwen") {
|
||||
return "enable_thinking"
|
||||
}
|
||||
}
|
||||
|
||||
if providerName == "openai" || providerName == "" {
|
||||
if isVolcengineHost(p.apiBase) || strings.Contains(lowerModel, "doubao") {
|
||||
return "thinking_type"
|
||||
}
|
||||
if isDashScopeHost(p.apiBase) || strings.Contains(lowerModel, "qwen") {
|
||||
return "enable_thinking"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func isVolcengineHost(apiBase string) bool {
|
||||
host := normalizedHostname(apiBase)
|
||||
return host == "volcengine.com" || strings.HasSuffix(host, ".volcengine.com") ||
|
||||
host == "volces.com" || strings.HasSuffix(host, ".volces.com")
|
||||
}
|
||||
|
||||
func isDashScopeHost(apiBase string) bool {
|
||||
host := normalizedHostname(apiBase)
|
||||
return host == "dashscope.aliyuncs.com" || strings.HasSuffix(host, ".dashscope.aliyuncs.com")
|
||||
}
|
||||
|
||||
// normalizedHostname parses rawURL (after trimming whitespace) and returns
// its hostname, without port, lower-cased. It returns "" when the URL cannot
// be parsed or has no host component, such as a value without a scheme.
func normalizedHostname(rawURL string) string {
	if parsed, err := url.Parse(strings.TrimSpace(rawURL)); err == nil {
		host := strings.TrimSpace(parsed.Hostname())
		return strings.ToLower(host)
	}
	return ""
}
|
||||
|
||||
func (p *Provider) applyCustomHeaders(req *http.Request) {
|
||||
for k, v := range p.customHeaders {
|
||||
if strings.TrimSpace(k) == "" {
|
||||
|
||||
@@ -62,6 +62,69 @@ func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequestBody_DisablesDoubaoThinkingWhenThinkingLevelOff(t *testing.T) {
|
||||
p := NewProvider("key", "https://ark.cn-beijing.volces.com/api/v3", "")
|
||||
p.SetProviderName("openai")
|
||||
|
||||
body := p.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"doubao-seed-1-6-flash-250828",
|
||||
map[string]any{"thinking_level": "off"},
|
||||
)
|
||||
|
||||
thinking, ok := body["thinking"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("thinking = %#v, want map", body["thinking"])
|
||||
}
|
||||
if got := thinking["type"]; got != "disabled" {
|
||||
t.Fatalf("thinking.type = %#v, want %q", got, "disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequestBody_DisablesModelDependentQwenThinkingWhenThinkingLevelOff(t *testing.T) {
|
||||
p := NewProvider("key", "https://api-inference.modelscope.cn/v1", "")
|
||||
p.SetProviderName("modelscope")
|
||||
|
||||
body := p.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"qwen3-coder-plus",
|
||||
map[string]any{"thinking_level": "off"},
|
||||
)
|
||||
|
||||
if got := body["enable_thinking"]; got != false {
|
||||
t.Fatalf("enable_thinking = %#v, want false", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequestBody_PreservesDoubaoRequestWhenThinkingLevelIsNotOff(t *testing.T) {
|
||||
p := NewProvider("key", "https://ark.cn-beijing.volces.com/api/v3", "")
|
||||
p.SetProviderName("openai")
|
||||
|
||||
for _, level := range []string{"low", "adaptive", "unexpected"} {
|
||||
t.Run(level, func(t *testing.T) {
|
||||
body := p.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"doubao-seed-1-6-flash-250828",
|
||||
map[string]any{"thinking_level": level},
|
||||
)
|
||||
|
||||
if _, ok := body["thinking"]; ok {
|
||||
t.Fatalf(
|
||||
"thinking should be omitted for %q to preserve existing behavior, got %#v",
|
||||
level,
|
||||
body["thinking"],
|
||||
)
|
||||
}
|
||||
if _, ok := body["enable_thinking"]; ok {
|
||||
t.Fatalf("enable_thinking should be omitted for %q, got %#v", level, body["enable_thinking"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_ParsesToolCalls(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]any{
|
||||
|
||||
@@ -720,7 +720,7 @@ export function AddModelSheet({
|
||||
<Input
|
||||
value={form.thinkingLevel}
|
||||
onChange={setField("thinkingLevel")}
|
||||
placeholder="off"
|
||||
placeholder={t("models.field.providerDefault")}
|
||||
/>
|
||||
</Field>
|
||||
|
||||
|
||||
@@ -690,7 +690,7 @@ export function EditModelSheet({
|
||||
<Input
|
||||
value={form.thinkingLevel}
|
||||
onChange={setField("thinkingLevel")}
|
||||
placeholder="off"
|
||||
placeholder={t("models.field.providerDefault")}
|
||||
/>
|
||||
</Field>
|
||||
|
||||
|
||||
@@ -313,7 +313,8 @@
|
||||
"rpm": "Rate Limit (RPM)",
|
||||
"rpmHint": "Maximum requests per minute. 0 = no limit.",
|
||||
"thinkingLevel": "Thinking Level",
|
||||
"thinkingLevelHint": "Extended thinking budget: off, low, medium, high, xhigh, adaptive.",
|
||||
"thinkingLevelHint": "Leave blank to omit thinking_level and use the provider default. Values: off, low, medium, high, xhigh, adaptive.",
|
||||
"providerDefault": "provider default",
|
||||
"maxTokensField": "Max Tokens Field",
|
||||
"maxTokensFieldHint": "Override the request field name for max tokens, e.g. max_completion_tokens.",
|
||||
"toolSchemaTransform": "Tool Schema Transform",
|
||||
|
||||
@@ -312,7 +312,8 @@
|
||||
"rpm": "Limite de Taxa (RPM)",
|
||||
"rpmHint": "Máximo de requisições por minuto. 0 = sem limite.",
|
||||
"thinkingLevel": "Nível de Pensamento",
|
||||
"thinkingLevelHint": "Orçamento de pensamento estendido: off, low, medium, high, xhigh, adaptive.",
|
||||
"thinkingLevelHint": "Deixe em branco para omitir thinking_level e usar o padrão do provider. Valores: off, low, medium, high, xhigh, adaptive.",
|
||||
"providerDefault": "padrão do provider",
|
||||
"maxTokensField": "Campo de Max Tokens",
|
||||
"maxTokensFieldHint": "Sobrescreve o nome do campo de max tokens na requisição, ex: max_completion_tokens.",
|
||||
"toolSchemaTransform": "Transformação de Schema de Ferramentas",
|
||||
|
||||
@@ -313,7 +313,8 @@
|
||||
"rpm": "速率限制(RPM)",
|
||||
"rpmHint": "每分钟最大请求数,0 表示不限制。",
|
||||
"thinkingLevel": "思考级别",
|
||||
"thinkingLevelHint": "扩展思考预算:off、low、medium、high、xhigh、adaptive。",
|
||||
"thinkingLevelHint": "留空则不传 thinking_level,使用 Provider 默认值。可选值:off、low、medium、high、xhigh、adaptive。",
|
||||
"providerDefault": "Provider 默认",
|
||||
"maxTokensField": "Max Tokens 字段名",
|
||||
"maxTokensFieldHint": "覆盖请求中 max_tokens 的字段名,例如 max_completion_tokens。",
|
||||
"toolSchemaTransform": "工具 Schema 转换",
|
||||
|
||||
Reference in New Issue
Block a user