feat: add userAgent config for ModelConfig (#2242)

* feat: add userAgent config for ModelConfig

* update docs for ModelConfig.userAgent

* make defaut userAgent to PicoClaw and add test case
This commit is contained in:
Cytown
2026-04-02 11:44:13 +08:00
committed by GitHub
parent 415abc8cd4
commit 2c446e1e07
15 changed files with 286 additions and 29 deletions
+2
View File
@@ -600,6 +600,8 @@ type ModelConfig struct {
// existing configs, the field is inferred during load: models with API keys
// or the reserved "local-model" name are auto-enabled.
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
// UserAgent is the user agent string to use for HTTP requests.
UserAgent string `json:"user_agent,omitempty" yaml:"-"`
// isVirtual marks this model as a virtual model generated from multi-key expansion.
// Virtual models should not be persisted to config files.
+10 -5
View File
@@ -41,15 +41,16 @@ type Provider struct {
apiKey string
apiBase string
httpClient *http.Client
userAgent string
}
// NewProvider creates a new Anthropic Messages API provider.
func NewProvider(apiKey, apiBase string) *Provider {
return NewProviderWithTimeout(apiKey, apiBase, 0)
func NewProvider(apiKey, apiBase, userAgent string) *Provider {
return NewProviderWithTimeout(apiKey, apiBase, userAgent, 0)
}
// NewProviderWithTimeout creates a provider with custom request timeout.
func NewProviderWithTimeout(apiKey, apiBase string, timeoutSeconds int) *Provider {
func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider {
baseURL := normalizeBaseURL(apiBase)
timeout := defaultRequestTimeout
if timeoutSeconds > 0 {
@@ -57,8 +58,9 @@ func NewProviderWithTimeout(apiKey, apiBase string, timeoutSeconds int) *Provide
}
return &Provider{
apiKey: apiKey,
apiBase: baseURL,
apiKey: apiKey,
apiBase: baseURL,
userAgent: userAgent,
httpClient: &http.Client{
Timeout: timeout,
},
@@ -105,6 +107,9 @@ func (p *Provider) Chat(
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-API-Key", p.apiKey) //nolint:canonicalheader // Anthropic API requires exact header name
req.Header.Set("Anthropic-Version", defaultAPIVersion)
if p.userAgent != "" {
req.Header.Set("User-Agent", p.userAgent)
}
// Execute request
resp, err := p.httpClient.Do(req)
@@ -411,7 +411,7 @@ func TestNormalizeBaseURL(t *testing.T) {
}
func TestNewProvider(t *testing.T) {
provider := NewProvider("test-key", "https://api.example.com")
provider := NewProvider("test-key", "https://api.example.com", "")
if provider == nil {
t.Fatal("NewProvider() returned nil")
}
@@ -424,7 +424,7 @@ func TestNewProvider(t *testing.T) {
}
func TestGetDefaultModel(t *testing.T) {
provider := NewProvider("test-key", "")
provider := NewProvider("test-key", "", "")
got := provider.GetDefaultModel()
expected := "claude-sonnet-4.6"
if got != expected {
@@ -743,7 +743,7 @@ func TestProviderChatErrors(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create provider using constructor to ensure proper initialization
provider := NewProvider(tt.apiKey, "https://api.example.com")
provider := NewProvider(tt.apiKey, "https://api.example.com", "")
_, err := provider.Chat(context.Background(), tt.messages, nil, "test-model", nil)
if err == nil {
+15 -3
View File
@@ -36,6 +36,7 @@ type Provider struct {
apiKey string
apiBase string
httpClient *http.Client
userAgent string
}
// Option configures the Azure Provider.
@@ -50,11 +51,19 @@ func WithRequestTimeout(timeout time.Duration) Option {
}
}
// WithUserAgent sets the User-Agent header for requests.
func WithUserAgent(userAgent string) Option {
return func(p *Provider) {
p.userAgent = userAgent
}
}
// NewProvider creates a new Azure OpenAI provider.
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
func NewProvider(apiKey, apiBase, proxy, userAgent string, opts ...Option) *Provider {
p := &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
userAgent: userAgent,
httpClient: common.NewHTTPClient(proxy),
}
@@ -68,9 +77,9 @@ func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
}
// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds.
func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
func NewProviderWithTimeout(apiKey, apiBase, proxy, userAgent string, requestTimeoutSeconds int) *Provider {
return NewProvider(
apiKey, apiBase, proxy,
apiKey, apiBase, proxy, userAgent,
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
)
}
@@ -141,6 +150,9 @@ func (p *Provider) Chat(
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
if p.userAgent != "" {
req.Header.Set("User-Agent", p.userAgent)
}
resp, err := p.httpClient.Do(req)
if err != nil {
+16 -16
View File
@@ -46,7 +46,7 @@ func TestProviderChat_AzureURLConstruction(t *testing.T) {
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -69,7 +69,7 @@ func TestProviderChat_AzureAuthHeader(t *testing.T) {
}))
defer server.Close()
p := NewProvider("test-azure-key", server.URL, "")
p := NewProvider("test-azure-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -92,7 +92,7 @@ func TestProviderChat_AzureRequestBodyContainsModel(t *testing.T) {
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -112,7 +112,7 @@ func TestProviderChat_AzureUsesMaxOutputTokens(t *testing.T) {
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
@@ -144,7 +144,7 @@ func TestProviderChat_AzureStoreIsFalse(t *testing.T) {
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -161,7 +161,7 @@ func TestProviderChat_AzureHTTPError(t *testing.T) {
}))
defer server.Close()
p := NewProvider("bad-key", server.URL, "")
p := NewProvider("bad-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error, got nil")
@@ -176,7 +176,7 @@ func TestProviderChat_AzureRateLimitError(t *testing.T) {
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for 429, got nil")
@@ -194,7 +194,7 @@ func TestProviderChat_AzureServerError(t *testing.T) {
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for 500, got nil")
@@ -229,7 +229,7 @@ func TestProviderChat_AzureParseTextOutput(t *testing.T) {
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
p := NewProvider("test-key", server.URL, "", "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -270,7 +270,7 @@ func TestProviderChat_AzureParseToolCalls(t *testing.T) {
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
p := NewProvider("test-key", server.URL, "", "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -287,7 +287,7 @@ func TestProviderChat_AzureParseToolCalls(t *testing.T) {
}
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
p := NewProvider("test-key", "", "")
p := NewProvider("test-key", "", "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for empty API base")
@@ -295,21 +295,21 @@ func TestProvider_AzureEmptyAPIBase(t *testing.T) {
}
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
p := NewProvider("test-key", "https://example.com", "")
p := NewProvider("test-key", "https://example.com", "", "")
if p.httpClient.Timeout != defaultRequestTimeout {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
}
}
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second))
p := NewProvider("test-key", "https://example.com", "", "", WithRequestTimeout(300*time.Second))
if p.httpClient.Timeout != 300*time.Second {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
}
}
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
p := NewProviderWithTimeout("test-key", "https://example.com", "", "", 180)
if p.httpClient.Timeout != 180*time.Second {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
}
@@ -343,7 +343,7 @@ func TestProviderChat_AzureNativeWebSearchInjection(t *testing.T) {
},
}
p := NewProvider("test-key", server.URL, "")
p := NewProvider("test-key", server.URL, "", "")
// With native_search=true: user-defined web_search should be replaced by built-in
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment",
@@ -393,7 +393,7 @@ func TestProviderChat_AzureNoNativeWebSearch(t *testing.T) {
},
}
p := NewProvider("test-key", server.URL, "")
p := NewProvider("test-key", server.URL, "", "")
// Without native_search: user-defined web_search should be kept as-is
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment", nil)
+12
View File
@@ -129,6 +129,11 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
protocol, modelID := ExtractProtocol(cfg.Model)
userAgent := cfg.UserAgent
if userAgent == "" {
userAgent = fmt.Sprintf("PicoClaw/%s", config.Version)
}
switch protocol {
case "openai":
// OpenAI with OAuth/token auth (Codex-style)
@@ -152,6 +157,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
), modelID, nil
@@ -171,6 +177,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.APIKey(),
cfg.APIBase,
cfg.Proxy,
userAgent,
cfg.RequestTimeout,
), modelID, nil
@@ -228,6 +235,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
), modelID, nil
@@ -253,6 +261,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
userAgent,
cfg.RequestTimeout,
extraBody,
), modelID, nil
@@ -279,6 +288,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
), modelID, nil
@@ -295,6 +305,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
return anthropicmessages.NewProviderWithTimeout(
cfg.APIKey(),
apiBase,
userAgent,
cfg.RequestTimeout,
), modelID, nil
@@ -310,6 +321,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
return anthropicmessages.NewProviderWithTimeout(
cfg.APIKey(),
apiBase,
userAgent,
cfg.RequestTimeout,
), modelID, nil
+101
View File
@@ -846,6 +846,107 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
}
}
// openaiCompatResponse is the JSON response used by OpenAI-compatible providers.
const openaiCompatResponse = `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`
// anthropicResponse is the JSON response used by Anthropic providers.
const anthropicResponse = `{"content":[{"type":"text","text":"ok"}],"stop_reason":"end_turn","model":"claude-sonnet-4-20250514","usage":{"input_tokens":10,"output_tokens":5}}`
func TestCreateProviderFromConfig_UserAgent(t *testing.T) {
defaultUA := "PicoClaw/" + config.Version
tests := []struct {
name string
model string
userAgent string
apiKey string
response string
wantUA string
chatOpts map[string]any
}{
{
name: "openai default user agent",
model: "openai/gpt-4o",
apiKey: "test-key",
response: openaiCompatResponse,
wantUA: defaultUA,
},
{
name: "openai custom user agent",
model: "openai/gpt-4o",
apiKey: "test-key",
userAgent: "MyAgent/1.2.3",
response: openaiCompatResponse,
wantUA: "MyAgent/1.2.3",
},
{
name: "anthropic default user agent",
model: "anthropic/claude-sonnet-4-20250514",
apiKey: "test-key",
response: anthropicResponse,
wantUA: defaultUA,
},
{
name: "anthropic-messages default user agent",
model: "anthropic-messages/claude-sonnet-4-20250514",
apiKey: "test-key",
response: anthropicResponse,
wantUA: defaultUA,
chatOpts: map[string]any{"max_tokens": 1024},
},
{
name: "azure default user agent",
model: "azure/my-deployment",
apiKey: "test-azure-key",
response: openaiCompatResponse,
wantUA: defaultUA,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var receivedUA string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedUA = r.Header.Get("User-Agent")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(tt.response))
}))
defer server.Close()
cfg := &config.ModelConfig{
ModelName: "test-ua-" + tt.name,
Model: tt.model,
APIBase: server.URL,
UserAgent: tt.userAgent,
}
cfg.SetAPIKey(tt.apiKey)
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
_, err = provider.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
modelID,
tt.chatOpts,
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if receivedUA != tt.wantUA {
t.Errorf("User-Agent = %q, want %q", receivedUA, tt.wantUA)
}
})
}
}
func TestCreateProviderFromConfig_Bedrock(t *testing.T) {
// Set dummy AWS env vars to make test deterministic
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
+3 -2
View File
@@ -24,11 +24,11 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
}
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0, nil)
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, "", 0, nil)
}
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
apiKey, apiBase, proxy, maxTokensField string,
apiKey, apiBase, proxy, maxTokensField, userAgent string,
requestTimeoutSeconds int,
extraBody map[string]any,
) *HTTPProvider {
@@ -40,6 +40,7 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
openai_compat.WithMaxTokensField(maxTokensField),
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
openai_compat.WithExtraBody(extraBody),
openai_compat.WithUserAgent(userAgent),
),
}
}
+10
View File
@@ -36,6 +36,7 @@ type Provider struct {
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
httpClient *http.Client
extraBody map[string]any // Additional fields to inject into request body
userAgent string
}
type Option func(*Provider)
@@ -66,6 +67,12 @@ func WithMaxTokensField(maxTokensField string) Option {
}
}
func WithUserAgent(userAgent string) Option {
return func(p *Provider) {
p.userAgent = userAgent
}
}
func WithRequestTimeout(timeout time.Duration) Option {
return func(p *Provider) {
if timeout > 0 {
@@ -198,6 +205,9 @@ func (p *Provider) Chat(
}
req.Header.Set("Content-Type", "application/json")
if p.userAgent != "" {
req.Header.Set("User-Agent", p.userAgent)
}
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}