mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(mcp): sanitize MCP tool schemas for Gemini function calling
This commit is contained in:
@@ -12,66 +12,6 @@ func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake
|
||||
return ""
|
||||
}
|
||||
|
||||
var geminiUnsupportedKeywords = map[string]bool{
|
||||
"patternProperties": true,
|
||||
"additionalProperties": true,
|
||||
"$schema": true,
|
||||
"$id": true,
|
||||
"$ref": true,
|
||||
"$defs": true,
|
||||
"definitions": true,
|
||||
"examples": true,
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"multipleOf": true,
|
||||
"pattern": true,
|
||||
"format": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
"uniqueItems": true,
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
}
|
||||
|
||||
func sanitizeSchemaForGemini(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
for k, v := range schema {
|
||||
if geminiUnsupportedKeywords[k] {
|
||||
continue
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case map[string]any:
|
||||
result[k] = sanitizeSchemaForGemini(val)
|
||||
case []any:
|
||||
sanitized := make([]any, len(val))
|
||||
for i, item := range val {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
sanitized[i] = sanitizeSchemaForGemini(m)
|
||||
} else {
|
||||
sanitized[i] = item
|
||||
}
|
||||
}
|
||||
result[k] = sanitized
|
||||
default:
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if _, hasProps := result["properties"]; hasProps {
|
||||
if _, hasType := result["type"]; !hasType {
|
||||
result["type"] = "object"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func extractProtocol(model string) (protocol, modelID string) {
|
||||
model = strings.TrimSpace(model)
|
||||
protocol, modelID, found := strings.Cut(model, "/")
|
||||
|
||||
@@ -264,7 +264,7 @@ func (p *GeminiProvider) buildRequestBody(
|
||||
funcDecls = append(funcDecls, geminiFunctionDeclaration{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: sanitizeSchemaForGemini(t.Function.Parameters),
|
||||
Parameters: common.SanitizeSchemaForGemini(t.Function.Parameters),
|
||||
})
|
||||
}
|
||||
if len(funcDecls) > 0 {
|
||||
|
||||
@@ -5,8 +5,11 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) {
|
||||
@@ -259,6 +262,65 @@ func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_SanitizesComplexToolSchemas(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"parent": map[string]any{
|
||||
"anyOf": []any{
|
||||
map[string]any{"$ref": "#/$defs/pageParent"},
|
||||
map[string]any{"$ref": "#/$defs/databaseParent"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"$defs": map[string]any{
|
||||
"pageParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"page_id": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"page_id"},
|
||||
},
|
||||
"databaseParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"database_id": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"database_id"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
[]ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "mcp_notion_create",
|
||||
Description: "Create a Notion object",
|
||||
Parameters: schema,
|
||||
},
|
||||
}},
|
||||
"gemini-3-flash-preview",
|
||||
nil,
|
||||
)
|
||||
|
||||
tools, ok := body["tools"].([]geminiTool)
|
||||
if !ok || len(tools) != 1 {
|
||||
t.Fatalf("tools = %#v, want one geminiTool", body["tools"])
|
||||
}
|
||||
got, ok := tools[0].FunctionDeclarations[0].Parameters.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters = %#v, want map", tools[0].FunctionDeclarations[0].Parameters)
|
||||
}
|
||||
|
||||
want := providercommon.SanitizeSchemaForGemini(schema)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
Reference in New Issue
Block a user