mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Functions deduplication
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -42,7 +43,7 @@ func NewProvider(token string) *Provider {
|
||||
}
|
||||
|
||||
func NewProviderWithBaseURL(token, apiBase string) *Provider {
|
||||
baseURL := normalizeBaseURL(apiBase)
|
||||
baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, false)
|
||||
client := anthropic.NewClient(
|
||||
option.WithAuthToken(token),
|
||||
option.WithBaseURL(baseURL),
|
||||
@@ -385,20 +386,3 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeBaseURL(apiBase string) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
base = strings.TrimRight(base, "/")
|
||||
if before, ok := strings.CutSuffix(base, "/v1"); ok {
|
||||
base = before
|
||||
}
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -51,7 +52,7 @@ func NewProvider(apiKey, apiBase, userAgent string) *Provider {
|
||||
|
||||
// NewProviderWithTimeout creates a provider with custom request timeout.
|
||||
func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider {
|
||||
baseURL := normalizeBaseURL(apiBase)
|
||||
baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, true)
|
||||
timeout := defaultRequestTimeout
|
||||
if timeoutSeconds > 0 {
|
||||
timeout = time.Duration(timeoutSeconds) * time.Second
|
||||
@@ -161,7 +162,7 @@ func buildRequestBody(
|
||||
options map[string]any,
|
||||
) (map[string]any, error) {
|
||||
// max_tokens is required and guaranteed by agent loop
|
||||
maxTokens, ok := asInt(options["max_tokens"])
|
||||
maxTokens, ok := common.AsInt(options["max_tokens"])
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("max_tokens is required in options")
|
||||
}
|
||||
@@ -173,7 +174,7 @@ func buildRequestBody(
|
||||
}
|
||||
|
||||
// Set temperature from options
|
||||
if temp, ok := asFloat(options["temperature"]); ok {
|
||||
if temp, ok := common.AsFloat(options["temperature"]); ok {
|
||||
result["temperature"] = temp
|
||||
}
|
||||
|
||||
@@ -361,61 +362,6 @@ func parseResponseBody(body []byte) (*LLMResponse, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeBaseURL ensures the base URL is properly formatted.
|
||||
// It removes /v1 suffix if present (to avoid duplication) and always appends /v1.
|
||||
// This handles edge cases like "https://api.example.com/v1/proxy" correctly.
|
||||
func normalizeBaseURL(apiBase string) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
// Remove trailing slashes
|
||||
base = strings.TrimRight(base, "/")
|
||||
|
||||
// Remove /v1 suffix if present (will be re-added)
|
||||
// This prevents duplication for URLs like "https://api.example.com/v1/proxy"
|
||||
if before, ok := strings.CutSuffix(base, "/v1"); ok {
|
||||
base = before
|
||||
}
|
||||
|
||||
// Ensure we don't have an empty string after cutting
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
// Add /v1 suffix (required by Anthropic Messages API)
|
||||
return base + "/v1"
|
||||
}
|
||||
|
||||
// Helper functions for type conversion
|
||||
|
||||
func asInt(v any) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case int64:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat(v any) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// Anthropic API response structures
|
||||
|
||||
type anthropicMessageResponse struct {
|
||||
|
||||
@@ -372,44 +372,6 @@ func TestParseResponseBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeBaseURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiBase string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty string defaults to official API",
|
||||
apiBase: "",
|
||||
expected: "https://api.anthropic.com/v1",
|
||||
},
|
||||
{
|
||||
name: "URL without /v1 gets it appended",
|
||||
apiBase: "https://api.example.com/anthropic",
|
||||
expected: "https://api.example.com/anthropic/v1",
|
||||
},
|
||||
{
|
||||
name: "URL with /v1 remains unchanged",
|
||||
apiBase: "https://api.example.com/v1",
|
||||
expected: "https://api.example.com/v1",
|
||||
},
|
||||
{
|
||||
name: "URL with trailing slash gets cleaned",
|
||||
apiBase: "https://api.example.com/anthropic/",
|
||||
expected: "https://api.example.com/anthropic/v1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeBaseURL(tt.apiBase)
|
||||
if got != tt.expected {
|
||||
t.Errorf("normalizeBaseURL(%q) = %q, want %q", tt.apiBase, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
provider := NewProvider("test-key", "https://api.example.com", "")
|
||||
if provider == nil {
|
||||
|
||||
@@ -478,6 +478,69 @@ func AsInt(v any) (int, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractProtocol extracts the effective protocol and model identifier from a
|
||||
// model configuration.
|
||||
//
|
||||
// The explicit Provider field takes precedence. When Provider is empty, the
|
||||
// protocol is inferred from Model. Plain model names default to "openai".
|
||||
// Provider-prefixed models strip the first slash-separated segment from the
|
||||
// returned model ID.
|
||||
//
|
||||
// The returned protocol is normalized to the provider's canonical spelling.
|
||||
// Examples:
|
||||
// - Model "openai/gpt-4o" -> ("openai", "gpt-4o")
|
||||
// - Model "nvidia/z-ai/glm-5.1" -> ("nvidia", "z-ai/glm-5.1")
|
||||
// - Provider "nvidia", Model "z-ai/glm-5.1" -> ("nvidia", "z-ai/glm-5.1")
|
||||
// - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o")
|
||||
// - Model "gpt-4o" -> ("openai", "gpt-4o")
|
||||
func ExtractProtocol(model string) (protocol, modelID string) {
|
||||
if cfg == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
model := strings.TrimSpace(cfg.Model)
|
||||
if provider := strings.TrimSpace(cfg.Provider); provider != "" {
|
||||
return NormalizeProvider(provider), model
|
||||
}
|
||||
if model == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
protocol, rest, found := strings.Cut(model, "/")
|
||||
if !found {
|
||||
return "openai", model
|
||||
}
|
||||
protocol = strings.TrimSpace(protocol)
|
||||
if protocol == "" {
|
||||
return "", strings.TrimSpace(rest)
|
||||
}
|
||||
return NormalizeProvider(protocol), strings.TrimSpace(rest)
|
||||
}
|
||||
|
||||
// NormalizeAnthropicBaseURL ensures the Anthropic base URL is properly formatted.
|
||||
// It removes a trailing /v1 suffix if present (to avoid duplication), then
|
||||
// re-appends /v1 when appendV1Suffix is true. An empty apiBase falls back to
|
||||
// defaultBaseURL.
|
||||
func NormalizeAnthropicBaseURL(apiBase, defaultBaseURL string, appendV1Suffix bool) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
base = strings.TrimRight(base, "/")
|
||||
if before, ok := strings.CutSuffix(base, "/v1"); ok {
|
||||
base = before
|
||||
}
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
if appendV1Suffix {
|
||||
return base + "/v1"
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
// AsFloat converts various numeric types to float64.
|
||||
func AsFloat(v any) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
|
||||
@@ -660,6 +660,71 @@ func TestAsFloat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- ExtractProtocol tests ---
|
||||
|
||||
func TestExtractProtocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantProtocol string
|
||||
wantModelID string
|
||||
}{
|
||||
{"openai with prefix", "openai/gpt-4o", "openai", "gpt-4o"},
|
||||
{"anthropic with prefix", "anthropic/claude-sonnet-4.6", "anthropic", "claude-sonnet-4.6"},
|
||||
{"no prefix defaults to openai", "gpt-4o", "openai", "gpt-4o"},
|
||||
{"groq with prefix", "groq/llama-3.1-70b", "groq", "llama-3.1-70b"},
|
||||
{"empty string", "", "openai", ""},
|
||||
{"with whitespace", " openai/gpt-4 ", "openai", "gpt-4"},
|
||||
{"multiple slashes", "nvidia/meta/llama-3.1-8b", "nvidia", "meta/llama-3.1-8b"},
|
||||
{"azure with prefix", "azure/my-gpt5-deployment", "azure", "my-gpt5-deployment"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
protocol, modelID := ExtractProtocol(tt.model)
|
||||
if protocol != tt.wantProtocol {
|
||||
t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol)
|
||||
}
|
||||
if modelID != tt.wantModelID {
|
||||
t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- NormalizeAnthropicBaseURL tests ---
|
||||
|
||||
func TestNormalizeAnthropicBaseURL(t *testing.T) {
|
||||
const defaultURL = "https://api.anthropic.com"
|
||||
const defaultURLWithV1 = "https://api.anthropic.com/v1"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
apiBase string
|
||||
defaultBase string
|
||||
appendV1Suffix bool
|
||||
expected string
|
||||
}{
|
||||
{"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1},
|
||||
{"empty without v1", "", defaultURL, false, defaultURL},
|
||||
{"URL without v1 gets it appended", "https://api.example.com/anthropic", defaultURLWithV1, true, "https://api.example.com/anthropic/v1"},
|
||||
{"URL without v1 stays as-is", "https://api.example.com/anthropic", defaultURL, false, "https://api.example.com/anthropic"},
|
||||
{"URL with v1 remains unchanged when appending", "https://api.example.com/v1", defaultURLWithV1, true, "https://api.example.com/v1"},
|
||||
{"URL with v1 gets it stripped when not appending", "https://api.example.com/v1", defaultURL, false, "https://api.example.com"},
|
||||
{"trailing slash cleaned with v1", "https://api.example.com/anthropic/", defaultURLWithV1, true, "https://api.example.com/anthropic/v1"},
|
||||
{"trailing slash cleaned without v1", "https://api.example.com/anthropic/", defaultURL, false, "https://api.example.com/anthropic"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NormalizeAnthropicBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix)
|
||||
if got != tt.expected {
|
||||
t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q",
|
||||
tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- WrapHTMLResponseError tests ---
|
||||
|
||||
func TestWrapHTMLResponseError(t *testing.T) {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/azure"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/bedrock"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
type protocolMeta struct {
|
||||
@@ -102,27 +103,7 @@ func createCodexAuthProvider() (LLMProvider, error) {
|
||||
// - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o")
|
||||
// - Model "gpt-4o" -> ("openai", "gpt-4o")
|
||||
func ExtractProtocol(cfg *config.ModelConfig) (protocol, modelID string) {
|
||||
if cfg == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
model := strings.TrimSpace(cfg.Model)
|
||||
if provider := strings.TrimSpace(cfg.Provider); provider != "" {
|
||||
return NormalizeProvider(provider), model
|
||||
}
|
||||
if model == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
protocol, rest, found := strings.Cut(model, "/")
|
||||
if !found {
|
||||
return "openai", model
|
||||
}
|
||||
protocol = strings.TrimSpace(protocol)
|
||||
if protocol == "" {
|
||||
return "", strings.TrimSpace(rest)
|
||||
}
|
||||
return NormalizeProvider(protocol), strings.TrimSpace(rest)
|
||||
return common.ExtractProtocol(model)
|
||||
}
|
||||
|
||||
// ResolveAPIBase returns the configured API base, or the protocol default when
|
||||
|
||||
@@ -128,12 +128,3 @@ func sanitizeSchemaForGemini(schema map[string]any) map[string]any {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func extractProtocol(model string) (protocol, modelID string) {
|
||||
model = strings.TrimSpace(model)
|
||||
protocol, modelID, found := strings.Cut(model, "/")
|
||||
if !found {
|
||||
return "openai", model
|
||||
}
|
||||
return protocol, modelID
|
||||
}
|
||||
|
||||
@@ -303,7 +303,7 @@ func normalizeGeminiModel(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
if strings.Contains(model, "/") {
|
||||
_, modelID := extractProtocol(model)
|
||||
_, modelID := common.ExtractProtocol(model)
|
||||
if modelID != "" {
|
||||
return modelID
|
||||
}
|
||||
|
||||
@@ -470,7 +470,9 @@ func (p *Provider) SupportsNativeSearch() bool {
|
||||
return isNativeSearchHost(p.apiBase)
|
||||
}
|
||||
|
||||
func isNativeSearchHost(apiBase string) bool {
|
||||
// isNativeOpenAIOrAzureEndpoint reports whether the given API base points to
|
||||
// OpenAI's own API or an Azure OpenAI deployment.
|
||||
func isNativeOpenAIOrAzureEndpoint(apiBase string) bool {
|
||||
u, err := url.Parse(apiBase)
|
||||
if err != nil {
|
||||
return false
|
||||
@@ -479,15 +481,14 @@ func isNativeSearchHost(apiBase string) bool {
|
||||
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
|
||||
}
|
||||
|
||||
func isNativeSearchHost(apiBase string) bool {
|
||||
return isNativeOpenAIOrAzureEndpoint(apiBase)
|
||||
}
|
||||
|
||||
// supportsPromptCacheKey reports whether the given API base is known to
|
||||
// support the prompt_cache_key request field. Currently only OpenAI's own
|
||||
// API and Azure OpenAI support this. All other OpenAI-compatible providers
|
||||
// (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
|
||||
func supportsPromptCacheKey(apiBase string) bool {
|
||||
u, err := url.Parse(apiBase)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := u.Hostname()
|
||||
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
|
||||
return isNativeOpenAIOrAzureEndpoint(apiBase)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user