mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(providers): add extra_body config to inject custom fields into request body
Allow configuring provider-specific fields like reasoning_split for minimax via the model config's extra_body map. These fields are merged into the request body last, giving them precedence over default values. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -674,10 +674,11 @@ type ModelConfig struct {
|
|||||||
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
|
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
|
||||||
|
|
||||||
// Optional optimizations
|
// Optional optimizations
|
||||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||||
|
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks if the ModelConfig has all required fields.
|
// Validate checks if the ModelConfig has all required fields.
|
||||||
|
|||||||
@@ -1099,3 +1099,59 @@ func TestConfigLogLevelEmpty(t *testing.T) {
|
|||||||
t.Errorf("LogLevel = %q, want \"fatal\"", cfg.Agents.Defaults.LogLevel)
|
t.Errorf("LogLevel = %q, want \"fatal\"", cfg.Agents.Defaults.LogLevel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultConfig_MinimaxExtraBody(t *testing.T) {
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
|
||||||
|
var minimaxCfg *ModelConfig
|
||||||
|
for i := range cfg.ModelList {
|
||||||
|
if cfg.ModelList[i].Model == "minimax/MiniMax-M2.5" {
|
||||||
|
minimaxCfg = &cfg.ModelList[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if minimaxCfg == nil {
|
||||||
|
t.Fatal("Minimax model not found in ModelList")
|
||||||
|
}
|
||||||
|
if minimaxCfg.ExtraBody == nil {
|
||||||
|
t.Fatal("Minimax ExtraBody should not be nil")
|
||||||
|
}
|
||||||
|
if got, ok := minimaxCfg.ExtraBody["reasoning_split"]; !ok || got != true {
|
||||||
|
t.Fatalf("Minimax ExtraBody[reasoning_split] = %v, want true", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelConfig_ExtraBodyRoundTrip(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgPath := filepath.Join(dir, "config.json")
|
||||||
|
|
||||||
|
cfg := &Config{
|
||||||
|
ModelList: []ModelConfig{
|
||||||
|
{
|
||||||
|
ModelName: "test-model",
|
||||||
|
Model: "openai/test",
|
||||||
|
APIKey: "sk-test",
|
||||||
|
ExtraBody: map[string]any{"custom_field": "value", "num_field": 42},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||||
|
t.Fatalf("SaveConfig error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := LoadConfig(cfgPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadConfig error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loaded.ModelList[0].ExtraBody == nil {
|
||||||
|
t.Fatal("ExtraBody should not be nil after round-trip")
|
||||||
|
}
|
||||||
|
if got := loaded.ModelList[0].ExtraBody["custom_field"]; got != "value" {
|
||||||
|
t.Errorf("ExtraBody[custom_field] = %v, want value", got)
|
||||||
|
}
|
||||||
|
if got := loaded.ModelList[0].ExtraBody["num_field"]; got != float64(42) {
|
||||||
|
t.Errorf("ExtraBody[num_field] = %v, want 42", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -376,6 +376,7 @@ func DefaultConfig() *Config {
|
|||||||
Model: "minimax/MiniMax-M2.5",
|
Model: "minimax/MiniMax-M2.5",
|
||||||
APIBase: "https://api.minimaxi.com/v1",
|
APIBase: "https://api.minimaxi.com/v1",
|
||||||
APIKey: "",
|
APIKey: "",
|
||||||
|
ExtraBody: map[string]any{"reasoning_split": true},
|
||||||
},
|
},
|
||||||
|
|
||||||
// LongCat - https://longcat.chat/platform
|
// LongCat - https://longcat.chat/platform
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
|||||||
cfg.Proxy,
|
cfg.Proxy,
|
||||||
cfg.MaxTokensField,
|
cfg.MaxTokensField,
|
||||||
cfg.RequestTimeout,
|
cfg.RequestTimeout,
|
||||||
|
cfg.ExtraBody,
|
||||||
), modelID, nil
|
), modelID, nil
|
||||||
|
|
||||||
case "azure", "azure-openai":
|
case "azure", "azure-openai":
|
||||||
@@ -132,6 +133,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
|||||||
cfg.Proxy,
|
cfg.Proxy,
|
||||||
cfg.MaxTokensField,
|
cfg.MaxTokensField,
|
||||||
cfg.RequestTimeout,
|
cfg.RequestTimeout,
|
||||||
|
cfg.ExtraBody,
|
||||||
), modelID, nil
|
), modelID, nil
|
||||||
|
|
||||||
case "anthropic":
|
case "anthropic":
|
||||||
@@ -157,6 +159,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
|||||||
cfg.Proxy,
|
cfg.Proxy,
|
||||||
cfg.MaxTokensField,
|
cfg.MaxTokensField,
|
||||||
cfg.RequestTimeout,
|
cfg.RequestTimeout,
|
||||||
|
cfg.ExtraBody,
|
||||||
), modelID, nil
|
), modelID, nil
|
||||||
|
|
||||||
case "anthropic-messages":
|
case "anthropic-messages":
|
||||||
|
|||||||
@@ -24,12 +24,13 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
|
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
|
||||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0)
|
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||||
apiKey, apiBase, proxy, maxTokensField string,
|
apiKey, apiBase, proxy, maxTokensField string,
|
||||||
requestTimeoutSeconds int,
|
requestTimeoutSeconds int,
|
||||||
|
extraBody map[string]any,
|
||||||
) *HTTPProvider {
|
) *HTTPProvider {
|
||||||
return &HTTPProvider{
|
return &HTTPProvider{
|
||||||
delegate: openai_compat.NewProvider(
|
delegate: openai_compat.NewProvider(
|
||||||
@@ -38,6 +39,7 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
|||||||
proxy,
|
proxy,
|
||||||
openai_compat.WithMaxTokensField(maxTokensField),
|
openai_compat.WithMaxTokensField(maxTokensField),
|
||||||
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||||
|
openai_compat.WithExtraBody(extraBody),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ type Provider struct {
|
|||||||
apiBase string
|
apiBase string
|
||||||
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
|
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
|
extraBody map[string]any // Additional fields to inject into request body
|
||||||
}
|
}
|
||||||
|
|
||||||
type Option func(*Provider)
|
type Option func(*Provider)
|
||||||
@@ -55,6 +56,12 @@ func WithRequestTimeout(timeout time.Duration) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithExtraBody(extraBody map[string]any) Option {
|
||||||
|
return func(p *Provider) {
|
||||||
|
p.extraBody = extraBody
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||||
p := &Provider{
|
p := &Provider{
|
||||||
apiKey: apiKey,
|
apiKey: apiKey,
|
||||||
@@ -140,6 +147,12 @@ func (p *Provider) buildRequestBody(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Merge extra body fields configured per-provider/model.
|
||||||
|
// These are injected last so they take precedence over defaults.
|
||||||
|
for k, v := range p.extraBody {
|
||||||
|
requestBody[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
return requestBody
|
return requestBody
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -610,6 +610,90 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProviderChat_ExtraBodyInjected(t *testing.T) {
|
||||||
|
var requestBody map[string]any
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp := map[string]any{
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{
|
||||||
|
"message": map[string]any{"content": "ok"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
extraBody := map[string]any{"reasoning_split": true, "custom_field": "test"}
|
||||||
|
p := NewProvider("key", server.URL, "", WithExtraBody(extraBody))
|
||||||
|
|
||||||
|
_, err := p.Chat(
|
||||||
|
t.Context(),
|
||||||
|
[]Message{{Role: "user", Content: "hi"}},
|
||||||
|
nil,
|
||||||
|
"minimax/abab7",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chat() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, ok := requestBody["reasoning_split"]; !ok || got != true {
|
||||||
|
t.Fatalf("reasoning_split = %v, want true", got)
|
||||||
|
}
|
||||||
|
if got, ok := requestBody["custom_field"]; !ok || got != "test" {
|
||||||
|
t.Fatalf("custom_field = %v, want test", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderChat_ExtraBodyOverridesOptions(t *testing.T) {
|
||||||
|
var requestBody map[string]any
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp := map[string]any{
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{
|
||||||
|
"message": map[string]any{"content": "ok"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
extraBody := map[string]any{"temperature": 0.9}
|
||||||
|
p := NewProvider("key", server.URL, "", WithExtraBody(extraBody))
|
||||||
|
|
||||||
|
_, err := p.Chat(
|
||||||
|
t.Context(),
|
||||||
|
[]Message{{Role: "user", Content: "hi"}},
|
||||||
|
nil,
|
||||||
|
"gpt-4o",
|
||||||
|
map[string]any{"temperature": 0.5},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chat() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtraBody takes precedence over options since it is merged last.
|
||||||
|
if got := requestBody["temperature"]; got != float64(0.9) {
|
||||||
|
t.Fatalf("temperature = %v, want 0.9 (from extraBody, overriding options)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user