From c71146b1d53f4f5e57a25a52c29268ec5b854a40 Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Sun, 19 Apr 2026 04:20:00 +0000 Subject: [PATCH 1/5] Functions deduplication --- pkg/providers/anthropic/provider.go | 20 +----- pkg/providers/anthropic_messages/provider.go | 62 ++---------------- .../anthropic_messages/provider_test.go | 38 ----------- pkg/providers/common/common.go | 63 ++++++++++++++++++ pkg/providers/common/common_test.go | 65 +++++++++++++++++++ pkg/providers/factory_provider.go | 23 +------ pkg/providers/httpapi/gemini_helpers.go | 9 --- pkg/providers/httpapi/gemini_provider.go | 2 +- pkg/providers/openai_compat/provider.go | 15 +++-- 9 files changed, 145 insertions(+), 152 deletions(-) diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go index d4ceaab2c..4330163df 100644 --- a/pkg/providers/anthropic/provider.go +++ b/pkg/providers/anthropic/provider.go @@ -10,6 +10,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/providers/common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -42,7 +43,7 @@ func NewProvider(token string) *Provider { } func NewProviderWithBaseURL(token, apiBase string) *Provider { - baseURL := normalizeBaseURL(apiBase) + baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, false) client := anthropic.NewClient( option.WithAuthToken(token), option.WithBaseURL(baseURL), @@ -385,20 +386,3 @@ func parseResponse(resp *anthropic.Message) *LLMResponse { }, } } - -func normalizeBaseURL(apiBase string) string { - base := strings.TrimSpace(apiBase) - if base == "" { - return defaultBaseURL - } - - base = strings.TrimRight(base, "/") - if before, ok := strings.CutSuffix(base, "/v1"); ok { - base = before - } - if base == "" { - return defaultBaseURL - } - - return base -} diff --git a/pkg/providers/anthropic_messages/provider.go b/pkg/providers/anthropic_messages/provider.go index 1e865b709..dcc31b6f9 100644 --- a/pkg/providers/anthropic_messages/provider.go +++ b/pkg/providers/anthropic_messages/provider.go @@ -16,6 +16,7 @@ import ( "strings" "time" + "github.com/sipeed/picoclaw/pkg/providers/common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -51,7 +52,7 @@ func NewProvider(apiKey, apiBase, userAgent string) *Provider { // NewProviderWithTimeout creates a provider with custom request timeout. func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider { - baseURL := normalizeBaseURL(apiBase) + baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, true) timeout := defaultRequestTimeout if timeoutSeconds > 0 { timeout = time.Duration(timeoutSeconds) * time.Second @@ -161,7 +162,7 @@ func buildRequestBody( options map[string]any, ) (map[string]any, error) { // max_tokens is required and guaranteed by agent loop - maxTokens, ok := asInt(options["max_tokens"]) + maxTokens, ok := common.AsInt(options["max_tokens"]) if !ok { return nil, fmt.Errorf("max_tokens is required in options") } @@ -173,7 +174,7 @@ func buildRequestBody( } // Set temperature from options - if temp, ok := asFloat(options["temperature"]); ok { + if temp, ok := common.AsFloat(options["temperature"]); ok { result["temperature"] = temp } @@ -361,61 +362,6 @@ func parseResponseBody(body []byte) (*LLMResponse, error) { }, nil } -// normalizeBaseURL ensures the base URL is properly formatted. -// It removes /v1 suffix if present (to avoid duplication) and always appends /v1. -// This handles edge cases like "https://api.example.com/v1/proxy" correctly. -func normalizeBaseURL(apiBase string) string { - base := strings.TrimSpace(apiBase) - if base == "" { - return defaultBaseURL - } - - // Remove trailing slashes - base = strings.TrimRight(base, "/") - - // Remove /v1 suffix if present (will be re-added) - // This prevents duplication for URLs like "https://api.example.com/v1/proxy" - if before, ok := strings.CutSuffix(base, "/v1"); ok { - base = before - } - - // Ensure we don't have an empty string after cutting - if base == "" { - return defaultBaseURL - } - - // Add /v1 suffix (required by Anthropic Messages API) - return base + "/v1" -} - -// Helper functions for type conversion - -func asInt(v any) (int, bool) { - switch val := v.(type) { - case int: - return val, true - case float64: - return int(val), true - case int64: - return int(val), true - default: - return 0, false - } -} - -func asFloat(v any) (float64, bool) { - switch val := v.(type) { - case float64: - return val, true - case int: - return float64(val), true - case int64: - return float64(val), true - default: - return 0, false - } -} - // Anthropic API response structures type anthropicMessageResponse struct { diff --git a/pkg/providers/anthropic_messages/provider_test.go b/pkg/providers/anthropic_messages/provider_test.go index ba9d24b66..6401d84bd 100644 --- a/pkg/providers/anthropic_messages/provider_test.go +++ b/pkg/providers/anthropic_messages/provider_test.go @@ -372,44 +372,6 @@ func TestParseResponseBody(t *testing.T) { } } -func TestNormalizeBaseURL(t *testing.T) { - tests := []struct { - name string - apiBase string - expected string - }{ - { - name: "empty string defaults to official API", - apiBase: "", - expected: "https://api.anthropic.com/v1", - }, - { - name: "URL without /v1 gets it appended", - apiBase: "https://api.example.com/anthropic", - expected: "https://api.example.com/anthropic/v1", - }, - { - name: "URL with /v1 remains unchanged", - apiBase: "https://api.example.com/v1", - expected: "https://api.example.com/v1", - }, - { - name: "URL with trailing slash gets cleaned", - apiBase: "https://api.example.com/anthropic/", - expected: "https://api.example.com/anthropic/v1", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := normalizeBaseURL(tt.apiBase) - if got != tt.expected { - t.Errorf("normalizeBaseURL(%q) = %q, want %q", tt.apiBase, got, tt.expected) - } - }) - } -} - func TestNewProvider(t *testing.T) { provider := NewProvider("test-key", "https://api.example.com", "") if provider == nil { diff --git a/pkg/providers/common/common.go b/pkg/providers/common/common.go index 0a702e85e..afc6877b1 100644 --- a/pkg/providers/common/common.go +++ b/pkg/providers/common/common.go @@ -478,6 +478,69 @@ func AsInt(v any) (int, bool) { } } +// ExtractProtocol extracts the effective protocol and model identifier from a +// model configuration. +// +// The explicit Provider field takes precedence. When Provider is empty, the +// protocol is inferred from Model. Plain model names default to "openai". +// Provider-prefixed models strip the first slash-separated segment from the +// returned model ID. +// +// The returned protocol is normalized to the provider's canonical spelling. +// Examples: +// - Model "openai/gpt-4o" -> ("openai", "gpt-4o") +// - Model "nvidia/z-ai/glm-5.1" -> ("nvidia", "z-ai/glm-5.1") +// - Provider "nvidia", Model "z-ai/glm-5.1" -> ("nvidia", "z-ai/glm-5.1") +// - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o") +// - Model "gpt-4o" -> ("openai", "gpt-4o") +func ExtractProtocol(model string) (protocol, modelID string) { + if cfg == nil { + return "", "" + } + + model := strings.TrimSpace(cfg.Model) + if provider := strings.TrimSpace(cfg.Provider); provider != "" { + return NormalizeProvider(provider), model + } + if model == "" { + return "", "" + } + + protocol, rest, found := strings.Cut(model, "/") + if !found { + return "openai", model + } + protocol = strings.TrimSpace(protocol) + if protocol == "" { + return "", strings.TrimSpace(rest) + } + return NormalizeProvider(protocol), strings.TrimSpace(rest) +} + +// NormalizeAnthropicBaseURL ensures the Anthropic base URL is properly formatted. +// It removes a trailing /v1 suffix if present (to avoid duplication), then +// re-appends /v1 when appendV1Suffix is true. An empty apiBase falls back to +// defaultBaseURL. +func NormalizeAnthropicBaseURL(apiBase, defaultBaseURL string, appendV1Suffix bool) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return defaultBaseURL + } + + base = strings.TrimRight(base, "/") + if before, ok := strings.CutSuffix(base, "/v1"); ok { + base = before + } + if base == "" { + return defaultBaseURL + } + + if appendV1Suffix { + return base + "/v1" + } + return base +} + // AsFloat converts various numeric types to float64. func AsFloat(v any) (float64, bool) { switch val := v.(type) { diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go index a42d778f1..56c80e754 100644 --- a/pkg/providers/common/common_test.go +++ b/pkg/providers/common/common_test.go @@ -660,6 +660,71 @@ func TestAsFloat(t *testing.T) { } } +// --- ExtractProtocol tests --- + +func TestExtractProtocol(t *testing.T) { + tests := []struct { + name string + model string + wantProtocol string + wantModelID string + }{ + {"openai with prefix", "openai/gpt-4o", "openai", "gpt-4o"}, + {"anthropic with prefix", "anthropic/claude-sonnet-4.6", "anthropic", "claude-sonnet-4.6"}, + {"no prefix defaults to openai", "gpt-4o", "openai", "gpt-4o"}, + {"groq with prefix", "groq/llama-3.1-70b", "groq", "llama-3.1-70b"}, + {"empty string", "", "openai", ""}, + {"with whitespace", " openai/gpt-4 ", "openai", "gpt-4"}, + {"multiple slashes", "nvidia/meta/llama-3.1-8b", "nvidia", "meta/llama-3.1-8b"}, + {"azure with prefix", "azure/my-gpt5-deployment", "azure", "my-gpt5-deployment"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + protocol, modelID := ExtractProtocol(tt.model) + if protocol != tt.wantProtocol { + t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol) + } + if modelID != tt.wantModelID { + t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID) + } + }) + } +} + +// --- NormalizeAnthropicBaseURL tests --- + +func TestNormalizeAnthropicBaseURL(t *testing.T) { + const defaultURL = "https://api.anthropic.com" + const defaultURLWithV1 = "https://api.anthropic.com/v1" + + tests := []struct { + name string + apiBase string + defaultBase string + appendV1Suffix bool + expected string + }{ + {"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1}, + {"empty without v1", "", defaultURL, false, defaultURL}, + {"URL without v1 gets it appended", "https://api.example.com/anthropic", defaultURLWithV1, true, "https://api.example.com/anthropic/v1"}, + {"URL without v1 stays as-is", "https://api.example.com/anthropic", defaultURL, false, "https://api.example.com/anthropic"}, + {"URL with v1 remains unchanged when appending", "https://api.example.com/v1", defaultURLWithV1, true, "https://api.example.com/v1"}, + {"URL with v1 gets it stripped when not appending", "https://api.example.com/v1", defaultURL, false, "https://api.example.com"}, + {"trailing slash cleaned with v1", "https://api.example.com/anthropic/", defaultURLWithV1, true, "https://api.example.com/anthropic/v1"}, + {"trailing slash cleaned without v1", "https://api.example.com/anthropic/", defaultURL, false, "https://api.example.com/anthropic"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NormalizeAnthropicBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix) + if got != tt.expected { + t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q", + tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected) + } + }) + } +} + // --- WrapHTMLResponseError tests --- func TestWrapHTMLResponseError(t *testing.T) { diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 86d009811..63413b0a1 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -15,6 +15,7 @@ import ( anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages" "github.com/sipeed/picoclaw/pkg/providers/azure" "github.com/sipeed/picoclaw/pkg/providers/bedrock" + "github.com/sipeed/picoclaw/pkg/providers/common" ) type protocolMeta struct { @@ -102,27 +103,7 @@ func createCodexAuthProvider() (LLMProvider, error) { // - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o") // - Model "gpt-4o" -> ("openai", "gpt-4o") func ExtractProtocol(cfg *config.ModelConfig) (protocol, modelID string) { - if cfg == nil { - return "", "" - } - - model := strings.TrimSpace(cfg.Model) - if provider := strings.TrimSpace(cfg.Provider); provider != "" { - return NormalizeProvider(provider), model - } - if model == "" { - return "", "" - } - - protocol, rest, found := strings.Cut(model, "/") - if !found { - return "openai", model - } - protocol = strings.TrimSpace(protocol) - if protocol == "" { - return "", strings.TrimSpace(rest) - } - return NormalizeProvider(protocol), strings.TrimSpace(rest) + return common.ExtractProtocol(model) } // ResolveAPIBase returns the configured API base, or the protocol default when diff --git a/pkg/providers/httpapi/gemini_helpers.go b/pkg/providers/httpapi/gemini_helpers.go index 36d95cf9e..0f1e20ca5 100644 --- a/pkg/providers/httpapi/gemini_helpers.go +++ b/pkg/providers/httpapi/gemini_helpers.go @@ -128,12 +128,3 @@ func sanitizeSchemaForGemini(schema map[string]any) map[string]any { return result } - -func extractProtocol(model string) (protocol, modelID string) { - model = strings.TrimSpace(model) - protocol, modelID, found := strings.Cut(model, "/") - if !found { - return "openai", model - } - return protocol, modelID -} diff --git a/pkg/providers/httpapi/gemini_provider.go b/pkg/providers/httpapi/gemini_provider.go index d488d06f8..dab6acd29 100644 --- a/pkg/providers/httpapi/gemini_provider.go +++ b/pkg/providers/httpapi/gemini_provider.go @@ -303,7 +303,7 @@ func normalizeGeminiModel(model string) string { model = strings.TrimSpace(model) model = strings.TrimPrefix(model, "models/") if strings.Contains(model, "/") { - _, modelID := extractProtocol(model) + _, modelID := common.ExtractProtocol(model) if modelID != "" { return modelID } diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 98a70cfd2..29667cd31 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -470,7 +470,9 @@ func (p *Provider) SupportsNativeSearch() bool { return isNativeSearchHost(p.apiBase) } -func isNativeSearchHost(apiBase string) bool { +// isNativeOpenAIOrAzureEndpoint reports whether the given API base points to +// OpenAI's own API or an Azure OpenAI deployment. +func isNativeOpenAIOrAzureEndpoint(apiBase string) bool { u, err := url.Parse(apiBase) if err != nil { return false @@ -479,15 +481,14 @@ func isNativeSearchHost(apiBase string) bool { return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com") } +func isNativeSearchHost(apiBase string) bool { + return isNativeOpenAIOrAzureEndpoint(apiBase) +} + // supportsPromptCacheKey reports whether the given API base is known to // support the prompt_cache_key request field. Currently only OpenAI's own // API and Azure OpenAI support this. All other OpenAI-compatible providers // (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors. func supportsPromptCacheKey(apiBase string) bool { - u, err := url.Parse(apiBase) - if err != nil { - return false - } - host := u.Hostname() - return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com") + return isNativeOpenAIOrAzureEndpoint(apiBase) } From e901e70c1493aede97e5fb7d9022fc76ea264d95 Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Sun, 19 Apr 2026 04:30:21 +0000 Subject: [PATCH 2/5] Fix linting --- pkg/providers/common/common_test.go | 36 ++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go index 56c80e754..71c1bd1d1 100644 --- a/pkg/providers/common/common_test.go +++ b/pkg/providers/common/common_test.go @@ -706,12 +706,36 @@ func TestNormalizeAnthropicBaseURL(t *testing.T) { }{ {"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1}, {"empty without v1", "", defaultURL, false, defaultURL}, - {"URL without v1 gets it appended", "https://api.example.com/anthropic", defaultURLWithV1, true, "https://api.example.com/anthropic/v1"}, - {"URL without v1 stays as-is", "https://api.example.com/anthropic", defaultURL, false, "https://api.example.com/anthropic"}, - {"URL with v1 remains unchanged when appending", "https://api.example.com/v1", defaultURLWithV1, true, "https://api.example.com/v1"}, - {"URL with v1 gets it stripped when not appending", "https://api.example.com/v1", defaultURL, false, "https://api.example.com"}, - {"trailing slash cleaned with v1", "https://api.example.com/anthropic/", defaultURLWithV1, true, "https://api.example.com/anthropic/v1"}, - {"trailing slash cleaned without v1", "https://api.example.com/anthropic/", defaultURL, false, "https://api.example.com/anthropic"}, + { + "URL without v1 gets it appended", + "https://api.example.com/anthropic", defaultURLWithV1, + true, "https://api.example.com/anthropic/v1", + }, + { + "URL without v1 stays as-is", + "https://api.example.com/anthropic", defaultURL, + false, "https://api.example.com/anthropic", + }, + { + "URL with v1 remains unchanged when appending", + "https://api.example.com/v1", defaultURLWithV1, + true, "https://api.example.com/v1", + }, + { + "URL with v1 gets it stripped when not appending", + "https://api.example.com/v1", defaultURL, + false, "https://api.example.com", + }, + { + "trailing slash cleaned with v1", + "https://api.example.com/anthropic/", defaultURLWithV1, + true, "https://api.example.com/anthropic/v1", + }, + { + "trailing slash cleaned without v1", + "https://api.example.com/anthropic/", defaultURL, + false, "https://api.example.com/anthropic", + }, } for _, tt := range tests { From bc077db0ee4f3730183b18cde74ddd7581c4cb47 Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Sun, 19 Apr 2026 06:26:52 +0000 Subject: [PATCH 3/5] Deduplicate ParseDataAudioURL function --- pkg/providers/common/common.go | 68 +------------------ pkg/providers/common/common_test.go | 31 +++++++++ .../responses_common.go | 22 +----- .../responses_common_test.go | 36 ---------- 4 files changed, 36 insertions(+), 121 deletions(-) diff --git a/pkg/providers/common/common.go b/pkg/providers/common/common.go index afc6877b1..5e03bc0c2 100644 --- a/pkg/providers/common/common.go +++ b/pkg/providers/common/common.go @@ -127,7 +127,7 @@ func SerializeMessages(messages []Message) []any { continue } - if format, data, ok := parseDataAudioURL(mediaURL); ok { + if format, data, ok := ParseDataAudioURL(mediaURL); ok { parts = append(parts, map[string]any{ "type": "input_audio", "input_audio": map[string]any{ @@ -205,7 +205,8 @@ func serializeToolCalls(toolCalls []ToolCall) []openaiToolCall { return out } -func parseDataAudioURL(mediaURL string) (format, data string, ok bool) { +// ParseDataAudioURL extracts the format and base64 data from a data:audio/... URL. +func ParseDataAudioURL(mediaURL string) (format, data string, ok bool) { if !strings.HasPrefix(mediaURL, "data:audio/") { return "", "", false } @@ -478,69 +479,6 @@ func AsInt(v any) (int, bool) { } } -// ExtractProtocol extracts the effective protocol and model identifier from a -// model configuration. -// -// The explicit Provider field takes precedence. When Provider is empty, the -// protocol is inferred from Model. Plain model names default to "openai". -// Provider-prefixed models strip the first slash-separated segment from the -// returned model ID. -// -// The returned protocol is normalized to the provider's canonical spelling. -// Examples: -// - Model "openai/gpt-4o" -> ("openai", "gpt-4o") -// - Model "nvidia/z-ai/glm-5.1" -> ("nvidia", "z-ai/glm-5.1") -// - Provider "nvidia", Model "z-ai/glm-5.1" -> ("nvidia", "z-ai/glm-5.1") -// - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o") -// - Model "gpt-4o" -> ("openai", "gpt-4o") -func ExtractProtocol(model string) (protocol, modelID string) { - if cfg == nil { - return "", "" - } - - model := strings.TrimSpace(cfg.Model) - if provider := strings.TrimSpace(cfg.Provider); provider != "" { - return NormalizeProvider(provider), model - } - if model == "" { - return "", "" - } - - protocol, rest, found := strings.Cut(model, "/") - if !found { - return "openai", model - } - protocol = strings.TrimSpace(protocol) - if protocol == "" { - return "", strings.TrimSpace(rest) - } - return NormalizeProvider(protocol), strings.TrimSpace(rest) -} - -// NormalizeAnthropicBaseURL ensures the Anthropic base URL is properly formatted. -// It removes a trailing /v1 suffix if present (to avoid duplication), then -// re-appends /v1 when appendV1Suffix is true. An empty apiBase falls back to -// defaultBaseURL. -func NormalizeAnthropicBaseURL(apiBase, defaultBaseURL string, appendV1Suffix bool) string { - base := strings.TrimSpace(apiBase) - if base == "" { - return defaultBaseURL - } - - base = strings.TrimRight(base, "/") - if before, ok := strings.CutSuffix(base, "/v1"); ok { - base = before - } - if base == "" { - return defaultBaseURL - } - - if appendV1Suffix { - return base + "/v1" - } - return base -} - // AsFloat converts various numeric types to float64. func AsFloat(v any) (float64, bool) { switch val := v.(type) { diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go index 71c1bd1d1..1f9a9b827 100644 --- a/pkg/providers/common/common_test.go +++ b/pkg/providers/common/common_test.go @@ -691,6 +691,37 @@ func TestExtractProtocol(t *testing.T) { } } +// --- ParseDataAudioURL tests --- + +func TestParseDataAudioURL(t *testing.T) { + tests := []struct { + name string + mediaURL string + wantFormat string + wantData string + wantOK bool + }{ + {"valid mp3", "data:audio/mp3;base64,SGVsbG8=", "mp3", "SGVsbG8=", true}, + {"valid wav", "data:audio/wav;base64,AAAA", "wav", "AAAA", true}, + {"not audio", "data:image/png;base64,abc", "", "", false}, + {"no comma", "data:audio/mp3;base64", "", "", false}, + {"empty data", "data:audio/mp3;base64,", "", "", false}, + {"empty string", "", "", "", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + format, data, ok := ParseDataAudioURL(tt.mediaURL) + if ok != tt.wantOK || format != tt.wantFormat || data != tt.wantData { + t.Errorf( + "ParseDataAudioURL(%q) = (%q, %q, %v), want (%q, %q, %v)", + tt.mediaURL, format, data, ok, + tt.wantFormat, tt.wantData, tt.wantOK, + ) + } + }) + } +} + // --- NormalizeAnthropicBaseURL tests --- func TestNormalizeAnthropicBaseURL(t *testing.T) { diff --git a/pkg/providers/openai_responses_common/responses_common.go b/pkg/providers/openai_responses_common/responses_common.go index 839471f69..17b731ed4 100644 --- a/pkg/providers/openai_responses_common/responses_common.go +++ b/pkg/providers/openai_responses_common/responses_common.go @@ -10,6 +10,7 @@ import ( "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/responses" + "github.com/sipeed/picoclaw/pkg/providers/common" "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) @@ -118,7 +119,7 @@ func BuildMultipartContent(text string, media []string) responses.ResponseInputM }, }) } else if strings.HasPrefix(mediaURL, "data:audio/") { - if format, data, ok := ParseDataAudioURL(mediaURL); ok { + if format, data, ok := common.ParseDataAudioURL(mediaURL); ok { parts = append(parts, responses.ResponseInputContentUnionParam{ OfInputFile: &responses.ResponseInputFileParam{ FileData: openai.Opt(data), @@ -132,25 +133,6 @@ func BuildMultipartContent(text string, media []string) responses.ResponseInputM return parts } -// ParseDataAudioURL extracts the format and base64 data from a data:audio/... URL. -func ParseDataAudioURL(mediaURL string) (format, data string, ok bool) { - if !strings.HasPrefix(mediaURL, "data:audio/") { - return "", "", false - } - payload := strings.TrimPrefix(mediaURL, "data:audio/") - meta, data, found := strings.Cut(payload, ",") - if !found { - return "", "", false - } - format, _, _ = strings.Cut(meta, ";") - format = strings.TrimSpace(format) - data = strings.TrimSpace(data) - if format == "" || data == "" { - return "", "", false - } - return format, data, true -} - // ResolveToolCall extracts the function name and JSON arguments string from a ToolCall. // Returns ok=false if the tool call has no name or if arguments fail to marshal. func ResolveToolCall(tc protocoltypes.ToolCall) (name string, arguments string, ok bool) { diff --git a/pkg/providers/openai_responses_common/responses_common_test.go b/pkg/providers/openai_responses_common/responses_common_test.go index 0d41190b1..ace91edf0 100644 --- a/pkg/providers/openai_responses_common/responses_common_test.go +++ b/pkg/providers/openai_responses_common/responses_common_test.go @@ -506,42 +506,6 @@ func TestParseResponseBody_CanceledStatus(t *testing.T) { } } -// --- ParseDataAudioURL tests --- - -func TestParseDataAudioURL_Valid(t *testing.T) { - format, data, ok := ParseDataAudioURL("data:audio/mp3;base64,SGVsbG8=") - if !ok { - t.Fatal("expected ok=true") - } - if format != "mp3" { - t.Errorf("format = %q, want %q", format, "mp3") - } - if data != "SGVsbG8=" { - t.Errorf("data = %q, want %q", data, "SGVsbG8=") - } -} - -func TestParseDataAudioURL_NotAudio(t *testing.T) { - _, _, ok := ParseDataAudioURL("data:image/png;base64,abc") - if ok { - t.Error("expected ok=false for non-audio URL") - } -} - -func TestParseDataAudioURL_MalformedNoComma(t *testing.T) { - _, _, ok := ParseDataAudioURL("data:audio/mp3;base64") - if ok { - t.Error("expected ok=false for malformed URL") - } -} - -func TestParseDataAudioURL_EmptyData(t *testing.T) { - _, _, ok := ParseDataAudioURL("data:audio/mp3;base64,") - if ok { - t.Error("expected ok=false for empty data") - } -} - // --- BuildMultipartContent tests --- func TestBuildMultipartContent_TextOnly(t *testing.T) { From 4ae11406d2118793848ef9f74627afbb74cd97cb Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Sun, 19 Apr 2026 06:48:28 +0000 Subject: [PATCH 4/5] Deduplicate further functions --- pkg/providers/anthropic/provider.go | 2 +- pkg/providers/anthropic_messages/provider.go | 2 +- pkg/providers/common/anthropic_common.go | 27 ++++ pkg/providers/common/anthropic_common_test.go | 59 +++++++ pkg/providers/common/common_test.go | 58 ------- pkg/providers/common/google_common.go | 70 +++++++++ pkg/providers/common/google_common_test.go | 146 ++++++++++++++++++ pkg/providers/httpapi/gemini_helpers.go | 59 ------- pkg/providers/httpapi/gemini_provider.go | 6 +- pkg/providers/oauth/antigravity_provider.go | 61 +------- .../oauth/antigravity_provider_test.go | 7 - 11 files changed, 311 insertions(+), 186 deletions(-) create mode 100644 pkg/providers/common/anthropic_common.go create mode 100644 pkg/providers/common/anthropic_common_test.go create mode 100644 pkg/providers/common/google_common.go create mode 100644 pkg/providers/common/google_common_test.go diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go index 4330163df..6f4aadb8b 100644 --- a/pkg/providers/anthropic/provider.go +++ b/pkg/providers/anthropic/provider.go @@ -43,7 +43,7 @@ func NewProvider(token string) *Provider { } func NewProviderWithBaseURL(token, apiBase string) *Provider { - baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, false) + baseURL := common.NormalizeBaseURL(apiBase, defaultBaseURL, false) client := anthropic.NewClient( option.WithAuthToken(token), option.WithBaseURL(baseURL), diff --git a/pkg/providers/anthropic_messages/provider.go b/pkg/providers/anthropic_messages/provider.go index dcc31b6f9..672fb9324 100644 --- a/pkg/providers/anthropic_messages/provider.go +++ b/pkg/providers/anthropic_messages/provider.go @@ -52,7 +52,7 @@ func NewProvider(apiKey, apiBase, userAgent string) *Provider { // NewProviderWithTimeout creates a provider with custom request timeout. func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider { - baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, true) + baseURL := common.NormalizeBaseURL(apiBase, defaultBaseURL, true) timeout := defaultRequestTimeout if timeoutSeconds > 0 { timeout = time.Duration(timeoutSeconds) * time.Second diff --git a/pkg/providers/common/anthropic_common.go b/pkg/providers/common/anthropic_common.go new file mode 100644 index 000000000..92dace9ac --- /dev/null +++ b/pkg/providers/common/anthropic_common.go @@ -0,0 +1,27 @@ +package common + +import "strings" + +// NormalizeBaseURL ensures the Anthropic base URL is properly formatted. +// It removes a trailing /v1 suffix if present (to avoid duplication), then +// re-appends /v1 when appendV1Suffix is true. An empty apiBase falls back to +// defaultBaseURL. +func NormalizeBaseURL(apiBase, defaultBaseURL string, appendV1Suffix bool) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return defaultBaseURL + } + + base = strings.TrimRight(base, "/") + if before, ok := strings.CutSuffix(base, "/v1"); ok { + base = before + } + if base == "" { + return defaultBaseURL + } + + if appendV1Suffix { + return base + "/v1" + } + return base +} diff --git a/pkg/providers/common/anthropic_common_test.go b/pkg/providers/common/anthropic_common_test.go new file mode 100644 index 000000000..7563141b5 --- /dev/null +++ b/pkg/providers/common/anthropic_common_test.go @@ -0,0 +1,59 @@ +package common + +import "testing" + +func TestNormalizeAnthropicBaseURL(t *testing.T) { + const defaultURL = "https://api.anthropic.com" + const defaultURLWithV1 = "https://api.anthropic.com/v1" + + tests := []struct { + name string + apiBase string + defaultBase string + appendV1Suffix bool + expected string + }{ + {"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1}, + {"empty without v1", "", defaultURL, false, defaultURL}, + { + "URL without v1 gets it appended", + "https://api.example.com/anthropic", defaultURLWithV1, + true, "https://api.example.com/anthropic/v1", + }, + { + "URL without v1 stays as-is", + "https://api.example.com/anthropic", defaultURL, + false, "https://api.example.com/anthropic", + }, + { + "URL with v1 remains unchanged when appending", + "https://api.example.com/v1", defaultURLWithV1, + true, "https://api.example.com/v1", + }, + { + "URL with v1 gets it stripped when not appending", + "https://api.example.com/v1", defaultURL, + false, "https://api.example.com", + }, + { + "trailing slash cleaned with v1", + "https://api.example.com/anthropic/", defaultURLWithV1, + true, "https://api.example.com/anthropic/v1", + }, + { + "trailing slash cleaned without v1", + "https://api.example.com/anthropic/", defaultURL, + false, "https://api.example.com/anthropic", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NormalizeBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix) + if got != tt.expected { + t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q", + tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected) + } + }) + } +} diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go index 1f9a9b827..84aa1a707 100644 --- a/pkg/providers/common/common_test.go +++ b/pkg/providers/common/common_test.go @@ -722,64 +722,6 @@ func TestParseDataAudioURL(t *testing.T) { } } -// --- NormalizeAnthropicBaseURL tests --- - -func TestNormalizeAnthropicBaseURL(t *testing.T) { - const defaultURL = "https://api.anthropic.com" - const defaultURLWithV1 = "https://api.anthropic.com/v1" - - tests := []struct { - name string - apiBase string - defaultBase string - appendV1Suffix bool - expected string - }{ - {"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1}, - {"empty without v1", "", defaultURL, false, defaultURL}, - { - "URL without v1 gets it appended", - "https://api.example.com/anthropic", defaultURLWithV1, - true, "https://api.example.com/anthropic/v1", - }, - { - "URL without v1 stays as-is", - "https://api.example.com/anthropic", defaultURL, - false, "https://api.example.com/anthropic", - }, - { - "URL with v1 remains unchanged when appending", - "https://api.example.com/v1", defaultURLWithV1, - true, "https://api.example.com/v1", - }, - { - "URL with v1 gets it stripped when not appending", - "https://api.example.com/v1", defaultURL, - false, "https://api.example.com", - }, - { - "trailing slash cleaned with v1", - "https://api.example.com/anthropic/", defaultURLWithV1, - true, "https://api.example.com/anthropic/v1", - }, - { - "trailing slash cleaned without v1", - "https://api.example.com/anthropic/", defaultURL, - false, "https://api.example.com/anthropic", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NormalizeAnthropicBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix) - if got != tt.expected { - t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q", - tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected) - } - }) - } -} - // --- WrapHTMLResponseError tests --- func TestWrapHTMLResponseError(t *testing.T) { diff --git a/pkg/providers/common/google_common.go b/pkg/providers/common/google_common.go new file mode 100644 index 000000000..954c0c802 --- /dev/null +++ b/pkg/providers/common/google_common.go @@ -0,0 +1,70 @@ +package common + +import ( + "encoding/json" + "strings" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +// NormalizeStoredToolCall extracts the tool name, arguments, and thought signature +// from a stored ToolCall. It handles both the top-level fields and the nested +// Function struct used by different API formats. +func NormalizeStoredToolCall(tc protocoltypes.ToolCall) (string, map[string]any, string) { + name := tc.Name + args := tc.Arguments + thoughtSignature := "" + + if name == "" && tc.Function != nil { + name = tc.Function.Name + thoughtSignature = tc.Function.ThoughtSignature + } else if tc.Function != nil { + thoughtSignature = tc.Function.ThoughtSignature + } + + if args == nil { + args = map[string]any{} + } + + if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" { + var parsed map[string]any + if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil { + args = parsed + } + } + + return name, args, thoughtSignature +} + +// ResolveToolResponseName returns the tool name for a given tool call ID. +// It first checks the provided name map, then falls back to inferring the +// name from the call ID format. +func ResolveToolResponseName(toolCallID string, toolCallNames map[string]string) string { + if toolCallID == "" { + return "" + } + + if name, ok := toolCallNames[toolCallID]; ok && name != "" { + return name + } + + return InferToolNameFromCallID(toolCallID) +} + +// InferToolNameFromCallID extracts a tool name from a call ID in the format +// "call__". Returns the original ID if it doesn't match. +func InferToolNameFromCallID(toolCallID string) string { + if !strings.HasPrefix(toolCallID, "call_") { + return toolCallID + } + + rest := strings.TrimPrefix(toolCallID, "call_") + if idx := strings.LastIndex(rest, "_"); idx > 0 { + candidate := rest[:idx] + if candidate != "" { + return candidate + } + } + + return toolCallID +} diff --git a/pkg/providers/common/google_common_test.go b/pkg/providers/common/google_common_test.go new file mode 100644 index 000000000..cc013dcd1 --- /dev/null +++ b/pkg/providers/common/google_common_test.go @@ -0,0 +1,146 @@ +package common + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +func TestNormalizeStoredToolCall_TopLevelFields(t *testing.T) { + tc := protocoltypes.ToolCall{ + Name: "search", + Arguments: map[string]any{"q": "hello"}, + } + name, args, sig := NormalizeStoredToolCall(tc) + if name != "search" { + t.Errorf("name = %q, want %q", name, "search") + } + if args["q"] != "hello" { + t.Errorf("args[q] = %v, want %q", args["q"], "hello") + } + if sig != "" { + t.Errorf("thoughtSignature = %q, want empty", sig) + } +} + +func TestNormalizeStoredToolCall_FallsBackToFunction(t *testing.T) { + tc := protocoltypes.ToolCall{ + Function: &protocoltypes.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"/tmp"}`, + ThoughtSignature: "sig123", + }, + } + name, args, sig := NormalizeStoredToolCall(tc) + if name != "read_file" { + t.Errorf("name = %q, want %q", name, "read_file") + } + if args["path"] != "/tmp" { + t.Errorf("args[path] = %v, want %q", args["path"], "/tmp") + } + if sig != "sig123" { + t.Errorf("thoughtSignature = %q, want %q", sig, "sig123") + } +} + +func TestNormalizeStoredToolCall_TopLevelNameWithFunctionSig(t *testing.T) { + tc := protocoltypes.ToolCall{ + Name: "search", + Arguments: map[string]any{"q": "hi"}, + Function: &protocoltypes.FunctionCall{ + ThoughtSignature: "thought1", + }, + } + name, _, sig := NormalizeStoredToolCall(tc) + if name != "search" { + t.Errorf("name = %q, want %q", name, "search") + } + if sig != "thought1" { + t.Errorf("thoughtSignature = %q, want %q", sig, "thought1") + } +} + +func TestNormalizeStoredToolCall_NilArgs(t *testing.T) { + tc := protocoltypes.ToolCall{Name: "test"} + _, args, _ := NormalizeStoredToolCall(tc) + if args == nil { + t.Fatal("args should not be nil") + } + if len(args) != 0 { + t.Errorf("args should be empty, got %v", args) + } +} + +func TestNormalizeStoredToolCall_EmptyArgsParseFromFunction(t *testing.T) { + tc := protocoltypes.ToolCall{ + Name: "tool", + Arguments: map[string]any{}, + Function: &protocoltypes.FunctionCall{ + Arguments: `{"key":"val"}`, + }, + } + _, args, _ := NormalizeStoredToolCall(tc) + if args["key"] != "val" { + t.Errorf("args[key] = %v, want %q", args["key"], "val") + } +} + +func TestNormalizeStoredToolCall_InvalidFunctionJSON(t *testing.T) { + tc := protocoltypes.ToolCall{ + Name: "tool", + Function: &protocoltypes.FunctionCall{ + Arguments: `not-json`, + }, + } + _, args, _ := NormalizeStoredToolCall(tc) + if len(args) != 0 { + t.Errorf("args should be empty for invalid JSON, got %v", args) + } +} + +func TestResolveToolResponseName_FromMap(t *testing.T) { + names := map[string]string{"call_1": "search"} + got := ResolveToolResponseName("call_1", names) + if got != "search" { + t.Errorf("got %q, want %q", got, "search") + } +} + +func TestResolveToolResponseName_EmptyID(t *testing.T) { + got := ResolveToolResponseName("", map[string]string{"x": "y"}) + if got != "" { + t.Errorf("got %q, want empty", got) + } +} + +func TestResolveToolResponseName_FallsBackToInfer(t *testing.T) { + got := ResolveToolResponseName("call_search_docs_999", map[string]string{}) + if got != "search_docs" { + t.Errorf("got %q, want %q", got, "search_docs") + } +} + +func TestInferToolNameFromCallID(t *testing.T) { + tests := []struct { + name string + id string + want string + }{ + {"standard format", "call_search_docs_999", "search_docs"}, + {"single name", "call_read_123", "read"}, + {"no call prefix", "some_id", "some_id"}, + {"call prefix no underscore suffix", "call_onlyname", "call_onlyname"}, + {"empty string", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := InferToolNameFromCallID(tt.id) + if got != tt.want { + t.Errorf( + "InferToolNameFromCallID(%q) = %q, want %q", + tt.id, got, tt.want, + ) + } + }) + } +} diff --git a/pkg/providers/httpapi/gemini_helpers.go b/pkg/providers/httpapi/gemini_helpers.go index 0f1e20ca5..249c1b8de 100644 --- a/pkg/providers/httpapi/gemini_helpers.go +++ b/pkg/providers/httpapi/gemini_helpers.go @@ -1,64 +1,5 @@ package httpapi -import ( - "encoding/json" - "strings" -) - -func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) { - name := tc.Name - args := tc.Arguments - thoughtSignature := "" - - if name == "" && tc.Function != nil { - name = tc.Function.Name - thoughtSignature = tc.Function.ThoughtSignature - } else if tc.Function != nil { - thoughtSignature = tc.Function.ThoughtSignature - } - - if args == nil { - args = map[string]any{} - } - - if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" { - var parsed map[string]any - if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil { - args = parsed - } - } - - return name, args, thoughtSignature -} - -func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string { - if toolCallID == "" { - return "" - } - - if name, ok := toolCallNames[toolCallID]; ok && name != "" { - return name - } - - return inferToolNameFromCallID(toolCallID) -} - -func inferToolNameFromCallID(toolCallID string) string { - if !strings.HasPrefix(toolCallID, "call_") { - return toolCallID - } - - rest := strings.TrimPrefix(toolCallID, "call_") - if idx := strings.LastIndex(rest, "_"); idx > 0 { - candidate := rest[:idx] - if candidate != "" { - return candidate - } - } - - return toolCallID -} - func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake string) string { if thoughtSignature != "" { return thoughtSignature diff --git a/pkg/providers/httpapi/gemini_provider.go b/pkg/providers/httpapi/gemini_provider.go index dab6acd29..9ad4693da 100644 --- a/pkg/providers/httpapi/gemini_provider.go +++ b/pkg/providers/httpapi/gemini_provider.go @@ -185,7 +185,7 @@ func (p *GeminiProvider) buildRequestBody( case "user": if msg.ToolCallID != "" { - toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames) contents = append(contents, geminiContent{ Role: "user", Parts: []geminiPart{{ @@ -210,7 +210,7 @@ func (p *GeminiProvider) buildRequestBody( content.Parts = append(content.Parts, geminiPart{Text: msg.Content}) } for _, tc := range msg.ToolCalls { - toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc) + toolName, toolArgs, thoughtSignature := common.NormalizeStoredToolCall(tc) if toolName == "" { continue } @@ -234,7 +234,7 @@ func (p *GeminiProvider) buildRequestBody( } case "tool": - toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames) contents = append(contents, geminiContent{ Role: "user", Parts: []geminiPart{{ diff --git a/pkg/providers/oauth/antigravity_provider.go b/pkg/providers/oauth/antigravity_provider.go index 38526dd7a..1ac2d9c7f 100644 --- a/pkg/providers/oauth/antigravity_provider.go +++ b/pkg/providers/oauth/antigravity_provider.go @@ -14,6 +14,7 @@ import ( "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers/common" ) const ( @@ -221,7 +222,7 @@ func (p *AntigravityProvider) buildRequest( } case "user": if msg.ToolCallID != "" { - toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames) // Tool result req.Contents = append(req.Contents, antigravityContent{ Role: "user", @@ -248,7 +249,7 @@ func (p *AntigravityProvider) buildRequest( content.Parts = append(content.Parts, antigravityPart{Text: msg.Content}) } for _, tc := range msg.ToolCalls { - toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc) + toolName, toolArgs, thoughtSignature := common.NormalizeStoredToolCall(tc) if toolName == "" { logger.WarnCF( "provider.antigravity", @@ -275,7 +276,7 @@ func (p *AntigravityProvider) buildRequest( req.Contents = append(req.Contents, content) } case "tool": - toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames) req.Contents = append(req.Contents, antigravityContent{ Role: "user", Parts: []antigravityPart{{ @@ -328,60 +329,6 @@ func (p *AntigravityProvider) buildRequest( return req } -func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) { - name := tc.Name - args := tc.Arguments - thoughtSignature := "" - - if name == "" && tc.Function != nil { - name = tc.Function.Name - thoughtSignature = tc.Function.ThoughtSignature - } else if tc.Function != nil { - thoughtSignature = tc.Function.ThoughtSignature - } - - if args == nil { - args = map[string]any{} - } - - if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" { - var parsed map[string]any - if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil { - args = parsed - } - } - - return name, args, thoughtSignature -} - -func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string { - if toolCallID == "" { - return "" - } - - if name, ok := toolCallNames[toolCallID]; ok && name != "" { - return name - } - - return inferToolNameFromCallID(toolCallID) -} - -func inferToolNameFromCallID(toolCallID string) string { - if !strings.HasPrefix(toolCallID, "call_") { - return toolCallID - } - - rest := strings.TrimPrefix(toolCallID, "call_") - if idx := strings.LastIndex(rest, "_"); idx > 0 { - candidate := rest[:idx] - if candidate != "" { - return candidate - } - } - - return toolCallID -} - // --- Response parsing --- type antigravityJSONResponse struct { diff --git a/pkg/providers/oauth/antigravity_provider_test.go b/pkg/providers/oauth/antigravity_provider_test.go index 41cb5b0db..2989f8519 100644 --- a/pkg/providers/oauth/antigravity_provider_test.go +++ b/pkg/providers/oauth/antigravity_provider_test.go @@ -48,13 +48,6 @@ func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) { } } -func TestResolveToolResponseNameInfersNameFromGeneratedCallID(t *testing.T) { - got := resolveToolResponseName("call_search_docs_999", map[string]string{}) - if got != "search_docs" { - t.Fatalf("expected inferred tool name search_docs, got %q", got) - } -} - func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) { p := &AntigravityProvider{} body := "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"hidden reasoning\",\"thought\":true},{\"text\":\"visible answer\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":17,\"totalTokenCount\":216}}}\n" + From 76164701370498463f610d43503c6d6fd4eb1b1d Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Wed, 22 Apr 2026 05:41:32 +0000 Subject: [PATCH 5/5] Revert deduplication --- pkg/providers/common/common_test.go | 31 ------------------------ pkg/providers/factory_provider.go | 23 ++++++++++++++++-- pkg/providers/httpapi/gemini_helpers.go | 11 +++++++++ pkg/providers/httpapi/gemini_provider.go | 2 +- 4 files changed, 33 insertions(+), 34 deletions(-) diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go index 84aa1a707..3cf2f4285 100644 --- a/pkg/providers/common/common_test.go +++ b/pkg/providers/common/common_test.go @@ -660,37 +660,6 @@ func TestAsFloat(t *testing.T) { } } -// --- ExtractProtocol tests --- - -func TestExtractProtocol(t *testing.T) { - tests := []struct { - name string - model string - wantProtocol string - wantModelID string - }{ - {"openai with prefix", "openai/gpt-4o", "openai", "gpt-4o"}, - {"anthropic with prefix", "anthropic/claude-sonnet-4.6", "anthropic", "claude-sonnet-4.6"}, - {"no prefix defaults to openai", "gpt-4o", "openai", "gpt-4o"}, - {"groq with prefix", "groq/llama-3.1-70b", "groq", "llama-3.1-70b"}, - {"empty string", "", "openai", ""}, - {"with whitespace", " openai/gpt-4 ", "openai", "gpt-4"}, - {"multiple slashes", "nvidia/meta/llama-3.1-8b", "nvidia", "meta/llama-3.1-8b"}, - {"azure with prefix", "azure/my-gpt5-deployment", "azure", "my-gpt5-deployment"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - protocol, modelID := ExtractProtocol(tt.model) - if protocol != tt.wantProtocol { - t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol) - } - if modelID != tt.wantModelID { - t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID) - } - }) - } -} - // --- ParseDataAudioURL tests --- func TestParseDataAudioURL(t *testing.T) { diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 63413b0a1..86d009811 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -15,7 +15,6 @@ import ( anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages" "github.com/sipeed/picoclaw/pkg/providers/azure" "github.com/sipeed/picoclaw/pkg/providers/bedrock" - "github.com/sipeed/picoclaw/pkg/providers/common" ) type protocolMeta struct { @@ -103,7 +102,27 @@ func createCodexAuthProvider() (LLMProvider, error) { // - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o") // - Model "gpt-4o" -> ("openai", "gpt-4o") func ExtractProtocol(cfg *config.ModelConfig) (protocol, modelID string) { - return common.ExtractProtocol(model) + if cfg == nil { + return "", "" + } + + model := strings.TrimSpace(cfg.Model) + if provider := strings.TrimSpace(cfg.Provider); provider != "" { + return NormalizeProvider(provider), model + } + if model == "" { + return "", "" + } + + protocol, rest, found := strings.Cut(model, "/") + if !found { + return "openai", model + } + protocol = strings.TrimSpace(protocol) + if protocol == "" { + return "", strings.TrimSpace(rest) + } + return NormalizeProvider(protocol), strings.TrimSpace(rest) } // ResolveAPIBase returns the configured API base, or the protocol default when diff --git a/pkg/providers/httpapi/gemini_helpers.go b/pkg/providers/httpapi/gemini_helpers.go index 249c1b8de..a2b2d63c3 100644 --- a/pkg/providers/httpapi/gemini_helpers.go +++ b/pkg/providers/httpapi/gemini_helpers.go @@ -1,5 +1,7 @@ package httpapi +import "strings" + func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake string) string { if thoughtSignature != "" { return thoughtSignature @@ -69,3 +71,12 @@ func sanitizeSchemaForGemini(schema map[string]any) map[string]any { return result } + +func extractProtocol(model string) (protocol, modelID string) { + model = strings.TrimSpace(model) + protocol, modelID, found := strings.Cut(model, "/") + if !found { + return "openai", model + } + return protocol, modelID +} diff --git a/pkg/providers/httpapi/gemini_provider.go b/pkg/providers/httpapi/gemini_provider.go index 9ad4693da..d1d523757 100644 --- a/pkg/providers/httpapi/gemini_provider.go +++ b/pkg/providers/httpapi/gemini_provider.go @@ -303,7 +303,7 @@ func normalizeGeminiModel(model string) string { model = strings.TrimSpace(model) model = strings.TrimPrefix(model, "models/") if strings.Contains(model, "/") { - _, modelID := common.ExtractProtocol(model) + _, modelID := extractProtocol(model) if modelID != "" { return modelID }