From 1ff8a418f65e36b46c39e3338568ec2e1f0baca7 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Sun, 26 Apr 2026 22:23:55 +0200 Subject: [PATCH 1/5] fix(mcp): sanitize MCP tool schemas for Gemini function calling --- pkg/providers/common/google_schema.go | 636 ++++++++++++++++++ pkg/providers/common/google_schema_test.go | 254 +++++++ pkg/providers/httpapi/gemini_helpers.go | 60 -- pkg/providers/httpapi/gemini_provider.go | 2 +- pkg/providers/httpapi/gemini_provider_test.go | 62 ++ pkg/providers/oauth/antigravity_provider.go | 67 +- .../oauth/antigravity_provider_test.go | 72 +- 7 files changed, 1025 insertions(+), 128 deletions(-) create mode 100644 pkg/providers/common/google_schema.go create mode 100644 pkg/providers/common/google_schema_test.go diff --git a/pkg/providers/common/google_schema.go b/pkg/providers/common/google_schema.go new file mode 100644 index 000000000..bdafadd23 --- /dev/null +++ b/pkg/providers/common/google_schema.go @@ -0,0 +1,636 @@ +package common + +import ( + "strconv" + "strings" +) + +const maxGeminiSchemaDepth = 64 + +var geminiSupportedTypes = map[string]bool{ + "array": true, + "boolean": true, + "integer": true, + "number": true, + "object": true, + "string": true, +} + +// SanitizeSchemaForGemini reduces a JSON Schema to the conservative subset +// accepted by Gemini-style function declarations. It resolves local refs, +// collapses composition keywords like anyOf/oneOf/allOf, and strips advanced +// keywords that Gemini rejects. +func SanitizeSchemaForGemini(schema map[string]any) map[string]any { + if schema == nil { + return nil + } + + sanitizer := geminiSchemaSanitizer{root: schema} + result := sanitizer.sanitizeNode(schema, nil, 0) + if len(result) == 0 { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + if _, hasProps := result["properties"]; hasProps { + result["type"] = "object" + } + return result +} + +type geminiSchemaSanitizer struct { + root map[string]any +} + +func (s geminiSchemaSanitizer) sanitizeNode( + node map[string]any, + refTrail map[string]struct{}, + depth int, +) map[string]any { + if node == nil || depth > maxGeminiSchemaDepth { + return map[string]any{} + } + + normalized := s.normalizeNode(node, refTrail, depth) + if len(normalized) == 0 { + return map[string]any{} + } + + result := make(map[string]any) + + if desc, ok := normalized["description"].(string); ok && strings.TrimSpace(desc) != "" { + result["description"] = desc + } + + if schemaType := sanitizeGeminiSchemaType(normalized["type"]); schemaType != "" { + result["type"] = schemaType + } + + if enumValues := sanitizeGeminiEnum(normalized["enum"]); len(enumValues) > 0 { + result["enum"] = enumValues + } + + if propsRaw, ok := normalized["properties"].(map[string]any); ok { + props := make(map[string]any, len(propsRaw)) + for name, rawProp := range propsRaw { + propSchema, ok := rawProp.(map[string]any) + if !ok { + continue + } + sanitizedProp := s.sanitizeNode(propSchema, refTrail, depth+1) + if len(sanitizedProp) == 0 { + sanitizedProp = map[string]any{} + } + props[name] = sanitizedProp + } + result["properties"] = props + result["type"] = "object" + if required := sanitizeGeminiRequired(normalized["required"], props); len(required) > 0 { + result["required"] = required + } + } + + if itemsRaw, ok := normalized["items"].(map[string]any); ok { + items := s.sanitizeNode(itemsRaw, refTrail, depth+1) + if len(items) == 0 { + items = map[string]any{} + } + result["items"] = items + if _, hasType := result["type"]; !hasType { + result["type"] = "array" + } + } + + return result +} + +func (s geminiSchemaSanitizer) normalizeNode( + node map[string]any, + refTrail map[string]struct{}, + depth int, +) map[string]any { + if node == nil || depth > maxGeminiSchemaDepth { + return map[string]any{} + } + + normalized := cloneGeminiSchemaMap(node) + + if ref, ok := normalized["$ref"].(string); ok { + delete(normalized, "$ref") + if _, seen := refTrail[ref]; !seen { + if target, ok := s.resolveLocalSchemaRef(ref); ok { + nextTrail := cloneRefTrail(refTrail) + nextTrail[ref] = struct{}{} + normalized = mergeGeminiSchemaMaps( + s.normalizeNode(target, nextTrail, depth+1), + normalized, + ) + } + } + } + + if rawAllOf, ok := normalized["allOf"]; ok { + delete(normalized, "allOf") + for _, part := range schemaSlice(rawAllOf) { + normalized = mergeGeminiSchemaMaps( + normalized, + s.normalizeNode(part, refTrail, depth+1), + ) + } + } + + if rawAnyOf, ok := normalized["anyOf"]; ok { + delete(normalized, "anyOf") + normalized = mergeGeminiSchemaMaps( + s.mergeUnionBranches(schemaSlice(rawAnyOf), refTrail, depth+1), + normalized, + ) + } + + if rawOneOf, ok := normalized["oneOf"]; ok { + delete(normalized, "oneOf") + normalized = mergeGeminiSchemaMaps( + s.mergeUnionBranches(schemaSlice(rawOneOf), refTrail, depth+1), + normalized, + ) + } + + return normalized +} + +func (s geminiSchemaSanitizer) mergeUnionBranches( + branches []map[string]any, + refTrail map[string]struct{}, + depth int, +) map[string]any { + if len(branches) == 0 { + return map[string]any{} + } + + objectBranches := make([]map[string]any, 0, len(branches)) + arrayBranches := make([]map[string]any, 0, len(branches)) + nonNullBranches := make([]map[string]any, 0, len(branches)) + sameType := "" + sameTypeConsistent := true + + for _, branch := range branches { + normalized := s.normalizeNode(branch, refTrail, depth+1) + if len(normalized) == 0 { + continue + } + + branchType := geminiSchemaBranchType(normalized["type"]) + if branchType == "null" { + continue + } + nonNullBranches = append(nonNullBranches, normalized) + + if sameType == "" { + sameType = branchType + } else if branchType != "" && branchType != sameType { + sameTypeConsistent = false + } + + if branchType == "object" || hasSchemaProperties(normalized) { + objectBranches = append(objectBranches, normalized) + continue + } + if branchType == "array" || hasSchemaItems(normalized) { + arrayBranches = append(arrayBranches, normalized) + } + } + + if len(nonNullBranches) == 0 { + return map[string]any{} + } + if len(objectBranches) > 0 { + return mergeUnionObjectSchemas(objectBranches) + } + if len(arrayBranches) == len(nonNullBranches) && len(arrayBranches) > 0 { + return mergeUnionArraySchemas(arrayBranches) + } + if sameTypeConsistent && sameType != "" { + merged := map[string]any{} + for _, branch := range nonNullBranches { + merged = mergeGeminiSchemaMaps(merged, branch) + } + return merged + } + + best := nonNullBranches[0] + bestScore := geminiUnionBranchScore(best) + for _, branch := range nonNullBranches[1:] { + if score := geminiUnionBranchScore(branch); score > bestScore { + best = branch + bestScore = score + } + } + return cloneGeminiSchemaMap(best) +} + +func (s geminiSchemaSanitizer) resolveLocalSchemaRef(ref string) (map[string]any, bool) { + if ref == "#" { + return s.root, true + } + if !strings.HasPrefix(ref, "#/") { + return nil, false + } + + var current any = s.root + for _, rawToken := range strings.Split(strings.TrimPrefix(ref, "#/"), "/") { + token := strings.ReplaceAll(strings.ReplaceAll(rawToken, "~1", "/"), "~0", "~") + switch value := current.(type) { + case map[string]any: + next, ok := value[token] + if !ok { + return nil, false + } + current = next + case []any: + index, err := strconv.Atoi(token) + if err != nil || index < 0 || index >= len(value) { + return nil, false + } + current = value[index] + default: + return nil, false + } + } + + resolved, ok := current.(map[string]any) + return resolved, ok +} + +func mergeUnionObjectSchemas(branches []map[string]any) map[string]any { + merged := map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + + var commonRequired map[string]struct{} + var requiredOrder []string + + for i, branch := range branches { + merged = mergeGeminiSchemaMaps(merged, branch) + + required := requiredStrings(branch["required"]) + if i == 0 { + commonRequired = make(map[string]struct{}, len(required)) + requiredOrder = append(requiredOrder, required...) + for _, name := range required { + commonRequired[name] = struct{}{} + } + continue + } + + current := make(map[string]struct{}, len(required)) + for _, name := range required { + current[name] = struct{}{} + } + for name := range commonRequired { + if _, ok := current[name]; !ok { + delete(commonRequired, name) + } + } + } + + if len(commonRequired) > 0 { + filtered := make([]string, 0, len(commonRequired)) + for _, name := range requiredOrder { + if _, ok := commonRequired[name]; ok { + filtered = append(filtered, name) + } + } + if len(filtered) > 0 { + merged["required"] = filtered + } + } else { + delete(merged, "required") + } + + return merged +} + +func mergeUnionArraySchemas(branches []map[string]any) map[string]any { + merged := map[string]any{ + "type": "array", + } + for _, branch := range branches { + merged = mergeGeminiSchemaMaps(merged, branch) + } + return merged +} + +func mergeGeminiSchemaMaps(base map[string]any, overlay map[string]any) map[string]any { + if len(base) == 0 { + return cloneGeminiSchemaMap(overlay) + } + if len(overlay) == 0 { + return cloneGeminiSchemaMap(base) + } + + result := cloneGeminiSchemaMap(base) + for key, value := range overlay { + switch key { + case "properties": + overlayProps, ok := value.(map[string]any) + if !ok { + continue + } + existing, _ := result["properties"].(map[string]any) + mergedProps := cloneGeminiSchemaMap(existing) + if mergedProps == nil { + mergedProps = make(map[string]any, len(overlayProps)) + } + for name, rawProp := range overlayProps { + propSchema, ok := rawProp.(map[string]any) + if !ok { + continue + } + if existingProp, ok := mergedProps[name].(map[string]any); ok { + mergedProps[name] = mergeGeminiSchemaMaps(existingProp, propSchema) + } else { + mergedProps[name] = cloneGeminiSchemaMap(propSchema) + } + } + result["properties"] = mergedProps + case "items": + overlayItems, ok := value.(map[string]any) + if !ok { + continue + } + if existingItems, ok := result["items"].(map[string]any); ok { + result["items"] = mergeGeminiSchemaMaps(existingItems, overlayItems) + } else { + result["items"] = cloneGeminiSchemaMap(overlayItems) + } + case "required": + if merged := mergeRequiredLists(result["required"], value); len(merged) > 0 { + result["required"] = merged + } + case "type": + if mergedType := mergeGeminiSchemaTypes(result["type"], value); mergedType != "" { + result["type"] = mergedType + } else { + delete(result, "type") + } + case "description": + desc, ok := value.(string) + if ok && strings.TrimSpace(desc) != "" { + result["description"] = desc + } + default: + result[key] = cloneGeminiSchemaValue(value) + } + } + + return result +} + +func mergeGeminiSchemaTypes(left any, right any) string { + leftType := geminiSchemaBranchType(left) + rightType := geminiSchemaBranchType(right) + + switch { + case leftType == "": + return rightType + case rightType == "": + return leftType + case leftType == rightType: + return leftType + case leftType == "null": + return rightType + case rightType == "null": + return leftType + default: + return "" + } +} + +func sanitizeGeminiSchemaType(raw any) string { + typeName := geminiSchemaBranchType(raw) + if typeName == "null" { + return "" + } + return typeName +} + +func geminiSchemaBranchType(raw any) string { + switch value := raw.(type) { + case string: + if value == "null" { + return value + } + if geminiSupportedTypes[value] { + return value + } + return "" + case []string: + return geminiSchemaBranchType(stringSliceToAny(value)) + case []any: + candidate := "" + sawNull := false + for _, item := range value { + typeName, ok := item.(string) + if !ok { + continue + } + if typeName == "null" { + sawNull = true + continue + } + if !geminiSupportedTypes[typeName] { + continue + } + if candidate == "" { + candidate = typeName + continue + } + if candidate != typeName { + return "" + } + } + if candidate == "" && sawNull { + return "null" + } + return candidate + default: + return "" + } +} + +func sanitizeGeminiEnum(raw any) []any { + values, ok := raw.([]any) + if !ok { + if stringValues, ok := raw.([]string); ok { + return stringSliceToAny(stringValues) + } + return nil + } + + sanitized := make([]any, 0, len(values)) + for _, value := range values { + switch value.(type) { + case string, bool, float64, int, int32, int64: + sanitized = append(sanitized, value) + } + } + if len(sanitized) == 0 { + return nil + } + return sanitized +} + +func sanitizeGeminiRequired(raw any, properties map[string]any) []string { + required := requiredStrings(raw) + if len(required) == 0 { + return nil + } + + filtered := make([]string, 0, len(required)) + seen := make(map[string]struct{}, len(required)) + for _, name := range required { + if _, ok := properties[name]; !ok { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + filtered = append(filtered, name) + } + if len(filtered) == 0 { + return nil + } + return filtered +} + +func requiredStrings(raw any) []string { + switch value := raw.(type) { + case []string: + return append([]string(nil), value...) + case []any: + required := make([]string, 0, len(value)) + for _, item := range value { + name, ok := item.(string) + if ok { + required = append(required, name) + } + } + return required + default: + return nil + } +} + +func mergeRequiredLists(left any, right any) []string { + merged := make([]string, 0) + seen := map[string]struct{}{} + + for _, name := range append(requiredStrings(left), requiredStrings(right)...) { + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + merged = append(merged, name) + } + + return merged +} + +func geminiUnionBranchScore(schema map[string]any) int { + score := 0 + if hasSchemaProperties(schema) { + score += 20 + } + if hasSchemaItems(schema) { + score += 10 + } + if _, ok := schema["enum"]; ok { + score += 5 + } + if _, ok := schema["description"]; ok { + score += 2 + } + score += len(schema) + return score +} + +func hasSchemaProperties(schema map[string]any) bool { + props, ok := schema["properties"].(map[string]any) + return ok && len(props) > 0 +} + +func hasSchemaItems(schema map[string]any) bool { + _, ok := schema["items"].(map[string]any) + return ok +} + +func schemaSlice(raw any) []map[string]any { + switch value := raw.(type) { + case []map[string]any: + return append([]map[string]any(nil), value...) + case []any: + schemas := make([]map[string]any, 0, len(value)) + for _, item := range value { + schema, ok := item.(map[string]any) + if ok { + schemas = append(schemas, schema) + } + } + return schemas + default: + return nil + } +} + +func cloneGeminiSchemaMap(in map[string]any) map[string]any { + if len(in) == 0 { + return nil + } + out := make(map[string]any, len(in)) + for key, value := range in { + out[key] = cloneGeminiSchemaValue(value) + } + return out +} + +func cloneGeminiSchemaValue(value any) any { + switch typed := value.(type) { + case map[string]any: + return cloneGeminiSchemaMap(typed) + case []any: + out := make([]any, len(typed)) + for i, item := range typed { + out[i] = cloneGeminiSchemaValue(item) + } + return out + case []string: + return append([]string(nil), typed...) + default: + return typed + } +} + +func cloneRefTrail(in map[string]struct{}) map[string]struct{} { + if len(in) == 0 { + return make(map[string]struct{}) + } + out := make(map[string]struct{}, len(in)) + for key := range in { + out[key] = struct{}{} + } + return out +} + +func stringSliceToAny(values []string) []any { + if len(values) == 0 { + return nil + } + result := make([]any, len(values)) + for i, value := range values { + result[i] = value + } + return result +} diff --git a/pkg/providers/common/google_schema_test.go b/pkg/providers/common/google_schema_test.go new file mode 100644 index 000000000..af2520493 --- /dev/null +++ b/pkg/providers/common/google_schema_test.go @@ -0,0 +1,254 @@ +package common + +import "testing" + +func TestSanitizeSchemaForGemini_DereferencesRefsAndFlattensUnions(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "parent": map[string]any{ + "anyOf": []any{ + map[string]any{"$ref": "#/$defs/pageParent"}, + map[string]any{"$ref": "#/$defs/databaseParent"}, + }, + }, + "icon": map[string]any{ + "anyOf": []any{ + map[string]any{"$ref": "#/$defs/emoji"}, + map[string]any{"type": "null"}, + }, + }, + "data": map[string]any{ + "$ref": "#/$defs/dataPayload", + }, + }, + "required": []any{"parent", "icon", "missing"}, + "$defs": map[string]any{ + "pageParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "page_id": map[string]any{ + "type": "string", + }, + }, + "required": []any{"page_id"}, + }, + "databaseParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "database_id": map[string]any{ + "type": "string", + }, + }, + "required": []any{"database_id"}, + }, + "emoji": map[string]any{ + "type": "string", + "pattern": "^:[a-z_]+:$", + }, + "dataPayload": map[string]any{ + "type": "object", + "additionalProperties": false, + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "minLength": 1, + }, + "count": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + "required": []any{"name"}, + }, + }, + } + + got := SanitizeSchemaForGemini(schema) + assertSchemaKeyAbsent(t, got, "$defs") + assertSchemaKeyAbsent(t, got, "$ref") + assertSchemaKeyAbsent(t, got, "anyOf") + assertSchemaKeyAbsent(t, got, "oneOf") + assertSchemaKeyAbsent(t, got, "allOf") + assertSchemaKeyAbsent(t, got, "additionalProperties") + assertSchemaKeyAbsent(t, got, "pattern") + assertSchemaKeyAbsent(t, got, "minLength") + assertSchemaKeyAbsent(t, got, "minimum") + + if got["type"] != "object" { + t.Fatalf("top-level type = %#v, want object", got["type"]) + } + + props, ok := got["properties"].(map[string]any) + if !ok { + t.Fatalf("properties = %#v, want map", got["properties"]) + } + + parent, ok := props["parent"].(map[string]any) + if !ok { + t.Fatalf("parent schema = %#v, want map", props["parent"]) + } + if parent["type"] != "object" { + t.Fatalf("parent.type = %#v, want object", parent["type"]) + } + parentProps, ok := parent["properties"].(map[string]any) + if !ok { + t.Fatalf("parent.properties = %#v, want map", parent["properties"]) + } + if _, ok := parentProps["page_id"]; !ok { + t.Fatalf("parent.properties missing page_id: %#v", parentProps) + } + if _, ok := parentProps["database_id"]; !ok { + t.Fatalf("parent.properties missing database_id: %#v", parentProps) + } + if _, hasRequired := parent["required"]; hasRequired { + t.Fatalf("parent.required = %#v, want omitted for merged anyOf branches", parent["required"]) + } + + icon, ok := props["icon"].(map[string]any) + if !ok { + t.Fatalf("icon schema = %#v, want map", props["icon"]) + } + if icon["type"] != "string" { + t.Fatalf("icon.type = %#v, want string", icon["type"]) + } + + data, ok := props["data"].(map[string]any) + if !ok { + t.Fatalf("data schema = %#v, want map", props["data"]) + } + if data["type"] != "object" { + t.Fatalf("data.type = %#v, want object", data["type"]) + } + dataProps, ok := data["properties"].(map[string]any) + if !ok { + t.Fatalf("data.properties = %#v, want map", data["properties"]) + } + if _, ok := dataProps["name"]; !ok { + t.Fatalf("data.properties missing name: %#v", dataProps) + } + if _, ok := dataProps["count"]; !ok { + t.Fatalf("data.properties missing count: %#v", dataProps) + } + + required, ok := got["required"].([]string) + if !ok { + t.Fatalf("required = %#v, want []string", got["required"]) + } + if len(required) != 2 || required[0] != "parent" || required[1] != "icon" { + t.Fatalf("required = %#v, want [parent icon]", required) + } +} + +func TestSanitizeSchemaForGemini_MergesAllOfAndFiltersRequired(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "payload": map[string]any{ + "allOf": []any{ + map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{ + "type": "string", + }, + }, + "required": []any{"id"}, + }, + map[string]any{ + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + }, + "count": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + "required": []any{"name", "missing"}, + }, + }, + }, + }, + } + + got := SanitizeSchemaForGemini(schema) + props := got["properties"].(map[string]any) + payload := props["payload"].(map[string]any) + + if payload["type"] != "object" { + t.Fatalf("payload.type = %#v, want object", payload["type"]) + } + payloadProps, ok := payload["properties"].(map[string]any) + if !ok { + t.Fatalf("payload.properties = %#v, want map", payload["properties"]) + } + for _, key := range []string{"id", "name", "count"} { + if _, ok := payloadProps[key]; !ok { + t.Fatalf("payload.properties missing %q: %#v", key, payloadProps) + } + } + + required, ok := payload["required"].([]string) + if !ok { + t.Fatalf("payload.required = %#v, want []string", payload["required"]) + } + if len(required) != 2 || required[0] != "id" || required[1] != "name" { + t.Fatalf("payload.required = %#v, want [id name]", required) + } + + assertSchemaKeyAbsent(t, payload, "allOf") + assertSchemaKeyAbsent(t, payload, "minimum") +} + +func TestSanitizeSchemaForGemini_HandlesRecursiveRefs(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "tree": map[string]any{ + "$ref": "#/$defs/node", + }, + }, + "$defs": map[string]any{ + "node": map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + }, + "child": map[string]any{ + "$ref": "#/$defs/node", + }, + }, + }, + }, + } + + got := SanitizeSchemaForGemini(schema) + props := got["properties"].(map[string]any) + tree := props["tree"].(map[string]any) + if tree["type"] != "object" { + t.Fatalf("tree.type = %#v, want object", tree["type"]) + } + assertSchemaKeyAbsent(t, tree, "$ref") +} + +func assertSchemaKeyAbsent(t *testing.T, value any, key string) { + t.Helper() + + switch typed := value.(type) { + case map[string]any: + if _, found := typed[key]; found { + t.Fatalf("schema still contains key %q: %#v", key, typed) + } + for _, nested := range typed { + assertSchemaKeyAbsent(t, nested, key) + } + case []any: + for _, nested := range typed { + assertSchemaKeyAbsent(t, nested, key) + } + case []string: + return + } +} diff --git a/pkg/providers/httpapi/gemini_helpers.go b/pkg/providers/httpapi/gemini_helpers.go index a2b2d63c3..87cc4c084 100644 --- a/pkg/providers/httpapi/gemini_helpers.go +++ b/pkg/providers/httpapi/gemini_helpers.go @@ -12,66 +12,6 @@ func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake return "" } -var geminiUnsupportedKeywords = map[string]bool{ - "patternProperties": true, - "additionalProperties": true, - "$schema": true, - "$id": true, - "$ref": true, - "$defs": true, - "definitions": true, - "examples": true, - "minLength": true, - "maxLength": true, - "minimum": true, - "maximum": true, - "multipleOf": true, - "pattern": true, - "format": true, - "minItems": true, - "maxItems": true, - "uniqueItems": true, - "minProperties": true, - "maxProperties": true, -} - -func sanitizeSchemaForGemini(schema map[string]any) map[string]any { - if schema == nil { - return nil - } - - result := make(map[string]any) - for k, v := range schema { - if geminiUnsupportedKeywords[k] { - continue - } - switch val := v.(type) { - case map[string]any: - result[k] = sanitizeSchemaForGemini(val) - case []any: - sanitized := make([]any, len(val)) - for i, item := range val { - if m, ok := item.(map[string]any); ok { - sanitized[i] = sanitizeSchemaForGemini(m) - } else { - sanitized[i] = item - } - } - result[k] = sanitized - default: - result[k] = v - } - } - - if _, hasProps := result["properties"]; hasProps { - if _, hasType := result["type"]; !hasType { - result["type"] = "object" - } - } - - return result -} - func extractProtocol(model string) (protocol, modelID string) { model = strings.TrimSpace(model) protocol, modelID, found := strings.Cut(model, "/") diff --git a/pkg/providers/httpapi/gemini_provider.go b/pkg/providers/httpapi/gemini_provider.go index d1d523757..e73568c88 100644 --- a/pkg/providers/httpapi/gemini_provider.go +++ b/pkg/providers/httpapi/gemini_provider.go @@ -264,7 +264,7 @@ func (p *GeminiProvider) buildRequestBody( funcDecls = append(funcDecls, geminiFunctionDeclaration{ Name: t.Function.Name, Description: t.Function.Description, - Parameters: sanitizeSchemaForGemini(t.Function.Parameters), + Parameters: common.SanitizeSchemaForGemini(t.Function.Parameters), }) } if len(funcDecls) > 0 { diff --git a/pkg/providers/httpapi/gemini_provider_test.go b/pkg/providers/httpapi/gemini_provider_test.go index aade90358..53770511e 100644 --- a/pkg/providers/httpapi/gemini_provider_test.go +++ b/pkg/providers/httpapi/gemini_provider_test.go @@ -5,8 +5,11 @@ import ( "fmt" "net/http" "net/http/httptest" + "reflect" "strings" "testing" + + providercommon "github.com/sipeed/picoclaw/pkg/providers/common" ) func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) { @@ -259,6 +262,65 @@ func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) { } } +func TestGeminiProvider_BuildRequestBody_SanitizesComplexToolSchemas(t *testing.T) { + provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "parent": map[string]any{ + "anyOf": []any{ + map[string]any{"$ref": "#/$defs/pageParent"}, + map[string]any{"$ref": "#/$defs/databaseParent"}, + }, + }, + }, + "$defs": map[string]any{ + "pageParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "page_id": map[string]any{"type": "string"}, + }, + "required": []any{"page_id"}, + }, + "databaseParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "database_id": map[string]any{"type": "string"}, + }, + "required": []any{"database_id"}, + }, + }, + } + + body := provider.buildRequestBody( + []Message{{Role: "user", Content: "hello"}}, + []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "mcp_notion_create", + Description: "Create a Notion object", + Parameters: schema, + }, + }}, + "gemini-3-flash-preview", + nil, + ) + + tools, ok := body["tools"].([]geminiTool) + if !ok || len(tools) != 1 { + t.Fatalf("tools = %#v, want one geminiTool", body["tools"]) + } + got, ok := tools[0].FunctionDeclarations[0].Parameters.(map[string]any) + if !ok { + t.Fatalf("parameters = %#v, want map", tools[0].FunctionDeclarations[0].Parameters) + } + + want := providercommon.SanitizeSchemaForGemini(schema) + if !reflect.DeepEqual(got, want) { + t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want) + } +} + func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") diff --git a/pkg/providers/oauth/antigravity_provider.go b/pkg/providers/oauth/antigravity_provider.go index 1ac2d9c7f..fd5431b03 100644 --- a/pkg/providers/oauth/antigravity_provider.go +++ b/pkg/providers/oauth/antigravity_provider.go @@ -298,7 +298,7 @@ func (p *AntigravityProvider) buildRequest( if t.Type != "function" { continue } - params := sanitizeSchemaForGemini(t.Function.Parameters) + params := common.SanitizeSchemaForGemini(t.Function.Parameters) funcDecls = append(funcDecls, antigravityFuncDecl{ Name: t.Function.Name, Description: t.Function.Description, @@ -446,71 +446,6 @@ func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake return "" } -// --- Schema sanitization --- - -// Google/Gemini doesn't support many JSON Schema keywords that other providers accept. -var geminiUnsupportedKeywords = map[string]bool{ - "patternProperties": true, - "additionalProperties": true, - "$schema": true, - "$id": true, - "$ref": true, - "$defs": true, - "definitions": true, - "examples": true, - "minLength": true, - "maxLength": true, - "minimum": true, - "maximum": true, - "multipleOf": true, - "pattern": true, - "format": true, - "minItems": true, - "maxItems": true, - "uniqueItems": true, - "minProperties": true, - "maxProperties": true, -} - -func sanitizeSchemaForGemini(schema map[string]any) map[string]any { - if schema == nil { - return nil - } - - result := make(map[string]any) - for k, v := range schema { - if geminiUnsupportedKeywords[k] { - continue - } - // Recursively sanitize nested objects - switch val := v.(type) { - case map[string]any: - result[k] = sanitizeSchemaForGemini(val) - case []any: - sanitized := make([]any, len(val)) - for i, item := range val { - if m, ok := item.(map[string]any); ok { - sanitized[i] = sanitizeSchemaForGemini(m) - } else { - sanitized[i] = item - } - } - result[k] = sanitized - default: - result[k] = v - } - } - - // Ensure top-level has type: "object" if properties are present - if _, hasProps := result["properties"]; hasProps { - if _, hasType := result["type"]; !hasType { - result["type"] = "object" - } - } - - return result -} - // --- Token source --- func createAntigravityTokenSource() func() (string, string, error) { diff --git a/pkg/providers/oauth/antigravity_provider_test.go b/pkg/providers/oauth/antigravity_provider_test.go index 2989f8519..015b1ab80 100644 --- a/pkg/providers/oauth/antigravity_provider_test.go +++ b/pkg/providers/oauth/antigravity_provider_test.go @@ -1,6 +1,11 @@ package oauthprovider -import "testing" +import ( + "reflect" + "testing" + + providercommon "github.com/sipeed/picoclaw/pkg/providers/common" +) func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) { p := &AntigravityProvider{} @@ -71,3 +76,68 @@ func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) { t.Fatalf("Usage.TotalTokens = %v, want %d", resp.Usage, 216) } } + +func TestBuildRequest_SanitizesComplexToolSchemas(t *testing.T) { + p := &AntigravityProvider{} + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "parent": map[string]any{ + "anyOf": []any{ + map[string]any{"$ref": "#/$defs/pageParent"}, + map[string]any{"$ref": "#/$defs/databaseParent"}, + }, + }, + "icon": map[string]any{ + "anyOf": []any{ + map[string]any{"type": "null"}, + map[string]any{"$ref": "#/$defs/emoji"}, + }, + }, + }, + "$defs": map[string]any{ + "pageParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "page_id": map[string]any{"type": "string"}, + }, + "required": []any{"page_id"}, + }, + "databaseParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "database_id": map[string]any{"type": "string"}, + }, + "required": []any{"database_id"}, + }, + "emoji": map[string]any{ + "type": "string", + "pattern": "^:[a-z_]+:$", + }, + }, + } + + req := p.buildRequest( + []Message{{Role: "user", Content: "hello"}}, + []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "mcp_notion_create", + Description: "Create a Notion object", + Parameters: schema, + }, + }}, + "gemini-3-flash", + nil, + ) + + if len(req.Tools) != 1 || len(req.Tools[0].FunctionDeclarations) != 1 { + t.Fatalf("request tools = %#v, want one function declaration", req.Tools) + } + + got := req.Tools[0].FunctionDeclarations[0].Parameters + want := providercommon.SanitizeSchemaForGemini(schema) + if !reflect.DeepEqual(got, want) { + t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want) + } +} From 4eeb69688e6dffbb4e26c3ca015ce51ee7a0c145 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Sun, 26 Apr 2026 22:33:35 +0200 Subject: [PATCH 2/5] fix lint --- pkg/providers/common/google_schema_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/providers/common/google_schema_test.go b/pkg/providers/common/google_schema_test.go index af2520493..23aadbf98 100644 --- a/pkg/providers/common/google_schema_test.go +++ b/pkg/providers/common/google_schema_test.go @@ -95,10 +95,10 @@ func TestSanitizeSchemaForGemini_DereferencesRefsAndFlattensUnions(t *testing.T) if !ok { t.Fatalf("parent.properties = %#v, want map", parent["properties"]) } - if _, ok := parentProps["page_id"]; !ok { + if _, found := parentProps["page_id"]; !found { t.Fatalf("parent.properties missing page_id: %#v", parentProps) } - if _, ok := parentProps["database_id"]; !ok { + if _, found := parentProps["database_id"]; !found { t.Fatalf("parent.properties missing database_id: %#v", parentProps) } if _, hasRequired := parent["required"]; hasRequired { @@ -124,10 +124,10 @@ func TestSanitizeSchemaForGemini_DereferencesRefsAndFlattensUnions(t *testing.T) if !ok { t.Fatalf("data.properties = %#v, want map", data["properties"]) } - if _, ok := dataProps["name"]; !ok { + if _, found := dataProps["name"]; !found { t.Fatalf("data.properties missing name: %#v", dataProps) } - if _, ok := dataProps["count"]; !ok { + if _, found := dataProps["count"]; !found { t.Fatalf("data.properties missing count: %#v", dataProps) } @@ -184,7 +184,7 @@ func TestSanitizeSchemaForGemini_MergesAllOfAndFiltersRequired(t *testing.T) { t.Fatalf("payload.properties = %#v, want map", payload["properties"]) } for _, key := range []string{"id", "name", "count"} { - if _, ok := payloadProps[key]; !ok { + if _, found := payloadProps[key]; !found { t.Fatalf("payload.properties missing %q: %#v", key, payloadProps) } } From cd7717bc155a440c06d42dad4c1b9c4aaed1d83b Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Mon, 27 Apr 2026 21:10:30 +0200 Subject: [PATCH 3/5] feat(tool): tool schema semplification --- config/config.example.json | 1 + docs/guides/providers.md | 56 +++++++--- pkg/config/config.go | 85 +++++++------- pkg/config/config_test.go | 30 +++++ pkg/config/model_config_test.go | 18 +++ pkg/config/multikey_test.go | 25 +++-- pkg/providers/common/google_schema.go | 16 ++- pkg/providers/common/tool_schema_transform.go | 59 ++++++++++ pkg/providers/factory_provider.go | 50 +++++---- pkg/providers/factory_provider_test.go | 39 +++++++ pkg/providers/httpapi/gemini_provider.go | 2 +- pkg/providers/httpapi/gemini_provider_test.go | 10 +- pkg/providers/oauth/antigravity_provider.go | 5 +- .../oauth/antigravity_provider_test.go | 15 ++- pkg/providers/tool_schema_transform.go | 84 ++++++++++++++ pkg/providers/tool_schema_transform_test.go | 104 ++++++++++++++++++ web/backend/api/models.go | 63 ++++++----- web/backend/api/models_test.go | 94 ++++++++++++++++ web/frontend/src/api/models.ts | 1 + .../src/components/models/add-model-sheet.tsx | 14 +++ .../components/models/edit-model-sheet.tsx | 15 +++ web/frontend/src/i18n/locales/en.json | 2 + web/frontend/src/i18n/locales/zh.json | 2 + 23 files changed, 654 insertions(+), 136 deletions(-) create mode 100644 pkg/providers/common/tool_schema_transform.go create mode 100644 pkg/providers/tool_schema_transform.go create mode 100644 pkg/providers/tool_schema_transform_test.go diff --git a/config/config.example.json b/config/config.example.json index 30460c231..33eed63d1 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -41,6 +41,7 @@ }, { "model_name": "gemini", + "_comment": "Optional: set \"tool_schema_transform\": \"simple\" for providers that reject complex tool JSON Schema.", "model": "antigravity/gemini-2.0-flash", "auth_method": "oauth" }, diff --git a/docs/guides/providers.md b/docs/guides/providers.md index d99d8c016..62bbbbe56 100644 --- a/docs/guides/providers.md +++ b/docs/guides/providers.md @@ -116,23 +116,47 @@ This design also enables **multi-agent support** with flexible provider selectio #### `model_list` Entry Fields -| Field | Type | Required | Description | -|-------|------|----------|-------------| -| `model_name` | string | Yes | Unique name used to reference this model in agent config | -| `provider` | string | No | Preferred provider identifier. When present, PicoClaw sends `model` unchanged to that provider | -| `model` | string | Yes | Native model ID when `provider` is set. If `provider` is omitted, the legacy `provider/model` form is still supported | -| `api_keys` | string[] | Yes* | API key(s) for authentication. Multiple keys enable per-request rotation. Not required for local providers (Ollama, LM Studio, VLLM) | -| `api_base` | string | No | Override the default API endpoint URL | -| `proxy` | string | No | HTTP proxy URL for this model entry | -| `user_agent` | string | No | Custom `User-Agent` header sent with API requests (supported by OpenAI-compatible, Gemini, Anthropic, and Azure providers) | -| `request_timeout` | int | No | Request timeout in seconds (default varies by provider) | -| `max_tokens_field` | string | No | Override the max tokens field name in request body (e.g., `max_completion_tokens` for o1 models) | -| `thinking_level` | string | No | Extended thinking level: `off`, `low`, `medium`, `high`, `xhigh`, or `adaptive` | -| `extra_body` | object | No | Additional fields to inject into every request body | +| Field | Type | Required | Description | +|-------|------|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `model_name` | string | Yes | Unique name used to reference this model in agent config | +| `provider` | string | No | Preferred provider identifier. When present, PicoClaw sends `model` unchanged to that provider | +| `model` | string | Yes | Native model ID when `provider` is set. If `provider` is omitted, the legacy `provider/model` form is still supported | +| `api_keys` | string[] | Yes* | API key(s) for authentication. Multiple keys enable per-request rotation. Not required for local providers (Ollama, LM Studio, VLLM) | +| `api_base` | string | No | Override the default API endpoint URL | +| `proxy` | string | No | HTTP proxy URL for this model entry | +| `user_agent` | string | No | Custom `User-Agent` header sent with API requests (supported by OpenAI-compatible, Gemini, Anthropic, and Azure providers) | +| `request_timeout` | int | No | Request timeout in seconds (default varies by provider) | +| `max_tokens_field` | string | No | Override the max tokens field name in request body (e.g., `max_completion_tokens` for o1 models) | +| `thinking_level` | string | No | Extended thinking level: `off`, `low`, `medium`, `high`, `xhigh`, or `adaptive` | +| `tool_schema_transform` | string | No | Optional compatibility transform for tool parameter schemas. Default: disabled. Supported values: `simple`. | +| `extra_body` | object | No | Additional fields to inject into every request body | | `custom_headers` | object | No | Additional HTTP headers to inject into every request (e.g., `{"X-Source":"coding-plan"}`). If a key matches a built-in header, the custom value overrides the built-in one (e.g., `Authorization`, `User-Agent`, `Content-Type`, `Accept`). | -| `rpm` | int | No | Per-minute request rate limit | -| `fallbacks` | string[] | No | Fallback model names for automatic failover | -| `enabled` | bool | No | Whether this model entry is active (default: `true`) | +| `rpm` | int | No | Per-minute request rate limit | +| `fallbacks` | string[] | No | Fallback model names for automatic failover | +| `enabled` | bool | No | Whether this model entry is active (default: `true`) | + +#### Tool Schema Compatibility + +By default, PicoClaw now forwards tool JSON Schemas unchanged. + +Some providers reject advanced JSON Schema features such as `$ref`, `$defs`, `anyOf`, `oneOf`, `allOf`, `pattern`, or numeric/string constraints inside tool declarations. For those models, you can opt into a compatibility transform per model entry with `tool_schema_transform`. + +Use `simple` when the upstream provider expects the conservative style function schema subset: + +```json +{ + "model_name": "gemini-2.5-flash-safe-tools", + "provider": "gemini", + "model": "gemini-2.5-flash", + "api_keys": ["your-gemini-key"], + "tool_schema_transform": "simple" +} +``` + +Notes: + +- Default behavior is disabled. If you omit `tool_schema_transform`, PicoClaw sends the original tool schema. +- The setting is per model entry, so you can enable it only for the providers that need it. #### Provider / Model Resolution diff --git a/pkg/config/config.go b/pkg/config/config.go index 305c3a5c0..d109f1361 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -17,6 +17,7 @@ import ( "github.com/sipeed/picoclaw/pkg" "github.com/sipeed/picoclaw/pkg/fileutil" "github.com/sipeed/picoclaw/pkg/logger" + providercommon "github.com/sipeed/picoclaw/pkg/providers/common" ) // rrCounter is a global counter for round-robin load balancing across models. @@ -553,12 +554,13 @@ type ModelConfig struct { Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers // Optional optimizations - RPM int `json:"rpm,omitempty"` // Requests per minute limit - MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens") - RequestTimeout int `json:"request_timeout,omitempty"` - ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive - ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body - CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request + RPM int `json:"rpm,omitempty"` // Requests per minute limit + MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens") + RequestTimeout int `json:"request_timeout,omitempty"` + ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive + ToolSchemaTransform string `json:"tool_schema_transform,omitempty"` // Optional tool schema compatibility transform (e.g. "simple") + ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body + CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover) @@ -595,6 +597,9 @@ func (c *ModelConfig) Validate() error { if c.Model == "" { return fmt.Errorf("model is required") } + if _, err := providercommon.NormalizeToolSchemaTransform(c.ToolSchemaTransform); err != nil { + return err + } return nil } @@ -1419,23 +1424,24 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig { // Create a copy for the additional key additionalEntry := &ModelConfig{ - ModelName: expandedName, - Provider: m.Provider, - Model: m.Model, - APIBase: m.APIBase, - APIKeys: SimpleSecureStrings(keys[i]), - Proxy: m.Proxy, - AuthMethod: m.AuthMethod, - ConnectMode: m.ConnectMode, - Workspace: m.Workspace, - RPM: m.RPM, - MaxTokensField: m.MaxTokensField, - RequestTimeout: m.RequestTimeout, - ThinkingLevel: m.ThinkingLevel, - ExtraBody: m.ExtraBody, - CustomHeaders: m.CustomHeaders, - UserAgent: m.UserAgent, - isVirtual: true, + ModelName: expandedName, + Provider: m.Provider, + Model: m.Model, + APIBase: m.APIBase, + APIKeys: SimpleSecureStrings(keys[i]), + Proxy: m.Proxy, + AuthMethod: m.AuthMethod, + ConnectMode: m.ConnectMode, + Workspace: m.Workspace, + RPM: m.RPM, + MaxTokensField: m.MaxTokensField, + RequestTimeout: m.RequestTimeout, + ThinkingLevel: m.ThinkingLevel, + ToolSchemaTransform: m.ToolSchemaTransform, + ExtraBody: m.ExtraBody, + CustomHeaders: m.CustomHeaders, + UserAgent: m.UserAgent, + isVirtual: true, } expanded = append(expanded, additionalEntry) fallbackNames = append(fallbackNames, expandedName) @@ -1443,22 +1449,23 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig { // Create the primary entry with first key and fallbacks primaryEntry := &ModelConfig{ - ModelName: originalName, - Provider: m.Provider, - Model: m.Model, - APIBase: m.APIBase, - Proxy: m.Proxy, - AuthMethod: m.AuthMethod, - ConnectMode: m.ConnectMode, - Workspace: m.Workspace, - RPM: m.RPM, - MaxTokensField: m.MaxTokensField, - RequestTimeout: m.RequestTimeout, - ThinkingLevel: m.ThinkingLevel, - ExtraBody: m.ExtraBody, - CustomHeaders: m.CustomHeaders, - UserAgent: m.UserAgent, - APIKeys: SimpleSecureStrings(keys[0]), + ModelName: originalName, + Provider: m.Provider, + Model: m.Model, + APIBase: m.APIBase, + Proxy: m.Proxy, + AuthMethod: m.AuthMethod, + ConnectMode: m.ConnectMode, + Workspace: m.Workspace, + RPM: m.RPM, + MaxTokensField: m.MaxTokensField, + RequestTimeout: m.RequestTimeout, + ThinkingLevel: m.ThinkingLevel, + ToolSchemaTransform: m.ToolSchemaTransform, + ExtraBody: m.ExtraBody, + CustomHeaders: m.CustomHeaders, + UserAgent: m.UserAgent, + APIKeys: SimpleSecureStrings(keys[0]), } // Prepend new fallbacks to existing ones diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 2be1bcc67..a8f32d47d 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1945,6 +1945,36 @@ func TestModelConfig_CustomHeadersRoundTrip(t *testing.T) { } } +func TestModelConfig_ToolSchemaTransformRoundTrip(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + + cfg := &Config{ + Version: CurrentVersion, + ModelList: []*ModelConfig{ + { + ModelName: "test-model", + Model: "openai/test", + APIKeys: SimpleSecureStrings("sk-test"), + ToolSchemaTransform: "simple", + }, + }, + } + + if err := SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("SaveConfig error: %v", err) + } + + loaded, err := LoadConfig(cfgPath) + if err != nil { + t.Fatalf("LoadConfig error: %v", err) + } + + if got := loaded.ModelList[0].ToolSchemaTransform; got != "simple" { + t.Fatalf("ToolSchemaTransform = %q, want %q", got, "simple") + } +} + func TestDefaultConfig_MinimaxExtraBody(t *testing.T) { cfg := DefaultConfig() diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go index 8fd501155..d22eb290f 100644 --- a/pkg/config/model_config_test.go +++ b/pkg/config/model_config_test.go @@ -158,6 +158,15 @@ func TestModelConfig_Validate(t *testing.T) { }, wantErr: false, }, + { + name: "valid tool schema transform", + config: ModelConfig{ + ModelName: "test", + Model: "openai/gpt-4o", + ToolSchemaTransform: "simple", + }, + wantErr: false, + }, { name: "missing model_name", config: ModelConfig{ @@ -177,6 +186,15 @@ func TestModelConfig_Validate(t *testing.T) { config: ModelConfig{}, wantErr: true, }, + { + name: "invalid tool schema transform", + config: ModelConfig{ + ModelName: "test", + Model: "openai/gpt-4o", + ToolSchemaTransform: "invalid", + }, + wantErr: true, + }, } for _, tt := range tests { diff --git a/pkg/config/multikey_test.go b/pkg/config/multikey_test.go index cb55db938..073cb7826 100644 --- a/pkg/config/multikey_test.go +++ b/pkg/config/multikey_test.go @@ -187,15 +187,16 @@ func TestExpandMultiKeyModels_Deduplication(t *testing.T) { func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) { modelCfg := &ModelConfig{ - ModelName: "gpt-4", - Provider: "openrouter", - Model: "openai/gpt-4o", - APIBase: "https://api.example.com", - Proxy: "http://proxy:8080", - RPM: 60, - MaxTokensField: "max_completion_tokens", - RequestTimeout: 30, - ThinkingLevel: "high", + ModelName: "gpt-4", + Provider: "openrouter", + Model: "openai/gpt-4o", + APIBase: "https://api.example.com", + Proxy: "http://proxy:8080", + RPM: 60, + MaxTokensField: "max_completion_tokens", + RequestTimeout: 30, + ThinkingLevel: "high", + ToolSchemaTransform: "simple", } modelCfg.APIKeys = SimpleSecureStrings("key0", "key1") // Use internal field for multi-key testing models := []*ModelConfig{modelCfg} @@ -225,6 +226,9 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) { if primary.ThinkingLevel != "high" { t.Errorf("expected thinking_level preserved, got %q", primary.ThinkingLevel) } + if primary.ToolSchemaTransform != "simple" { + t.Errorf("expected tool_schema_transform preserved, got %q", primary.ToolSchemaTransform) + } // Check additional entry also preserves fields additional := result[0] @@ -237,6 +241,9 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) { if additional.RPM != 60 { t.Errorf("expected additional rpm preserved, got %d", additional.RPM) } + if additional.ToolSchemaTransform != "simple" { + t.Errorf("expected additional tool_schema_transform preserved, got %q", additional.ToolSchemaTransform) + } } func TestExpandMultiKeyModels_IsVirtualFlag(t *testing.T) { diff --git a/pkg/providers/common/google_schema.go b/pkg/providers/common/google_schema.go index bdafadd23..f7b2a337b 100644 --- a/pkg/providers/common/google_schema.go +++ b/pkg/providers/common/google_schema.go @@ -16,11 +16,11 @@ var geminiSupportedTypes = map[string]bool{ "string": true, } -// SanitizeSchemaForGemini reduces a JSON Schema to the conservative subset -// accepted by Gemini-style function declarations. It resolves local refs, -// collapses composition keywords like anyOf/oneOf/allOf, and strips advanced -// keywords that Gemini rejects. -func SanitizeSchemaForGemini(schema map[string]any) map[string]any { +// SanitizeSchemaForGoogle reduces a JSON Schema to the conservative subset +// accepted by Google/Gemini-style function declarations. It resolves local +// refs, collapses composition keywords like anyOf/oneOf/allOf, and strips +// advanced keywords that Gemini-compatible backends often reject. +func SanitizeSchemaForGoogle(schema map[string]any) map[string]any { if schema == nil { return nil } @@ -39,6 +39,12 @@ func SanitizeSchemaForGemini(schema map[string]any) map[string]any { return result } +// SanitizeSchemaForGemini is kept as a compatibility alias for the original +// Google/Gemini sanitizer name. +func SanitizeSchemaForGemini(schema map[string]any) map[string]any { + return SanitizeSchemaForGoogle(schema) +} + type geminiSchemaSanitizer struct { root map[string]any } diff --git a/pkg/providers/common/tool_schema_transform.go b/pkg/providers/common/tool_schema_transform.go new file mode 100644 index 000000000..10e96d056 --- /dev/null +++ b/pkg/providers/common/tool_schema_transform.go @@ -0,0 +1,59 @@ +package common + +import ( + "fmt" + "strings" +) + +const ( + ToolSchemaTransformOff = "" + ToolSchemaTransformSimple = "simple" +) + +// NormalizeToolSchemaTransform resolves user-facing aliases to a canonical +// transform mode. Empty values and explicit "off"-style values disable schema +// transformation. +func NormalizeToolSchemaTransform(raw string) (string, error) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", "off", "none", "native": + return ToolSchemaTransformOff, nil + case "simple", "basic", "strict", "flat": + return ToolSchemaTransformSimple, nil + default: + return "", fmt.Errorf("unsupported tool_schema_transform %q (supported: off, simple)", raw) + } +} + +// TransformToolDefinitions clones tool definitions and applies the configured +// schema transform to function parameter schemas. When the transform is off, the +// original slice is returned unchanged. +func TransformToolDefinitions(tools []ToolDefinition, transform string) ([]ToolDefinition, error) { + transform, err := NormalizeToolSchemaTransform(transform) + if err != nil { + return nil, err + } + if transform == ToolSchemaTransformOff || len(tools) == 0 { + return tools, nil + } + + out := make([]ToolDefinition, len(tools)) + for i, tool := range tools { + out[i] = tool + if tool.Type != "function" { + continue + } + out[i].Function = tool.Function + out[i].Function.Parameters = transformToolSchema(tool.Function.Parameters, transform) + } + + return out, nil +} + +func transformToolSchema(schema map[string]any, transform string) map[string]any { + switch transform { + case ToolSchemaTransformSimple: + return SanitizeSchemaForGoogle(schema) + default: + return cloneGeminiSchemaMap(schema) + } +} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index ce83c6c54..a59e2de25 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -168,7 +168,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if err != nil { return nil, "", err } - return provider, modelID, nil + return finalizeProviderFromConfig(provider, modelID, cfg) } // OpenAI with API key if cfg.APIKey() == "" && cfg.APIBase == "" { @@ -189,7 +189,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.CustomHeaders, ) provider.SetProviderName(protocol) - return provider, modelID, nil + return finalizeProviderFromConfig(provider, modelID, cfg) case "azure", "azure-openai": // Azure OpenAI uses deployment-based URLs, api-key header auth, @@ -202,13 +202,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err "api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)", ) } - return azure.NewProviderWithTimeout( + return finalizeProviderFromConfig(azure.NewProviderWithTimeout( cfg.APIKey(), cfg.APIBase, cfg.Proxy, userAgent, cfg.RequestTimeout, - ), modelID, nil + ), modelID, cfg) case "bedrock": // AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.) @@ -244,7 +244,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if err != nil { return nil, "", fmt.Errorf("creating bedrock provider: %w", err) } - return provider, modelID, nil + return finalizeProviderFromConfig(provider, modelID, cfg) case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "nvidia", "venice", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", @@ -270,7 +270,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.CustomHeaders, ) provider.SetProviderName(protocol) - return provider, modelID, nil + return finalizeProviderFromConfig(provider, modelID, cfg) case "gemini": if cfg.APIKey() == "" && cfg.APIBase == "" { @@ -280,7 +280,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if apiBase == "" { apiBase = getDefaultAPIBase(protocol) } - return NewGeminiProvider( + return finalizeProviderFromConfig(NewGeminiProvider( cfg.APIKey(), apiBase, cfg.Proxy, @@ -288,7 +288,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.RequestTimeout, cfg.ExtraBody, cfg.CustomHeaders, - ), modelID, nil + ), modelID, cfg) case "minimax": // Minimax requires reasoning_split: true in the request body @@ -317,7 +317,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.CustomHeaders, ) provider.SetProviderName(protocol) - return provider, modelID, nil + return finalizeProviderFromConfig(provider, modelID, cfg) case "anthropic": if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" { @@ -326,7 +326,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if err != nil { return nil, "", err } - return provider, modelID, nil + return finalizeProviderFromConfig(provider, modelID, cfg) } // Use API key with HTTP API apiBase := cfg.APIBase @@ -347,7 +347,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.CustomHeaders, ) provider.SetProviderName(protocol) - return provider, modelID, nil + return finalizeProviderFromConfig(provider, modelID, cfg) case "anthropic-messages": // Anthropic Messages API with native format (HTTP-based, no SDK) @@ -358,12 +358,12 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if cfg.APIKey() == "" { return nil, "", fmt.Errorf("api_key is required for anthropic-messages protocol (model: %s)", cfg.Model) } - return anthropicmessages.NewProviderWithTimeout( + return finalizeProviderFromConfig(anthropicmessages.NewProviderWithTimeout( cfg.APIKey(), apiBase, userAgent, cfg.RequestTimeout, - ), modelID, nil + ), modelID, cfg) case "coding-plan-anthropic", "alibaba-coding-anthropic": // Alibaba Coding Plan with Anthropic-compatible API @@ -374,29 +374,29 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if cfg.APIKey() == "" { return nil, "", fmt.Errorf("api_key is required for %q protocol (model: %s)", protocol, cfg.Model) } - return anthropicmessages.NewProviderWithTimeout( + return finalizeProviderFromConfig(anthropicmessages.NewProviderWithTimeout( cfg.APIKey(), apiBase, userAgent, cfg.RequestTimeout, - ), modelID, nil + ), modelID, cfg) case "antigravity": - return NewAntigravityProvider(), modelID, nil + return finalizeProviderFromConfig(NewAntigravityProvider(), modelID, cfg) case "claude-cli", "claudecli": workspace := cfg.Workspace if workspace == "" { workspace = "." } - return NewClaudeCliProvider(workspace), modelID, nil + return finalizeProviderFromConfig(NewClaudeCliProvider(workspace), modelID, cfg) case "codex-cli", "codexcli": workspace := cfg.Workspace if workspace == "" { workspace = "." } - return NewCodexCliProvider(workspace), modelID, nil + return finalizeProviderFromConfig(NewCodexCliProvider(workspace), modelID, cfg) case "github-copilot", "copilot": apiBase := cfg.APIBase @@ -411,13 +411,25 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if err != nil { return nil, "", err } - return provider, modelID, nil + return finalizeProviderFromConfig(provider, modelID, cfg) default: return nil, "", fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model) } } +func finalizeProviderFromConfig( + provider LLMProvider, + modelID string, + cfg *config.ModelConfig, +) (LLMProvider, string, error) { + wrapped, err := wrapProviderWithToolSchemaTransform(provider, cfg.ToolSchemaTransform) + if err != nil { + return nil, "", err + } + return wrapped, modelID, nil +} + func isEmptyAPIKeyAllowed(protocol string) bool { meta, ok := protocolMetaByName[protocol] return ok && meta.emptyAPIKeyAllowed diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index 3dd1eefb3..3d3c30ce0 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -1202,3 +1202,42 @@ func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) { // Unexpected error - fail the test t.Errorf("unexpected error from bedrock provider: %v", err) } + +func TestCreateProviderFromConfig_ToolSchemaTransformWrapsProvider(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "claude-cli-test", + Provider: "claude-cli", + Model: "claude-sonnet-4.6", + Workspace: t.TempDir(), + ToolSchemaTransform: "simple", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if modelID != "claude-sonnet-4.6" { + t.Fatalf("modelID = %q, want %q", modelID, "claude-sonnet-4.6") + } + if _, ok := provider.(*toolSchemaTransformProvider); !ok { + t.Fatalf("provider = %T, want *toolSchemaTransformProvider", provider) + } +} + +func TestCreateProviderFromConfig_InvalidToolSchemaTransform(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "claude-cli-test", + Provider: "claude-cli", + Model: "claude-sonnet-4.6", + Workspace: t.TempDir(), + ToolSchemaTransform: "invalid", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for invalid tool_schema_transform") + } + if !strings.Contains(err.Error(), "tool_schema_transform") { + t.Fatalf("error = %v, want mention tool_schema_transform", err) + } +} diff --git a/pkg/providers/httpapi/gemini_provider.go b/pkg/providers/httpapi/gemini_provider.go index e73568c88..395c555d1 100644 --- a/pkg/providers/httpapi/gemini_provider.go +++ b/pkg/providers/httpapi/gemini_provider.go @@ -264,7 +264,7 @@ func (p *GeminiProvider) buildRequestBody( funcDecls = append(funcDecls, geminiFunctionDeclaration{ Name: t.Function.Name, Description: t.Function.Description, - Parameters: common.SanitizeSchemaForGemini(t.Function.Parameters), + Parameters: t.Function.Parameters, }) } if len(funcDecls) > 0 { diff --git a/pkg/providers/httpapi/gemini_provider_test.go b/pkg/providers/httpapi/gemini_provider_test.go index 53770511e..b455357c0 100644 --- a/pkg/providers/httpapi/gemini_provider_test.go +++ b/pkg/providers/httpapi/gemini_provider_test.go @@ -5,11 +5,8 @@ import ( "fmt" "net/http" "net/http/httptest" - "reflect" "strings" "testing" - - providercommon "github.com/sipeed/picoclaw/pkg/providers/common" ) func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) { @@ -262,7 +259,7 @@ func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) { } } -func TestGeminiProvider_BuildRequestBody_SanitizesComplexToolSchemas(t *testing.T) { +func TestGeminiProvider_BuildRequestBody_PreservesComplexToolSchemasByDefault(t *testing.T) { provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil) schema := map[string]any{ "type": "object", @@ -315,9 +312,8 @@ func TestGeminiProvider_BuildRequestBody_SanitizesComplexToolSchemas(t *testing. t.Fatalf("parameters = %#v, want map", tools[0].FunctionDeclarations[0].Parameters) } - want := providercommon.SanitizeSchemaForGemini(schema) - if !reflect.DeepEqual(got, want) { - t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want) + if got["$defs"] == nil { + t.Fatalf("parameters = %#v, want raw schema with $defs preserved by default", got) } } diff --git a/pkg/providers/oauth/antigravity_provider.go b/pkg/providers/oauth/antigravity_provider.go index fd5431b03..abf1e4bd6 100644 --- a/pkg/providers/oauth/antigravity_provider.go +++ b/pkg/providers/oauth/antigravity_provider.go @@ -291,18 +291,17 @@ func (p *AntigravityProvider) buildRequest( } } - // Build tools (sanitize schemas for Gemini compatibility) + // Build tools if len(tools) > 0 { var funcDecls []antigravityFuncDecl for _, t := range tools { if t.Type != "function" { continue } - params := common.SanitizeSchemaForGemini(t.Function.Parameters) funcDecls = append(funcDecls, antigravityFuncDecl{ Name: t.Function.Name, Description: t.Function.Description, - Parameters: params, + Parameters: t.Function.Parameters, }) } if len(funcDecls) > 0 { diff --git a/pkg/providers/oauth/antigravity_provider_test.go b/pkg/providers/oauth/antigravity_provider_test.go index 015b1ab80..d85e47dfa 100644 --- a/pkg/providers/oauth/antigravity_provider_test.go +++ b/pkg/providers/oauth/antigravity_provider_test.go @@ -1,10 +1,7 @@ package oauthprovider import ( - "reflect" "testing" - - providercommon "github.com/sipeed/picoclaw/pkg/providers/common" ) func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) { @@ -77,7 +74,7 @@ func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) { } } -func TestBuildRequest_SanitizesComplexToolSchemas(t *testing.T) { +func TestBuildRequest_PreservesComplexToolSchemasByDefault(t *testing.T) { p := &AntigravityProvider{} schema := map[string]any{ "type": "object", @@ -135,9 +132,11 @@ func TestBuildRequest_SanitizesComplexToolSchemas(t *testing.T) { t.Fatalf("request tools = %#v, want one function declaration", req.Tools) } - got := req.Tools[0].FunctionDeclarations[0].Parameters - want := providercommon.SanitizeSchemaForGemini(schema) - if !reflect.DeepEqual(got, want) { - t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want) + got, ok := req.Tools[0].FunctionDeclarations[0].Parameters.(map[string]any) + if !ok { + t.Fatalf("parameters = %#v, want map", req.Tools[0].FunctionDeclarations[0].Parameters) + } + if got["$defs"] == nil { + t.Fatalf("parameters = %#v, want raw schema with $defs preserved by default", got) } } diff --git a/pkg/providers/tool_schema_transform.go b/pkg/providers/tool_schema_transform.go new file mode 100644 index 000000000..6b6cab7a6 --- /dev/null +++ b/pkg/providers/tool_schema_transform.go @@ -0,0 +1,84 @@ +package providers + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/providers/common" +) + +type toolSchemaTransformProvider struct { + delegate LLMProvider + transform string +} + +type toolSchemaStreamingProvider struct { + *toolSchemaTransformProvider +} + +func wrapProviderWithToolSchemaTransform(delegate LLMProvider, transform string) (LLMProvider, error) { + transform, err := common.NormalizeToolSchemaTransform(transform) + if err != nil { + return nil, err + } + if transform == common.ToolSchemaTransformOff || delegate == nil { + return delegate, nil + } + base := &toolSchemaTransformProvider{ + delegate: delegate, + transform: transform, + } + if _, ok := delegate.(StreamingProvider); ok { + return &toolSchemaStreamingProvider{toolSchemaTransformProvider: base}, nil + } + return base, nil +} + +func (p *toolSchemaTransformProvider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + transformed, err := common.TransformToolDefinitions(tools, p.transform) + if err != nil { + return nil, err + } + return p.delegate.Chat(ctx, messages, transformed, model, options) +} + +func (p *toolSchemaTransformProvider) GetDefaultModel() string { + return p.delegate.GetDefaultModel() +} + +func (p *toolSchemaStreamingProvider) ChatStream( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, + onChunk func(accumulated string), +) (*LLMResponse, error) { + streaming := p.delegate.(StreamingProvider) + transformed, err := common.TransformToolDefinitions(tools, p.transform) + if err != nil { + return nil, err + } + return streaming.ChatStream(ctx, messages, transformed, model, options, onChunk) +} + +func (p *toolSchemaTransformProvider) SupportsThinking() bool { + tc, ok := p.delegate.(ThinkingCapable) + return ok && tc.SupportsThinking() +} + +func (p *toolSchemaTransformProvider) SupportsNativeSearch() bool { + ns, ok := p.delegate.(NativeSearchCapable) + return ok && ns.SupportsNativeSearch() +} + +func (p *toolSchemaTransformProvider) Close() { + if stateful, ok := p.delegate.(StatefulProvider); ok { + stateful.Close() + } +} diff --git a/pkg/providers/tool_schema_transform_test.go b/pkg/providers/tool_schema_transform_test.go new file mode 100644 index 000000000..c83416377 --- /dev/null +++ b/pkg/providers/tool_schema_transform_test.go @@ -0,0 +1,104 @@ +package providers + +import ( + "context" + "reflect" + "testing" + + providercommon "github.com/sipeed/picoclaw/pkg/providers/common" +) + +type toolCaptureProvider struct { + lastTools []ToolDefinition +} + +func (p *toolCaptureProvider) Chat( + _ context.Context, + _ []Message, + tools []ToolDefinition, + _ string, + _ map[string]any, +) (*LLMResponse, error) { + p.lastTools = tools + return &LLMResponse{Content: "ok"}, nil +} + +func (p *toolCaptureProvider) GetDefaultModel() string { + return "test" +} + +func TestWrapProviderWithToolSchemaTransform_DisabledPassesToolsThrough(t *testing.T) { + capture := &toolCaptureProvider{} + wrapped, err := wrapProviderWithToolSchemaTransform(capture, "") + if err != nil { + t.Fatalf("wrapProviderWithToolSchemaTransform() error = %v", err) + } + + tools := []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "noop", + Parameters: map[string]any{"type": "object"}, + }, + }} + + _, err = wrapped.Chat(t.Context(), nil, tools, "test", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if !reflect.DeepEqual(capture.lastTools, tools) { + t.Fatalf("tools mutated with transform off\n got: %#v\nwant: %#v", capture.lastTools, tools) + } +} + +func TestWrapProviderWithToolSchemaTransform_GoogleSanitizesSchemas(t *testing.T) { + capture := &toolCaptureProvider{} + wrapped, err := wrapProviderWithToolSchemaTransform(capture, "google") + if err != nil { + t.Fatalf("wrapProviderWithToolSchemaTransform() error = %v", err) + } + + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "parent": map[string]any{ + "anyOf": []any{ + map[string]any{"$ref": "#/$defs/pageParent"}, + map[string]any{"$ref": "#/$defs/databaseParent"}, + }, + }, + }, + "$defs": map[string]any{ + "pageParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "page_id": map[string]any{"type": "string"}, + }, + }, + "databaseParent": map[string]any{ + "type": "object", + "properties": map[string]any{ + "database_id": map[string]any{"type": "string"}, + }, + }, + }, + } + tools := []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "mcp_notion_create", + Parameters: schema, + }, + }} + + _, err = wrapped.Chat(t.Context(), nil, tools, "test", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + want := providercommon.SanitizeSchemaForGoogle(schema) + got := capture.lastTools[0].Function.Parameters + if !reflect.DeepEqual(got, want) { + t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want) + } +} diff --git a/web/backend/api/models.go b/web/backend/api/models.go index cf903ce4c..61eb235cb 100644 --- a/web/backend/api/models.go +++ b/web/backend/api/models.go @@ -35,14 +35,15 @@ type modelResponse struct { Proxy string `json:"proxy,omitempty"` AuthMethod string `json:"auth_method,omitempty"` // Advanced fields - ConnectMode string `json:"connect_mode,omitempty"` - Workspace string `json:"workspace,omitempty"` - RPM int `json:"rpm,omitempty"` - MaxTokensField string `json:"max_tokens_field,omitempty"` - RequestTimeout int `json:"request_timeout,omitempty"` - ThinkingLevel string `json:"thinking_level,omitempty"` - ExtraBody map[string]any `json:"extra_body,omitempty"` - CustomHeaders map[string]string `json:"custom_headers,omitempty"` + ConnectMode string `json:"connect_mode,omitempty"` + Workspace string `json:"workspace,omitempty"` + RPM int `json:"rpm,omitempty"` + MaxTokensField string `json:"max_tokens_field,omitempty"` + RequestTimeout int `json:"request_timeout,omitempty"` + ThinkingLevel string `json:"thinking_level,omitempty"` + ToolSchemaTransform string `json:"tool_schema_transform,omitempty"` + ExtraBody map[string]any `json:"extra_body,omitempty"` + CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Meta Enabled bool `json:"enabled"` Available bool `json:"available"` @@ -78,27 +79,28 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) { for i, m := range cfg.ModelList { provider, modelID := providers.ExtractProtocol(m) models = append(models, modelResponse{ - Index: i, - ModelName: m.ModelName, - Provider: provider, - Model: modelID, - APIBase: m.APIBase, - APIKey: maskAPIKey(m.APIKey()), - Proxy: m.Proxy, - AuthMethod: m.AuthMethod, - ConnectMode: m.ConnectMode, - Workspace: m.Workspace, - RPM: m.RPM, - MaxTokensField: m.MaxTokensField, - RequestTimeout: m.RequestTimeout, - ThinkingLevel: m.ThinkingLevel, - ExtraBody: m.ExtraBody, - CustomHeaders: m.CustomHeaders, - Enabled: m.Enabled, - Available: modelStatuses[i].Available, - Status: modelStatuses[i].Status, - IsDefault: m.ModelName == defaultModel, - IsVirtual: m.IsVirtual(), + Index: i, + ModelName: m.ModelName, + Provider: provider, + Model: modelID, + APIBase: m.APIBase, + APIKey: maskAPIKey(m.APIKey()), + Proxy: m.Proxy, + AuthMethod: m.AuthMethod, + ConnectMode: m.ConnectMode, + Workspace: m.Workspace, + RPM: m.RPM, + MaxTokensField: m.MaxTokensField, + RequestTimeout: m.RequestTimeout, + ThinkingLevel: m.ThinkingLevel, + ToolSchemaTransform: m.ToolSchemaTransform, + ExtraBody: m.ExtraBody, + CustomHeaders: m.CustomHeaders, + Enabled: m.Enabled, + Available: modelStatuses[i].Available, + Status: modelStatuses[i].Status, + IsDefault: m.ModelName == defaultModel, + IsVirtual: m.IsVirtual(), }) } @@ -237,6 +239,9 @@ func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) { } else if len(mc.CustomHeaders) == 0 { mc.CustomHeaders = nil } + if _, ok := rawFields["tool_schema_transform"]; !ok { + mc.ToolSchemaTransform = cfg.ModelList[idx].ToolSchemaTransform + } // Preserve the existing Provider when the caller omits it. This keeps the // update API backward-compatible for clients that haven't started sending // the new field yet, while still allowing explicit clearing via "". diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go index f374ac15b..a60c6ef7a 100644 --- a/web/backend/api/models_test.go +++ b/web/backend/api/models_test.go @@ -584,6 +584,37 @@ func TestHandleAddModel_PersistsCustomHeaders(t *testing.T) { } } +func TestHandleAddModel_PersistsToolSchemaTransform(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/models", bytes.NewBufferString(`{ + "model_name":"new-model-transform", + "model":"openai/gpt-4o-mini", + "tool_schema_transform":"simple" + }`)) + req.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + added := cfg.ModelList[len(cfg.ModelList)-1] + if got := added.ToolSchemaTransform; got != "simple" { + t.Fatalf("tool_schema_transform = %q, want %q", got, "simple") + } +} + func TestHandleUpdateModel_CustomHeadersPreserveAndClear(t *testing.T) { configPath, cleanup := setupOAuthTestEnv(t) defer cleanup() @@ -649,6 +680,69 @@ func TestHandleUpdateModel_CustomHeadersPreserveAndClear(t *testing.T) { } } +func TestHandleUpdateModel_ToolSchemaTransformPreserveAndClear(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + cfg.ModelList = []*config.ModelConfig{{ + ModelName: "editable", + Model: "openai/gpt-4o-mini", + APIKeys: config.SimpleSecureStrings("sk-existing"), + ToolSchemaTransform: "google", + }} + err = config.SaveConfig(configPath, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + recPreserve := httptest.NewRecorder() + reqPreserve := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{ + "model_name":"editable", + "model":"openai/gpt-4o-mini" + }`)) + reqPreserve.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(recPreserve, reqPreserve) + if recPreserve.Code != http.StatusOK { + t.Fatalf("preserve status = %d, want %d, body=%s", recPreserve.Code, http.StatusOK, recPreserve.Body.String()) + } + + afterPreserve, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() after preserve error = %v", err) + } + if got := afterPreserve.ModelList[0].ToolSchemaTransform; got != "google" { + t.Fatalf("preserved tool_schema_transform = %q, want %q", got, "google") + } + + recClear := httptest.NewRecorder() + reqClear := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{ + "model_name":"editable", + "model":"openai/gpt-4o-mini", + "tool_schema_transform":"" + }`)) + reqClear.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(recClear, reqClear) + if recClear.Code != http.StatusOK { + t.Fatalf("clear status = %d, want %d, body=%s", recClear.Code, http.StatusOK, recClear.Body.String()) + } + + afterClear, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() after clear error = %v", err) + } + if afterClear.ModelList[0].ToolSchemaTransform != "" { + t.Fatalf("tool_schema_transform = %q, want empty", afterClear.ModelList[0].ToolSchemaTransform) + } +} + func TestHandleUpdateModel_PersistsProvider(t *testing.T) { configPath, cleanup := setupOAuthTestEnv(t) defer cleanup() diff --git a/web/frontend/src/api/models.ts b/web/frontend/src/api/models.ts index d2d2dca88..926bf8a0a 100644 --- a/web/frontend/src/api/models.ts +++ b/web/frontend/src/api/models.ts @@ -19,6 +19,7 @@ export interface ModelInfo { max_tokens_field?: string request_timeout?: number thinking_level?: string + tool_schema_transform?: string extra_body?: Record custom_headers?: Record // Meta diff --git a/web/frontend/src/components/models/add-model-sheet.tsx b/web/frontend/src/components/models/add-model-sheet.tsx index a9102aa8a..be2a8fd64 100644 --- a/web/frontend/src/components/models/add-model-sheet.tsx +++ b/web/frontend/src/components/models/add-model-sheet.tsx @@ -36,6 +36,7 @@ interface AddForm { maxTokensField: string requestTimeout: string thinkingLevel: string + toolSchemaTransform: string extraBody: string customHeaders: string } @@ -54,6 +55,7 @@ const EMPTY_ADD_FORM: AddForm = { maxTokensField: "", requestTimeout: "", thinkingLevel: "", + toolSchemaTransform: "", extraBody: "", customHeaders: "", } @@ -139,6 +141,7 @@ export function AddModelSheet({ ? Number(form.requestTimeout) : undefined, thinking_level: form.thinkingLevel.trim() || undefined, + tool_schema_transform: form.toolSchemaTransform.trim() || undefined, extra_body: form.extraBody.trim() ? JSON.parse(form.extraBody.trim()) : undefined, @@ -333,6 +336,17 @@ export function AddModelSheet({ /> + + + + + + + + Date: Mon, 27 Apr 2026 21:18:19 +0200 Subject: [PATCH 4/5] fix test --- pkg/providers/tool_schema_transform_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/providers/tool_schema_transform_test.go b/pkg/providers/tool_schema_transform_test.go index c83416377..a162c3cb4 100644 --- a/pkg/providers/tool_schema_transform_test.go +++ b/pkg/providers/tool_schema_transform_test.go @@ -53,7 +53,7 @@ func TestWrapProviderWithToolSchemaTransform_DisabledPassesToolsThrough(t *testi func TestWrapProviderWithToolSchemaTransform_GoogleSanitizesSchemas(t *testing.T) { capture := &toolCaptureProvider{} - wrapped, err := wrapProviderWithToolSchemaTransform(capture, "google") + wrapped, err := wrapProviderWithToolSchemaTransform(capture, "simple") if err != nil { t.Fatalf("wrapProviderWithToolSchemaTransform() error = %v", err) } From 23df824c776334317aae2cbcf9c6ff3a26695fe2 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Mon, 27 Apr 2026 21:27:02 +0200 Subject: [PATCH 5/5] fix test --- web/backend/api/models_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go index a60c6ef7a..dd5ff6a54 100644 --- a/web/backend/api/models_test.go +++ b/web/backend/api/models_test.go @@ -692,7 +692,7 @@ func TestHandleUpdateModel_ToolSchemaTransformPreserveAndClear(t *testing.T) { ModelName: "editable", Model: "openai/gpt-4o-mini", APIKeys: config.SimpleSecureStrings("sk-existing"), - ToolSchemaTransform: "google", + ToolSchemaTransform: "simple", }} err = config.SaveConfig(configPath, cfg) if err != nil { @@ -718,8 +718,8 @@ func TestHandleUpdateModel_ToolSchemaTransformPreserveAndClear(t *testing.T) { if err != nil { t.Fatalf("LoadConfig() after preserve error = %v", err) } - if got := afterPreserve.ModelList[0].ToolSchemaTransform; got != "google" { - t.Fatalf("preserved tool_schema_transform = %q, want %q", got, "google") + if got := afterPreserve.ModelList[0].ToolSchemaTransform; got != "simple" { + t.Fatalf("preserved tool_schema_transform = %q, want %q", got, "simple") } recClear := httptest.NewRecorder()