mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat: add userAgent config for ModelConfig (#2242)
* feat: add userAgent config for ModelConfig * update docs for ModelConfig.userAgent * make defaut userAgent to PicoClaw and add test case
This commit is contained in:
@@ -600,6 +600,8 @@ type ModelConfig struct {
|
||||
// existing configs, the field is inferred during load: models with API keys
|
||||
// or the reserved "local-model" name are auto-enabled.
|
||||
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
|
||||
// UserAgent is the user agent string to use for HTTP requests.
|
||||
UserAgent string `json:"user_agent,omitempty" yaml:"-"`
|
||||
|
||||
// isVirtual marks this model as a virtual model generated from multi-key expansion.
|
||||
// Virtual models should not be persisted to config files.
|
||||
|
||||
@@ -41,15 +41,16 @@ type Provider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
userAgent string
|
||||
}
|
||||
|
||||
// NewProvider creates a new Anthropic Messages API provider.
|
||||
func NewProvider(apiKey, apiBase string) *Provider {
|
||||
return NewProviderWithTimeout(apiKey, apiBase, 0)
|
||||
func NewProvider(apiKey, apiBase, userAgent string) *Provider {
|
||||
return NewProviderWithTimeout(apiKey, apiBase, userAgent, 0)
|
||||
}
|
||||
|
||||
// NewProviderWithTimeout creates a provider with custom request timeout.
|
||||
func NewProviderWithTimeout(apiKey, apiBase string, timeoutSeconds int) *Provider {
|
||||
func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider {
|
||||
baseURL := normalizeBaseURL(apiBase)
|
||||
timeout := defaultRequestTimeout
|
||||
if timeoutSeconds > 0 {
|
||||
@@ -57,8 +58,9 @@ func NewProviderWithTimeout(apiKey, apiBase string, timeoutSeconds int) *Provide
|
||||
}
|
||||
|
||||
return &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: baseURL,
|
||||
apiKey: apiKey,
|
||||
apiBase: baseURL,
|
||||
userAgent: userAgent,
|
||||
httpClient: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
@@ -105,6 +107,9 @@ func (p *Provider) Chat(
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-API-Key", p.apiKey) //nolint:canonicalheader // Anthropic API requires exact header name
|
||||
req.Header.Set("Anthropic-Version", defaultAPIVersion)
|
||||
if p.userAgent != "" {
|
||||
req.Header.Set("User-Agent", p.userAgent)
|
||||
}
|
||||
|
||||
// Execute request
|
||||
resp, err := p.httpClient.Do(req)
|
||||
|
||||
@@ -411,7 +411,7 @@ func TestNormalizeBaseURL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
provider := NewProvider("test-key", "https://api.example.com")
|
||||
provider := NewProvider("test-key", "https://api.example.com", "")
|
||||
if provider == nil {
|
||||
t.Fatal("NewProvider() returned nil")
|
||||
}
|
||||
@@ -424,7 +424,7 @@ func TestNewProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetDefaultModel(t *testing.T) {
|
||||
provider := NewProvider("test-key", "")
|
||||
provider := NewProvider("test-key", "", "")
|
||||
got := provider.GetDefaultModel()
|
||||
expected := "claude-sonnet-4.6"
|
||||
if got != expected {
|
||||
@@ -743,7 +743,7 @@ func TestProviderChatErrors(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create provider using constructor to ensure proper initialization
|
||||
provider := NewProvider(tt.apiKey, "https://api.example.com")
|
||||
provider := NewProvider(tt.apiKey, "https://api.example.com", "")
|
||||
|
||||
_, err := provider.Chat(context.Background(), tt.messages, nil, "test-model", nil)
|
||||
if err == nil {
|
||||
|
||||
@@ -36,6 +36,7 @@ type Provider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
userAgent string
|
||||
}
|
||||
|
||||
// Option configures the Azure Provider.
|
||||
@@ -50,11 +51,19 @@ func WithRequestTimeout(timeout time.Duration) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserAgent sets the User-Agent header for requests.
|
||||
func WithUserAgent(userAgent string) Option {
|
||||
return func(p *Provider) {
|
||||
p.userAgent = userAgent
|
||||
}
|
||||
}
|
||||
|
||||
// NewProvider creates a new Azure OpenAI provider.
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
func NewProvider(apiKey, apiBase, proxy, userAgent string, opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
userAgent: userAgent,
|
||||
httpClient: common.NewHTTPClient(proxy),
|
||||
}
|
||||
|
||||
@@ -68,9 +77,9 @@ func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
}
|
||||
|
||||
// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds.
|
||||
func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
|
||||
func NewProviderWithTimeout(apiKey, apiBase, proxy, userAgent string, requestTimeoutSeconds int) *Provider {
|
||||
return NewProvider(
|
||||
apiKey, apiBase, proxy,
|
||||
apiKey, apiBase, proxy, userAgent,
|
||||
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
)
|
||||
}
|
||||
@@ -141,6 +150,9 @@ func (p *Provider) Chat(
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
if p.userAgent != "" {
|
||||
req.Header.Set("User-Agent", p.userAgent)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -46,7 +46,7 @@ func TestProviderChat_AzureURLConstruction(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
p := NewProvider("test-key", server.URL, "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -69,7 +69,7 @@ func TestProviderChat_AzureAuthHeader(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-azure-key", server.URL, "")
|
||||
p := NewProvider("test-azure-key", server.URL, "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -92,7 +92,7 @@ func TestProviderChat_AzureRequestBodyContainsModel(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
p := NewProvider("test-key", server.URL, "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -112,7 +112,7 @@ func TestProviderChat_AzureUsesMaxOutputTokens(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
p := NewProvider("test-key", server.URL, "", "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
@@ -144,7 +144,7 @@ func TestProviderChat_AzureStoreIsFalse(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
p := NewProvider("test-key", server.URL, "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -161,7 +161,7 @@ func TestProviderChat_AzureHTTPError(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("bad-key", server.URL, "")
|
||||
p := NewProvider("bad-key", server.URL, "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
@@ -176,7 +176,7 @@ func TestProviderChat_AzureRateLimitError(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
p := NewProvider("test-key", server.URL, "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 429, got nil")
|
||||
@@ -194,7 +194,7 @@ func TestProviderChat_AzureServerError(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
p := NewProvider("test-key", server.URL, "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 500, got nil")
|
||||
@@ -229,7 +229,7 @@ func TestProviderChat_AzureParseTextOutput(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
p := NewProvider("test-key", server.URL, "", "")
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -270,7 +270,7 @@ func TestProviderChat_AzureParseToolCalls(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
p := NewProvider("test-key", server.URL, "", "")
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -287,7 +287,7 @@ func TestProviderChat_AzureParseToolCalls(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
|
||||
p := NewProvider("test-key", "", "")
|
||||
p := NewProvider("test-key", "", "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty API base")
|
||||
@@ -295,21 +295,21 @@ func TestProvider_AzureEmptyAPIBase(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
|
||||
p := NewProvider("test-key", "https://example.com", "")
|
||||
p := NewProvider("test-key", "https://example.com", "", "")
|
||||
if p.httpClient.Timeout != defaultRequestTimeout {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
|
||||
p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second))
|
||||
p := NewProvider("test-key", "https://example.com", "", "", WithRequestTimeout(300*time.Second))
|
||||
if p.httpClient.Timeout != 300*time.Second {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
|
||||
p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
|
||||
p := NewProviderWithTimeout("test-key", "https://example.com", "", "", 180)
|
||||
if p.httpClient.Timeout != 180*time.Second {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
|
||||
}
|
||||
@@ -343,7 +343,7 @@ func TestProviderChat_AzureNativeWebSearchInjection(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
p := NewProvider("test-key", server.URL, "", "")
|
||||
|
||||
// With native_search=true: user-defined web_search should be replaced by built-in
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment",
|
||||
@@ -393,7 +393,7 @@ func TestProviderChat_AzureNoNativeWebSearch(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
p := NewProvider("test-key", server.URL, "", "")
|
||||
|
||||
// Without native_search: user-defined web_search should be kept as-is
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment", nil)
|
||||
|
||||
@@ -129,6 +129,11 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
|
||||
protocol, modelID := ExtractProtocol(cfg.Model)
|
||||
|
||||
userAgent := cfg.UserAgent
|
||||
if userAgent == "" {
|
||||
userAgent = fmt.Sprintf("PicoClaw/%s", config.Version)
|
||||
}
|
||||
|
||||
switch protocol {
|
||||
case "openai":
|
||||
// OpenAI with OAuth/token auth (Codex-style)
|
||||
@@ -152,6 +157,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
), modelID, nil
|
||||
@@ -171,6 +177,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.APIKey(),
|
||||
cfg.APIBase,
|
||||
cfg.Proxy,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
@@ -228,6 +235,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
), modelID, nil
|
||||
@@ -253,6 +261,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
extraBody,
|
||||
), modelID, nil
|
||||
@@ -279,6 +288,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
cfg.MaxTokensField,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
), modelID, nil
|
||||
@@ -295,6 +305,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
return anthropicmessages.NewProviderWithTimeout(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
@@ -310,6 +321,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
return anthropicmessages.NewProviderWithTimeout(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
|
||||
@@ -846,6 +846,107 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// openaiCompatResponse is the JSON response used by OpenAI-compatible providers.
|
||||
const openaiCompatResponse = `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`
|
||||
|
||||
// anthropicResponse is the JSON response used by Anthropic providers.
|
||||
const anthropicResponse = `{"content":[{"type":"text","text":"ok"}],"stop_reason":"end_turn","model":"claude-sonnet-4-20250514","usage":{"input_tokens":10,"output_tokens":5}}`
|
||||
|
||||
func TestCreateProviderFromConfig_UserAgent(t *testing.T) {
|
||||
defaultUA := "PicoClaw/" + config.Version
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
userAgent string
|
||||
apiKey string
|
||||
response string
|
||||
wantUA string
|
||||
chatOpts map[string]any
|
||||
}{
|
||||
{
|
||||
name: "openai default user agent",
|
||||
model: "openai/gpt-4o",
|
||||
apiKey: "test-key",
|
||||
response: openaiCompatResponse,
|
||||
wantUA: defaultUA,
|
||||
},
|
||||
{
|
||||
name: "openai custom user agent",
|
||||
model: "openai/gpt-4o",
|
||||
apiKey: "test-key",
|
||||
userAgent: "MyAgent/1.2.3",
|
||||
response: openaiCompatResponse,
|
||||
wantUA: "MyAgent/1.2.3",
|
||||
},
|
||||
{
|
||||
name: "anthropic default user agent",
|
||||
model: "anthropic/claude-sonnet-4-20250514",
|
||||
apiKey: "test-key",
|
||||
response: anthropicResponse,
|
||||
wantUA: defaultUA,
|
||||
},
|
||||
{
|
||||
name: "anthropic-messages default user agent",
|
||||
model: "anthropic-messages/claude-sonnet-4-20250514",
|
||||
apiKey: "test-key",
|
||||
response: anthropicResponse,
|
||||
wantUA: defaultUA,
|
||||
chatOpts: map[string]any{"max_tokens": 1024},
|
||||
},
|
||||
{
|
||||
name: "azure default user agent",
|
||||
model: "azure/my-deployment",
|
||||
apiKey: "test-azure-key",
|
||||
response: openaiCompatResponse,
|
||||
wantUA: defaultUA,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var receivedUA string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedUA = r.Header.Get("User-Agent")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(tt.response))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-ua-" + tt.name,
|
||||
Model: tt.model,
|
||||
APIBase: server.URL,
|
||||
UserAgent: tt.userAgent,
|
||||
}
|
||||
cfg.SetAPIKey(tt.apiKey)
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
|
||||
_, err = provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
modelID,
|
||||
tt.chatOpts,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if receivedUA != tt.wantUA {
|
||||
t.Errorf("User-Agent = %q, want %q", receivedUA, tt.wantUA)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Bedrock(t *testing.T) {
|
||||
// Set dummy AWS env vars to make test deterministic
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
|
||||
|
||||
@@ -24,11 +24,11 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0, nil)
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, "", 0, nil)
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
apiKey, apiBase, proxy, maxTokensField string,
|
||||
apiKey, apiBase, proxy, maxTokensField, userAgent string,
|
||||
requestTimeoutSeconds int,
|
||||
extraBody map[string]any,
|
||||
) *HTTPProvider {
|
||||
@@ -40,6 +40,7 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
openai_compat.WithMaxTokensField(maxTokensField),
|
||||
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
openai_compat.WithExtraBody(extraBody),
|
||||
openai_compat.WithUserAgent(userAgent),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ type Provider struct {
|
||||
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
|
||||
httpClient *http.Client
|
||||
extraBody map[string]any // Additional fields to inject into request body
|
||||
userAgent string
|
||||
}
|
||||
|
||||
type Option func(*Provider)
|
||||
@@ -66,6 +67,12 @@ func WithMaxTokensField(maxTokensField string) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func WithUserAgent(userAgent string) Option {
|
||||
return func(p *Provider) {
|
||||
p.userAgent = userAgent
|
||||
}
|
||||
}
|
||||
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(p *Provider) {
|
||||
if timeout > 0 {
|
||||
@@ -198,6 +205,9 @@ func (p *Provider) Chat(
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.userAgent != "" {
|
||||
req.Header.Set("User-Agent", p.userAgent)
|
||||
}
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user