From 1ff8a418f65e36b46c39e3338568ec2e1f0baca7 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Sun, 26 Apr 2026 22:23:55 +0200 Subject: [PATCH] fix(mcp): sanitize MCP tool schemas for Gemini function calling --- pkg/providers/common/google_schema.go | 636 ++++++++++++++++++ pkg/providers/common/google_schema_test.go | 254 +++++++ pkg/providers/httpapi/gemini_helpers.go | 60 -- pkg/providers/httpapi/gemini_provider.go | 2 +- pkg/providers/httpapi/gemini_provider_test.go | 62 ++ pkg/providers/oauth/antigravity_provider.go | 67 +- .../oauth/antigravity_provider_test.go | 72 +- 7 files changed, 1025 insertions(+), 128 deletions(-) create mode 100644 pkg/providers/common/google_schema.go create mode 100644 pkg/providers/common/google_schema_test.go diff --git a/pkg/providers/common/google_schema.go b/pkg/providers/common/google_schema.go new file mode 100644 index 000000000..bdafadd23 --- /dev/null +++ b/pkg/providers/common/google_schema.go @@ -0,0 +1,636 @@ +package common + +import ( + "strconv" + "strings" +) + +const maxGeminiSchemaDepth = 64 + +var geminiSupportedTypes = map[string]bool{ + "array": true, + "boolean": true, + "integer": true, + "number": true, + "object": true, + "string": true, +} + +// SanitizeSchemaForGemini reduces a JSON Schema to the conservative subset +// accepted by Gemini-style function declarations. It resolves local refs, +// collapses composition keywords like anyOf/oneOf/allOf, and strips advanced +// keywords that Gemini rejects. +func SanitizeSchemaForGemini(schema map[string]any) map[string]any { + if schema == nil { + return nil + } + + sanitizer := geminiSchemaSanitizer{root: schema} + result := sanitizer.sanitizeNode(schema, nil, 0) + if len(result) == 0 { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + if _, hasProps := result["properties"]; hasProps { + result["type"] = "object" + } + return result +} + +type geminiSchemaSanitizer struct { + root map[string]any +} + +func (s geminiSchemaSanitizer) sanitizeNode( + node map[string]any, + refTrail map[string]struct{}, + depth int, +) map[string]any { + if node == nil || depth > maxGeminiSchemaDepth { + return map[string]any{} + } + + normalized := s.normalizeNode(node, refTrail, depth) + if len(normalized) == 0 { + return map[string]any{} + } + + result := make(map[string]any) + + if desc, ok := normalized["description"].(string); ok && strings.TrimSpace(desc) != "" { + result["description"] = desc + } + + if schemaType := sanitizeGeminiSchemaType(normalized["type"]); schemaType != "" { + result["type"] = schemaType + } + + if enumValues := sanitizeGeminiEnum(normalized["enum"]); len(enumValues) > 0 { + result["enum"] = enumValues + } + + if propsRaw, ok := normalized["properties"].(map[string]any); ok { + props := make(map[string]any, len(propsRaw)) + for name, rawProp := range propsRaw { + propSchema, ok := rawProp.(map[string]any) + if !ok { + continue + } + sanitizedProp := s.sanitizeNode(propSchema, refTrail, depth+1) + if len(sanitizedProp) == 0 { + sanitizedProp = map[string]any{} + } + props[name] = sanitizedProp + } + result["properties"] = props + result["type"] = "object" + if required := sanitizeGeminiRequired(normalized["required"], props); len(required) > 0 { + result["required"] = required + } + } + + if itemsRaw, ok := normalized["items"].(map[string]any); ok { + items := s.sanitizeNode(itemsRaw, refTrail, depth+1) + if len(items) == 0 { + items = map[string]any{} + } + result["items"] = items + if _, hasType := result["type"]; !hasType { + result["type"] = "array" + } + } + + return result +} + +func (s geminiSchemaSanitizer) normalizeNode( + node map[string]any, + refTrail map[string]struct{}, + depth int, +) map[string]any { + if node == nil || depth > maxGeminiSchemaDepth { + return map[string]any{} + } + + normalized := cloneGeminiSchemaMap(node) + + if ref, ok := normalized["$ref"].(string); ok { + delete(normalized, "$ref") + if _, seen := refTrail[ref]; !seen { + if target, ok := s.resolveLocalSchemaRef(ref); ok { + nextTrail := cloneRefTrail(refTrail) + nextTrail[ref] = struct{}{} + normalized = mergeGeminiSchemaMaps( + s.normalizeNode(target, nextTrail, depth+1), + normalized, + ) + } + } + } + + if rawAllOf, ok := normalized["allOf"]; ok { + delete(normalized, "allOf") + for _, part := range schemaSlice(rawAllOf) { + normalized = mergeGeminiSchemaMaps( + normalized, + s.normalizeNode(part, refTrail, depth+1), + ) + } + } + + if rawAnyOf, ok := normalized["anyOf"]; ok { + delete(normalized, "anyOf") + normalized = mergeGeminiSchemaMaps( + s.mergeUnionBranches(schemaSlice(rawAnyOf), refTrail, depth+1), + normalized, + ) + } + + if rawOneOf, ok := normalized["oneOf"]; ok { + delete(normalized, "oneOf") + normalized = mergeGeminiSchemaMaps( + s.mergeUnionBranches(schemaSlice(rawOneOf), refTrail, depth+1), + normalized, + ) + } + + return normalized +} + +func (s geminiSchemaSanitizer) mergeUnionBranches( + branches []map[string]any, + refTrail map[string]struct{}, + depth int, +) map[string]any { + if len(branches) == 0 { + return map[string]any{} + } + + objectBranches := make([]map[string]any, 0, len(branches)) + arrayBranches := make([]map[string]any, 0, len(branches)) + nonNullBranches := make([]map[string]any, 0, len(branches)) + sameType := "" + sameTypeConsistent := true + + for _, branch := range branches { + normalized := s.normalizeNode(branch, refTrail, depth+1) + if len(normalized) == 0 { + continue + } + + branchType := geminiSchemaBranchType(normalized["type"]) + if branchType == "null" { + continue + } + nonNullBranches = append(nonNullBranches, normalized) + + if sameType == "" { + sameType = branchType + } else if branchType != "" && branchType != sameType { + sameTypeConsistent = false + } + + if branchType == "object" || hasSchemaProperties(normalized) { + objectBranches = append(objectBranches, normalized) + continue + } + if branchType == "array" || hasSchemaItems(normalized) { + arrayBranches = append(arrayBranches, normalized) + } + } + + if len(nonNullBranches) == 0 { + return map[string]any{} + } + if len(objectBranches) > 0 { + return mergeUnionObjectSchemas(objectBranches) + } + if len(arrayBranches) == len(nonNullBranches) && len(arrayBranches) > 0 { + return mergeUnionArraySchemas(arrayBranches) + } + if sameTypeConsistent && sameType != "" { + merged := map[string]any{} + for _, branch := range nonNullBranches { + merged = mergeGeminiSchemaMaps(merged, branch) + } + return merged + } + + best := nonNullBranches[0] + bestScore := geminiUnionBranchScore(best) + for _, branch := range nonNullBranches[1:] { + if score := geminiUnionBranchScore(branch); score > bestScore { + best = branch + bestScore = score + } + } + return cloneGeminiSchemaMap(best) +} + +func (s geminiSchemaSanitizer) resolveLocalSchemaRef(ref string) (map[string]any, bool) { + if ref == "#" { + return s.root, true + } + if !strings.HasPrefix(ref, "#/") { + return nil, false + } + + var current any = s.root + for _, rawToken := range strings.Split(strings.TrimPrefix(ref, "#/"), "/") { + token := strings.ReplaceAll(strings.ReplaceAll(rawToken, "~1", "/"), "~0", "~") + switch value := current.(type) { + case map[string]any: + next, ok := value[token] + if !ok { + return nil, false + } + current = next + case []any: + index, err := strconv.Atoi(token) + if err != nil || index < 0 || index >= len(value) { + return nil, false + } + current = value[index] + default: + return nil, false + } + } + + resolved, ok := current.(map[string]any) + return resolved, ok +} + +func mergeUnionObjectSchemas(branches []map[string]any) map[string]any { + merged := map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + + var commonRequired map[string]struct{} + var requiredOrder []string + + for i, branch := range branches { + merged = mergeGeminiSchemaMaps(merged, branch) + + required := requiredStrings(branch["required"]) + if i == 0 { + commonRequired = make(map[string]struct{}, len(required)) + requiredOrder = append(requiredOrder, required...) + for _, name := range required { + commonRequired[name] = struct{}{} + } + continue + } + + current := make(map[string]struct{}, len(required)) + for _, name := range required { + current[name] = struct{}{} + } + for name := range commonRequired { + if _, ok := current[name]; !ok { + delete(commonRequired, name) + } + } + } + + if len(commonRequired) > 0 { + filtered := make([]string, 0, len(commonRequired)) + for _, name := range requiredOrder { + if _, ok := commonRequired[name]; ok { + filtered = append(filtered, name) + } + } + if len(filtered) > 0 { + merged["required"] = filtered + } + } else { + delete(merged, "required") + } + + return merged +} + +func mergeUnionArraySchemas(branches []map[string]any) map[string]any { + merged := map[string]any{ + "type": "array", + } + for _, branch := range branches { + merged = mergeGeminiSchemaMaps(merged, branch) + } + return merged +} + +func mergeGeminiSchemaMaps(base map[string]any, overlay map[string]any) map[string]any { + if len(base) == 0 { + return cloneGeminiSchemaMap(overlay) + } + if len(overlay) == 0 { + return cloneGeminiSchemaMap(base) + } + + result := cloneGeminiSchemaMap(base) + for key, value := range overlay { + switch key { + case "properties": + overlayProps, ok := value.(map[string]any) + if !ok { + continue + } + existing, _ := result["properties"].(map[string]any) + mergedProps := cloneGeminiSchemaMap(existing) + if mergedProps == nil { + mergedProps = make(map[string]any, len(overlayProps)) + } + for name, rawProp := range overlayProps { + propSchema, ok := rawProp.(map[string]any) + if !ok { + continue + } + if existingProp, ok := mergedProps[name].(map[string]any); ok { + mergedProps[name] = mergeGeminiSchemaMaps(existingProp, propSchema) + } else { + mergedProps[name] = cloneGeminiSchemaMap(propSchema) + } + } + result["properties"] = mergedProps + case "items": + overlayItems, ok := value.(map[string]any) + if !ok { + continue + } + if existingItems, ok := result["items"].(map[string]any); ok { + result["items"] = mergeGeminiSchemaMaps(existingItems, overlayItems) + } else { + result["items"] = cloneGeminiSchemaMap(overlayItems) + } + case "required": + if merged := mergeRequiredLists(result["required"], value); len(merged) > 0 { + result["required"] = merged + } + case "type": + if mergedType := mergeGeminiSchemaTypes(result["type"], value); mergedType != "" { + result["type"] = mergedType + } else { + delete(result, "type") + } + case "description": + desc, ok := value.(string) + if ok && strings.TrimSpace(desc) != "" { + result["description"] = desc + } + default: + result[key] = cloneGeminiSchemaValue(value) + } + } + + return result +} + +func mergeGeminiSchemaTypes(left any, right any) string { + leftType := geminiSchemaBranchType(left) + rightType := geminiSchemaBranchType(right) + + switch { + case leftType == "": + return rightType + case rightType == "": + return leftType + case leftType == rightType: + return leftType + case leftType == "null": + return rightType + case rightType == "null": + return leftType + default: + return "" + } +} + +func sanitizeGeminiSchemaType(raw any) string { + typeName := geminiSchemaBranchType(raw) + if typeName == "null" { + return "" + } + return typeName +} + +func geminiSchemaBranchType(raw any) string { + switch value := raw.(type) { + case string: + if value == "null" { + return value + } + if geminiSupportedTypes[value] { + return value + } + return "" + case []string: + return geminiSchemaBranchType(stringSliceToAny(value)) + case []any: + candidate := "" + sawNull := false + for _, item := range value { + typeName, ok := item.(string) + if !ok { + continue + } + if typeName == "null" { + sawNull = true + continue + } + if !geminiSupportedTypes[typeName] { + continue + } + if candidate == "" { + candidate = typeName + continue + } + if candidate != typeName { + return "" + } + } + if candidate == "" && sawNull { + return "null" + } + return candidate + default: + return "" + } +} + +func sanitizeGeminiEnum(raw any) []any { + values, ok := raw.([]any) + if !ok { + if stringValues, ok := raw.([]string); ok { + return stringSliceToAny(stringValues) + } + return nil + } + + sanitized := make([]any, 0, len(values)) + for _, value := range values { + switch value.(type) { + case string, bool, float64, int, int32, int64: + sanitized = append(sanitized, value) + } + } + if len(sanitized) == 0 { + return nil + } + return sanitized +} + +func sanitizeGeminiRequired(raw any, properties map[string]any) []string { + required := requiredStrings(raw) + if len(required) == 0 { + return nil + } + + filtered := make([]string, 0, len(required)) + seen := make(map[string]struct{}, len(required)) + for _, name := range required { + if _, ok := properties[name]; !ok { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + filtered = append(filtered, name) + } + if len(filtered) == 0 { + return nil + } + return filtered +} + +func requiredStrings(raw any) []string { + switch value := raw.(type) { + case []string: + return append([]string(nil), value...) + case []any: + required := make([]string, 0, len(value)) + for _, item := range value { + name, ok := item.(string) + if ok { + required = append(required, name) + } + } + return required + default: + return nil + } +} + +func mergeRequiredLists(left any, right any) []string { + merged := make([]string, 0) + seen := map[string]struct{}{} + + for _, name := range append(requiredStrings(left), requiredStrings(right)...) { + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + merged = append(merged, name) + } + + return merged +} + +func geminiUnionBranchScore(schema map[string]any) int { + score := 0 + if hasSchemaProperties(schema) { + score += 20 + } + if hasSchemaItems(schema) { + score += 10 + } + if _, ok := schema["enum"]; ok { + score += 5 + } + if _, ok := schema["description"]; ok { + score += 2 + } + score += len(schema) + return score +} + +func hasSchemaProperties(schema map[string]any) bool { + props, ok := schema["properties"].(map[string]any) + return ok && len(props) > 0 +} + +func hasSchemaItems(schema map[string]any) bool { + _, ok := schema["items"].(map[string]any) + return ok +} + +func schemaSlice(raw any) []map[string]any { + switch value := raw.(type) { + case []map[string]any: + return append([]map[string]any(nil), value...) + case []any: + schemas := make([]map[string]any, 0, len(value)) + for _, item := range value { + schema, ok := item.(map[string]any) + if ok { + schemas = append(schemas, schema) + } + } + return schemas + default: + return nil + } +} + +func cloneGeminiSchemaMap(in map[string]any) map[string]any { + if len(in) == 0 { + return nil + } + out := make(map[string]any, len(in)) + for key, value := range in { + out[key] = cloneGeminiSchemaValue(value) + } + return out +} + +func cloneGeminiSchemaValue(value any) any { + switch typed := value.(type) { + case map[string]any: + return cloneGeminiSchemaMap(typed) + case []any: + out := make([]any, len(typed)) + for i, item := range typed { + out[i] = cloneGeminiSchemaValue(item) + } + return out + case []string: + return append([]string(nil), typed...) + default: + return typed + } +} + +func cloneRefTrail(in map[string]struct{}) map[string]struct{} { + if len(in) == 0 { + return make(map[string]struct{}) + } + out := make(map[string]struct{}, len(in)) + for key := range in { + out[key] = struct{}{} + } + return out +} + +func stringSliceToAny(values []string) []any { + if len(values) == 0 { + return nil + } + result := make([]any, len(values)) + for i, value := range values { + result[i] = value + } + return result +} diff --git a/pkg/providers/common/google_schema_test.go b/pkg/providers/common/google_schema_test.go new file mode 100644 index 000000000..af2520493 --- /dev/null +++ b/pkg/providers/common/google_schema_test.go @@ -0,0 +1,254 @@ +package common + +import "testing" + +func TestSanitizeSchemaForGemini_DereferencesRefsAndFlattensUnions(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "parent": map[string]any{ + "anyOf": []any{ + map[string]any{"$ref": "#/$defs/pageParent"}, + map[string]any{"$ref": "#/$defs/databaseParent"}, + }, + }, + "icon": map[string]any{ + "anyOf": []any{ + map[string]any{"$ref": "#/$defs/emoji"}, + map[string]any{"type": "null"}, + }, + }, + "data": map[string]any{ + "$ref": "#/$defs/dataPayload", + }, + }, + "required": []any{"parent", "icon", "missing"}, + "$defs": map[string]any{ + "pageParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "page_id": map[string]any{ + "type": "string", + }, + }, + "required": []any{"page_id"}, + }, + "databaseParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "database_id": map[string]any{ + "type": "string", + }, + }, + "required": []any{"database_id"}, + }, + "emoji": map[string]any{ + "type": "string", + "pattern": "^:[a-z_]+:$", + }, + "dataPayload": map[string]any{ + "type": "object", + "additionalProperties": false, + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "minLength": 1, + }, + "count": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + "required": []any{"name"}, + }, + }, + } + + got := SanitizeSchemaForGemini(schema) + assertSchemaKeyAbsent(t, got, "$defs") + assertSchemaKeyAbsent(t, got, "$ref") + assertSchemaKeyAbsent(t, got, "anyOf") + assertSchemaKeyAbsent(t, got, "oneOf") + assertSchemaKeyAbsent(t, got, "allOf") + assertSchemaKeyAbsent(t, got, "additionalProperties") + assertSchemaKeyAbsent(t, got, "pattern") + assertSchemaKeyAbsent(t, got, "minLength") + assertSchemaKeyAbsent(t, got, "minimum") + + if got["type"] != "object" { + t.Fatalf("top-level type = %#v, want object", got["type"]) + } + + props, ok := got["properties"].(map[string]any) + if !ok { + t.Fatalf("properties = %#v, want map", got["properties"]) + } + + parent, ok := props["parent"].(map[string]any) + if !ok { + t.Fatalf("parent schema = %#v, want map", props["parent"]) + } + if parent["type"] != "object" { + t.Fatalf("parent.type = %#v, want object", parent["type"]) + } + parentProps, ok := parent["properties"].(map[string]any) + if !ok { + t.Fatalf("parent.properties = %#v, want map", parent["properties"]) + } + if _, ok := parentProps["page_id"]; !ok { + t.Fatalf("parent.properties missing page_id: %#v", parentProps) + } + if _, ok := parentProps["database_id"]; !ok { + t.Fatalf("parent.properties missing database_id: %#v", parentProps) + } + if _, hasRequired := parent["required"]; hasRequired { + t.Fatalf("parent.required = %#v, want omitted for merged anyOf branches", parent["required"]) + } + + icon, ok := props["icon"].(map[string]any) + if !ok { + t.Fatalf("icon schema = %#v, want map", props["icon"]) + } + if icon["type"] != "string" { + t.Fatalf("icon.type = %#v, want string", icon["type"]) + } + + data, ok := props["data"].(map[string]any) + if !ok { + t.Fatalf("data schema = %#v, want map", props["data"]) + } + if data["type"] != "object" { + t.Fatalf("data.type = %#v, want object", data["type"]) + } + dataProps, ok := data["properties"].(map[string]any) + if !ok { + t.Fatalf("data.properties = %#v, want map", data["properties"]) + } + if _, ok := dataProps["name"]; !ok { + t.Fatalf("data.properties missing name: %#v", dataProps) + } + if _, ok := dataProps["count"]; !ok { + t.Fatalf("data.properties missing count: %#v", dataProps) + } + + required, ok := got["required"].([]string) + if !ok { + t.Fatalf("required = %#v, want []string", got["required"]) + } + if len(required) != 2 || required[0] != "parent" || required[1] != "icon" { + t.Fatalf("required = %#v, want [parent icon]", required) + } +} + +func TestSanitizeSchemaForGemini_MergesAllOfAndFiltersRequired(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "payload": map[string]any{ + "allOf": []any{ + map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{ + "type": "string", + }, + }, + "required": []any{"id"}, + }, + map[string]any{ + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + }, + "count": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + "required": []any{"name", "missing"}, + }, + }, + }, + }, + } + + got := SanitizeSchemaForGemini(schema) + props := got["properties"].(map[string]any) + payload := props["payload"].(map[string]any) + + if payload["type"] != "object" { + t.Fatalf("payload.type = %#v, want object", payload["type"]) + } + payloadProps, ok := payload["properties"].(map[string]any) + if !ok { + t.Fatalf("payload.properties = %#v, want map", payload["properties"]) + } + for _, key := range []string{"id", "name", "count"} { + if _, ok := payloadProps[key]; !ok { + t.Fatalf("payload.properties missing %q: %#v", key, payloadProps) + } + } + + required, ok := payload["required"].([]string) + if !ok { + t.Fatalf("payload.required = %#v, want []string", payload["required"]) + } + if len(required) != 2 || required[0] != "id" || required[1] != "name" { + t.Fatalf("payload.required = %#v, want [id name]", required) + } + + assertSchemaKeyAbsent(t, payload, "allOf") + assertSchemaKeyAbsent(t, payload, "minimum") +} + +func TestSanitizeSchemaForGemini_HandlesRecursiveRefs(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "tree": map[string]any{ + "$ref": "#/$defs/node", + }, + }, + "$defs": map[string]any{ + "node": map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + }, + "child": map[string]any{ + "$ref": "#/$defs/node", + }, + }, + }, + }, + } + + got := SanitizeSchemaForGemini(schema) + props := got["properties"].(map[string]any) + tree := props["tree"].(map[string]any) + if tree["type"] != "object" { + t.Fatalf("tree.type = %#v, want object", tree["type"]) + } + assertSchemaKeyAbsent(t, tree, "$ref") +} + +func assertSchemaKeyAbsent(t *testing.T, value any, key string) { + t.Helper() + + switch typed := value.(type) { + case map[string]any: + if _, found := typed[key]; found { + t.Fatalf("schema still contains key %q: %#v", key, typed) + } + for _, nested := range typed { + assertSchemaKeyAbsent(t, nested, key) + } + case []any: + for _, nested := range typed { + assertSchemaKeyAbsent(t, nested, key) + } + case []string: + return + } +} diff --git a/pkg/providers/httpapi/gemini_helpers.go b/pkg/providers/httpapi/gemini_helpers.go index a2b2d63c3..87cc4c084 100644 --- a/pkg/providers/httpapi/gemini_helpers.go +++ b/pkg/providers/httpapi/gemini_helpers.go @@ -12,66 +12,6 @@ func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake return "" } -var geminiUnsupportedKeywords = map[string]bool{ - "patternProperties": true, - "additionalProperties": true, - "$schema": true, - "$id": true, - "$ref": true, - "$defs": true, - "definitions": true, - "examples": true, - "minLength": true, - "maxLength": true, - "minimum": true, - "maximum": true, - "multipleOf": true, - "pattern": true, - "format": true, - "minItems": true, - "maxItems": true, - "uniqueItems": true, - "minProperties": true, - "maxProperties": true, -} - -func sanitizeSchemaForGemini(schema map[string]any) map[string]any { - if schema == nil { - return nil - } - - result := make(map[string]any) - for k, v := range schema { - if geminiUnsupportedKeywords[k] { - continue - } - switch val := v.(type) { - case map[string]any: - result[k] = sanitizeSchemaForGemini(val) - case []any: - sanitized := make([]any, len(val)) - for i, item := range val { - if m, ok := item.(map[string]any); ok { - sanitized[i] = sanitizeSchemaForGemini(m) - } else { - sanitized[i] = item - } - } - result[k] = sanitized - default: - result[k] = v - } - } - - if _, hasProps := result["properties"]; hasProps { - if _, hasType := result["type"]; !hasType { - result["type"] = "object" - } - } - - return result -} - func extractProtocol(model string) (protocol, modelID string) { model = strings.TrimSpace(model) protocol, modelID, found := strings.Cut(model, "/") diff --git a/pkg/providers/httpapi/gemini_provider.go b/pkg/providers/httpapi/gemini_provider.go index d1d523757..e73568c88 100644 --- a/pkg/providers/httpapi/gemini_provider.go +++ b/pkg/providers/httpapi/gemini_provider.go @@ -264,7 +264,7 @@ func (p *GeminiProvider) buildRequestBody( funcDecls = append(funcDecls, geminiFunctionDeclaration{ Name: t.Function.Name, Description: t.Function.Description, - Parameters: sanitizeSchemaForGemini(t.Function.Parameters), + Parameters: common.SanitizeSchemaForGemini(t.Function.Parameters), }) } if len(funcDecls) > 0 { diff --git a/pkg/providers/httpapi/gemini_provider_test.go b/pkg/providers/httpapi/gemini_provider_test.go index aade90358..53770511e 100644 --- a/pkg/providers/httpapi/gemini_provider_test.go +++ b/pkg/providers/httpapi/gemini_provider_test.go @@ -5,8 +5,11 @@ import ( "fmt" "net/http" "net/http/httptest" + "reflect" "strings" "testing" + + providercommon "github.com/sipeed/picoclaw/pkg/providers/common" ) func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) { @@ -259,6 +262,65 @@ func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) { } } +func TestGeminiProvider_BuildRequestBody_SanitizesComplexToolSchemas(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "parent": map[string]any{ + "anyOf": []any{ + map[string]any{"$ref": "#/$defs/pageParent"}, + map[string]any{"$ref": "#/$defs/databaseParent"}, + }, + }, + }, + "$defs": map[string]any{ + "pageParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "page_id": map[string]any{"type": "string"}, + }, + "required": []any{"page_id"}, + }, + "databaseParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "database_id": map[string]any{"type": "string"}, + }, + "required": []any{"database_id"}, + }, + }, + } + + body := provider.buildRequestBody( + []Message{{Role: "user", Content: "hello"}}, + []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "mcp_notion_create", + Description: "Create a Notion object", + Parameters: schema, + }, + }}, + "gemini-3-flash-preview", + nil, + ) + + tools, ok := body["tools"].([]geminiTool) + if !ok || len(tools) != 1 { + t.Fatalf("tools = %#v, want one geminiTool", body["tools"]) + } + got, ok := tools[0].FunctionDeclarations[0].Parameters.(map[string]any) + if !ok { + t.Fatalf("parameters = %#v, want map", tools[0].FunctionDeclarations[0].Parameters) + } + + want := providercommon.SanitizeSchemaForGemini(schema) + if !reflect.DeepEqual(got, want) { + t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want) + } +} + func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") diff --git a/pkg/providers/oauth/antigravity_provider.go b/pkg/providers/oauth/antigravity_provider.go index 1ac2d9c7f..fd5431b03 100644 --- a/pkg/providers/oauth/antigravity_provider.go +++ b/pkg/providers/oauth/antigravity_provider.go @@ -298,7 +298,7 @@ func (p *AntigravityProvider) buildRequest( if t.Type != "function" { continue } - params := sanitizeSchemaForGemini(t.Function.Parameters) + params := common.SanitizeSchemaForGemini(t.Function.Parameters) funcDecls = append(funcDecls, antigravityFuncDecl{ Name: t.Function.Name, Description: t.Function.Description, @@ -446,71 +446,6 @@ func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake return "" } -// --- Schema sanitization --- - -// Google/Gemini doesn't support many JSON Schema keywords that other providers accept. -var geminiUnsupportedKeywords = map[string]bool{ - "patternProperties": true, - "additionalProperties": true, - "$schema": true, - "$id": true, - "$ref": true, - "$defs": true, - "definitions": true, - "examples": true, - "minLength": true, - "maxLength": true, - "minimum": true, - "maximum": true, - "multipleOf": true, - "pattern": true, - "format": true, - "minItems": true, - "maxItems": true, - "uniqueItems": true, - "minProperties": true, - "maxProperties": true, -} - -func sanitizeSchemaForGemini(schema map[string]any) map[string]any { - if schema == nil { - return nil - } - - result := make(map[string]any) - for k, v := range schema { - if geminiUnsupportedKeywords[k] { - continue - } - // Recursively sanitize nested objects - switch val := v.(type) { - case map[string]any: - result[k] = sanitizeSchemaForGemini(val) - case []any: - sanitized := make([]any, len(val)) - for i, item := range val { - if m, ok := item.(map[string]any); ok { - sanitized[i] = sanitizeSchemaForGemini(m) - } else { - sanitized[i] = item - } - } - result[k] = sanitized - default: - result[k] = v - } - } - - // Ensure top-level has type: "object" if properties are present - if _, hasProps := result["properties"]; hasProps { - if _, hasType := result["type"]; !hasType { - result["type"] = "object" - } - } - - return result -} - // --- Token source --- func createAntigravityTokenSource() func() (string, string, error) { diff --git a/pkg/providers/oauth/antigravity_provider_test.go b/pkg/providers/oauth/antigravity_provider_test.go index 2989f8519..015b1ab80 100644 --- a/pkg/providers/oauth/antigravity_provider_test.go +++ b/pkg/providers/oauth/antigravity_provider_test.go @@ -1,6 +1,11 @@ package oauthprovider -import "testing" +import ( + "reflect" + "testing" + + providercommon "github.com/sipeed/picoclaw/pkg/providers/common" +) func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) { p := &AntigravityProvider{} @@ -71,3 +76,68 @@ func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) { t.Fatalf("Usage.TotalTokens = %v, want %d", resp.Usage, 216) } } + +func TestBuildRequest_SanitizesComplexToolSchemas(t *testing.T) { + p := &AntigravityProvider{} + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "parent": map[string]any{ + "anyOf": []any{ + map[string]any{"$ref": "#/$defs/pageParent"}, + map[string]any{"$ref": "#/$defs/databaseParent"}, + }, + }, + "icon": map[string]any{ + "anyOf": []any{ + map[string]any{"type": "null"}, + map[string]any{"$ref": "#/$defs/emoji"}, + }, + }, + }, + "$defs": map[string]any{ + "pageParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "page_id": map[string]any{"type": "string"}, + }, + "required": []any{"page_id"}, + }, + "databaseParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "database_id": map[string]any{"type": "string"}, + }, + "required": []any{"database_id"}, + }, + "emoji": map[string]any{ + "type": "string", + "pattern": "^:[a-z_]+:$", + }, + }, + } + + req := p.buildRequest( + []Message{{Role: "user", Content: "hello"}}, + []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "mcp_notion_create", + Description: "Create a Notion object", + Parameters: schema, + }, + }}, + "gemini-3-flash", + nil, + ) + + if len(req.Tools) != 1 || len(req.Tools[0].FunctionDeclarations) != 1 { + t.Fatalf("request tools = %#v, want one function declaration", req.Tools) + } + + got := req.Tools[0].FunctionDeclarations[0].Parameters + want := providercommon.SanitizeSchemaForGemini(schema) + if !reflect.DeepEqual(got, want) { + t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want) + } +}