From 4ae11406d2118793848ef9f74627afbb74cd97cb Mon Sep 17 00:00:00 2001 From: Kunal Karmakar Date: Sun, 19 Apr 2026 06:48:28 +0000 Subject: [PATCH] Deduplicate further functions --- pkg/providers/anthropic/provider.go | 2 +- pkg/providers/anthropic_messages/provider.go | 2 +- pkg/providers/common/anthropic_common.go | 27 ++++ pkg/providers/common/anthropic_common_test.go | 59 +++++++ pkg/providers/common/common_test.go | 58 ------- pkg/providers/common/google_common.go | 70 +++++++++ pkg/providers/common/google_common_test.go | 146 ++++++++++++++++++ pkg/providers/httpapi/gemini_helpers.go | 59 ------- pkg/providers/httpapi/gemini_provider.go | 6 +- pkg/providers/oauth/antigravity_provider.go | 61 +------- .../oauth/antigravity_provider_test.go | 7 - 11 files changed, 311 insertions(+), 186 deletions(-) create mode 100644 pkg/providers/common/anthropic_common.go create mode 100644 pkg/providers/common/anthropic_common_test.go create mode 100644 pkg/providers/common/google_common.go create mode 100644 pkg/providers/common/google_common_test.go diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go index 4330163df..6f4aadb8b 100644 --- a/pkg/providers/anthropic/provider.go +++ b/pkg/providers/anthropic/provider.go @@ -43,7 +43,7 @@ func NewProvider(token string) *Provider { } func NewProviderWithBaseURL(token, apiBase string) *Provider { - baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, false) + baseURL := common.NormalizeBaseURL(apiBase, defaultBaseURL, false) client := anthropic.NewClient( option.WithAuthToken(token), option.WithBaseURL(baseURL), diff --git a/pkg/providers/anthropic_messages/provider.go b/pkg/providers/anthropic_messages/provider.go index dcc31b6f9..672fb9324 100644 --- a/pkg/providers/anthropic_messages/provider.go +++ b/pkg/providers/anthropic_messages/provider.go @@ -52,7 +52,7 @@ func NewProvider(apiKey, apiBase, userAgent string) *Provider { // NewProviderWithTimeout creates a provider with custom request timeout. func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider { - baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, true) + baseURL := common.NormalizeBaseURL(apiBase, defaultBaseURL, true) timeout := defaultRequestTimeout if timeoutSeconds > 0 { timeout = time.Duration(timeoutSeconds) * time.Second diff --git a/pkg/providers/common/anthropic_common.go b/pkg/providers/common/anthropic_common.go new file mode 100644 index 000000000..92dace9ac --- /dev/null +++ b/pkg/providers/common/anthropic_common.go @@ -0,0 +1,27 @@ +package common + +import "strings" + +// NormalizeBaseURL ensures the Anthropic base URL is properly formatted. +// It removes a trailing /v1 suffix if present (to avoid duplication), then +// re-appends /v1 when appendV1Suffix is true. An empty apiBase falls back to +// defaultBaseURL. +func NormalizeBaseURL(apiBase, defaultBaseURL string, appendV1Suffix bool) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return defaultBaseURL + } + + base = strings.TrimRight(base, "/") + if before, ok := strings.CutSuffix(base, "/v1"); ok { + base = before + } + if base == "" { + return defaultBaseURL + } + + if appendV1Suffix { + return base + "/v1" + } + return base +} diff --git a/pkg/providers/common/anthropic_common_test.go b/pkg/providers/common/anthropic_common_test.go new file mode 100644 index 000000000..7563141b5 --- /dev/null +++ b/pkg/providers/common/anthropic_common_test.go @@ -0,0 +1,59 @@ +package common + +import "testing" + +func TestNormalizeAnthropicBaseURL(t *testing.T) { + const defaultURL = "https://api.anthropic.com" + const defaultURLWithV1 = "https://api.anthropic.com/v1" + + tests := []struct { + name string + apiBase string + defaultBase string + appendV1Suffix bool + expected string + }{ + {"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1}, + {"empty without v1", "", defaultURL, false, defaultURL}, + { + "URL without v1 gets it appended", + "https://api.example.com/anthropic", defaultURLWithV1, + true, "https://api.example.com/anthropic/v1", + }, + { + "URL without v1 stays as-is", + "https://api.example.com/anthropic", defaultURL, + false, "https://api.example.com/anthropic", + }, + { + "URL with v1 remains unchanged when appending", + "https://api.example.com/v1", defaultURLWithV1, + true, "https://api.example.com/v1", + }, + { + "URL with v1 gets it stripped when not appending", + "https://api.example.com/v1", defaultURL, + false, "https://api.example.com", + }, + { + "trailing slash cleaned with v1", + "https://api.example.com/anthropic/", defaultURLWithV1, + true, "https://api.example.com/anthropic/v1", + }, + { + "trailing slash cleaned without v1", + "https://api.example.com/anthropic/", defaultURL, + false, "https://api.example.com/anthropic", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NormalizeBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix) + if got != tt.expected { + t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q", + tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected) + } + }) + } +} diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go index 1f9a9b827..84aa1a707 100644 --- a/pkg/providers/common/common_test.go +++ b/pkg/providers/common/common_test.go @@ -722,64 +722,6 @@ func TestParseDataAudioURL(t *testing.T) { } } -// --- NormalizeAnthropicBaseURL tests --- - -func TestNormalizeAnthropicBaseURL(t *testing.T) { - const defaultURL = "https://api.anthropic.com" - const defaultURLWithV1 = "https://api.anthropic.com/v1" - - tests := []struct { - name string - apiBase string - defaultBase string - appendV1Suffix bool - expected string - }{ - {"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1}, - {"empty without v1", "", defaultURL, false, defaultURL}, - { - "URL without v1 gets it appended", - "https://api.example.com/anthropic", defaultURLWithV1, - true, "https://api.example.com/anthropic/v1", - }, - { - "URL without v1 stays as-is", - "https://api.example.com/anthropic", defaultURL, - false, "https://api.example.com/anthropic", - }, - { - "URL with v1 remains unchanged when appending", - "https://api.example.com/v1", defaultURLWithV1, - true, "https://api.example.com/v1", - }, - { - "URL with v1 gets it stripped when not appending", - "https://api.example.com/v1", defaultURL, - false, "https://api.example.com", - }, - { - "trailing slash cleaned with v1", - "https://api.example.com/anthropic/", defaultURLWithV1, - true, "https://api.example.com/anthropic/v1", - }, - { - "trailing slash cleaned without v1", - "https://api.example.com/anthropic/", defaultURL, - false, "https://api.example.com/anthropic", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NormalizeAnthropicBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix) - if got != tt.expected { - t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q", - tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected) - } - }) - } -} - // --- WrapHTMLResponseError tests --- func TestWrapHTMLResponseError(t *testing.T) { diff --git a/pkg/providers/common/google_common.go b/pkg/providers/common/google_common.go new file mode 100644 index 000000000..954c0c802 --- /dev/null +++ b/pkg/providers/common/google_common.go @@ -0,0 +1,70 @@ +package common + +import ( + "encoding/json" + "strings" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +// NormalizeStoredToolCall extracts the tool name, arguments, and thought signature +// from a stored ToolCall. It handles both the top-level fields and the nested +// Function struct used by different API formats. +func NormalizeStoredToolCall(tc protocoltypes.ToolCall) (string, map[string]any, string) { + name := tc.Name + args := tc.Arguments + thoughtSignature := "" + + if name == "" && tc.Function != nil { + name = tc.Function.Name + thoughtSignature = tc.Function.ThoughtSignature + } else if tc.Function != nil { + thoughtSignature = tc.Function.ThoughtSignature + } + + if args == nil { + args = map[string]any{} + } + + if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" { + var parsed map[string]any + if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil { + args = parsed + } + } + + return name, args, thoughtSignature +} + +// ResolveToolResponseName returns the tool name for a given tool call ID. +// It first checks the provided name map, then falls back to inferring the +// name from the call ID format. +func ResolveToolResponseName(toolCallID string, toolCallNames map[string]string) string { + if toolCallID == "" { + return "" + } + + if name, ok := toolCallNames[toolCallID]; ok && name != "" { + return name + } + + return InferToolNameFromCallID(toolCallID) +} + +// InferToolNameFromCallID extracts a tool name from a call ID in the format +// "call__". Returns the original ID if it doesn't match. +func InferToolNameFromCallID(toolCallID string) string { + if !strings.HasPrefix(toolCallID, "call_") { + return toolCallID + } + + rest := strings.TrimPrefix(toolCallID, "call_") + if idx := strings.LastIndex(rest, "_"); idx > 0 { + candidate := rest[:idx] + if candidate != "" { + return candidate + } + } + + return toolCallID +} diff --git a/pkg/providers/common/google_common_test.go b/pkg/providers/common/google_common_test.go new file mode 100644 index 000000000..cc013dcd1 --- /dev/null +++ b/pkg/providers/common/google_common_test.go @@ -0,0 +1,146 @@ +package common + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +func TestNormalizeStoredToolCall_TopLevelFields(t *testing.T) { + tc := protocoltypes.ToolCall{ + Name: "search", + Arguments: map[string]any{"q": "hello"}, + } + name, args, sig := NormalizeStoredToolCall(tc) + if name != "search" { + t.Errorf("name = %q, want %q", name, "search") + } + if args["q"] != "hello" { + t.Errorf("args[q] = %v, want %q", args["q"], "hello") + } + if sig != "" { + t.Errorf("thoughtSignature = %q, want empty", sig) + } +} + +func TestNormalizeStoredToolCall_FallsBackToFunction(t *testing.T) { + tc := protocoltypes.ToolCall{ + Function: &protocoltypes.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"/tmp"}`, + ThoughtSignature: "sig123", + }, + } + name, args, sig := NormalizeStoredToolCall(tc) + if name != "read_file" { + t.Errorf("name = %q, want %q", name, "read_file") + } + if args["path"] != "/tmp" { + t.Errorf("args[path] = %v, want %q", args["path"], "/tmp") + } + if sig != "sig123" { + t.Errorf("thoughtSignature = %q, want %q", sig, "sig123") + } +} + +func TestNormalizeStoredToolCall_TopLevelNameWithFunctionSig(t *testing.T) { + tc := protocoltypes.ToolCall{ + Name: "search", + Arguments: map[string]any{"q": "hi"}, + Function: &protocoltypes.FunctionCall{ + ThoughtSignature: "thought1", + }, + } + name, _, sig := NormalizeStoredToolCall(tc) + if name != "search" { + t.Errorf("name = %q, want %q", name, "search") + } + if sig != "thought1" { + t.Errorf("thoughtSignature = %q, want %q", sig, "thought1") + } +} + +func TestNormalizeStoredToolCall_NilArgs(t *testing.T) { + tc := protocoltypes.ToolCall{Name: "test"} + _, args, _ := NormalizeStoredToolCall(tc) + if args == nil { + t.Fatal("args should not be nil") + } + if len(args) != 0 { + t.Errorf("args should be empty, got %v", args) + } +} + +func TestNormalizeStoredToolCall_EmptyArgsParseFromFunction(t *testing.T) { + tc := protocoltypes.ToolCall{ + Name: "tool", + Arguments: map[string]any{}, + Function: &protocoltypes.FunctionCall{ + Arguments: `{"key":"val"}`, + }, + } + _, args, _ := NormalizeStoredToolCall(tc) + if args["key"] != "val" { + t.Errorf("args[key] = %v, want %q", args["key"], "val") + } +} + +func TestNormalizeStoredToolCall_InvalidFunctionJSON(t *testing.T) { + tc := protocoltypes.ToolCall{ + Name: "tool", + Function: &protocoltypes.FunctionCall{ + Arguments: `not-json`, + }, + } + _, args, _ := NormalizeStoredToolCall(tc) + if len(args) != 0 { + t.Errorf("args should be empty for invalid JSON, got %v", args) + } +} + +func TestResolveToolResponseName_FromMap(t *testing.T) { + names := map[string]string{"call_1": "search"} + got := ResolveToolResponseName("call_1", names) + if got != "search" { + t.Errorf("got %q, want %q", got, "search") + } +} + +func TestResolveToolResponseName_EmptyID(t *testing.T) { + got := ResolveToolResponseName("", map[string]string{"x": "y"}) + if got != "" { + t.Errorf("got %q, want empty", got) + } +} + +func TestResolveToolResponseName_FallsBackToInfer(t *testing.T) { + got := ResolveToolResponseName("call_search_docs_999", map[string]string{}) + if got != "search_docs" { + t.Errorf("got %q, want %q", got, "search_docs") + } +} + +func TestInferToolNameFromCallID(t *testing.T) { + tests := []struct { + name string + id string + want string + }{ + {"standard format", "call_search_docs_999", "search_docs"}, + {"single name", "call_read_123", "read"}, + {"no call prefix", "some_id", "some_id"}, + {"call prefix no underscore suffix", "call_onlyname", "call_onlyname"}, + {"empty string", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := InferToolNameFromCallID(tt.id) + if got != tt.want { + t.Errorf( + "InferToolNameFromCallID(%q) = %q, want %q", + tt.id, got, tt.want, + ) + } + }) + } +} diff --git a/pkg/providers/httpapi/gemini_helpers.go b/pkg/providers/httpapi/gemini_helpers.go index 0f1e20ca5..249c1b8de 100644 --- a/pkg/providers/httpapi/gemini_helpers.go +++ b/pkg/providers/httpapi/gemini_helpers.go @@ -1,64 +1,5 @@ package httpapi -import ( - "encoding/json" - "strings" -) - -func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) { - name := tc.Name - args := tc.Arguments - thoughtSignature := "" - - if name == "" && tc.Function != nil { - name = tc.Function.Name - thoughtSignature = tc.Function.ThoughtSignature - } else if tc.Function != nil { - thoughtSignature = tc.Function.ThoughtSignature - } - - if args == nil { - args = map[string]any{} - } - - if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" { - var parsed map[string]any - if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil { - args = parsed - } - } - - return name, args, thoughtSignature -} - -func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string { - if toolCallID == "" { - return "" - } - - if name, ok := toolCallNames[toolCallID]; ok && name != "" { - return name - } - - return inferToolNameFromCallID(toolCallID) -} - -func inferToolNameFromCallID(toolCallID string) string { - if !strings.HasPrefix(toolCallID, "call_") { - return toolCallID - } - - rest := strings.TrimPrefix(toolCallID, "call_") - if idx := strings.LastIndex(rest, "_"); idx > 0 { - candidate := rest[:idx] - if candidate != "" { - return candidate - } - } - - return toolCallID -} - func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake string) string { if thoughtSignature != "" { return thoughtSignature diff --git a/pkg/providers/httpapi/gemini_provider.go b/pkg/providers/httpapi/gemini_provider.go index dab6acd29..9ad4693da 100644 --- a/pkg/providers/httpapi/gemini_provider.go +++ b/pkg/providers/httpapi/gemini_provider.go @@ -185,7 +185,7 @@ func (p *GeminiProvider) buildRequestBody( case "user": if msg.ToolCallID != "" { - toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames) contents = append(contents, geminiContent{ Role: "user", Parts: []geminiPart{{ @@ -210,7 +210,7 @@ func (p *GeminiProvider) buildRequestBody( content.Parts = append(content.Parts, geminiPart{Text: msg.Content}) } for _, tc := range msg.ToolCalls { - toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc) + toolName, toolArgs, thoughtSignature := common.NormalizeStoredToolCall(tc) if toolName == "" { continue } @@ -234,7 +234,7 @@ func (p *GeminiProvider) buildRequestBody( } case "tool": - toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames) contents = append(contents, geminiContent{ Role: "user", Parts: []geminiPart{{ diff --git a/pkg/providers/oauth/antigravity_provider.go b/pkg/providers/oauth/antigravity_provider.go index 38526dd7a..1ac2d9c7f 100644 --- a/pkg/providers/oauth/antigravity_provider.go +++ b/pkg/providers/oauth/antigravity_provider.go @@ -14,6 +14,7 @@ import ( "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers/common" ) const ( @@ -221,7 +222,7 @@ func (p *AntigravityProvider) buildRequest( } case "user": if msg.ToolCallID != "" { - toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames) // Tool result req.Contents = append(req.Contents, antigravityContent{ Role: "user", @@ -248,7 +249,7 @@ func (p *AntigravityProvider) buildRequest( content.Parts = append(content.Parts, antigravityPart{Text: msg.Content}) } for _, tc := range msg.ToolCalls { - toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc) + toolName, toolArgs, thoughtSignature := common.NormalizeStoredToolCall(tc) if toolName == "" { logger.WarnCF( "provider.antigravity", @@ -275,7 +276,7 @@ func (p *AntigravityProvider) buildRequest( req.Contents = append(req.Contents, content) } case "tool": - toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames) req.Contents = append(req.Contents, antigravityContent{ Role: "user", Parts: []antigravityPart{{ @@ -328,60 +329,6 @@ func (p *AntigravityProvider) buildRequest( return req } -func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) { - name := tc.Name - args := tc.Arguments - thoughtSignature := "" - - if name == "" && tc.Function != nil { - name = tc.Function.Name - thoughtSignature = tc.Function.ThoughtSignature - } else if tc.Function != nil { - thoughtSignature = tc.Function.ThoughtSignature - } - - if args == nil { - args = map[string]any{} - } - - if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" { - var parsed map[string]any - if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil { - args = parsed - } - } - - return name, args, thoughtSignature -} - -func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string { - if toolCallID == "" { - return "" - } - - if name, ok := toolCallNames[toolCallID]; ok && name != "" { - return name - } - - return inferToolNameFromCallID(toolCallID) -} - -func inferToolNameFromCallID(toolCallID string) string { - if !strings.HasPrefix(toolCallID, "call_") { - return toolCallID - } - - rest := strings.TrimPrefix(toolCallID, "call_") - if idx := strings.LastIndex(rest, "_"); idx > 0 { - candidate := rest[:idx] - if candidate != "" { - return candidate - } - } - - return toolCallID -} - // --- Response parsing --- type antigravityJSONResponse struct { diff --git a/pkg/providers/oauth/antigravity_provider_test.go b/pkg/providers/oauth/antigravity_provider_test.go index 41cb5b0db..2989f8519 100644 --- a/pkg/providers/oauth/antigravity_provider_test.go +++ b/pkg/providers/oauth/antigravity_provider_test.go @@ -48,13 +48,6 @@ func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) { } } -func TestResolveToolResponseNameInfersNameFromGeneratedCallID(t *testing.T) { - got := resolveToolResponseName("call_search_docs_999", map[string]string{}) - if got != "search_docs" { - t.Fatalf("expected inferred tool name search_docs, got %q", got) - } -} - func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) { p := &AntigravityProvider{} body := "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"hidden reasoning\",\"thought\":true},{\"text\":\"visible answer\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":17,\"totalTokenCount\":216}}}\n" +