mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(mcp): sanitize MCP tool schemas for Gemini function calling
This commit is contained in:
@@ -0,0 +1,636 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const maxGeminiSchemaDepth = 64
|
||||
|
||||
var geminiSupportedTypes = map[string]bool{
|
||||
"array": true,
|
||||
"boolean": true,
|
||||
"integer": true,
|
||||
"number": true,
|
||||
"object": true,
|
||||
"string": true,
|
||||
}
|
||||
|
||||
// SanitizeSchemaForGemini reduces a JSON Schema to the conservative subset
|
||||
// accepted by Gemini-style function declarations. It resolves local refs,
|
||||
// collapses composition keywords like anyOf/oneOf/allOf, and strips advanced
|
||||
// keywords that Gemini rejects.
|
||||
func SanitizeSchemaForGemini(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sanitizer := geminiSchemaSanitizer{root: schema}
|
||||
result := sanitizer.sanitizeNode(schema, nil, 0)
|
||||
if len(result) == 0 {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
if _, hasProps := result["properties"]; hasProps {
|
||||
result["type"] = "object"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type geminiSchemaSanitizer struct {
|
||||
root map[string]any
|
||||
}
|
||||
|
||||
func (s geminiSchemaSanitizer) sanitizeNode(
|
||||
node map[string]any,
|
||||
refTrail map[string]struct{},
|
||||
depth int,
|
||||
) map[string]any {
|
||||
if node == nil || depth > maxGeminiSchemaDepth {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
normalized := s.normalizeNode(node, refTrail, depth)
|
||||
if len(normalized) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
|
||||
if desc, ok := normalized["description"].(string); ok && strings.TrimSpace(desc) != "" {
|
||||
result["description"] = desc
|
||||
}
|
||||
|
||||
if schemaType := sanitizeGeminiSchemaType(normalized["type"]); schemaType != "" {
|
||||
result["type"] = schemaType
|
||||
}
|
||||
|
||||
if enumValues := sanitizeGeminiEnum(normalized["enum"]); len(enumValues) > 0 {
|
||||
result["enum"] = enumValues
|
||||
}
|
||||
|
||||
if propsRaw, ok := normalized["properties"].(map[string]any); ok {
|
||||
props := make(map[string]any, len(propsRaw))
|
||||
for name, rawProp := range propsRaw {
|
||||
propSchema, ok := rawProp.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
sanitizedProp := s.sanitizeNode(propSchema, refTrail, depth+1)
|
||||
if len(sanitizedProp) == 0 {
|
||||
sanitizedProp = map[string]any{}
|
||||
}
|
||||
props[name] = sanitizedProp
|
||||
}
|
||||
result["properties"] = props
|
||||
result["type"] = "object"
|
||||
if required := sanitizeGeminiRequired(normalized["required"], props); len(required) > 0 {
|
||||
result["required"] = required
|
||||
}
|
||||
}
|
||||
|
||||
if itemsRaw, ok := normalized["items"].(map[string]any); ok {
|
||||
items := s.sanitizeNode(itemsRaw, refTrail, depth+1)
|
||||
if len(items) == 0 {
|
||||
items = map[string]any{}
|
||||
}
|
||||
result["items"] = items
|
||||
if _, hasType := result["type"]; !hasType {
|
||||
result["type"] = "array"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s geminiSchemaSanitizer) normalizeNode(
|
||||
node map[string]any,
|
||||
refTrail map[string]struct{},
|
||||
depth int,
|
||||
) map[string]any {
|
||||
if node == nil || depth > maxGeminiSchemaDepth {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
normalized := cloneGeminiSchemaMap(node)
|
||||
|
||||
if ref, ok := normalized["$ref"].(string); ok {
|
||||
delete(normalized, "$ref")
|
||||
if _, seen := refTrail[ref]; !seen {
|
||||
if target, ok := s.resolveLocalSchemaRef(ref); ok {
|
||||
nextTrail := cloneRefTrail(refTrail)
|
||||
nextTrail[ref] = struct{}{}
|
||||
normalized = mergeGeminiSchemaMaps(
|
||||
s.normalizeNode(target, nextTrail, depth+1),
|
||||
normalized,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rawAllOf, ok := normalized["allOf"]; ok {
|
||||
delete(normalized, "allOf")
|
||||
for _, part := range schemaSlice(rawAllOf) {
|
||||
normalized = mergeGeminiSchemaMaps(
|
||||
normalized,
|
||||
s.normalizeNode(part, refTrail, depth+1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if rawAnyOf, ok := normalized["anyOf"]; ok {
|
||||
delete(normalized, "anyOf")
|
||||
normalized = mergeGeminiSchemaMaps(
|
||||
s.mergeUnionBranches(schemaSlice(rawAnyOf), refTrail, depth+1),
|
||||
normalized,
|
||||
)
|
||||
}
|
||||
|
||||
if rawOneOf, ok := normalized["oneOf"]; ok {
|
||||
delete(normalized, "oneOf")
|
||||
normalized = mergeGeminiSchemaMaps(
|
||||
s.mergeUnionBranches(schemaSlice(rawOneOf), refTrail, depth+1),
|
||||
normalized,
|
||||
)
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
func (s geminiSchemaSanitizer) mergeUnionBranches(
|
||||
branches []map[string]any,
|
||||
refTrail map[string]struct{},
|
||||
depth int,
|
||||
) map[string]any {
|
||||
if len(branches) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
objectBranches := make([]map[string]any, 0, len(branches))
|
||||
arrayBranches := make([]map[string]any, 0, len(branches))
|
||||
nonNullBranches := make([]map[string]any, 0, len(branches))
|
||||
sameType := ""
|
||||
sameTypeConsistent := true
|
||||
|
||||
for _, branch := range branches {
|
||||
normalized := s.normalizeNode(branch, refTrail, depth+1)
|
||||
if len(normalized) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
branchType := geminiSchemaBranchType(normalized["type"])
|
||||
if branchType == "null" {
|
||||
continue
|
||||
}
|
||||
nonNullBranches = append(nonNullBranches, normalized)
|
||||
|
||||
if sameType == "" {
|
||||
sameType = branchType
|
||||
} else if branchType != "" && branchType != sameType {
|
||||
sameTypeConsistent = false
|
||||
}
|
||||
|
||||
if branchType == "object" || hasSchemaProperties(normalized) {
|
||||
objectBranches = append(objectBranches, normalized)
|
||||
continue
|
||||
}
|
||||
if branchType == "array" || hasSchemaItems(normalized) {
|
||||
arrayBranches = append(arrayBranches, normalized)
|
||||
}
|
||||
}
|
||||
|
||||
if len(nonNullBranches) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
if len(objectBranches) > 0 {
|
||||
return mergeUnionObjectSchemas(objectBranches)
|
||||
}
|
||||
if len(arrayBranches) == len(nonNullBranches) && len(arrayBranches) > 0 {
|
||||
return mergeUnionArraySchemas(arrayBranches)
|
||||
}
|
||||
if sameTypeConsistent && sameType != "" {
|
||||
merged := map[string]any{}
|
||||
for _, branch := range nonNullBranches {
|
||||
merged = mergeGeminiSchemaMaps(merged, branch)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
best := nonNullBranches[0]
|
||||
bestScore := geminiUnionBranchScore(best)
|
||||
for _, branch := range nonNullBranches[1:] {
|
||||
if score := geminiUnionBranchScore(branch); score > bestScore {
|
||||
best = branch
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
return cloneGeminiSchemaMap(best)
|
||||
}
|
||||
|
||||
func (s geminiSchemaSanitizer) resolveLocalSchemaRef(ref string) (map[string]any, bool) {
|
||||
if ref == "#" {
|
||||
return s.root, true
|
||||
}
|
||||
if !strings.HasPrefix(ref, "#/") {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var current any = s.root
|
||||
for _, rawToken := range strings.Split(strings.TrimPrefix(ref, "#/"), "/") {
|
||||
token := strings.ReplaceAll(strings.ReplaceAll(rawToken, "~1", "/"), "~0", "~")
|
||||
switch value := current.(type) {
|
||||
case map[string]any:
|
||||
next, ok := value[token]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
current = next
|
||||
case []any:
|
||||
index, err := strconv.Atoi(token)
|
||||
if err != nil || index < 0 || index >= len(value) {
|
||||
return nil, false
|
||||
}
|
||||
current = value[index]
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
resolved, ok := current.(map[string]any)
|
||||
return resolved, ok
|
||||
}
|
||||
|
||||
func mergeUnionObjectSchemas(branches []map[string]any) map[string]any {
|
||||
merged := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
|
||||
var commonRequired map[string]struct{}
|
||||
var requiredOrder []string
|
||||
|
||||
for i, branch := range branches {
|
||||
merged = mergeGeminiSchemaMaps(merged, branch)
|
||||
|
||||
required := requiredStrings(branch["required"])
|
||||
if i == 0 {
|
||||
commonRequired = make(map[string]struct{}, len(required))
|
||||
requiredOrder = append(requiredOrder, required...)
|
||||
for _, name := range required {
|
||||
commonRequired[name] = struct{}{}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
current := make(map[string]struct{}, len(required))
|
||||
for _, name := range required {
|
||||
current[name] = struct{}{}
|
||||
}
|
||||
for name := range commonRequired {
|
||||
if _, ok := current[name]; !ok {
|
||||
delete(commonRequired, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(commonRequired) > 0 {
|
||||
filtered := make([]string, 0, len(commonRequired))
|
||||
for _, name := range requiredOrder {
|
||||
if _, ok := commonRequired[name]; ok {
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
}
|
||||
if len(filtered) > 0 {
|
||||
merged["required"] = filtered
|
||||
}
|
||||
} else {
|
||||
delete(merged, "required")
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func mergeUnionArraySchemas(branches []map[string]any) map[string]any {
|
||||
merged := map[string]any{
|
||||
"type": "array",
|
||||
}
|
||||
for _, branch := range branches {
|
||||
merged = mergeGeminiSchemaMaps(merged, branch)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func mergeGeminiSchemaMaps(base map[string]any, overlay map[string]any) map[string]any {
|
||||
if len(base) == 0 {
|
||||
return cloneGeminiSchemaMap(overlay)
|
||||
}
|
||||
if len(overlay) == 0 {
|
||||
return cloneGeminiSchemaMap(base)
|
||||
}
|
||||
|
||||
result := cloneGeminiSchemaMap(base)
|
||||
for key, value := range overlay {
|
||||
switch key {
|
||||
case "properties":
|
||||
overlayProps, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
existing, _ := result["properties"].(map[string]any)
|
||||
mergedProps := cloneGeminiSchemaMap(existing)
|
||||
if mergedProps == nil {
|
||||
mergedProps = make(map[string]any, len(overlayProps))
|
||||
}
|
||||
for name, rawProp := range overlayProps {
|
||||
propSchema, ok := rawProp.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if existingProp, ok := mergedProps[name].(map[string]any); ok {
|
||||
mergedProps[name] = mergeGeminiSchemaMaps(existingProp, propSchema)
|
||||
} else {
|
||||
mergedProps[name] = cloneGeminiSchemaMap(propSchema)
|
||||
}
|
||||
}
|
||||
result["properties"] = mergedProps
|
||||
case "items":
|
||||
overlayItems, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if existingItems, ok := result["items"].(map[string]any); ok {
|
||||
result["items"] = mergeGeminiSchemaMaps(existingItems, overlayItems)
|
||||
} else {
|
||||
result["items"] = cloneGeminiSchemaMap(overlayItems)
|
||||
}
|
||||
case "required":
|
||||
if merged := mergeRequiredLists(result["required"], value); len(merged) > 0 {
|
||||
result["required"] = merged
|
||||
}
|
||||
case "type":
|
||||
if mergedType := mergeGeminiSchemaTypes(result["type"], value); mergedType != "" {
|
||||
result["type"] = mergedType
|
||||
} else {
|
||||
delete(result, "type")
|
||||
}
|
||||
case "description":
|
||||
desc, ok := value.(string)
|
||||
if ok && strings.TrimSpace(desc) != "" {
|
||||
result["description"] = desc
|
||||
}
|
||||
default:
|
||||
result[key] = cloneGeminiSchemaValue(value)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func mergeGeminiSchemaTypes(left any, right any) string {
|
||||
leftType := geminiSchemaBranchType(left)
|
||||
rightType := geminiSchemaBranchType(right)
|
||||
|
||||
switch {
|
||||
case leftType == "":
|
||||
return rightType
|
||||
case rightType == "":
|
||||
return leftType
|
||||
case leftType == rightType:
|
||||
return leftType
|
||||
case leftType == "null":
|
||||
return rightType
|
||||
case rightType == "null":
|
||||
return leftType
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeGeminiSchemaType(raw any) string {
|
||||
typeName := geminiSchemaBranchType(raw)
|
||||
if typeName == "null" {
|
||||
return ""
|
||||
}
|
||||
return typeName
|
||||
}
|
||||
|
||||
func geminiSchemaBranchType(raw any) string {
|
||||
switch value := raw.(type) {
|
||||
case string:
|
||||
if value == "null" {
|
||||
return value
|
||||
}
|
||||
if geminiSupportedTypes[value] {
|
||||
return value
|
||||
}
|
||||
return ""
|
||||
case []string:
|
||||
return geminiSchemaBranchType(stringSliceToAny(value))
|
||||
case []any:
|
||||
candidate := ""
|
||||
sawNull := false
|
||||
for _, item := range value {
|
||||
typeName, ok := item.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if typeName == "null" {
|
||||
sawNull = true
|
||||
continue
|
||||
}
|
||||
if !geminiSupportedTypes[typeName] {
|
||||
continue
|
||||
}
|
||||
if candidate == "" {
|
||||
candidate = typeName
|
||||
continue
|
||||
}
|
||||
if candidate != typeName {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if candidate == "" && sawNull {
|
||||
return "null"
|
||||
}
|
||||
return candidate
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeGeminiEnum(raw any) []any {
|
||||
values, ok := raw.([]any)
|
||||
if !ok {
|
||||
if stringValues, ok := raw.([]string); ok {
|
||||
return stringSliceToAny(stringValues)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
sanitized := make([]any, 0, len(values))
|
||||
for _, value := range values {
|
||||
switch value.(type) {
|
||||
case string, bool, float64, int, int32, int64:
|
||||
sanitized = append(sanitized, value)
|
||||
}
|
||||
}
|
||||
if len(sanitized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
func sanitizeGeminiRequired(raw any, properties map[string]any) []string {
|
||||
required := requiredStrings(raw)
|
||||
if len(required) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
filtered := make([]string, 0, len(required))
|
||||
seen := make(map[string]struct{}, len(required))
|
||||
for _, name := range required {
|
||||
if _, ok := properties[name]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func requiredStrings(raw any) []string {
|
||||
switch value := raw.(type) {
|
||||
case []string:
|
||||
return append([]string(nil), value...)
|
||||
case []any:
|
||||
required := make([]string, 0, len(value))
|
||||
for _, item := range value {
|
||||
name, ok := item.(string)
|
||||
if ok {
|
||||
required = append(required, name)
|
||||
}
|
||||
}
|
||||
return required
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func mergeRequiredLists(left any, right any) []string {
|
||||
merged := make([]string, 0)
|
||||
seen := map[string]struct{}{}
|
||||
|
||||
for _, name := range append(requiredStrings(left), requiredStrings(right)...) {
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
merged = append(merged, name)
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func geminiUnionBranchScore(schema map[string]any) int {
|
||||
score := 0
|
||||
if hasSchemaProperties(schema) {
|
||||
score += 20
|
||||
}
|
||||
if hasSchemaItems(schema) {
|
||||
score += 10
|
||||
}
|
||||
if _, ok := schema["enum"]; ok {
|
||||
score += 5
|
||||
}
|
||||
if _, ok := schema["description"]; ok {
|
||||
score += 2
|
||||
}
|
||||
score += len(schema)
|
||||
return score
|
||||
}
|
||||
|
||||
func hasSchemaProperties(schema map[string]any) bool {
|
||||
props, ok := schema["properties"].(map[string]any)
|
||||
return ok && len(props) > 0
|
||||
}
|
||||
|
||||
func hasSchemaItems(schema map[string]any) bool {
|
||||
_, ok := schema["items"].(map[string]any)
|
||||
return ok
|
||||
}
|
||||
|
||||
func schemaSlice(raw any) []map[string]any {
|
||||
switch value := raw.(type) {
|
||||
case []map[string]any:
|
||||
return append([]map[string]any(nil), value...)
|
||||
case []any:
|
||||
schemas := make([]map[string]any, 0, len(value))
|
||||
for _, item := range value {
|
||||
schema, ok := item.(map[string]any)
|
||||
if ok {
|
||||
schemas = append(schemas, schema)
|
||||
}
|
||||
}
|
||||
return schemas
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func cloneGeminiSchemaMap(in map[string]any) map[string]any {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for key, value := range in {
|
||||
out[key] = cloneGeminiSchemaValue(value)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneGeminiSchemaValue(value any) any {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
return cloneGeminiSchemaMap(typed)
|
||||
case []any:
|
||||
out := make([]any, len(typed))
|
||||
for i, item := range typed {
|
||||
out[i] = cloneGeminiSchemaValue(item)
|
||||
}
|
||||
return out
|
||||
case []string:
|
||||
return append([]string(nil), typed...)
|
||||
default:
|
||||
return typed
|
||||
}
|
||||
}
|
||||
|
||||
func cloneRefTrail(in map[string]struct{}) map[string]struct{} {
|
||||
if len(in) == 0 {
|
||||
return make(map[string]struct{})
|
||||
}
|
||||
out := make(map[string]struct{}, len(in))
|
||||
for key := range in {
|
||||
out[key] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stringSliceToAny(values []string) []any {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]any, len(values))
|
||||
for i, value := range values {
|
||||
result[i] = value
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
package common
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeSchemaForGemini_DereferencesRefsAndFlattensUnions(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"parent": map[string]any{
|
||||
"anyOf": []any{
|
||||
map[string]any{"$ref": "#/$defs/pageParent"},
|
||||
map[string]any{"$ref": "#/$defs/databaseParent"},
|
||||
},
|
||||
},
|
||||
"icon": map[string]any{
|
||||
"anyOf": []any{
|
||||
map[string]any{"$ref": "#/$defs/emoji"},
|
||||
map[string]any{"type": "null"},
|
||||
},
|
||||
},
|
||||
"data": map[string]any{
|
||||
"$ref": "#/$defs/dataPayload",
|
||||
},
|
||||
},
|
||||
"required": []any{"parent", "icon", "missing"},
|
||||
"$defs": map[string]any{
|
||||
"pageParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"page_id": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": []any{"page_id"},
|
||||
},
|
||||
"databaseParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"database_id": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": []any{"database_id"},
|
||||
},
|
||||
"emoji": map[string]any{
|
||||
"type": "string",
|
||||
"pattern": "^:[a-z_]+:$",
|
||||
},
|
||||
"dataPayload": map[string]any{
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
},
|
||||
"count": map[string]any{
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": []any{"name"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := SanitizeSchemaForGemini(schema)
|
||||
assertSchemaKeyAbsent(t, got, "$defs")
|
||||
assertSchemaKeyAbsent(t, got, "$ref")
|
||||
assertSchemaKeyAbsent(t, got, "anyOf")
|
||||
assertSchemaKeyAbsent(t, got, "oneOf")
|
||||
assertSchemaKeyAbsent(t, got, "allOf")
|
||||
assertSchemaKeyAbsent(t, got, "additionalProperties")
|
||||
assertSchemaKeyAbsent(t, got, "pattern")
|
||||
assertSchemaKeyAbsent(t, got, "minLength")
|
||||
assertSchemaKeyAbsent(t, got, "minimum")
|
||||
|
||||
if got["type"] != "object" {
|
||||
t.Fatalf("top-level type = %#v, want object", got["type"])
|
||||
}
|
||||
|
||||
props, ok := got["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("properties = %#v, want map", got["properties"])
|
||||
}
|
||||
|
||||
parent, ok := props["parent"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parent schema = %#v, want map", props["parent"])
|
||||
}
|
||||
if parent["type"] != "object" {
|
||||
t.Fatalf("parent.type = %#v, want object", parent["type"])
|
||||
}
|
||||
parentProps, ok := parent["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parent.properties = %#v, want map", parent["properties"])
|
||||
}
|
||||
if _, ok := parentProps["page_id"]; !ok {
|
||||
t.Fatalf("parent.properties missing page_id: %#v", parentProps)
|
||||
}
|
||||
if _, ok := parentProps["database_id"]; !ok {
|
||||
t.Fatalf("parent.properties missing database_id: %#v", parentProps)
|
||||
}
|
||||
if _, hasRequired := parent["required"]; hasRequired {
|
||||
t.Fatalf("parent.required = %#v, want omitted for merged anyOf branches", parent["required"])
|
||||
}
|
||||
|
||||
icon, ok := props["icon"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("icon schema = %#v, want map", props["icon"])
|
||||
}
|
||||
if icon["type"] != "string" {
|
||||
t.Fatalf("icon.type = %#v, want string", icon["type"])
|
||||
}
|
||||
|
||||
data, ok := props["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("data schema = %#v, want map", props["data"])
|
||||
}
|
||||
if data["type"] != "object" {
|
||||
t.Fatalf("data.type = %#v, want object", data["type"])
|
||||
}
|
||||
dataProps, ok := data["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("data.properties = %#v, want map", data["properties"])
|
||||
}
|
||||
if _, ok := dataProps["name"]; !ok {
|
||||
t.Fatalf("data.properties missing name: %#v", dataProps)
|
||||
}
|
||||
if _, ok := dataProps["count"]; !ok {
|
||||
t.Fatalf("data.properties missing count: %#v", dataProps)
|
||||
}
|
||||
|
||||
required, ok := got["required"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("required = %#v, want []string", got["required"])
|
||||
}
|
||||
if len(required) != 2 || required[0] != "parent" || required[1] != "icon" {
|
||||
t.Fatalf("required = %#v, want [parent icon]", required)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSchemaForGemini_MergesAllOfAndFiltersRequired(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"payload": map[string]any{
|
||||
"allOf": []any{
|
||||
map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"id": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": []any{"id"},
|
||||
},
|
||||
map[string]any{
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
"count": map[string]any{
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": []any{"name", "missing"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := SanitizeSchemaForGemini(schema)
|
||||
props := got["properties"].(map[string]any)
|
||||
payload := props["payload"].(map[string]any)
|
||||
|
||||
if payload["type"] != "object" {
|
||||
t.Fatalf("payload.type = %#v, want object", payload["type"])
|
||||
}
|
||||
payloadProps, ok := payload["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("payload.properties = %#v, want map", payload["properties"])
|
||||
}
|
||||
for _, key := range []string{"id", "name", "count"} {
|
||||
if _, ok := payloadProps[key]; !ok {
|
||||
t.Fatalf("payload.properties missing %q: %#v", key, payloadProps)
|
||||
}
|
||||
}
|
||||
|
||||
required, ok := payload["required"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("payload.required = %#v, want []string", payload["required"])
|
||||
}
|
||||
if len(required) != 2 || required[0] != "id" || required[1] != "name" {
|
||||
t.Fatalf("payload.required = %#v, want [id name]", required)
|
||||
}
|
||||
|
||||
assertSchemaKeyAbsent(t, payload, "allOf")
|
||||
assertSchemaKeyAbsent(t, payload, "minimum")
|
||||
}
|
||||
|
||||
func TestSanitizeSchemaForGemini_HandlesRecursiveRefs(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"tree": map[string]any{
|
||||
"$ref": "#/$defs/node",
|
||||
},
|
||||
},
|
||||
"$defs": map[string]any{
|
||||
"node": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
"child": map[string]any{
|
||||
"$ref": "#/$defs/node",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := SanitizeSchemaForGemini(schema)
|
||||
props := got["properties"].(map[string]any)
|
||||
tree := props["tree"].(map[string]any)
|
||||
if tree["type"] != "object" {
|
||||
t.Fatalf("tree.type = %#v, want object", tree["type"])
|
||||
}
|
||||
assertSchemaKeyAbsent(t, tree, "$ref")
|
||||
}
|
||||
|
||||
func assertSchemaKeyAbsent(t *testing.T, value any, key string) {
|
||||
t.Helper()
|
||||
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
if _, found := typed[key]; found {
|
||||
t.Fatalf("schema still contains key %q: %#v", key, typed)
|
||||
}
|
||||
for _, nested := range typed {
|
||||
assertSchemaKeyAbsent(t, nested, key)
|
||||
}
|
||||
case []any:
|
||||
for _, nested := range typed {
|
||||
assertSchemaKeyAbsent(t, nested, key)
|
||||
}
|
||||
case []string:
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -12,66 +12,6 @@ func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake
|
||||
return ""
|
||||
}
|
||||
|
||||
var geminiUnsupportedKeywords = map[string]bool{
|
||||
"patternProperties": true,
|
||||
"additionalProperties": true,
|
||||
"$schema": true,
|
||||
"$id": true,
|
||||
"$ref": true,
|
||||
"$defs": true,
|
||||
"definitions": true,
|
||||
"examples": true,
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"multipleOf": true,
|
||||
"pattern": true,
|
||||
"format": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
"uniqueItems": true,
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
}
|
||||
|
||||
func sanitizeSchemaForGemini(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
for k, v := range schema {
|
||||
if geminiUnsupportedKeywords[k] {
|
||||
continue
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case map[string]any:
|
||||
result[k] = sanitizeSchemaForGemini(val)
|
||||
case []any:
|
||||
sanitized := make([]any, len(val))
|
||||
for i, item := range val {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
sanitized[i] = sanitizeSchemaForGemini(m)
|
||||
} else {
|
||||
sanitized[i] = item
|
||||
}
|
||||
}
|
||||
result[k] = sanitized
|
||||
default:
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if _, hasProps := result["properties"]; hasProps {
|
||||
if _, hasType := result["type"]; !hasType {
|
||||
result["type"] = "object"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func extractProtocol(model string) (protocol, modelID string) {
|
||||
model = strings.TrimSpace(model)
|
||||
protocol, modelID, found := strings.Cut(model, "/")
|
||||
|
||||
@@ -264,7 +264,7 @@ func (p *GeminiProvider) buildRequestBody(
|
||||
funcDecls = append(funcDecls, geminiFunctionDeclaration{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: sanitizeSchemaForGemini(t.Function.Parameters),
|
||||
Parameters: common.SanitizeSchemaForGemini(t.Function.Parameters),
|
||||
})
|
||||
}
|
||||
if len(funcDecls) > 0 {
|
||||
|
||||
@@ -5,8 +5,11 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) {
|
||||
@@ -259,6 +262,65 @@ func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_SanitizesComplexToolSchemas(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"parent": map[string]any{
|
||||
"anyOf": []any{
|
||||
map[string]any{"$ref": "#/$defs/pageParent"},
|
||||
map[string]any{"$ref": "#/$defs/databaseParent"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"$defs": map[string]any{
|
||||
"pageParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"page_id": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"page_id"},
|
||||
},
|
||||
"databaseParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"database_id": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"database_id"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
[]ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "mcp_notion_create",
|
||||
Description: "Create a Notion object",
|
||||
Parameters: schema,
|
||||
},
|
||||
}},
|
||||
"gemini-3-flash-preview",
|
||||
nil,
|
||||
)
|
||||
|
||||
tools, ok := body["tools"].([]geminiTool)
|
||||
if !ok || len(tools) != 1 {
|
||||
t.Fatalf("tools = %#v, want one geminiTool", body["tools"])
|
||||
}
|
||||
got, ok := tools[0].FunctionDeclarations[0].Parameters.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters = %#v, want map", tools[0].FunctionDeclarations[0].Parameters)
|
||||
}
|
||||
|
||||
want := providercommon.SanitizeSchemaForGemini(schema)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
@@ -298,7 +298,7 @@ func (p *AntigravityProvider) buildRequest(
|
||||
if t.Type != "function" {
|
||||
continue
|
||||
}
|
||||
params := sanitizeSchemaForGemini(t.Function.Parameters)
|
||||
params := common.SanitizeSchemaForGemini(t.Function.Parameters)
|
||||
funcDecls = append(funcDecls, antigravityFuncDecl{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
@@ -446,71 +446,6 @@ func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake
|
||||
return ""
|
||||
}
|
||||
|
||||
// --- Schema sanitization ---
|
||||
|
||||
// Google/Gemini doesn't support many JSON Schema keywords that other providers accept.
|
||||
var geminiUnsupportedKeywords = map[string]bool{
|
||||
"patternProperties": true,
|
||||
"additionalProperties": true,
|
||||
"$schema": true,
|
||||
"$id": true,
|
||||
"$ref": true,
|
||||
"$defs": true,
|
||||
"definitions": true,
|
||||
"examples": true,
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"multipleOf": true,
|
||||
"pattern": true,
|
||||
"format": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
"uniqueItems": true,
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
}
|
||||
|
||||
func sanitizeSchemaForGemini(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
for k, v := range schema {
|
||||
if geminiUnsupportedKeywords[k] {
|
||||
continue
|
||||
}
|
||||
// Recursively sanitize nested objects
|
||||
switch val := v.(type) {
|
||||
case map[string]any:
|
||||
result[k] = sanitizeSchemaForGemini(val)
|
||||
case []any:
|
||||
sanitized := make([]any, len(val))
|
||||
for i, item := range val {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
sanitized[i] = sanitizeSchemaForGemini(m)
|
||||
} else {
|
||||
sanitized[i] = item
|
||||
}
|
||||
}
|
||||
result[k] = sanitized
|
||||
default:
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure top-level has type: "object" if properties are present
|
||||
if _, hasProps := result["properties"]; hasProps {
|
||||
if _, hasType := result["type"]; !hasType {
|
||||
result["type"] = "object"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// --- Token source ---
|
||||
|
||||
func createAntigravityTokenSource() func() (string, string, error) {
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package oauthprovider
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) {
|
||||
p := &AntigravityProvider{}
|
||||
@@ -71,3 +76,68 @@ func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) {
|
||||
t.Fatalf("Usage.TotalTokens = %v, want %d", resp.Usage, 216)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequest_SanitizesComplexToolSchemas(t *testing.T) {
|
||||
p := &AntigravityProvider{}
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"parent": map[string]any{
|
||||
"anyOf": []any{
|
||||
map[string]any{"$ref": "#/$defs/pageParent"},
|
||||
map[string]any{"$ref": "#/$defs/databaseParent"},
|
||||
},
|
||||
},
|
||||
"icon": map[string]any{
|
||||
"anyOf": []any{
|
||||
map[string]any{"type": "null"},
|
||||
map[string]any{"$ref": "#/$defs/emoji"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"$defs": map[string]any{
|
||||
"pageParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"page_id": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"page_id"},
|
||||
},
|
||||
"databaseParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"database_id": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"database_id"},
|
||||
},
|
||||
"emoji": map[string]any{
|
||||
"type": "string",
|
||||
"pattern": "^:[a-z_]+:$",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := p.buildRequest(
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
[]ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "mcp_notion_create",
|
||||
Description: "Create a Notion object",
|
||||
Parameters: schema,
|
||||
},
|
||||
}},
|
||||
"gemini-3-flash",
|
||||
nil,
|
||||
)
|
||||
|
||||
if len(req.Tools) != 1 || len(req.Tools[0].FunctionDeclarations) != 1 {
|
||||
t.Fatalf("request tools = %#v, want one function declaration", req.Tools)
|
||||
}
|
||||
|
||||
got := req.Tools[0].FunctionDeclarations[0].Parameters
|
||||
want := providercommon.SanitizeSchemaForGemini(schema)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user