diff --git a/pkg/config/config.go b/pkg/config/config.go index eab770991..c4f1e751f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -674,10 +674,11 @@ type ModelConfig struct { Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers // Optional optimizations - RPM int `json:"rpm,omitempty"` // Requests per minute limit - MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens") - RequestTimeout int `json:"request_timeout,omitempty"` - ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive + RPM int `json:"rpm,omitempty"` // Requests per minute limit + MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens") + RequestTimeout int `json:"request_timeout,omitempty"` + ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive + ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body } // Validate checks if the ModelConfig has all required fields. diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 45906ee70..678f02000 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1099,3 +1099,59 @@ func TestConfigLogLevelEmpty(t *testing.T) { t.Errorf("LogLevel = %q, want \"fatal\"", cfg.Agents.Defaults.LogLevel) } } + +func TestDefaultConfig_MinimaxExtraBody(t *testing.T) { + cfg := DefaultConfig() + + var minimaxCfg *ModelConfig + for i := range cfg.ModelList { + if cfg.ModelList[i].Model == "minimax/MiniMax-M2.5" { + minimaxCfg = &cfg.ModelList[i] + break + } + } + if minimaxCfg == nil { + t.Fatal("Minimax model not found in ModelList") + } + if minimaxCfg.ExtraBody == nil { + t.Fatal("Minimax ExtraBody should not be nil") + } + if got, ok := minimaxCfg.ExtraBody["reasoning_split"]; !ok || got != true { + t.Fatalf("Minimax ExtraBody[reasoning_split] = %v, want true", got) + } +} + +func TestModelConfig_ExtraBodyRoundTrip(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + + cfg := &Config{ + ModelList: []ModelConfig{ + { + ModelName: "test-model", + Model: "openai/test", + APIKey: "sk-test", + ExtraBody: map[string]any{"custom_field": "value", "num_field": 42}, + }, + }, + } + + if err := SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("SaveConfig error: %v", err) + } + + loaded, err := LoadConfig(cfgPath) + if err != nil { + t.Fatalf("LoadConfig error: %v", err) + } + + if loaded.ModelList[0].ExtraBody == nil { + t.Fatal("ExtraBody should not be nil after round-trip") + } + if got := loaded.ModelList[0].ExtraBody["custom_field"]; got != "value" { + t.Errorf("ExtraBody[custom_field] = %v, want value", got) + } + if got := loaded.ModelList[0].ExtraBody["num_field"]; got != float64(42) { + t.Errorf("ExtraBody[num_field] = %v, want 42", got) + } +} diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index f4056eca6..d96b139d1 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -376,6 +376,7 @@ func DefaultConfig() *Config { Model: "minimax/MiniMax-M2.5", APIBase: "https://api.minimaxi.com/v1", APIKey: "", + ExtraBody: map[string]any{"reasoning_split": true}, }, // LongCat - https://longcat.chat/platform diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index a7fef8f5b..98e781da3 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -93,6 +93,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.Proxy, cfg.MaxTokensField, cfg.RequestTimeout, + cfg.ExtraBody, ), modelID, nil case "azure", "azure-openai": @@ -132,6 +133,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.Proxy, cfg.MaxTokensField, cfg.RequestTimeout, + cfg.ExtraBody, ), modelID, nil case "anthropic": @@ -157,6 +159,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.Proxy, cfg.MaxTokensField, cfg.RequestTimeout, + cfg.ExtraBody, ), modelID, nil case "anthropic-messages": diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 803165edb..f2ff52f1d 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -24,12 +24,13 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { } func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider { - return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0) + return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0, nil) } func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout( apiKey, apiBase, proxy, maxTokensField string, requestTimeoutSeconds int, + extraBody map[string]any, ) *HTTPProvider { return &HTTPProvider{ delegate: openai_compat.NewProvider( @@ -38,6 +39,7 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout( proxy, openai_compat.WithMaxTokensField(maxTokensField), openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second), + openai_compat.WithExtraBody(extraBody), ), } } diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 938e4ea8b..90bc683b8 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -35,6 +35,7 @@ type Provider struct { apiBase string maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models) httpClient *http.Client + extraBody map[string]any // Additional fields to inject into request body } type Option func(*Provider) @@ -55,6 +56,12 @@ func WithRequestTimeout(timeout time.Duration) Option { } } +func WithExtraBody(extraBody map[string]any) Option { + return func(p *Provider) { + p.extraBody = extraBody + } +} + func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider { p := &Provider{ apiKey: apiKey, @@ -140,6 +147,12 @@ func (p *Provider) buildRequestBody( } } + // Merge extra body fields configured per-provider/model. + // These are injected last so they take precedence over defaults. + for k, v := range p.extraBody { + requestBody[k] = v + } + return requestBody } diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index efb03ccb8..ab632ccf3 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -610,6 +610,90 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) { } } +func TestProviderChat_ExtraBodyInjected(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + extraBody := map[string]any{"reasoning_split": true, "custom_field": "test"} + p := NewProvider("key", server.URL, "", WithExtraBody(extraBody)) + + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "minimax/abab7", + nil, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if got, ok := requestBody["reasoning_split"]; !ok || got != true { + t.Fatalf("reasoning_split = %v, want true", got) + } + if got, ok := requestBody["custom_field"]; !ok || got != "test" { + t.Fatalf("custom_field = %v, want test", got) + } +} + +func TestProviderChat_ExtraBodyOverridesOptions(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + extraBody := map[string]any{"temperature": 0.9} + p := NewProvider("key", server.URL, "", WithExtraBody(extraBody)) + + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "gpt-4o", + map[string]any{"temperature": 0.5}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + // ExtraBody takes precedence over options since it is merged last. + if got := requestBody["temperature"]; got != float64(0.9) { + t.Fatalf("temperature = %v, want 0.9 (from extraBody, overriding options)", got) + } +} + type roundTripperFunc func(*http.Request) (*http.Response, error) func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {