diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go index d4ceaab2c..4330163df 100644 --- a/pkg/providers/anthropic/provider.go +++ b/pkg/providers/anthropic/provider.go @@ -10,6 +10,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/providers/common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -42,7 +43,7 @@ func NewProvider(token string) *Provider { } func NewProviderWithBaseURL(token, apiBase string) *Provider { - baseURL := normalizeBaseURL(apiBase) + baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, false) client := anthropic.NewClient( option.WithAuthToken(token), option.WithBaseURL(baseURL), @@ -385,20 +386,3 @@ func parseResponse(resp *anthropic.Message) *LLMResponse { }, } } - -func normalizeBaseURL(apiBase string) string { - base := strings.TrimSpace(apiBase) - if base == "" { - return defaultBaseURL - } - - base = strings.TrimRight(base, "/") - if before, ok := strings.CutSuffix(base, "/v1"); ok { - base = before - } - if base == "" { - return defaultBaseURL - } - - return base -} diff --git a/pkg/providers/anthropic_messages/provider.go b/pkg/providers/anthropic_messages/provider.go index 1e865b709..dcc31b6f9 100644 --- a/pkg/providers/anthropic_messages/provider.go +++ b/pkg/providers/anthropic_messages/provider.go @@ -16,6 +16,7 @@ import ( "strings" "time" + "github.com/sipeed/picoclaw/pkg/providers/common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -51,7 +52,7 @@ func NewProvider(apiKey, apiBase, userAgent string) *Provider { // NewProviderWithTimeout creates a provider with custom request timeout. func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider { - baseURL := normalizeBaseURL(apiBase) + baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, true) timeout := defaultRequestTimeout if timeoutSeconds > 0 { timeout = time.Duration(timeoutSeconds) * time.Second @@ -161,7 +162,7 @@ func buildRequestBody( options map[string]any, ) (map[string]any, error) { // max_tokens is required and guaranteed by agent loop - maxTokens, ok := asInt(options["max_tokens"]) + maxTokens, ok := common.AsInt(options["max_tokens"]) if !ok { return nil, fmt.Errorf("max_tokens is required in options") } @@ -173,7 +174,7 @@ func buildRequestBody( } // Set temperature from options - if temp, ok := asFloat(options["temperature"]); ok { + if temp, ok := common.AsFloat(options["temperature"]); ok { result["temperature"] = temp } @@ -361,61 +362,6 @@ func parseResponseBody(body []byte) (*LLMResponse, error) { }, nil } -// normalizeBaseURL ensures the base URL is properly formatted. -// It removes /v1 suffix if present (to avoid duplication) and always appends /v1. -// This handles edge cases like "https://api.example.com/v1/proxy" correctly. -func normalizeBaseURL(apiBase string) string { - base := strings.TrimSpace(apiBase) - if base == "" { - return defaultBaseURL - } - - // Remove trailing slashes - base = strings.TrimRight(base, "/") - - // Remove /v1 suffix if present (will be re-added) - // This prevents duplication for URLs like "https://api.example.com/v1/proxy" - if before, ok := strings.CutSuffix(base, "/v1"); ok { - base = before - } - - // Ensure we don't have an empty string after cutting - if base == "" { - return defaultBaseURL - } - - // Add /v1 suffix (required by Anthropic Messages API) - return base + "/v1" -} - -// Helper functions for type conversion - -func asInt(v any) (int, bool) { - switch val := v.(type) { - case int: - return val, true - case float64: - return int(val), true - case int64: - return int(val), true - default: - return 0, false - } -} - -func asFloat(v any) (float64, bool) { - switch val := v.(type) { - case float64: - return val, true - case int: - return float64(val), true - case int64: - return float64(val), true - default: - return 0, false - } -} - // Anthropic API response structures type anthropicMessageResponse struct { diff --git a/pkg/providers/anthropic_messages/provider_test.go b/pkg/providers/anthropic_messages/provider_test.go index ba9d24b66..6401d84bd 100644 --- a/pkg/providers/anthropic_messages/provider_test.go +++ b/pkg/providers/anthropic_messages/provider_test.go @@ -372,44 +372,6 @@ func TestParseResponseBody(t *testing.T) { } } -func TestNormalizeBaseURL(t *testing.T) { - tests := []struct { - name string - apiBase string - expected string - }{ - { - name: "empty string defaults to official API", - apiBase: "", - expected: "https://api.anthropic.com/v1", - }, - { - name: "URL without /v1 gets it appended", - apiBase: "https://api.example.com/anthropic", - expected: "https://api.example.com/anthropic/v1", - }, - { - name: "URL with /v1 remains unchanged", - apiBase: "https://api.example.com/v1", - expected: "https://api.example.com/v1", - }, - { - name: "URL with trailing slash gets cleaned", - apiBase: "https://api.example.com/anthropic/", - expected: "https://api.example.com/anthropic/v1", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := normalizeBaseURL(tt.apiBase) - if got != tt.expected { - t.Errorf("normalizeBaseURL(%q) = %q, want %q", tt.apiBase, got, tt.expected) - } - }) - } -} - func TestNewProvider(t *testing.T) { provider := NewProvider("test-key", "https://api.example.com", "") if provider == nil { diff --git a/pkg/providers/common/common.go b/pkg/providers/common/common.go index 0a702e85e..afc6877b1 100644 --- a/pkg/providers/common/common.go +++ b/pkg/providers/common/common.go @@ -478,6 +478,69 @@ func AsInt(v any) (int, bool) { } } +// ExtractProtocol extracts the effective protocol and model identifier from a +// model configuration. +// +// The explicit Provider field takes precedence. When Provider is empty, the +// protocol is inferred from Model. Plain model names default to "openai". +// Provider-prefixed models strip the first slash-separated segment from the +// returned model ID. +// +// The returned protocol is normalized to the provider's canonical spelling. +// Examples: +// - Model "openai/gpt-4o" -> ("openai", "gpt-4o") +// - Model "nvidia/z-ai/glm-5.1" -> ("nvidia", "z-ai/glm-5.1") +// - Provider "nvidia", Model "z-ai/glm-5.1" -> ("nvidia", "z-ai/glm-5.1") +// - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o") +// - Model "gpt-4o" -> ("openai", "gpt-4o") +func ExtractProtocol(model string) (protocol, modelID string) { + if cfg == nil { + return "", "" + } + + model := strings.TrimSpace(cfg.Model) + if provider := strings.TrimSpace(cfg.Provider); provider != "" { + return NormalizeProvider(provider), model + } + if model == "" { + return "", "" + } + + protocol, rest, found := strings.Cut(model, "/") + if !found { + return "openai", model + } + protocol = strings.TrimSpace(protocol) + if protocol == "" { + return "", strings.TrimSpace(rest) + } + return NormalizeProvider(protocol), strings.TrimSpace(rest) +} + +// NormalizeAnthropicBaseURL ensures the Anthropic base URL is properly formatted. +// It removes a trailing /v1 suffix if present (to avoid duplication), then +// re-appends /v1 when appendV1Suffix is true. An empty apiBase falls back to +// defaultBaseURL. +func NormalizeAnthropicBaseURL(apiBase, defaultBaseURL string, appendV1Suffix bool) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return defaultBaseURL + } + + base = strings.TrimRight(base, "/") + if before, ok := strings.CutSuffix(base, "/v1"); ok { + base = before + } + if base == "" { + return defaultBaseURL + } + + if appendV1Suffix { + return base + "/v1" + } + return base +} + // AsFloat converts various numeric types to float64. func AsFloat(v any) (float64, bool) { switch val := v.(type) { diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go index a42d778f1..56c80e754 100644 --- a/pkg/providers/common/common_test.go +++ b/pkg/providers/common/common_test.go @@ -660,6 +660,71 @@ func TestAsFloat(t *testing.T) { } } +// --- ExtractProtocol tests --- + +func TestExtractProtocol(t *testing.T) { + tests := []struct { + name string + model string + wantProtocol string + wantModelID string + }{ + {"openai with prefix", "openai/gpt-4o", "openai", "gpt-4o"}, + {"anthropic with prefix", "anthropic/claude-sonnet-4.6", "anthropic", "claude-sonnet-4.6"}, + {"no prefix defaults to openai", "gpt-4o", "openai", "gpt-4o"}, + {"groq with prefix", "groq/llama-3.1-70b", "groq", "llama-3.1-70b"}, + {"empty string", "", "openai", ""}, + {"with whitespace", " openai/gpt-4 ", "openai", "gpt-4"}, + {"multiple slashes", "nvidia/meta/llama-3.1-8b", "nvidia", "meta/llama-3.1-8b"}, + {"azure with prefix", "azure/my-gpt5-deployment", "azure", "my-gpt5-deployment"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + protocol, modelID := ExtractProtocol(tt.model) + if protocol != tt.wantProtocol { + t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol) + } + if modelID != tt.wantModelID { + t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID) + } + }) + } +} + +// --- NormalizeAnthropicBaseURL tests --- + +func TestNormalizeAnthropicBaseURL(t *testing.T) { + const defaultURL = "https://api.anthropic.com" + const defaultURLWithV1 = "https://api.anthropic.com/v1" + + tests := []struct { + name string + apiBase string + defaultBase string + appendV1Suffix bool + expected string + }{ + {"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1}, + {"empty without v1", "", defaultURL, false, defaultURL}, + {"URL without v1 gets it appended", "https://api.example.com/anthropic", defaultURLWithV1, true, "https://api.example.com/anthropic/v1"}, + {"URL without v1 stays as-is", "https://api.example.com/anthropic", defaultURL, false, "https://api.example.com/anthropic"}, + {"URL with v1 remains unchanged when appending", "https://api.example.com/v1", defaultURLWithV1, true, "https://api.example.com/v1"}, + {"URL with v1 gets it stripped when not appending", "https://api.example.com/v1", defaultURL, false, "https://api.example.com"}, + {"trailing slash cleaned with v1", "https://api.example.com/anthropic/", defaultURLWithV1, true, "https://api.example.com/anthropic/v1"}, + {"trailing slash cleaned without v1", "https://api.example.com/anthropic/", defaultURL, false, "https://api.example.com/anthropic"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NormalizeAnthropicBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix) + if got != tt.expected { + t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q", + tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected) + } + }) + } +} + // --- WrapHTMLResponseError tests --- func TestWrapHTMLResponseError(t *testing.T) { diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 86d009811..63413b0a1 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -15,6 +15,7 @@ import ( anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages" "github.com/sipeed/picoclaw/pkg/providers/azure" "github.com/sipeed/picoclaw/pkg/providers/bedrock" + "github.com/sipeed/picoclaw/pkg/providers/common" ) type protocolMeta struct { @@ -102,27 +103,7 @@ func createCodexAuthProvider() (LLMProvider, error) { // - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o") // - Model "gpt-4o" -> ("openai", "gpt-4o") func ExtractProtocol(cfg *config.ModelConfig) (protocol, modelID string) { - if cfg == nil { - return "", "" - } - - model := strings.TrimSpace(cfg.Model) - if provider := strings.TrimSpace(cfg.Provider); provider != "" { - return NormalizeProvider(provider), model - } - if model == "" { - return "", "" - } - - protocol, rest, found := strings.Cut(model, "/") - if !found { - return "openai", model - } - protocol = strings.TrimSpace(protocol) - if protocol == "" { - return "", strings.TrimSpace(rest) - } - return NormalizeProvider(protocol), strings.TrimSpace(rest) + return common.ExtractProtocol(model) } // ResolveAPIBase returns the configured API base, or the protocol default when diff --git a/pkg/providers/httpapi/gemini_helpers.go b/pkg/providers/httpapi/gemini_helpers.go index 36d95cf9e..0f1e20ca5 100644 --- a/pkg/providers/httpapi/gemini_helpers.go +++ b/pkg/providers/httpapi/gemini_helpers.go @@ -128,12 +128,3 @@ func sanitizeSchemaForGemini(schema map[string]any) map[string]any { return result } - -func extractProtocol(model string) (protocol, modelID string) { - model = strings.TrimSpace(model) - protocol, modelID, found := strings.Cut(model, "/") - if !found { - return "openai", model - } - return protocol, modelID -} diff --git a/pkg/providers/httpapi/gemini_provider.go b/pkg/providers/httpapi/gemini_provider.go index d488d06f8..dab6acd29 100644 --- a/pkg/providers/httpapi/gemini_provider.go +++ b/pkg/providers/httpapi/gemini_provider.go @@ -303,7 +303,7 @@ func normalizeGeminiModel(model string) string { model = strings.TrimSpace(model) model = strings.TrimPrefix(model, "models/") if strings.Contains(model, "/") { - _, modelID := extractProtocol(model) + _, modelID := common.ExtractProtocol(model) if modelID != "" { return modelID } diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 98a70cfd2..29667cd31 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -470,7 +470,9 @@ func (p *Provider) SupportsNativeSearch() bool { return isNativeSearchHost(p.apiBase) } -func isNativeSearchHost(apiBase string) bool { +// isNativeOpenAIOrAzureEndpoint reports whether the given API base points to +// OpenAI's own API or an Azure OpenAI deployment. +func isNativeOpenAIOrAzureEndpoint(apiBase string) bool { u, err := url.Parse(apiBase) if err != nil { return false @@ -479,15 +481,14 @@ func isNativeSearchHost(apiBase string) bool { return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com") } +func isNativeSearchHost(apiBase string) bool { + return isNativeOpenAIOrAzureEndpoint(apiBase) +} + // supportsPromptCacheKey reports whether the given API base is known to // support the prompt_cache_key request field. Currently only OpenAI's own // API and Azure OpenAI support this. All other OpenAI-compatible providers // (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors. func supportsPromptCacheKey(apiBase string) bool { - u, err := url.Parse(apiBase) - if err != nil { - return false - } - host := u.Hostname() - return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com") + return isNativeOpenAIOrAzureEndpoint(apiBase) }