Functions deduplication

This commit is contained in:
Kunal Karmakar
2026-04-19 04:20:00 +00:00
parent 451db2f5d8
commit c71146b1d5
9 changed files with 145 additions and 152 deletions
+2 -18
View File
@@ -10,6 +10,7 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -42,7 +43,7 @@ func NewProvider(token string) *Provider {
}
func NewProviderWithBaseURL(token, apiBase string) *Provider {
baseURL := normalizeBaseURL(apiBase)
baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, false)
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL(baseURL),
@@ -385,20 +386,3 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
},
}
}
func normalizeBaseURL(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
base = strings.TrimRight(base, "/")
if before, ok := strings.CutSuffix(base, "/v1"); ok {
base = before
}
if base == "" {
return defaultBaseURL
}
return base
}
+4 -58
View File
@@ -16,6 +16,7 @@ import (
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -51,7 +52,7 @@ func NewProvider(apiKey, apiBase, userAgent string) *Provider {
// NewProviderWithTimeout creates a provider with custom request timeout.
func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider {
baseURL := normalizeBaseURL(apiBase)
baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, true)
timeout := defaultRequestTimeout
if timeoutSeconds > 0 {
timeout = time.Duration(timeoutSeconds) * time.Second
@@ -161,7 +162,7 @@ func buildRequestBody(
options map[string]any,
) (map[string]any, error) {
// max_tokens is required and guaranteed by agent loop
maxTokens, ok := asInt(options["max_tokens"])
maxTokens, ok := common.AsInt(options["max_tokens"])
if !ok {
return nil, fmt.Errorf("max_tokens is required in options")
}
@@ -173,7 +174,7 @@ func buildRequestBody(
}
// Set temperature from options
if temp, ok := asFloat(options["temperature"]); ok {
if temp, ok := common.AsFloat(options["temperature"]); ok {
result["temperature"] = temp
}
@@ -361,61 +362,6 @@ func parseResponseBody(body []byte) (*LLMResponse, error) {
}, nil
}
// normalizeBaseURL ensures the base URL is properly formatted.
// It removes /v1 suffix if present (to avoid duplication) and always appends /v1.
// This handles edge cases like "https://api.example.com/v1/proxy" correctly.
func normalizeBaseURL(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
// Remove trailing slashes
base = strings.TrimRight(base, "/")
// Remove /v1 suffix if present (will be re-added)
// This prevents duplication for URLs like "https://api.example.com/v1/proxy"
if before, ok := strings.CutSuffix(base, "/v1"); ok {
base = before
}
// Ensure we don't have an empty string after cutting
if base == "" {
return defaultBaseURL
}
// Add /v1 suffix (required by Anthropic Messages API)
return base + "/v1"
}
// Helper functions for type conversion
func asInt(v any) (int, bool) {
switch val := v.(type) {
case int:
return val, true
case float64:
return int(val), true
case int64:
return int(val), true
default:
return 0, false
}
}
func asFloat(v any) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
// Anthropic API response structures
type anthropicMessageResponse struct {
@@ -372,44 +372,6 @@ func TestParseResponseBody(t *testing.T) {
}
}
func TestNormalizeBaseURL(t *testing.T) {
tests := []struct {
name string
apiBase string
expected string
}{
{
name: "empty string defaults to official API",
apiBase: "",
expected: "https://api.anthropic.com/v1",
},
{
name: "URL without /v1 gets it appended",
apiBase: "https://api.example.com/anthropic",
expected: "https://api.example.com/anthropic/v1",
},
{
name: "URL with /v1 remains unchanged",
apiBase: "https://api.example.com/v1",
expected: "https://api.example.com/v1",
},
{
name: "URL with trailing slash gets cleaned",
apiBase: "https://api.example.com/anthropic/",
expected: "https://api.example.com/anthropic/v1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeBaseURL(tt.apiBase)
if got != tt.expected {
t.Errorf("normalizeBaseURL(%q) = %q, want %q", tt.apiBase, got, tt.expected)
}
})
}
}
func TestNewProvider(t *testing.T) {
provider := NewProvider("test-key", "https://api.example.com", "")
if provider == nil {
+63
View File
@@ -478,6 +478,69 @@ func AsInt(v any) (int, bool) {
}
}
// ExtractProtocol extracts the effective protocol and model identifier from a
// model configuration.
//
// The explicit Provider field takes precedence. When Provider is empty, the
// protocol is inferred from Model. Plain model names default to "openai".
// Provider-prefixed models strip the first slash-separated segment from the
// returned model ID.
//
// The returned protocol is normalized to the provider's canonical spelling.
// Examples:
// - Model "openai/gpt-4o" -> ("openai", "gpt-4o")
// - Model "nvidia/z-ai/glm-5.1" -> ("nvidia", "z-ai/glm-5.1")
// - Provider "nvidia", Model "z-ai/glm-5.1" -> ("nvidia", "z-ai/glm-5.1")
// - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o")
// - Model "gpt-4o" -> ("openai", "gpt-4o")
func ExtractProtocol(model string) (protocol, modelID string) {
if cfg == nil {
return "", ""
}
model := strings.TrimSpace(cfg.Model)
if provider := strings.TrimSpace(cfg.Provider); provider != "" {
return NormalizeProvider(provider), model
}
if model == "" {
return "", ""
}
protocol, rest, found := strings.Cut(model, "/")
if !found {
return "openai", model
}
protocol = strings.TrimSpace(protocol)
if protocol == "" {
return "", strings.TrimSpace(rest)
}
return NormalizeProvider(protocol), strings.TrimSpace(rest)
}
// NormalizeAnthropicBaseURL ensures the Anthropic base URL is properly formatted.
// It removes a trailing /v1 suffix if present (to avoid duplication), then
// re-appends /v1 when appendV1Suffix is true. An empty apiBase falls back to
// defaultBaseURL.
func NormalizeAnthropicBaseURL(apiBase, defaultBaseURL string, appendV1Suffix bool) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
base = strings.TrimRight(base, "/")
if before, ok := strings.CutSuffix(base, "/v1"); ok {
base = before
}
if base == "" {
return defaultBaseURL
}
if appendV1Suffix {
return base + "/v1"
}
return base
}
// AsFloat converts various numeric types to float64.
func AsFloat(v any) (float64, bool) {
switch val := v.(type) {
+65
View File
@@ -660,6 +660,71 @@ func TestAsFloat(t *testing.T) {
}
}
// --- ExtractProtocol tests ---
func TestExtractProtocol(t *testing.T) {
tests := []struct {
name string
model string
wantProtocol string
wantModelID string
}{
{"openai with prefix", "openai/gpt-4o", "openai", "gpt-4o"},
{"anthropic with prefix", "anthropic/claude-sonnet-4.6", "anthropic", "claude-sonnet-4.6"},
{"no prefix defaults to openai", "gpt-4o", "openai", "gpt-4o"},
{"groq with prefix", "groq/llama-3.1-70b", "groq", "llama-3.1-70b"},
{"empty string", "", "openai", ""},
{"with whitespace", " openai/gpt-4 ", "openai", "gpt-4"},
{"multiple slashes", "nvidia/meta/llama-3.1-8b", "nvidia", "meta/llama-3.1-8b"},
{"azure with prefix", "azure/my-gpt5-deployment", "azure", "my-gpt5-deployment"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
protocol, modelID := ExtractProtocol(tt.model)
if protocol != tt.wantProtocol {
t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol)
}
if modelID != tt.wantModelID {
t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID)
}
})
}
}
// --- NormalizeAnthropicBaseURL tests ---
func TestNormalizeAnthropicBaseURL(t *testing.T) {
const defaultURL = "https://api.anthropic.com"
const defaultURLWithV1 = "https://api.anthropic.com/v1"
tests := []struct {
name string
apiBase string
defaultBase string
appendV1Suffix bool
expected string
}{
{"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1},
{"empty without v1", "", defaultURL, false, defaultURL},
{"URL without v1 gets it appended", "https://api.example.com/anthropic", defaultURLWithV1, true, "https://api.example.com/anthropic/v1"},
{"URL without v1 stays as-is", "https://api.example.com/anthropic", defaultURL, false, "https://api.example.com/anthropic"},
{"URL with v1 remains unchanged when appending", "https://api.example.com/v1", defaultURLWithV1, true, "https://api.example.com/v1"},
{"URL with v1 gets it stripped when not appending", "https://api.example.com/v1", defaultURL, false, "https://api.example.com"},
{"trailing slash cleaned with v1", "https://api.example.com/anthropic/", defaultURLWithV1, true, "https://api.example.com/anthropic/v1"},
{"trailing slash cleaned without v1", "https://api.example.com/anthropic/", defaultURL, false, "https://api.example.com/anthropic"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NormalizeAnthropicBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix)
if got != tt.expected {
t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q",
tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected)
}
})
}
}
// --- WrapHTMLResponseError tests ---
func TestWrapHTMLResponseError(t *testing.T) {
+2 -21
View File
@@ -15,6 +15,7 @@ import (
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
"github.com/sipeed/picoclaw/pkg/providers/azure"
"github.com/sipeed/picoclaw/pkg/providers/bedrock"
"github.com/sipeed/picoclaw/pkg/providers/common"
)
type protocolMeta struct {
@@ -102,27 +103,7 @@ func createCodexAuthProvider() (LLMProvider, error) {
// - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o")
// - Model "gpt-4o" -> ("openai", "gpt-4o")
func ExtractProtocol(cfg *config.ModelConfig) (protocol, modelID string) {
if cfg == nil {
return "", ""
}
model := strings.TrimSpace(cfg.Model)
if provider := strings.TrimSpace(cfg.Provider); provider != "" {
return NormalizeProvider(provider), model
}
if model == "" {
return "", ""
}
protocol, rest, found := strings.Cut(model, "/")
if !found {
return "openai", model
}
protocol = strings.TrimSpace(protocol)
if protocol == "" {
return "", strings.TrimSpace(rest)
}
return NormalizeProvider(protocol), strings.TrimSpace(rest)
return common.ExtractProtocol(model)
}
// ResolveAPIBase returns the configured API base, or the protocol default when
-9
View File
@@ -128,12 +128,3 @@ func sanitizeSchemaForGemini(schema map[string]any) map[string]any {
return result
}
func extractProtocol(model string) (protocol, modelID string) {
model = strings.TrimSpace(model)
protocol, modelID, found := strings.Cut(model, "/")
if !found {
return "openai", model
}
return protocol, modelID
}
+1 -1
View File
@@ -303,7 +303,7 @@ func normalizeGeminiModel(model string) string {
model = strings.TrimSpace(model)
model = strings.TrimPrefix(model, "models/")
if strings.Contains(model, "/") {
_, modelID := extractProtocol(model)
_, modelID := common.ExtractProtocol(model)
if modelID != "" {
return modelID
}
+8 -7
View File
@@ -470,7 +470,9 @@ func (p *Provider) SupportsNativeSearch() bool {
return isNativeSearchHost(p.apiBase)
}
func isNativeSearchHost(apiBase string) bool {
// isNativeOpenAIOrAzureEndpoint reports whether the given API base points to
// OpenAI's own API or an Azure OpenAI deployment.
func isNativeOpenAIOrAzureEndpoint(apiBase string) bool {
u, err := url.Parse(apiBase)
if err != nil {
return false
@@ -479,15 +481,14 @@ func isNativeSearchHost(apiBase string) bool {
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
}
func isNativeSearchHost(apiBase string) bool {
return isNativeOpenAIOrAzureEndpoint(apiBase)
}
// supportsPromptCacheKey reports whether the given API base is known to
// support the prompt_cache_key request field. Currently only OpenAI's own
// API and Azure OpenAI support this. All other OpenAI-compatible providers
// (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
func supportsPromptCacheKey(apiBase string) bool {
u, err := url.Parse(apiBase)
if err != nil {
return false
}
host := u.Hostname()
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
return isNativeOpenAIOrAzureEndpoint(apiBase)
}