mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #2586 from kunalk16/fix-functions-deduplication
refactor(deduplication): functions deduplication in pkg/providers
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -42,7 +43,7 @@ func NewProvider(token string) *Provider {
|
||||
}
|
||||
|
||||
func NewProviderWithBaseURL(token, apiBase string) *Provider {
|
||||
baseURL := normalizeBaseURL(apiBase)
|
||||
baseURL := common.NormalizeBaseURL(apiBase, defaultBaseURL, false)
|
||||
client := anthropic.NewClient(
|
||||
option.WithAuthToken(token),
|
||||
option.WithBaseURL(baseURL),
|
||||
@@ -385,20 +386,3 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeBaseURL(apiBase string) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
base = strings.TrimRight(base, "/")
|
||||
if before, ok := strings.CutSuffix(base, "/v1"); ok {
|
||||
base = before
|
||||
}
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -51,7 +52,7 @@ func NewProvider(apiKey, apiBase, userAgent string) *Provider {
|
||||
|
||||
// NewProviderWithTimeout creates a provider with custom request timeout.
|
||||
func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider {
|
||||
baseURL := normalizeBaseURL(apiBase)
|
||||
baseURL := common.NormalizeBaseURL(apiBase, defaultBaseURL, true)
|
||||
timeout := defaultRequestTimeout
|
||||
if timeoutSeconds > 0 {
|
||||
timeout = time.Duration(timeoutSeconds) * time.Second
|
||||
@@ -161,7 +162,7 @@ func buildRequestBody(
|
||||
options map[string]any,
|
||||
) (map[string]any, error) {
|
||||
// max_tokens is required and guaranteed by agent loop
|
||||
maxTokens, ok := asInt(options["max_tokens"])
|
||||
maxTokens, ok := common.AsInt(options["max_tokens"])
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("max_tokens is required in options")
|
||||
}
|
||||
@@ -173,7 +174,7 @@ func buildRequestBody(
|
||||
}
|
||||
|
||||
// Set temperature from options
|
||||
if temp, ok := asFloat(options["temperature"]); ok {
|
||||
if temp, ok := common.AsFloat(options["temperature"]); ok {
|
||||
result["temperature"] = temp
|
||||
}
|
||||
|
||||
@@ -361,61 +362,6 @@ func parseResponseBody(body []byte) (*LLMResponse, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeBaseURL ensures the base URL is properly formatted.
|
||||
// It removes /v1 suffix if present (to avoid duplication) and always appends /v1.
|
||||
// This handles edge cases like "https://api.example.com/v1/proxy" correctly.
|
||||
func normalizeBaseURL(apiBase string) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
// Remove trailing slashes
|
||||
base = strings.TrimRight(base, "/")
|
||||
|
||||
// Remove /v1 suffix if present (will be re-added)
|
||||
// This prevents duplication for URLs like "https://api.example.com/v1/proxy"
|
||||
if before, ok := strings.CutSuffix(base, "/v1"); ok {
|
||||
base = before
|
||||
}
|
||||
|
||||
// Ensure we don't have an empty string after cutting
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
// Add /v1 suffix (required by Anthropic Messages API)
|
||||
return base + "/v1"
|
||||
}
|
||||
|
||||
// Helper functions for type conversion
|
||||
|
||||
func asInt(v any) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case int64:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat(v any) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// Anthropic API response structures
|
||||
|
||||
type anthropicMessageResponse struct {
|
||||
|
||||
@@ -372,44 +372,6 @@ func TestParseResponseBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeBaseURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiBase string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty string defaults to official API",
|
||||
apiBase: "",
|
||||
expected: "https://api.anthropic.com/v1",
|
||||
},
|
||||
{
|
||||
name: "URL without /v1 gets it appended",
|
||||
apiBase: "https://api.example.com/anthropic",
|
||||
expected: "https://api.example.com/anthropic/v1",
|
||||
},
|
||||
{
|
||||
name: "URL with /v1 remains unchanged",
|
||||
apiBase: "https://api.example.com/v1",
|
||||
expected: "https://api.example.com/v1",
|
||||
},
|
||||
{
|
||||
name: "URL with trailing slash gets cleaned",
|
||||
apiBase: "https://api.example.com/anthropic/",
|
||||
expected: "https://api.example.com/anthropic/v1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeBaseURL(tt.apiBase)
|
||||
if got != tt.expected {
|
||||
t.Errorf("normalizeBaseURL(%q) = %q, want %q", tt.apiBase, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
provider := NewProvider("test-key", "https://api.example.com", "")
|
||||
if provider == nil {
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package common
|
||||
|
||||
import "strings"
|
||||
|
||||
// NormalizeBaseURL ensures the Anthropic base URL is properly formatted.
|
||||
// It removes a trailing /v1 suffix if present (to avoid duplication), then
|
||||
// re-appends /v1 when appendV1Suffix is true. An empty apiBase falls back to
|
||||
// defaultBaseURL.
|
||||
func NormalizeBaseURL(apiBase, defaultBaseURL string, appendV1Suffix bool) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
base = strings.TrimRight(base, "/")
|
||||
if before, ok := strings.CutSuffix(base, "/v1"); ok {
|
||||
base = before
|
||||
}
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
if appendV1Suffix {
|
||||
return base + "/v1"
|
||||
}
|
||||
return base
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package common
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeAnthropicBaseURL(t *testing.T) {
|
||||
const defaultURL = "https://api.anthropic.com"
|
||||
const defaultURLWithV1 = "https://api.anthropic.com/v1"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
apiBase string
|
||||
defaultBase string
|
||||
appendV1Suffix bool
|
||||
expected string
|
||||
}{
|
||||
{"empty with v1", "", defaultURLWithV1, true, defaultURLWithV1},
|
||||
{"empty without v1", "", defaultURL, false, defaultURL},
|
||||
{
|
||||
"URL without v1 gets it appended",
|
||||
"https://api.example.com/anthropic", defaultURLWithV1,
|
||||
true, "https://api.example.com/anthropic/v1",
|
||||
},
|
||||
{
|
||||
"URL without v1 stays as-is",
|
||||
"https://api.example.com/anthropic", defaultURL,
|
||||
false, "https://api.example.com/anthropic",
|
||||
},
|
||||
{
|
||||
"URL with v1 remains unchanged when appending",
|
||||
"https://api.example.com/v1", defaultURLWithV1,
|
||||
true, "https://api.example.com/v1",
|
||||
},
|
||||
{
|
||||
"URL with v1 gets it stripped when not appending",
|
||||
"https://api.example.com/v1", defaultURL,
|
||||
false, "https://api.example.com",
|
||||
},
|
||||
{
|
||||
"trailing slash cleaned with v1",
|
||||
"https://api.example.com/anthropic/", defaultURLWithV1,
|
||||
true, "https://api.example.com/anthropic/v1",
|
||||
},
|
||||
{
|
||||
"trailing slash cleaned without v1",
|
||||
"https://api.example.com/anthropic/", defaultURL,
|
||||
false, "https://api.example.com/anthropic",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NormalizeBaseURL(tt.apiBase, tt.defaultBase, tt.appendV1Suffix)
|
||||
if got != tt.expected {
|
||||
t.Errorf("NormalizeAnthropicBaseURL(%q, %q, %v) = %q, want %q",
|
||||
tt.apiBase, tt.defaultBase, tt.appendV1Suffix, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -127,7 +127,7 @@ func SerializeMessages(messages []Message) []any {
|
||||
continue
|
||||
}
|
||||
|
||||
if format, data, ok := parseDataAudioURL(mediaURL); ok {
|
||||
if format, data, ok := ParseDataAudioURL(mediaURL); ok {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "input_audio",
|
||||
"input_audio": map[string]any{
|
||||
@@ -205,7 +205,8 @@ func serializeToolCalls(toolCalls []ToolCall) []openaiToolCall {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseDataAudioURL(mediaURL string) (format, data string, ok bool) {
|
||||
// ParseDataAudioURL extracts the format and base64 data from a data:audio/... URL.
|
||||
func ParseDataAudioURL(mediaURL string) (format, data string, ok bool) {
|
||||
if !strings.HasPrefix(mediaURL, "data:audio/") {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
@@ -660,6 +660,37 @@ func TestAsFloat(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseDataAudioURL tests ---
|
||||
|
||||
func TestParseDataAudioURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mediaURL string
|
||||
wantFormat string
|
||||
wantData string
|
||||
wantOK bool
|
||||
}{
|
||||
{"valid mp3", "data:audio/mp3;base64,SGVsbG8=", "mp3", "SGVsbG8=", true},
|
||||
{"valid wav", "data:audio/wav;base64,AAAA", "wav", "AAAA", true},
|
||||
{"not audio", "data:image/png;base64,abc", "", "", false},
|
||||
{"no comma", "data:audio/mp3;base64", "", "", false},
|
||||
{"empty data", "data:audio/mp3;base64,", "", "", false},
|
||||
{"empty string", "", "", "", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
format, data, ok := ParseDataAudioURL(tt.mediaURL)
|
||||
if ok != tt.wantOK || format != tt.wantFormat || data != tt.wantData {
|
||||
t.Errorf(
|
||||
"ParseDataAudioURL(%q) = (%q, %q, %v), want (%q, %q, %v)",
|
||||
tt.mediaURL, format, data, ok,
|
||||
tt.wantFormat, tt.wantData, tt.wantOK,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- WrapHTMLResponseError tests ---
|
||||
|
||||
func TestWrapHTMLResponseError(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// NormalizeStoredToolCall extracts the tool name, arguments, and thought signature
|
||||
// from a stored ToolCall. It handles both the top-level fields and the nested
|
||||
// Function struct used by different API formats.
|
||||
func NormalizeStoredToolCall(tc protocoltypes.ToolCall) (string, map[string]any, string) {
|
||||
name := tc.Name
|
||||
args := tc.Arguments
|
||||
thoughtSignature := ""
|
||||
|
||||
if name == "" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
} else if tc.Function != nil {
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
}
|
||||
|
||||
if args == nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
|
||||
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
|
||||
args = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return name, args, thoughtSignature
|
||||
}
|
||||
|
||||
// ResolveToolResponseName returns the tool name for a given tool call ID.
|
||||
// It first checks the provided name map, then falls back to inferring the
|
||||
// name from the call ID format.
|
||||
func ResolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
|
||||
if toolCallID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if name, ok := toolCallNames[toolCallID]; ok && name != "" {
|
||||
return name
|
||||
}
|
||||
|
||||
return InferToolNameFromCallID(toolCallID)
|
||||
}
|
||||
|
||||
// InferToolNameFromCallID extracts a tool name from a call ID in the format
|
||||
// "call_<name>_<suffix>". Returns the original ID if it doesn't match.
|
||||
func InferToolNameFromCallID(toolCallID string) string {
|
||||
if !strings.HasPrefix(toolCallID, "call_") {
|
||||
return toolCallID
|
||||
}
|
||||
|
||||
rest := strings.TrimPrefix(toolCallID, "call_")
|
||||
if idx := strings.LastIndex(rest, "_"); idx > 0 {
|
||||
candidate := rest[:idx]
|
||||
if candidate != "" {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
|
||||
return toolCallID
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
func TestNormalizeStoredToolCall_TopLevelFields(t *testing.T) {
|
||||
tc := protocoltypes.ToolCall{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"q": "hello"},
|
||||
}
|
||||
name, args, sig := NormalizeStoredToolCall(tc)
|
||||
if name != "search" {
|
||||
t.Errorf("name = %q, want %q", name, "search")
|
||||
}
|
||||
if args["q"] != "hello" {
|
||||
t.Errorf("args[q] = %v, want %q", args["q"], "hello")
|
||||
}
|
||||
if sig != "" {
|
||||
t.Errorf("thoughtSignature = %q, want empty", sig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStoredToolCall_FallsBackToFunction(t *testing.T) {
|
||||
tc := protocoltypes.ToolCall{
|
||||
Function: &protocoltypes.FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"/tmp"}`,
|
||||
ThoughtSignature: "sig123",
|
||||
},
|
||||
}
|
||||
name, args, sig := NormalizeStoredToolCall(tc)
|
||||
if name != "read_file" {
|
||||
t.Errorf("name = %q, want %q", name, "read_file")
|
||||
}
|
||||
if args["path"] != "/tmp" {
|
||||
t.Errorf("args[path] = %v, want %q", args["path"], "/tmp")
|
||||
}
|
||||
if sig != "sig123" {
|
||||
t.Errorf("thoughtSignature = %q, want %q", sig, "sig123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStoredToolCall_TopLevelNameWithFunctionSig(t *testing.T) {
|
||||
tc := protocoltypes.ToolCall{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"q": "hi"},
|
||||
Function: &protocoltypes.FunctionCall{
|
||||
ThoughtSignature: "thought1",
|
||||
},
|
||||
}
|
||||
name, _, sig := NormalizeStoredToolCall(tc)
|
||||
if name != "search" {
|
||||
t.Errorf("name = %q, want %q", name, "search")
|
||||
}
|
||||
if sig != "thought1" {
|
||||
t.Errorf("thoughtSignature = %q, want %q", sig, "thought1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStoredToolCall_NilArgs(t *testing.T) {
|
||||
tc := protocoltypes.ToolCall{Name: "test"}
|
||||
_, args, _ := NormalizeStoredToolCall(tc)
|
||||
if args == nil {
|
||||
t.Fatal("args should not be nil")
|
||||
}
|
||||
if len(args) != 0 {
|
||||
t.Errorf("args should be empty, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStoredToolCall_EmptyArgsParseFromFunction(t *testing.T) {
|
||||
tc := protocoltypes.ToolCall{
|
||||
Name: "tool",
|
||||
Arguments: map[string]any{},
|
||||
Function: &protocoltypes.FunctionCall{
|
||||
Arguments: `{"key":"val"}`,
|
||||
},
|
||||
}
|
||||
_, args, _ := NormalizeStoredToolCall(tc)
|
||||
if args["key"] != "val" {
|
||||
t.Errorf("args[key] = %v, want %q", args["key"], "val")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStoredToolCall_InvalidFunctionJSON(t *testing.T) {
|
||||
tc := protocoltypes.ToolCall{
|
||||
Name: "tool",
|
||||
Function: &protocoltypes.FunctionCall{
|
||||
Arguments: `not-json`,
|
||||
},
|
||||
}
|
||||
_, args, _ := NormalizeStoredToolCall(tc)
|
||||
if len(args) != 0 {
|
||||
t.Errorf("args should be empty for invalid JSON, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToolResponseName_FromMap(t *testing.T) {
|
||||
names := map[string]string{"call_1": "search"}
|
||||
got := ResolveToolResponseName("call_1", names)
|
||||
if got != "search" {
|
||||
t.Errorf("got %q, want %q", got, "search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToolResponseName_EmptyID(t *testing.T) {
|
||||
got := ResolveToolResponseName("", map[string]string{"x": "y"})
|
||||
if got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToolResponseName_FallsBackToInfer(t *testing.T) {
|
||||
got := ResolveToolResponseName("call_search_docs_999", map[string]string{})
|
||||
if got != "search_docs" {
|
||||
t.Errorf("got %q, want %q", got, "search_docs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferToolNameFromCallID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
want string
|
||||
}{
|
||||
{"standard format", "call_search_docs_999", "search_docs"},
|
||||
{"single name", "call_read_123", "read"},
|
||||
{"no call prefix", "some_id", "some_id"},
|
||||
{"call prefix no underscore suffix", "call_onlyname", "call_onlyname"},
|
||||
{"empty string", "", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := InferToolNameFromCallID(tt.id)
|
||||
if got != tt.want {
|
||||
t.Errorf(
|
||||
"InferToolNameFromCallID(%q) = %q, want %q",
|
||||
tt.id, got, tt.want,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,63 +1,6 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) {
|
||||
name := tc.Name
|
||||
args := tc.Arguments
|
||||
thoughtSignature := ""
|
||||
|
||||
if name == "" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
} else if tc.Function != nil {
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
}
|
||||
|
||||
if args == nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
|
||||
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
|
||||
args = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return name, args, thoughtSignature
|
||||
}
|
||||
|
||||
func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
|
||||
if toolCallID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if name, ok := toolCallNames[toolCallID]; ok && name != "" {
|
||||
return name
|
||||
}
|
||||
|
||||
return inferToolNameFromCallID(toolCallID)
|
||||
}
|
||||
|
||||
func inferToolNameFromCallID(toolCallID string) string {
|
||||
if !strings.HasPrefix(toolCallID, "call_") {
|
||||
return toolCallID
|
||||
}
|
||||
|
||||
rest := strings.TrimPrefix(toolCallID, "call_")
|
||||
if idx := strings.LastIndex(rest, "_"); idx > 0 {
|
||||
candidate := rest[:idx]
|
||||
if candidate != "" {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
|
||||
return toolCallID
|
||||
}
|
||||
import "strings"
|
||||
|
||||
func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake string) string {
|
||||
if thoughtSignature != "" {
|
||||
|
||||
@@ -185,7 +185,7 @@ func (p *GeminiProvider) buildRequestBody(
|
||||
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
contents = append(contents, geminiContent{
|
||||
Role: "user",
|
||||
Parts: []geminiPart{{
|
||||
@@ -210,7 +210,7 @@ func (p *GeminiProvider) buildRequestBody(
|
||||
content.Parts = append(content.Parts, geminiPart{Text: msg.Content})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
|
||||
toolName, toolArgs, thoughtSignature := common.NormalizeStoredToolCall(tc)
|
||||
if toolName == "" {
|
||||
continue
|
||||
}
|
||||
@@ -234,7 +234,7 @@ func (p *GeminiProvider) buildRequestBody(
|
||||
}
|
||||
|
||||
case "tool":
|
||||
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
contents = append(contents, geminiContent{
|
||||
Role: "user",
|
||||
Parts: []geminiPart{{
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -221,7 +222,7 @@ func (p *AntigravityProvider) buildRequest(
|
||||
}
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
// Tool result
|
||||
req.Contents = append(req.Contents, antigravityContent{
|
||||
Role: "user",
|
||||
@@ -248,7 +249,7 @@ func (p *AntigravityProvider) buildRequest(
|
||||
content.Parts = append(content.Parts, antigravityPart{Text: msg.Content})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
|
||||
toolName, toolArgs, thoughtSignature := common.NormalizeStoredToolCall(tc)
|
||||
if toolName == "" {
|
||||
logger.WarnCF(
|
||||
"provider.antigravity",
|
||||
@@ -275,7 +276,7 @@ func (p *AntigravityProvider) buildRequest(
|
||||
req.Contents = append(req.Contents, content)
|
||||
}
|
||||
case "tool":
|
||||
toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
toolName := common.ResolveToolResponseName(msg.ToolCallID, toolCallNames)
|
||||
req.Contents = append(req.Contents, antigravityContent{
|
||||
Role: "user",
|
||||
Parts: []antigravityPart{{
|
||||
@@ -328,60 +329,6 @@ func (p *AntigravityProvider) buildRequest(
|
||||
return req
|
||||
}
|
||||
|
||||
func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) {
|
||||
name := tc.Name
|
||||
args := tc.Arguments
|
||||
thoughtSignature := ""
|
||||
|
||||
if name == "" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
} else if tc.Function != nil {
|
||||
thoughtSignature = tc.Function.ThoughtSignature
|
||||
}
|
||||
|
||||
if args == nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
|
||||
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
|
||||
args = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return name, args, thoughtSignature
|
||||
}
|
||||
|
||||
func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string {
|
||||
if toolCallID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if name, ok := toolCallNames[toolCallID]; ok && name != "" {
|
||||
return name
|
||||
}
|
||||
|
||||
return inferToolNameFromCallID(toolCallID)
|
||||
}
|
||||
|
||||
func inferToolNameFromCallID(toolCallID string) string {
|
||||
if !strings.HasPrefix(toolCallID, "call_") {
|
||||
return toolCallID
|
||||
}
|
||||
|
||||
rest := strings.TrimPrefix(toolCallID, "call_")
|
||||
if idx := strings.LastIndex(rest, "_"); idx > 0 {
|
||||
candidate := rest[:idx]
|
||||
if candidate != "" {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
|
||||
return toolCallID
|
||||
}
|
||||
|
||||
// --- Response parsing ---
|
||||
|
||||
type antigravityJSONResponse struct {
|
||||
|
||||
@@ -48,13 +48,6 @@ func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveToolResponseNameInfersNameFromGeneratedCallID(t *testing.T) {
|
||||
got := resolveToolResponseName("call_search_docs_999", map[string]string{})
|
||||
if got != "search_docs" {
|
||||
t.Fatalf("expected inferred tool name search_docs, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) {
|
||||
p := &AntigravityProvider{}
|
||||
body := "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"hidden reasoning\",\"thought\":true},{\"text\":\"visible answer\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":17,\"totalTokenCount\":216}}}\n" +
|
||||
|
||||
@@ -470,7 +470,9 @@ func (p *Provider) SupportsNativeSearch() bool {
|
||||
return isNativeSearchHost(p.apiBase)
|
||||
}
|
||||
|
||||
func isNativeSearchHost(apiBase string) bool {
|
||||
// isNativeOpenAIOrAzureEndpoint reports whether the given API base points to
|
||||
// OpenAI's own API or an Azure OpenAI deployment.
|
||||
func isNativeOpenAIOrAzureEndpoint(apiBase string) bool {
|
||||
u, err := url.Parse(apiBase)
|
||||
if err != nil {
|
||||
return false
|
||||
@@ -479,15 +481,14 @@ func isNativeSearchHost(apiBase string) bool {
|
||||
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
|
||||
}
|
||||
|
||||
func isNativeSearchHost(apiBase string) bool {
|
||||
return isNativeOpenAIOrAzureEndpoint(apiBase)
|
||||
}
|
||||
|
||||
// supportsPromptCacheKey reports whether the given API base is known to
|
||||
// support the prompt_cache_key request field. Currently only OpenAI's own
|
||||
// API and Azure OpenAI support this. All other OpenAI-compatible providers
|
||||
// (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors.
|
||||
func supportsPromptCacheKey(apiBase string) bool {
|
||||
u, err := url.Parse(apiBase)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := u.Hostname()
|
||||
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
|
||||
return isNativeOpenAIOrAzureEndpoint(apiBase)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -118,7 +119,7 @@ func BuildMultipartContent(text string, media []string) responses.ResponseInputM
|
||||
},
|
||||
})
|
||||
} else if strings.HasPrefix(mediaURL, "data:audio/") {
|
||||
if format, data, ok := ParseDataAudioURL(mediaURL); ok {
|
||||
if format, data, ok := common.ParseDataAudioURL(mediaURL); ok {
|
||||
parts = append(parts, responses.ResponseInputContentUnionParam{
|
||||
OfInputFile: &responses.ResponseInputFileParam{
|
||||
FileData: openai.Opt(data),
|
||||
@@ -132,25 +133,6 @@ func BuildMultipartContent(text string, media []string) responses.ResponseInputM
|
||||
return parts
|
||||
}
|
||||
|
||||
// ParseDataAudioURL extracts the format and base64 data from a data:audio/... URL.
|
||||
func ParseDataAudioURL(mediaURL string) (format, data string, ok bool) {
|
||||
if !strings.HasPrefix(mediaURL, "data:audio/") {
|
||||
return "", "", false
|
||||
}
|
||||
payload := strings.TrimPrefix(mediaURL, "data:audio/")
|
||||
meta, data, found := strings.Cut(payload, ",")
|
||||
if !found {
|
||||
return "", "", false
|
||||
}
|
||||
format, _, _ = strings.Cut(meta, ";")
|
||||
format = strings.TrimSpace(format)
|
||||
data = strings.TrimSpace(data)
|
||||
if format == "" || data == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return format, data, true
|
||||
}
|
||||
|
||||
// ResolveToolCall extracts the function name and JSON arguments string from a ToolCall.
|
||||
// Returns ok=false if the tool call has no name or if arguments fail to marshal.
|
||||
func ResolveToolCall(tc protocoltypes.ToolCall) (name string, arguments string, ok bool) {
|
||||
|
||||
@@ -506,42 +506,6 @@ func TestParseResponseBody_CanceledStatus(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseDataAudioURL tests ---
|
||||
|
||||
func TestParseDataAudioURL_Valid(t *testing.T) {
|
||||
format, data, ok := ParseDataAudioURL("data:audio/mp3;base64,SGVsbG8=")
|
||||
if !ok {
|
||||
t.Fatal("expected ok=true")
|
||||
}
|
||||
if format != "mp3" {
|
||||
t.Errorf("format = %q, want %q", format, "mp3")
|
||||
}
|
||||
if data != "SGVsbG8=" {
|
||||
t.Errorf("data = %q, want %q", data, "SGVsbG8=")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDataAudioURL_NotAudio(t *testing.T) {
|
||||
_, _, ok := ParseDataAudioURL("data:image/png;base64,abc")
|
||||
if ok {
|
||||
t.Error("expected ok=false for non-audio URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDataAudioURL_MalformedNoComma(t *testing.T) {
|
||||
_, _, ok := ParseDataAudioURL("data:audio/mp3;base64")
|
||||
if ok {
|
||||
t.Error("expected ok=false for malformed URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDataAudioURL_EmptyData(t *testing.T) {
|
||||
_, _, ok := ParseDataAudioURL("data:audio/mp3;base64,")
|
||||
if ok {
|
||||
t.Error("expected ok=false for empty data")
|
||||
}
|
||||
}
|
||||
|
||||
// --- BuildMultipartContent tests ---
|
||||
|
||||
func TestBuildMultipartContent_TextOnly(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user