mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(provider): support custom headers injection for HTTP providers (#2402)
* feat(provider): support custom headers injection for HTTP providers * fix(provider): resolve lint problem * fix(provider): align stream user-agent and header precedence docs
This commit is contained in:
@@ -605,11 +605,12 @@ type ModelConfig struct {
|
||||
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
|
||||
|
||||
// Optional optimizations
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
|
||||
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
|
||||
|
||||
APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
|
||||
|
||||
@@ -1279,6 +1280,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
isVirtual: true,
|
||||
}
|
||||
expanded = append(expanded, additionalEntry)
|
||||
@@ -1299,6 +1301,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
APIKeys: SimpleSecureStrings(keys[0]),
|
||||
}
|
||||
|
||||
|
||||
@@ -1528,6 +1528,42 @@ func TestModelConfig_ExtraBodyRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_CustomHeadersRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
cfg := &Config{
|
||||
Version: CurrentVersion,
|
||||
ModelList: []*ModelConfig{
|
||||
{
|
||||
ModelName: "test-model",
|
||||
Model: "openai/test",
|
||||
APIKeys: SimpleSecureStrings("sk-test"),
|
||||
CustomHeaders: map[string]string{"X-Source": "coding-plan", "X-Agent": "openclaw"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig error: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig error: %v", err)
|
||||
}
|
||||
|
||||
if loaded.ModelList[0].CustomHeaders == nil {
|
||||
t.Fatal("CustomHeaders should not be nil after round-trip")
|
||||
}
|
||||
if got := loaded.ModelList[0].CustomHeaders["X-Source"]; got != "coding-plan" {
|
||||
t.Errorf("CustomHeaders[X-Source] = %q, want coding-plan", got)
|
||||
}
|
||||
if got := loaded.ModelList[0].CustomHeaders["X-Agent"]; got != "openclaw" {
|
||||
t.Errorf("CustomHeaders[X-Agent] = %q, want openclaw", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_MinimaxExtraBody(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
|
||||
@@ -160,6 +160,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
|
||||
case "azure", "azure-openai":
|
||||
@@ -238,6 +239,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
|
||||
case "minimax":
|
||||
@@ -264,6 +266,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
extraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
|
||||
case "anthropic":
|
||||
@@ -291,6 +294,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
|
||||
case "anthropic-messages":
|
||||
|
||||
@@ -846,6 +846,49 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_CustomHeaders(t *testing.T) {
|
||||
var gotSource, gotAuth string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSource = r.Header.Get("X-Source")
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-headers",
|
||||
Model: "openai/gpt-4o",
|
||||
APIBase: server.URL,
|
||||
CustomHeaders: map[string]string{"X-Source": "coding-plan", "Authorization": "Token config-auth"},
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
modelID,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if gotSource != "coding-plan" {
|
||||
t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
|
||||
}
|
||||
if gotAuth != "Token config-auth" {
|
||||
t.Fatalf("Authorization = %q, want %q", gotAuth, "Token config-auth")
|
||||
}
|
||||
}
|
||||
|
||||
// openaiCompatResponse is the JSON response used by OpenAI-compatible providers.
|
||||
const openaiCompatResponse = `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`
|
||||
|
||||
|
||||
@@ -24,13 +24,14 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, "", 0, nil)
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, "", 0, nil, nil)
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
apiKey, apiBase, proxy, maxTokensField, userAgent string,
|
||||
requestTimeoutSeconds int,
|
||||
extraBody map[string]any,
|
||||
customHeaders map[string]string,
|
||||
) *HTTPProvider {
|
||||
return &HTTPProvider{
|
||||
delegate: openai_compat.NewProvider(
|
||||
@@ -40,6 +41,7 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
openai_compat.WithMaxTokensField(maxTokensField),
|
||||
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
openai_compat.WithExtraBody(extraBody),
|
||||
openai_compat.WithCustomHeaders(customHeaders),
|
||||
openai_compat.WithUserAgent(userAgent),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ type Provider struct {
|
||||
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
|
||||
httpClient *http.Client
|
||||
extraBody map[string]any // Additional fields to inject into request body
|
||||
customHeaders map[string]string
|
||||
userAgent string
|
||||
}
|
||||
|
||||
@@ -87,6 +88,12 @@ func WithExtraBody(extraBody map[string]any) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomHeaders(customHeaders map[string]string) Option {
|
||||
return func(p *Provider) {
|
||||
p.customHeaders = customHeaders
|
||||
}
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
@@ -181,6 +188,15 @@ func (p *Provider) buildRequestBody(
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func (p *Provider) applyCustomHeaders(req *http.Request) {
|
||||
for k, v := range p.customHeaders {
|
||||
if strings.TrimSpace(k) == "" {
|
||||
continue
|
||||
}
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
@@ -211,6 +227,7 @@ func (p *Provider) Chat(
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
p.applyCustomHeaders(req)
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -254,9 +271,13 @@ func (p *Provider) ChatStream(
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
if p.userAgent != "" {
|
||||
req.Header.Set("User-Agent", p.userAgent)
|
||||
}
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
p.applyCustomHeaders(req)
|
||||
|
||||
// Use a client without Timeout for streaming — the http.Client.Timeout covers
|
||||
// the entire request lifecycle including body reads, which would kill long streams.
|
||||
|
||||
@@ -710,6 +710,111 @@ func TestProviderChat_ExtraBodyOverridesOptions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_CustomHeadersInjected(t *testing.T) {
|
||||
var gotSource, gotAuth, gotUserAgent string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSource = r.Header.Get("X-Source")
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider(
|
||||
"key",
|
||||
server.URL,
|
||||
"",
|
||||
WithUserAgent("PicoClaw/Test"),
|
||||
WithCustomHeaders(map[string]string{
|
||||
"X-Source": "coding-plan",
|
||||
"Authorization": "Token custom-auth",
|
||||
"User-Agent": "Custom-UA/1.0",
|
||||
}),
|
||||
)
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if gotSource != "coding-plan" {
|
||||
t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
|
||||
}
|
||||
if gotAuth != "Token custom-auth" {
|
||||
t.Fatalf("Authorization = %q, want %q", gotAuth, "Token custom-auth")
|
||||
}
|
||||
if gotUserAgent != "Custom-UA/1.0" {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, "Custom-UA/1.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChatStream_CustomHeadersInjected(t *testing.T) {
|
||||
var gotSource, gotAuth, gotUserAgent string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSource = r.Header.Get("X-Source")
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"ok\"},\"finish_reason\":\"stop\"}]}\n\n"))
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider(
|
||||
"key",
|
||||
server.URL,
|
||||
"",
|
||||
WithUserAgent("PicoClaw/Test"),
|
||||
WithCustomHeaders(map[string]string{
|
||||
"X-Source": "coding-plan",
|
||||
"Authorization": "Token stream-auth",
|
||||
"User-Agent": "Custom-UA/Stream",
|
||||
}),
|
||||
)
|
||||
|
||||
out, err := p.ChatStream(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ChatStream() error = %v", err)
|
||||
}
|
||||
if out.Content != "ok" {
|
||||
t.Fatalf("Content = %q, want %q", out.Content, "ok")
|
||||
}
|
||||
if gotSource != "coding-plan" {
|
||||
t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
|
||||
}
|
||||
if gotAuth != "Token stream-auth" {
|
||||
t.Fatalf("Authorization = %q, want %q", gotAuth, "Token stream-auth")
|
||||
}
|
||||
if gotUserAgent != "Custom-UA/Stream" {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, "Custom-UA/Stream")
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
|
||||
Reference in New Issue
Block a user