mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #1878 from uiYzzi/feat/provider-extra-body-config
feat(providers): add extra_body config to inject custom fields into request body
This commit is contained in:
@@ -150,7 +150,7 @@ func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
@@ -196,7 +196,7 @@ func TestHandleCommand_UseCommandRejectsUnknownSkill(t *testing.T) {
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
@@ -240,7 +240,7 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) {
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
|
||||
@@ -936,10 +936,11 @@ type ModelConfig struct {
|
||||
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
|
||||
|
||||
// Optional optimizations
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
|
||||
|
||||
// from security
|
||||
secModelName string
|
||||
@@ -2079,6 +2080,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
}
|
||||
expanded = append(expanded, additionalEntry)
|
||||
fallbackNames = append(fallbackNames, expandedName)
|
||||
@@ -2097,6 +2099,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
apiKeys: []string{keys[0]},
|
||||
}
|
||||
|
||||
|
||||
@@ -1193,3 +1193,62 @@ func TestConfigLogLevelEmpty(t *testing.T) {
|
||||
t.Errorf("LogLevel = %q, want \"fatal\"", cfg.Gateway.LogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_ExtraBodyRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
cfg := &Config{
|
||||
ModelList: []*ModelConfig{
|
||||
{
|
||||
ModelName: "test-model",
|
||||
Model: "openai/test",
|
||||
apiKeys: []string{"sk-test"},
|
||||
ExtraBody: map[string]any{"custom_field": "value", "num_field": 42},
|
||||
},
|
||||
},
|
||||
security: &SecurityConfig{
|
||||
ModelList: map[string]ModelSecurityEntry{"test-model:0": {APIKeys: []string{"sk-test"}}},
|
||||
},
|
||||
}
|
||||
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig error: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig error: %v", err)
|
||||
}
|
||||
|
||||
if loaded.ModelList[0].ExtraBody == nil {
|
||||
t.Fatal("ExtraBody should not be nil after round-trip")
|
||||
}
|
||||
if got := loaded.ModelList[0].ExtraBody["custom_field"]; got != "value" {
|
||||
t.Errorf("ExtraBody[custom_field] = %v, want value", got)
|
||||
}
|
||||
if got := loaded.ModelList[0].ExtraBody["num_field"]; got != float64(42) {
|
||||
t.Errorf("ExtraBody[num_field] = %v, want 42", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_MinimaxExtraBody(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
var minimaxCfg *ModelConfig
|
||||
for i := range cfg.ModelList {
|
||||
if cfg.ModelList[i].Model == "minimax/MiniMax-M2.5" {
|
||||
minimaxCfg = cfg.ModelList[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if minimaxCfg == nil {
|
||||
t.Fatal("Minimax model not found in ModelList")
|
||||
}
|
||||
if minimaxCfg.ExtraBody == nil {
|
||||
t.Fatal("Minimax ExtraBody should not be nil")
|
||||
}
|
||||
if got, ok := minimaxCfg.ExtraBody["reasoning_split"]; !ok || got != true {
|
||||
t.Fatalf("Minimax ExtraBody[reasoning_split] = %v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,6 +339,7 @@ func DefaultConfig() *Config {
|
||||
ModelName: "MiniMax-M2.5",
|
||||
Model: "minimax/MiniMax-M2.5",
|
||||
APIBase: "https://api.minimaxi.com/v1",
|
||||
ExtraBody: map[string]any{"reasoning_split": true},
|
||||
},
|
||||
|
||||
// LongCat - https://longcat.chat/platform
|
||||
|
||||
@@ -93,6 +93,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
), modelID, nil
|
||||
|
||||
case "azure", "azure-openai":
|
||||
@@ -116,7 +117,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
|
||||
"qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
|
||||
"qwen-us", "dashscope-us", "mistral", "avian", "longcat", "modelscope", "novita",
|
||||
"coding-plan", "alibaba-coding", "qwen-coding":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
@@ -132,6 +133,32 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
), modelID, nil
|
||||
|
||||
case "minimax":
|
||||
// Minimax requires reasoning_split: true in the request body
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
}
|
||||
apiBase := cfg.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
extraBody := cfg.ExtraBody
|
||||
if extraBody == nil {
|
||||
extraBody = make(map[string]any)
|
||||
}
|
||||
if _, ok := extraBody["reasoning_split"]; !ok {
|
||||
extraBody["reasoning_split"] = true
|
||||
}
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
extraBody,
|
||||
), modelID, nil
|
||||
|
||||
case "anthropic":
|
||||
@@ -157,6 +184,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
), modelID, nil
|
||||
|
||||
case "anthropic-messages":
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -604,3 +605,98 @@ func TestGetDefaultAPIBase_QwenUSAliases(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_MinimaxInjectsReasoningSplit(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-minimax",
|
||||
Model: "minimax/MiniMax-M2.5",
|
||||
APIBase: server.URL,
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "MiniMax-M2.5" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "MiniMax-M2.5")
|
||||
}
|
||||
|
||||
_, err = provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
modelID,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify reasoning_split is automatically injected
|
||||
if got, ok := requestBody["reasoning_split"]; !ok || got != true {
|
||||
t.Fatalf("reasoning_split = %v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-minimax-custom",
|
||||
Model: "minimax/MiniMax-M2.5",
|
||||
APIBase: server.URL,
|
||||
ExtraBody: map[string]any{"custom_field": "test"},
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
modelID,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify reasoning_split is automatically injected
|
||||
if got, ok := requestBody["reasoning_split"]; !ok || got != true {
|
||||
t.Fatalf("reasoning_split = %v, want true", got)
|
||||
}
|
||||
// Verify user's custom field is preserved
|
||||
if got, ok := requestBody["custom_field"]; !ok || got != "test" {
|
||||
t.Fatalf("custom_field = %v, want test", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,12 +24,13 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0)
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0, nil)
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
apiKey, apiBase, proxy, maxTokensField string,
|
||||
requestTimeoutSeconds int,
|
||||
extraBody map[string]any,
|
||||
) *HTTPProvider {
|
||||
return &HTTPProvider{
|
||||
delegate: openai_compat.NewProvider(
|
||||
@@ -38,6 +39,7 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
proxy,
|
||||
openai_compat.WithMaxTokensField(maxTokensField),
|
||||
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
openai_compat.WithExtraBody(extraBody),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ type Provider struct {
|
||||
apiBase string
|
||||
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
|
||||
httpClient *http.Client
|
||||
extraBody map[string]any // Additional fields to inject into request body
|
||||
}
|
||||
|
||||
type Option func(*Provider)
|
||||
@@ -55,6 +56,12 @@ func WithRequestTimeout(timeout time.Duration) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func WithExtraBody(extraBody map[string]any) Option {
|
||||
return func(p *Provider) {
|
||||
p.extraBody = extraBody
|
||||
}
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
@@ -140,6 +147,12 @@ func (p *Provider) buildRequestBody(
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra body fields configured per-provider/model.
|
||||
// These are injected last so they take precedence over defaults.
|
||||
for k, v := range p.extraBody {
|
||||
requestBody[k] = v
|
||||
}
|
||||
|
||||
return requestBody
|
||||
}
|
||||
|
||||
|
||||
@@ -610,6 +610,90 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_ExtraBodyInjected(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
extraBody := map[string]any{"reasoning_split": true, "custom_field": "test"}
|
||||
p := NewProvider("key", server.URL, "", WithExtraBody(extraBody))
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"minimax/abab7",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if got, ok := requestBody["reasoning_split"]; !ok || got != true {
|
||||
t.Fatalf("reasoning_split = %v, want true", got)
|
||||
}
|
||||
if got, ok := requestBody["custom_field"]; !ok || got != "test" {
|
||||
t.Fatalf("custom_field = %v, want test", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_ExtraBodyOverridesOptions(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
extraBody := map[string]any{"temperature": 0.9}
|
||||
p := NewProvider("key", server.URL, "", WithExtraBody(extraBody))
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
map[string]any{"temperature": 0.5},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// ExtraBody takes precedence over options since it is merged last.
|
||||
if got := requestBody["temperature"]; got != float64(0.9) {
|
||||
t.Fatalf("temperature = %v, want 0.9 (from extraBody, overriding options)", got)
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
|
||||
@@ -31,12 +31,13 @@ type modelResponse struct {
|
||||
Proxy string `json:"proxy,omitempty"`
|
||||
AuthMethod string `json:"auth_method,omitempty"`
|
||||
// Advanced fields
|
||||
ConnectMode string `json:"connect_mode,omitempty"`
|
||||
Workspace string `json:"workspace,omitempty"`
|
||||
RPM int `json:"rpm,omitempty"`
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"`
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"`
|
||||
ConnectMode string `json:"connect_mode,omitempty"`
|
||||
Workspace string `json:"workspace,omitempty"`
|
||||
RPM int `json:"rpm,omitempty"`
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"`
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"`
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"`
|
||||
// Meta
|
||||
Configured bool `json:"configured"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
@@ -81,6 +82,7 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
Configured: configured[i],
|
||||
IsDefault: m.ModelName == defaultModel,
|
||||
})
|
||||
@@ -183,6 +185,9 @@ func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) {
|
||||
if mc.APIKey() == "" {
|
||||
mc.SetAPIKey(cfg.ModelList[idx].APIKey())
|
||||
}
|
||||
if mc.ExtraBody == nil {
|
||||
mc.ExtraBody = cfg.ModelList[idx].ExtraBody
|
||||
}
|
||||
|
||||
cfg.ModelList[idx] = &mc
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ export interface ModelInfo {
|
||||
max_tokens_field?: string
|
||||
request_timeout?: number
|
||||
thinking_level?: string
|
||||
extra_body?: Record<string, unknown>
|
||||
// Meta
|
||||
configured: boolean
|
||||
is_default: boolean
|
||||
|
||||
Reference in New Issue
Block a user