diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go index 84aa1a707..3cf2f4285 100644 --- a/pkg/providers/common/common_test.go +++ b/pkg/providers/common/common_test.go @@ -660,37 +660,6 @@ func TestAsFloat(t *testing.T) { } } -// --- ExtractProtocol tests --- - -func TestExtractProtocol(t *testing.T) { - tests := []struct { - name string - model string - wantProtocol string - wantModelID string - }{ - {"openai with prefix", "openai/gpt-4o", "openai", "gpt-4o"}, - {"anthropic with prefix", "anthropic/claude-sonnet-4.6", "anthropic", "claude-sonnet-4.6"}, - {"no prefix defaults to openai", "gpt-4o", "openai", "gpt-4o"}, - {"groq with prefix", "groq/llama-3.1-70b", "groq", "llama-3.1-70b"}, - {"empty string", "", "openai", ""}, - {"with whitespace", " openai/gpt-4 ", "openai", "gpt-4"}, - {"multiple slashes", "nvidia/meta/llama-3.1-8b", "nvidia", "meta/llama-3.1-8b"}, - {"azure with prefix", "azure/my-gpt5-deployment", "azure", "my-gpt5-deployment"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - protocol, modelID := ExtractProtocol(tt.model) - if protocol != tt.wantProtocol { - t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol) - } - if modelID != tt.wantModelID { - t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID) - } - }) - } -} - // --- ParseDataAudioURL tests --- func TestParseDataAudioURL(t *testing.T) { diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 63413b0a1..86d009811 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -15,7 +15,6 @@ import ( anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages" "github.com/sipeed/picoclaw/pkg/providers/azure" "github.com/sipeed/picoclaw/pkg/providers/bedrock" - "github.com/sipeed/picoclaw/pkg/providers/common" ) type protocolMeta struct { @@ -103,7 +102,27 @@ func createCodexAuthProvider() (LLMProvider, error) { // - Provider "openai", Model "openai/gpt-4o" -> ("openai", "openai/gpt-4o") // - Model "gpt-4o" -> ("openai", "gpt-4o") func ExtractProtocol(cfg *config.ModelConfig) (protocol, modelID string) { - return common.ExtractProtocol(model) + if cfg == nil { + return "", "" + } + + model := strings.TrimSpace(cfg.Model) + if provider := strings.TrimSpace(cfg.Provider); provider != "" { + return NormalizeProvider(provider), model + } + if model == "" { + return "", "" + } + + protocol, rest, found := strings.Cut(model, "/") + if !found { + return "openai", model + } + protocol = strings.TrimSpace(protocol) + if protocol == "" { + return "", strings.TrimSpace(rest) + } + return NormalizeProvider(protocol), strings.TrimSpace(rest) } // ResolveAPIBase returns the configured API base, or the protocol default when diff --git a/pkg/providers/httpapi/gemini_helpers.go b/pkg/providers/httpapi/gemini_helpers.go index 249c1b8de..a2b2d63c3 100644 --- a/pkg/providers/httpapi/gemini_helpers.go +++ b/pkg/providers/httpapi/gemini_helpers.go @@ -1,5 +1,7 @@ package httpapi +import "strings" + func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake string) string { if thoughtSignature != "" { return thoughtSignature @@ -69,3 +71,12 @@ func sanitizeSchemaForGemini(schema map[string]any) map[string]any { return result } + +func extractProtocol(model string) (protocol, modelID string) { + model = strings.TrimSpace(model) + protocol, modelID, found := strings.Cut(model, "/") + if !found { + return "openai", model + } + return protocol, modelID +} diff --git a/pkg/providers/httpapi/gemini_provider.go b/pkg/providers/httpapi/gemini_provider.go index 9ad4693da..d1d523757 100644 --- a/pkg/providers/httpapi/gemini_provider.go +++ b/pkg/providers/httpapi/gemini_provider.go @@ -303,7 +303,7 @@ func normalizeGeminiModel(model string) string { model = strings.TrimSpace(model) model = strings.TrimPrefix(model, "models/") if strings.Contains(model, "/") { - _, modelID := common.ExtractProtocol(model) + _, modelID := extractProtocol(model) if modelID != "" { return modelID }