Merge pull request #2586 from kunalk16/fix-functions-deduplication

refactor(deduplication): functions deduplication in pkg/providers
This commit is contained in:
美電球
2026-04-23 20:55:11 +08:00
committed by GitHub
16 changed files with 360 additions and 304 deletions
+2 -18
View File
@@ -10,6 +10,7 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -42,7 +43,7 @@ func NewProvider(token string) *Provider {
}
func NewProviderWithBaseURL(token, apiBase string) *Provider {
baseURL := normalizeBaseURL(apiBase)
baseURL := common.NormalizeBaseURL(apiBase, defaultBaseURL, false)
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL(baseURL),
@@ -385,20 +386,3 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
},
}
}
func normalizeBaseURL(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
base = strings.TrimRight(base, "/")
if before, ok := strings.CutSuffix(base, "/v1"); ok {
base = before
}
if base == "" {
return defaultBaseURL
}
return base
}
+4 -58
View File
@@ -16,6 +16,7 @@ import (
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -51,7 +52,7 @@ func NewProvider(apiKey, apiBase, userAgent string) *Provider {
// NewProviderWithTimeout creates a provider with custom request timeout.
func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider {
baseURL := normalizeBaseURL(apiBase)
baseURL := common.NormalizeBaseURL(apiBase, defaultBaseURL, true)
timeout := defaultRequestTimeout
if timeoutSeconds > 0 {
timeout = time.Duration(timeoutSeconds) * time.Second
@@ -161,7 +162,7 @@ func buildRequestBody(
options map[string]any,
) (map[string]any, error) {
// max_tokens is required and guaranteed by agent loop
maxTokens, ok := asInt(options["max_tokens"])
maxTokens, ok := common.AsInt(options["max_tokens"])
if !ok {
return nil, fmt.Errorf("max_tokens is required in options")
}
@@ -173,7 +174,7 @@ func buildRequestBody(
}
// Set temperature from options
if temp, ok := asFloat(options["temperature"]); ok {
if temp, ok := common.AsFloat(options["temperature"]); ok {
result["temperature"] = temp
}
@@ -361,61 +362,6 @@ func parseResponseBody(body []byte) (*LLMResponse, error) {
}, nil
}
// normalizeBaseURL ensures the base URL is properly formatted.
// It removes /v1 suffix if present (to avoid duplication) and always appends /v1.
// This handles edge cases like "https://api.example.com/v1/proxy" correctly.
func normalizeBaseURL(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
// Remove trailing slashes
base = strings.TrimRight(base, "/")
// Remove /v1 suffix if present (will be re-added)
// This prevents duplication for URLs like "https://api.example.com/v1/proxy"
if before, ok := strings.CutSuffix(base, "/v1"); ok {
base = before
}
// Ensure we don't have an empty string after cutting
if base == "" {
return defaultBaseURL
}
// Add /v1 suffix (required by Anthropic Messages API)
return base + "/v1"
}
// Helper functions for type conversion
func asInt(v any) (int, bool) {
switch val := v.(type) {
case int:
return val, true
case float64:
return int(val), true
case int64:
return int(val), true
default:
return 0, false
}
}
func asFloat(v any) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
// Anthropic API response structures
type anthropicMessageResponse struct {
@@ -372,44 +372,6 @@ func TestParseResponseBody(t *testing.T) {
}
}
func TestNormalizeBaseURL(t *testing.T) {
tests := []struct {
name string
apiBase string
expected string
}{
{
name: "empty string defaults to official API",
apiBase: "",
expected: "https://api.anthropic.com/v1",
},
{
name: "URL without /v1 gets it appended",
apiBase: "https://api.example.com/anthropic",
expected: "https://api.example.com/anthropic/v1",
},
{
name: "URL with /v1 remains unchanged",
apiBase: "https://api.example.com/v1",
expected: "https://api.example.com/v1",
},
{
name: "URL with trailing slash gets cleaned",
apiBase: "https://api.example.com/anthropic/",
expected: "https://api.example.com/anthropic/v1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeBaseURL(tt.apiBase)
if got != tt.expected {
t.Errorf("normalizeBaseURL(%q) = %q, want %q", tt.apiBase, got, tt.expected)
}
})
}
}
func TestNewProvider(t *testing.T) {
provider := NewProvider("test-key", "https://api.example.com", "")
if provider == nil {
+27
View File
@@ -0,0 +1,27 @@
package common
import "strings"
// NormalizeBaseURL ensures the Anthropic base URL is properly formatted.
// It removes a trailing /v1 suffix if present (to avoid duplication), then
// re-appends /v1 when appendV1Suffix is true. An empty apiBase falls back to
// defaultBaseURL.
func NormalizeBaseURL(apiBase, defaultBaseURL string, appendV1Suffix bool) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
base = strings.TrimRight(base, "/")
if before, ok := strings.CutSuffix(base, "/v1"); ok {
base = before
}
if base == "" {
return defaultBaseURL
}
if appendV1Suffix {
return base + "/v1"
}
return base
}
@@ -0,0 +1,59 @@
package common
import "testing"
func TestNormalizeAnthropicBaseURL(t *testing.T) {
const defaultURL = "https://api.anthropic.com"
const defaultURLWithV1 = "https://api.anthropic.com/v1"
tests := []struct {
name string
apiBase string
defaultBase string
appendV1Suffix bool
expected string
}{
{"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1},
{"empty without v1", "", defaultURL, false, defaultURL},
{
"URL without v1 gets it appended",
"https://api.example.com/anthropic", defaultURLWithV1,
true, "https://api.example.com/anthropic/v1",
},
{
"URL without v1 stays as-is",
"https://api.example.com/anthropic", defaultURL,
false, "https://api.example.com/anthropic",
},
{
"URL with v1 remains unchanged when appending",
"https://api.example.com/v1", defaultURLWithV1,
true, "https://api.example.com/v1",
},
{
"URL with v1 gets it stripped when not appending",
"https://api.example.com/v1", defaultURL,
false, "https://api.example.com",
},
{
"trailing slash cleaned with v1",
"https://api.example.com/anthropic/", defaultURLWithV1,
true, "https://api.example.com/anthropic/v1",
},
{
"trailing slash cleaned without v1",
"https://api.example.com/anthropic/", defaultURL,
false, "https://api.example.com/anthropic",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NormalizeBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix)
if got != tt.expected {
t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q",
tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected)
}
})
}
}
+3 -2
View File
@@ -127,7 +127,7 @@ func SerializeMessages(messages []Message) []any {
continue
}
if format, data, ok := parseDataAudioURL(mediaURL); ok {
if format, data, ok := ParseDataAudioURL(mediaURL); ok {
parts = append(parts, map[string]any{
"type": "input_audio",
"input_audio": map[string]any{
@@ -205,7 +205,8 @@ func serializeToolCalls(toolCalls []ToolCall) []openaiToolCall {
return out
}
func parseDataAudioURL(mediaURL string) (format, data string, ok bool) {
// ParseDataAudioURL extracts the format and base64 data from a data:audio/... URL.
func ParseDataAudioURL(mediaURL string) (format, data string, ok bool) {
if !strings.HasPrefix(mediaURL, "data:audio/") {
return "", "", false
}
+31
View File
@@ -660,6 +660,37 @@ func TestAsFloat(t *testing.T) {
}
}
// --- ParseDataAudioURL tests ---
func TestParseDataAudioURL(t *testing.T) {
tests := []struct {
name string
mediaURL string
wantFormat string
wantData string
wantOK bool
}{
{"valid mp3", "data:audio/mp3;base64,SGVsbG8=", "mp3", "SGVsbG8=", true},
{"valid wav", "data:audio/wav;base64,AAAA", "wav", "AAAA", true},
{"not audio", "data:image/png;base64,abc", "", "", false},
{"no comma", "data:audio/mp3;base64", "", "", false},
{"empty data", "data:audio/mp3;base64,", "", "", false},
{"empty string", "", "", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
format, data, ok := ParseDataAudioURL(tt.mediaURL)
if ok != tt.wantOK || format != tt.wantFormat || data != tt.wantData {
t.Errorf(
"ParseDataAudioURL(%q) = (%q, %q, %v), want (%q, %q, %v)",
tt.mediaURL, format, data, ok,
tt.wantFormat, tt.wantData, tt.wantOK,
)
}
})
}
}
// --- WrapHTMLResponseError tests ---
func TestWrapHTMLResponseError(t *testing.T) {
+70
View File
@@ -0,0 +1,70 @@
package common
import (
"encoding/json"
"strings"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// NormalizeStoredToolCall extracts the tool name, arguments, and thought signature
// from a stored ToolCall. It handles both the top-level fields and the nested
// Function struct used by different API formats.
func NormalizeStoredToolCall(tc protocoltypes.ToolCall) (string, map[string]any, string) {
name := tc.Name
args := tc.Arguments
thoughtSignature := ""
if name == "" && tc.Function != nil {
name = tc.Function.Name
thoughtSignature = tc.Function.ThoughtSignature
} else if tc.Function != nil {
thoughtSignature = tc.Function.ThoughtSignature
}
if args == nil {
args = map[string]any{}
}
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
var parsed map[string]any
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
args = parsed
}
}
return name, args, thoughtSignature
}
// ResolveToolResponseName returns the tool name for a given tool call ID.
// It first checks the provided name map, then falls back to inferring the
// name from the call ID format.
func ResolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
if toolCallID == "" {
return ""
}
if name, ok := toolCallNames[toolCallID]; ok && name != "" {
return name
}
return InferToolNameFromCallID(toolCallID)
}
// InferToolNameFromCallID extracts a tool name from a call ID in the format
// "call_<name>_<suffix>". Returns the original ID if it doesn't match.
func InferToolNameFromCallID(toolCallID string) string {
if !strings.HasPrefix(toolCallID, "call_") {
return toolCallID
}
rest := strings.TrimPrefix(toolCallID, "call_")
if idx := strings.LastIndex(rest, "_"); idx > 0 {
candidate := rest[:idx]
if candidate != "" {
return candidate
}
}
return toolCallID
}
+146
View File
@@ -0,0 +1,146 @@
package common
import (
"testing"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
func TestNormalizeStoredToolCall_TopLevelFields(t *testing.T) {
tc := protocoltypes.ToolCall{
Name: "search",
Arguments: map[string]any{"q": "hello"},
}
name, args, sig := NormalizeStoredToolCall(tc)
if name != "search" {
t.Errorf("name = %q, want %q", name, "search")
}
if args["q"] != "hello" {
t.Errorf("args[q] = %v, want %q", args["q"], "hello")
}
if sig != "" {
t.Errorf("thoughtSignature = %q, want empty", sig)
}
}
func TestNormalizeStoredToolCall_FallsBackToFunction(t *testing.T) {
tc := protocoltypes.ToolCall{
Function: &protocoltypes.FunctionCall{
Name: "read_file",
Arguments: `{"path":"/tmp"}`,
ThoughtSignature: "sig123",
},
}
name, args, sig := NormalizeStoredToolCall(tc)
if name != "read_file" {
t.Errorf("name = %q, want %q", name, "read_file")
}
if args["path"] != "/tmp" {
t.Errorf("args[path] = %v, want %q", args["path"], "/tmp")
}
if sig != "sig123" {
t.Errorf("thoughtSignature = %q, want %q", sig, "sig123")
}
}
func TestNormalizeStoredToolCall_TopLevelNameWithFunctionSig(t *testing.T) {
tc := protocoltypes.ToolCall{
Name: "search",
Arguments: map[string]any{"q": "hi"},
Function: &protocoltypes.FunctionCall{
ThoughtSignature: "thought1",
},
}
name, _, sig := NormalizeStoredToolCall(tc)
if name != "search" {
t.Errorf("name = %q, want %q", name, "search")
}
if sig != "thought1" {
t.Errorf("thoughtSignature = %q, want %q", sig, "thought1")
}
}
func TestNormalizeStoredToolCall_NilArgs(t *testing.T) {
tc := protocoltypes.ToolCall{Name: "test"}
_, args, _ := NormalizeStoredToolCall(tc)
if args == nil {
t.Fatal("args should not be nil")
}
if len(args) != 0 {
t.Errorf("args should be empty, got %v", args)
}
}
func TestNormalizeStoredToolCall_EmptyArgsParseFromFunction(t *testing.T) {
tc := protocoltypes.ToolCall{
Name: "tool",
Arguments: map[string]any{},
Function: &protocoltypes.FunctionCall{
Arguments: `{"key":"val"}`,
},
}
_, args, _ := NormalizeStoredToolCall(tc)
if args["key"] != "val" {
t.Errorf("args[key] = %v, want %q", args["key"], "val")
}
}
func TestNormalizeStoredToolCall_InvalidFunctionJSON(t *testing.T) {
tc := protocoltypes.ToolCall{
Name: "tool",
Function: &protocoltypes.FunctionCall{
Arguments: `not-json`,
},
}
_, args, _ := NormalizeStoredToolCall(tc)
if len(args) != 0 {
t.Errorf("args should be empty for invalid JSON, got %v", args)
}
}
func TestResolveToolResponseName_FromMap(t *testing.T) {
names := map[string]string{"call_1": "search"}
got := ResolveToolResponseName("call_1", names)
if got != "search" {
t.Errorf("got %q, want %q", got, "search")
}
}
func TestResolveToolResponseName_EmptyID(t *testing.T) {
got := ResolveToolResponseName("", map[string]string{"x": "y"})
if got != "" {
t.Errorf("got %q, want empty", got)
}
}
func TestResolveToolResponseName_FallsBackToInfer(t *testing.T) {
got := ResolveToolResponseName("call_search_docs_999", map[string]string{})
if got != "search_docs" {
t.Errorf("got %q, want %q", got, "search_docs")
}
}
func TestInferToolNameFromCallID(t *testing.T) {
tests := []struct {
name string
id string
want string
}{
{"standard format", "call_search_docs_999", "search_docs"},
{"single name", "call_read_123", "read"},
{"no call prefix", "some_id", "some_id"},
{"call prefix no underscore suffix", "call_onlyname", "call_onlyname"},
{"empty string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := InferToolNameFromCallID(tt.id)
if got != tt.want {
t.Errorf(
"InferToolNameFromCallID(%q) = %q, want %q",
tt.id, got, tt.want,
)
}
})
}
}
+1 -58
View File
@@ -1,63 +1,6 @@
package httpapi
import (
"encoding/json"
"strings"
)
func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) {
name := tc.Name
args := tc.Arguments
thoughtSignature := ""
if name == "" && tc.Function != nil {
name = tc.Function.Name
thoughtSignature = tc.Function.ThoughtSignature
} else if tc.Function != nil {
thoughtSignature = tc.Function.ThoughtSignature
}
if args == nil {
args = map[string]any{}
}
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
var parsed map[string]any
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
args = parsed
}
}
return name, args, thoughtSignature
}
func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
if toolCallID == "" {
return ""
}
if name, ok := toolCallNames[toolCallID]; ok && name != "" {
return name
}
return inferToolNameFromCallID(toolCallID)
}
func inferToolNameFromCallID(toolCallID string) string {
if !strings.HasPrefix(toolCallID, "call_") {
return toolCallID
}
rest := strings.TrimPrefix(toolCallID, "call_")
if idx := strings.LastIndex(rest, "_"); idx > 0 {
candidate := rest[:idx]
if candidate != "" {
return candidate
}
}
return toolCallID
}
import "strings"
func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake string) string {
if thoughtSignature != "" {
+3 -3
View File
@@ -185,7 +185,7 @@ func (p *GeminiProvider) buildRequestBody(
case "user":
if msg.ToolCallID != "" {
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
contents = append(contents, geminiContent{
Role: "user",
Parts: []geminiPart{{
@@ -210,7 +210,7 @@ func (p *GeminiProvider) buildRequestBody(
content.Parts = append(content.Parts, geminiPart{Text: msg.Content})
}
for _, tc := range msg.ToolCalls {
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
toolName, toolArgs, thoughtSignature := common.NormalizeStoredToolCall(tc)
if toolName == "" {
continue
}
@@ -234,7 +234,7 @@ func (p *GeminiProvider) buildRequestBody(
}
case "tool":
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
contents = append(contents, geminiContent{
Role: "user",
Parts: []geminiPart{{
+4 -57
View File
@@ -14,6 +14,7 @@ import (
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers/common"
)
const (
@@ -221,7 +222,7 @@ func (p *AntigravityProvider) buildRequest(
}
case "user":
if msg.ToolCallID != "" {
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
// Tool result
req.Contents = append(req.Contents, antigravityContent{
Role: "user",
@@ -248,7 +249,7 @@ func (p *AntigravityProvider) buildRequest(
content.Parts = append(content.Parts, antigravityPart{Text: msg.Content})
}
for _, tc := range msg.ToolCalls {
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
toolName, toolArgs, thoughtSignature := common.NormalizeStoredToolCall(tc)
if toolName == "" {
logger.WarnCF(
"provider.antigravity",
@@ -275,7 +276,7 @@ func (p *AntigravityProvider) buildRequest(
req.Contents = append(req.Contents, content)
}
case "tool":
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
req.Contents = append(req.Contents, antigravityContent{
Role: "user",
Parts: []antigravityPart{{
@@ -328,60 +329,6 @@ func (p *AntigravityProvider) buildRequest(
return req
}
func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) {
name := tc.Name
args := tc.Arguments
thoughtSignature := ""
if name == "" && tc.Function != nil {
name = tc.Function.Name
thoughtSignature = tc.Function.ThoughtSignature
} else if tc.Function != nil {
thoughtSignature = tc.Function.ThoughtSignature
}
if args == nil {
args = map[string]any{}
}
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
var parsed map[string]any
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
args = parsed
}
}
return name, args, thoughtSignature
}
func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
if toolCallID == "" {
return ""
}
if name, ok := toolCallNames[toolCallID]; ok && name != "" {
return name
}
return inferToolNameFromCallID(toolCallID)
}
func inferToolNameFromCallID(toolCallID string) string {
if !strings.HasPrefix(toolCallID, "call_") {
return toolCallID
}
rest := strings.TrimPrefix(toolCallID, "call_")
if idx := strings.LastIndex(rest, "_"); idx > 0 {
candidate := rest[:idx]
if candidate != "" {
return candidate
}
}
return toolCallID
}
// --- Response parsing ---
type antigravityJSONResponse struct {
@@ -48,13 +48,6 @@ func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) {
}
}
func TestResolveToolResponseNameInfersNameFromGeneratedCallID(t *testing.T) {
got := resolveToolResponseName("call_search_docs_999", map[string]string{})
if got != "search_docs" {
t.Fatalf("expected inferred tool name search_docs, got %q", got)
}
}
func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) {
p := &AntigravityProvider{}
body := "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"hidden reasoning\",\"thought\":true},{\"text\":\"visible answer\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":17,\"totalTokenCount\":216}}}\n" +
+8 -7
View File
@@ -470,7 +470,9 @@ func (p *Provider) SupportsNativeSearch() bool {
return isNativeSearchHost(p.apiBase)
}
func isNativeSearchHost(apiBase string) bool {
// isNativeOpenAIOrAzureEndpoint reports whether the given API base points to
// OpenAI's own API or an Azure OpenAI deployment.
func isNativeOpenAIOrAzureEndpoint(apiBase string) bool {
u, err := url.Parse(apiBase)
if err != nil {
return false
@@ -479,15 +481,14 @@ func isNativeSearchHost(apiBase string) bool {
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
}
func isNativeSearchHost(apiBase string) bool {
return isNativeOpenAIOrAzureEndpoint(apiBase)
}
// supportsPromptCacheKey reports whether the given API base is known to
// support the prompt_cache_key request field. Currently only OpenAI's own
// API and Azure OpenAI support this. All other OpenAI-compatible providers
// (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
func supportsPromptCacheKey(apiBase string) bool {
u, err := url.Parse(apiBase)
if err != nil {
return false
}
host := u.Hostname()
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
return isNativeOpenAIOrAzureEndpoint(apiBase)
}
@@ -10,6 +10,7 @@ import (
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -118,7 +119,7 @@ func BuildMultipartContent(text string, media []string) responses.ResponseInputM
},
})
} else if strings.HasPrefix(mediaURL, "data:audio/") {
if format, data, ok := ParseDataAudioURL(mediaURL); ok {
if format, data, ok := common.ParseDataAudioURL(mediaURL); ok {
parts = append(parts, responses.ResponseInputContentUnionParam{
OfInputFile: &responses.ResponseInputFileParam{
FileData: openai.Opt(data),
@@ -132,25 +133,6 @@ func BuildMultipartContent(text string, media []string) responses.ResponseInputM
return parts
}
// ParseDataAudioURL extracts the format and base64 data from a data:audio/... URL.
func ParseDataAudioURL(mediaURL string) (format, data string, ok bool) {
if !strings.HasPrefix(mediaURL, "data:audio/") {
return "", "", false
}
payload := strings.TrimPrefix(mediaURL, "data:audio/")
meta, data, found := strings.Cut(payload, ",")
if !found {
return "", "", false
}
format, _, _ = strings.Cut(meta, ";")
format = strings.TrimSpace(format)
data = strings.TrimSpace(data)
if format == "" || data == "" {
return "", "", false
}
return format, data, true
}
// ResolveToolCall extracts the function name and JSON arguments string from a ToolCall.
// Returns ok=false if the tool call has no name or if arguments fail to marshal.
func ResolveToolCall(tc protocoltypes.ToolCall) (name string, arguments string, ok bool) {
@@ -506,42 +506,6 @@ func TestParseResponseBody_CanceledStatus(t *testing.T) {
}
}
// --- ParseDataAudioURL tests ---
func TestParseDataAudioURL_Valid(t *testing.T) {
format, data, ok := ParseDataAudioURL("data:audio/mp3;base64,SGVsbG8=")
if !ok {
t.Fatal("expected ok=true")
}
if format != "mp3" {
t.Errorf("format = %q, want %q", format, "mp3")
}
if data != "SGVsbG8=" {
t.Errorf("data = %q, want %q", data, "SGVsbG8=")
}
}
func TestParseDataAudioURL_NotAudio(t *testing.T) {
_, _, ok := ParseDataAudioURL("data:image/png;base64,abc")
if ok {
t.Error("expected ok=false for non-audio URL")
}
}
func TestParseDataAudioURL_MalformedNoComma(t *testing.T) {
_, _, ok := ParseDataAudioURL("data:audio/mp3;base64")
if ok {
t.Error("expected ok=false for malformed URL")
}
}
func TestParseDataAudioURL_EmptyData(t *testing.T) {
_, _, ok := ParseDataAudioURL("data:audio/mp3;base64,")
if ok {
t.Error("expected ok=false for empty data")
}
}
// --- BuildMultipartContent tests ---
func TestBuildMultipartContent_TextOnly(t *testing.T) {