Deduplicate further functions

This commit is contained in:
Kunal Karmakar
2026-04-19 06:48:28 +00:00
parent bc077db0ee
commit 4ae11406d2
11 changed files with 311 additions and 186 deletions
+1 -1
View File
@@ -43,7 +43,7 @@ func NewProvider(token string) *Provider {
}
func NewProviderWithBaseURL(token, apiBase string) *Provider {
baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, false)
baseURL := common.NormalizeBaseURL(apiBase, defaultBaseURL, false)
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL(baseURL),
+1 -1
View File
@@ -52,7 +52,7 @@ func NewProvider(apiKey, apiBase, userAgent string) *Provider {
// NewProviderWithTimeout creates a provider with custom request timeout.
func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider {
baseURL := common.NormalizeAnthropicBaseURL(apiBase, defaultBaseURL, true)
baseURL := common.NormalizeBaseURL(apiBase, defaultBaseURL, true)
timeout := defaultRequestTimeout
if timeoutSeconds > 0 {
timeout = time.Duration(timeoutSeconds) * time.Second
+27
View File
@@ -0,0 +1,27 @@
package common
import "strings"
// NormalizeBaseURL ensures the Anthropic base URL is properly formatted.
// It removes a trailing /v1 suffix if present (to avoid duplication), then
// re-appends /v1 when appendV1Suffix is true. An empty apiBase falls back to
// defaultBaseURL.
func NormalizeBaseURL(apiBase, defaultBaseURL string, appendV1Suffix bool) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
base = strings.TrimRight(base, "/")
if before, ok := strings.CutSuffix(base, "/v1"); ok {
base = before
}
if base == "" {
return defaultBaseURL
}
if appendV1Suffix {
return base + "/v1"
}
return base
}
@@ -0,0 +1,59 @@
package common
import "testing"
func TestNormalizeAnthropicBaseURL(t *testing.T) {
const defaultURL = "https://api.anthropic.com"
const defaultURLWithV1 = "https://api.anthropic.com/v1"
tests := []struct {
name string
apiBase string
defaultBase string
appendV1Suffix bool
expected string
}{
{"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1},
{"empty without v1", "", defaultURL, false, defaultURL},
{
"URL without v1 gets it appended",
"https://api.example.com/anthropic", defaultURLWithV1,
true, "https://api.example.com/anthropic/v1",
},
{
"URL without v1 stays as-is",
"https://api.example.com/anthropic", defaultURL,
false, "https://api.example.com/anthropic",
},
{
"URL with v1 remains unchanged when appending",
"https://api.example.com/v1", defaultURLWithV1,
true, "https://api.example.com/v1",
},
{
"URL with v1 gets it stripped when not appending",
"https://api.example.com/v1", defaultURL,
false, "https://api.example.com",
},
{
"trailing slash cleaned with v1",
"https://api.example.com/anthropic/", defaultURLWithV1,
true, "https://api.example.com/anthropic/v1",
},
{
"trailing slash cleaned without v1",
"https://api.example.com/anthropic/", defaultURL,
false, "https://api.example.com/anthropic",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NormalizeBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix)
if got != tt.expected {
t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q",
tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected)
}
})
}
}
-58
View File
@@ -722,64 +722,6 @@ func TestParseDataAudioURL(t *testing.T) {
}
}
// --- NormalizeAnthropicBaseURL tests ---
func TestNormalizeAnthropicBaseURL(t *testing.T) {
const defaultURL = "https://api.anthropic.com"
const defaultURLWithV1 = "https://api.anthropic.com/v1"
tests := []struct {
name string
apiBase string
defaultBase string
appendV1Suffix bool
expected string
}{
{"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1},
{"empty without v1", "", defaultURL, false, defaultURL},
{
"URL without v1 gets it appended",
"https://api.example.com/anthropic", defaultURLWithV1,
true, "https://api.example.com/anthropic/v1",
},
{
"URL without v1 stays as-is",
"https://api.example.com/anthropic", defaultURL,
false, "https://api.example.com/anthropic",
},
{
"URL with v1 remains unchanged when appending",
"https://api.example.com/v1", defaultURLWithV1,
true, "https://api.example.com/v1",
},
{
"URL with v1 gets it stripped when not appending",
"https://api.example.com/v1", defaultURL,
false, "https://api.example.com",
},
{
"trailing slash cleaned with v1",
"https://api.example.com/anthropic/", defaultURLWithV1,
true, "https://api.example.com/anthropic/v1",
},
{
"trailing slash cleaned without v1",
"https://api.example.com/anthropic/", defaultURL,
false, "https://api.example.com/anthropic",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NormalizeAnthropicBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix)
if got != tt.expected {
t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q",
tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected)
}
})
}
}
// --- WrapHTMLResponseError tests ---
func TestWrapHTMLResponseError(t *testing.T) {
+70
View File
@@ -0,0 +1,70 @@
package common
import (
"encoding/json"
"strings"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// NormalizeStoredToolCall extracts the tool name, arguments, and thought signature
// from a stored ToolCall. It handles both the top-level fields and the nested
// Function struct used by different API formats.
func NormalizeStoredToolCall(tc protocoltypes.ToolCall) (string, map[string]any, string) {
name := tc.Name
args := tc.Arguments
thoughtSignature := ""
if name == "" && tc.Function != nil {
name = tc.Function.Name
thoughtSignature = tc.Function.ThoughtSignature
} else if tc.Function != nil {
thoughtSignature = tc.Function.ThoughtSignature
}
if args == nil {
args = map[string]any{}
}
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
var parsed map[string]any
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
args = parsed
}
}
return name, args, thoughtSignature
}
// ResolveToolResponseName returns the tool name for a given tool call ID.
// It first checks the provided name map, then falls back to inferring the
// name from the call ID format.
func ResolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
if toolCallID == "" {
return ""
}
if name, ok := toolCallNames[toolCallID]; ok && name != "" {
return name
}
return InferToolNameFromCallID(toolCallID)
}
// InferToolNameFromCallID extracts a tool name from a call ID in the format
// "call_<name>_<suffix>". Returns the original ID if it doesn't match.
func InferToolNameFromCallID(toolCallID string) string {
if !strings.HasPrefix(toolCallID, "call_") {
return toolCallID
}
rest := strings.TrimPrefix(toolCallID, "call_")
if idx := strings.LastIndex(rest, "_"); idx > 0 {
candidate := rest[:idx]
if candidate != "" {
return candidate
}
}
return toolCallID
}
+146
View File
@@ -0,0 +1,146 @@
package common
import (
"testing"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
func TestNormalizeStoredToolCall_TopLevelFields(t *testing.T) {
tc := protocoltypes.ToolCall{
Name: "search",
Arguments: map[string]any{"q": "hello"},
}
name, args, sig := NormalizeStoredToolCall(tc)
if name != "search" {
t.Errorf("name = %q, want %q", name, "search")
}
if args["q"] != "hello" {
t.Errorf("args[q] = %v, want %q", args["q"], "hello")
}
if sig != "" {
t.Errorf("thoughtSignature = %q, want empty", sig)
}
}
func TestNormalizeStoredToolCall_FallsBackToFunction(t *testing.T) {
tc := protocoltypes.ToolCall{
Function: &protocoltypes.FunctionCall{
Name: "read_file",
Arguments: `{"path":"/tmp"}`,
ThoughtSignature: "sig123",
},
}
name, args, sig := NormalizeStoredToolCall(tc)
if name != "read_file" {
t.Errorf("name = %q, want %q", name, "read_file")
}
if args["path"] != "/tmp" {
t.Errorf("args[path] = %v, want %q", args["path"], "/tmp")
}
if sig != "sig123" {
t.Errorf("thoughtSignature = %q, want %q", sig, "sig123")
}
}
func TestNormalizeStoredToolCall_TopLevelNameWithFunctionSig(t *testing.T) {
tc := protocoltypes.ToolCall{
Name: "search",
Arguments: map[string]any{"q": "hi"},
Function: &protocoltypes.FunctionCall{
ThoughtSignature: "thought1",
},
}
name, _, sig := NormalizeStoredToolCall(tc)
if name != "search" {
t.Errorf("name = %q, want %q", name, "search")
}
if sig != "thought1" {
t.Errorf("thoughtSignature = %q, want %q", sig, "thought1")
}
}
func TestNormalizeStoredToolCall_NilArgs(t *testing.T) {
tc := protocoltypes.ToolCall{Name: "test"}
_, args, _ := NormalizeStoredToolCall(tc)
if args == nil {
t.Fatal("args should not be nil")
}
if len(args) != 0 {
t.Errorf("args should be empty, got %v", args)
}
}
func TestNormalizeStoredToolCall_EmptyArgsParseFromFunction(t *testing.T) {
tc := protocoltypes.ToolCall{
Name: "tool",
Arguments: map[string]any{},
Function: &protocoltypes.FunctionCall{
Arguments: `{"key":"val"}`,
},
}
_, args, _ := NormalizeStoredToolCall(tc)
if args["key"] != "val" {
t.Errorf("args[key] = %v, want %q", args["key"], "val")
}
}
func TestNormalizeStoredToolCall_InvalidFunctionJSON(t *testing.T) {
tc := protocoltypes.ToolCall{
Name: "tool",
Function: &protocoltypes.FunctionCall{
Arguments: `not-json`,
},
}
_, args, _ := NormalizeStoredToolCall(tc)
if len(args) != 0 {
t.Errorf("args should be empty for invalid JSON, got %v", args)
}
}
func TestResolveToolResponseName_FromMap(t *testing.T) {
names := map[string]string{"call_1": "search"}
got := ResolveToolResponseName("call_1", names)
if got != "search" {
t.Errorf("got %q, want %q", got, "search")
}
}
func TestResolveToolResponseName_EmptyID(t *testing.T) {
got := ResolveToolResponseName("", map[string]string{"x": "y"})
if got != "" {
t.Errorf("got %q, want empty", got)
}
}
func TestResolveToolResponseName_FallsBackToInfer(t *testing.T) {
got := ResolveToolResponseName("call_search_docs_999", map[string]string{})
if got != "search_docs" {
t.Errorf("got %q, want %q", got, "search_docs")
}
}
func TestInferToolNameFromCallID(t *testing.T) {
tests := []struct {
name string
id string
want string
}{
{"standard format", "call_search_docs_999", "search_docs"},
{"single name", "call_read_123", "read"},
{"no call prefix", "some_id", "some_id"},
{"call prefix no underscore suffix", "call_onlyname", "call_onlyname"},
{"empty string", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := InferToolNameFromCallID(tt.id)
if got != tt.want {
t.Errorf(
"InferToolNameFromCallID(%q) = %q, want %q",
tt.id, got, tt.want,
)
}
})
}
}
-59
View File
@@ -1,64 +1,5 @@
package httpapi
import (
"encoding/json"
"strings"
)
func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) {
name := tc.Name
args := tc.Arguments
thoughtSignature := ""
if name == "" && tc.Function != nil {
name = tc.Function.Name
thoughtSignature = tc.Function.ThoughtSignature
} else if tc.Function != nil {
thoughtSignature = tc.Function.ThoughtSignature
}
if args == nil {
args = map[string]any{}
}
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
var parsed map[string]any
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
args = parsed
}
}
return name, args, thoughtSignature
}
func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
if toolCallID == "" {
return ""
}
if name, ok := toolCallNames[toolCallID]; ok && name != "" {
return name
}
return inferToolNameFromCallID(toolCallID)
}
func inferToolNameFromCallID(toolCallID string) string {
if !strings.HasPrefix(toolCallID, "call_") {
return toolCallID
}
rest := strings.TrimPrefix(toolCallID, "call_")
if idx := strings.LastIndex(rest, "_"); idx > 0 {
candidate := rest[:idx]
if candidate != "" {
return candidate
}
}
return toolCallID
}
func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake string) string {
if thoughtSignature != "" {
return thoughtSignature
+3 -3
View File
@@ -185,7 +185,7 @@ func (p *GeminiProvider) buildRequestBody(
case "user":
if msg.ToolCallID != "" {
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
contents = append(contents, geminiContent{
Role: "user",
Parts: []geminiPart{{
@@ -210,7 +210,7 @@ func (p *GeminiProvider) buildRequestBody(
content.Parts = append(content.Parts, geminiPart{Text: msg.Content})
}
for _, tc := range msg.ToolCalls {
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
toolName, toolArgs, thoughtSignature := common.NormalizeStoredToolCall(tc)
if toolName == "" {
continue
}
@@ -234,7 +234,7 @@ func (p *GeminiProvider) buildRequestBody(
}
case "tool":
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
contents = append(contents, geminiContent{
Role: "user",
Parts: []geminiPart{{
+4 -57
View File
@@ -14,6 +14,7 @@ import (
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers/common"
)
const (
@@ -221,7 +222,7 @@ func (p *AntigravityProvider) buildRequest(
}
case "user":
if msg.ToolCallID != "" {
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
// Tool result
req.Contents = append(req.Contents, antigravityContent{
Role: "user",
@@ -248,7 +249,7 @@ func (p *AntigravityProvider) buildRequest(
content.Parts = append(content.Parts, antigravityPart{Text: msg.Content})
}
for _, tc := range msg.ToolCalls {
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
toolName, toolArgs, thoughtSignature := common.NormalizeStoredToolCall(tc)
if toolName == "" {
logger.WarnCF(
"provider.antigravity",
@@ -275,7 +276,7 @@ func (p *AntigravityProvider) buildRequest(
req.Contents = append(req.Contents, content)
}
case "tool":
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
req.Contents = append(req.Contents, antigravityContent{
Role: "user",
Parts: []antigravityPart{{
@@ -328,60 +329,6 @@ func (p *AntigravityProvider) buildRequest(
return req
}
func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) {
name := tc.Name
args := tc.Arguments
thoughtSignature := ""
if name == "" && tc.Function != nil {
name = tc.Function.Name
thoughtSignature = tc.Function.ThoughtSignature
} else if tc.Function != nil {
thoughtSignature = tc.Function.ThoughtSignature
}
if args == nil {
args = map[string]any{}
}
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
var parsed map[string]any
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
args = parsed
}
}
return name, args, thoughtSignature
}
func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
if toolCallID == "" {
return ""
}
if name, ok := toolCallNames[toolCallID]; ok && name != "" {
return name
}
return inferToolNameFromCallID(toolCallID)
}
func inferToolNameFromCallID(toolCallID string) string {
if !strings.HasPrefix(toolCallID, "call_") {
return toolCallID
}
rest := strings.TrimPrefix(toolCallID, "call_")
if idx := strings.LastIndex(rest, "_"); idx > 0 {
candidate := rest[:idx]
if candidate != "" {
return candidate
}
}
return toolCallID
}
// --- Response parsing ---
type antigravityJSONResponse struct {
@@ -48,13 +48,6 @@ func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) {
}
}
func TestResolveToolResponseNameInfersNameFromGeneratedCallID(t *testing.T) {
got := resolveToolResponseName("call_search_docs_999", map[string]string{})
if got != "search_docs" {
t.Fatalf("expected inferred tool name search_docs, got %q", got)
}
}
func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) {
p := &AntigravityProvider{}
body := "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"hidden reasoning\",\"thought\":true},{\"text\":\"visible answer\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":17,\"totalTokenCount\":216}}}\n" +