feat(tool): tool schema semplification

This commit is contained in:
afjcjsbx
2026-04-27 21:10:30 +02:00
parent 4eeb69688e
commit cd7717bc15
23 changed files with 654 additions and 136 deletions
+46 -39
View File
@@ -17,6 +17,7 @@ import (
"github.com/sipeed/picoclaw/pkg"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/logger"
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
)
// rrCounter is a global counter for round-robin load balancing across models.
@@ -553,12 +554,13 @@ type ModelConfig struct {
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
// Optional optimizations
RPM int `json:"rpm,omitempty"` // Requests per minute limit
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
RequestTimeout int `json:"request_timeout,omitempty"`
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
RPM int `json:"rpm,omitempty"` // Requests per minute limit
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
RequestTimeout int `json:"request_timeout,omitempty"`
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
ToolSchemaTransform string `json:"tool_schema_transform,omitempty"` // Optional tool schema compatibility transform (e.g. "simple")
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
@@ -595,6 +597,9 @@ func (c *ModelConfig) Validate() error {
if c.Model == "" {
return fmt.Errorf("model is required")
}
if _, err := providercommon.NormalizeToolSchemaTransform(c.ToolSchemaTransform); err != nil {
return err
}
return nil
}
@@ -1419,23 +1424,24 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
// Create a copy for the additional key
additionalEntry := &ModelConfig{
ModelName: expandedName,
Provider: m.Provider,
Model: m.Model,
APIBase: m.APIBase,
APIKeys: SimpleSecureStrings(keys[i]),
Proxy: m.Proxy,
AuthMethod: m.AuthMethod,
ConnectMode: m.ConnectMode,
Workspace: m.Workspace,
RPM: m.RPM,
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
UserAgent: m.UserAgent,
isVirtual: true,
ModelName: expandedName,
Provider: m.Provider,
Model: m.Model,
APIBase: m.APIBase,
APIKeys: SimpleSecureStrings(keys[i]),
Proxy: m.Proxy,
AuthMethod: m.AuthMethod,
ConnectMode: m.ConnectMode,
Workspace: m.Workspace,
RPM: m.RPM,
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ToolSchemaTransform: m.ToolSchemaTransform,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
UserAgent: m.UserAgent,
isVirtual: true,
}
expanded = append(expanded, additionalEntry)
fallbackNames = append(fallbackNames, expandedName)
@@ -1443,22 +1449,23 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
// Create the primary entry with first key and fallbacks
primaryEntry := &ModelConfig{
ModelName: originalName,
Provider: m.Provider,
Model: m.Model,
APIBase: m.APIBase,
Proxy: m.Proxy,
AuthMethod: m.AuthMethod,
ConnectMode: m.ConnectMode,
Workspace: m.Workspace,
RPM: m.RPM,
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
UserAgent: m.UserAgent,
APIKeys: SimpleSecureStrings(keys[0]),
ModelName: originalName,
Provider: m.Provider,
Model: m.Model,
APIBase: m.APIBase,
Proxy: m.Proxy,
AuthMethod: m.AuthMethod,
ConnectMode: m.ConnectMode,
Workspace: m.Workspace,
RPM: m.RPM,
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ToolSchemaTransform: m.ToolSchemaTransform,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
UserAgent: m.UserAgent,
APIKeys: SimpleSecureStrings(keys[0]),
}
// Prepend new fallbacks to existing ones
+30
View File
@@ -1945,6 +1945,36 @@ func TestModelConfig_CustomHeadersRoundTrip(t *testing.T) {
}
}
func TestModelConfig_ToolSchemaTransformRoundTrip(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
cfg := &Config{
Version: CurrentVersion,
ModelList: []*ModelConfig{
{
ModelName: "test-model",
Model: "openai/test",
APIKeys: SimpleSecureStrings("sk-test"),
ToolSchemaTransform: "simple",
},
},
}
if err := SaveConfig(cfgPath, cfg); err != nil {
t.Fatalf("SaveConfig error: %v", err)
}
loaded, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig error: %v", err)
}
if got := loaded.ModelList[0].ToolSchemaTransform; got != "simple" {
t.Fatalf("ToolSchemaTransform = %q, want %q", got, "simple")
}
}
func TestDefaultConfig_MinimaxExtraBody(t *testing.T) {
cfg := DefaultConfig()
+18
View File
@@ -158,6 +158,15 @@ func TestModelConfig_Validate(t *testing.T) {
},
wantErr: false,
},
{
name: "valid tool schema transform",
config: ModelConfig{
ModelName: "test",
Model: "openai/gpt-4o",
ToolSchemaTransform: "simple",
},
wantErr: false,
},
{
name: "missing model_name",
config: ModelConfig{
@@ -177,6 +186,15 @@ func TestModelConfig_Validate(t *testing.T) {
config: ModelConfig{},
wantErr: true,
},
{
name: "invalid tool schema transform",
config: ModelConfig{
ModelName: "test",
Model: "openai/gpt-4o",
ToolSchemaTransform: "invalid",
},
wantErr: true,
},
}
for _, tt := range tests {
+16 -9
View File
@@ -187,15 +187,16 @@ func TestExpandMultiKeyModels_Deduplication(t *testing.T) {
func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
modelCfg := &ModelConfig{
ModelName: "gpt-4",
Provider: "openrouter",
Model: "openai/gpt-4o",
APIBase: "https://api.example.com",
Proxy: "http://proxy:8080",
RPM: 60,
MaxTokensField: "max_completion_tokens",
RequestTimeout: 30,
ThinkingLevel: "high",
ModelName: "gpt-4",
Provider: "openrouter",
Model: "openai/gpt-4o",
APIBase: "https://api.example.com",
Proxy: "http://proxy:8080",
RPM: 60,
MaxTokensField: "max_completion_tokens",
RequestTimeout: 30,
ThinkingLevel: "high",
ToolSchemaTransform: "simple",
}
modelCfg.APIKeys = SimpleSecureStrings("key0", "key1") // Use internal field for multi-key testing
models := []*ModelConfig{modelCfg}
@@ -225,6 +226,9 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
if primary.ThinkingLevel != "high" {
t.Errorf("expected thinking_level preserved, got %q", primary.ThinkingLevel)
}
if primary.ToolSchemaTransform != "simple" {
t.Errorf("expected tool_schema_transform preserved, got %q", primary.ToolSchemaTransform)
}
// Check additional entry also preserves fields
additional := result[0]
@@ -237,6 +241,9 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
if additional.RPM != 60 {
t.Errorf("expected additional rpm preserved, got %d", additional.RPM)
}
if additional.ToolSchemaTransform != "simple" {
t.Errorf("expected additional tool_schema_transform preserved, got %q", additional.ToolSchemaTransform)
}
}
func TestExpandMultiKeyModels_IsVirtualFlag(t *testing.T) {
+11 -5
View File
@@ -16,11 +16,11 @@ var geminiSupportedTypes = map[string]bool{
"string": true,
}
// SanitizeSchemaForGemini reduces a JSON Schema to the conservative subset
// accepted by Gemini-style function declarations. It resolves local refs,
// collapses composition keywords like anyOf/oneOf/allOf, and strips advanced
// keywords that Gemini rejects.
func SanitizeSchemaForGemini(schema map[string]any) map[string]any {
// SanitizeSchemaForGoogle reduces a JSON Schema to the conservative subset
// accepted by Google/Gemini-style function declarations. It resolves local
// refs, collapses composition keywords like anyOf/oneOf/allOf, and strips
// advanced keywords that Gemini-compatible backends often reject.
func SanitizeSchemaForGoogle(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
@@ -39,6 +39,12 @@ func SanitizeSchemaForGemini(schema map[string]any) map[string]any {
return result
}
// SanitizeSchemaForGemini is kept as a compatibility alias for the original
// Google/Gemini sanitizer name.
func SanitizeSchemaForGemini(schema map[string]any) map[string]any {
return SanitizeSchemaForGoogle(schema)
}
type geminiSchemaSanitizer struct {
root map[string]any
}
@@ -0,0 +1,59 @@
package common
import (
"fmt"
"strings"
)
const (
ToolSchemaTransformOff = ""
ToolSchemaTransformSimple = "simple"
)
// NormalizeToolSchemaTransform resolves user-facing aliases to a canonical
// transform mode. Empty values and explicit "off"-style values disable schema
// transformation.
func NormalizeToolSchemaTransform(raw string) (string, error) {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "", "off", "none", "native":
return ToolSchemaTransformOff, nil
case "simple", "basic", "strict", "flat":
return ToolSchemaTransformSimple, nil
default:
return "", fmt.Errorf("unsupported tool_schema_transform %q (supported: off, simple)", raw)
}
}
// TransformToolDefinitions clones tool definitions and applies the configured
// schema transform to function parameter schemas. When the transform is off, the
// original slice is returned unchanged.
func TransformToolDefinitions(tools []ToolDefinition, transform string) ([]ToolDefinition, error) {
transform, err := NormalizeToolSchemaTransform(transform)
if err != nil {
return nil, err
}
if transform == ToolSchemaTransformOff || len(tools) == 0 {
return tools, nil
}
out := make([]ToolDefinition, len(tools))
for i, tool := range tools {
out[i] = tool
if tool.Type != "function" {
continue
}
out[i].Function = tool.Function
out[i].Function.Parameters = transformToolSchema(tool.Function.Parameters, transform)
}
return out, nil
}
func transformToolSchema(schema map[string]any, transform string) map[string]any {
switch transform {
case ToolSchemaTransformSimple:
return SanitizeSchemaForGoogle(schema)
default:
return cloneGeminiSchemaMap(schema)
}
}
+31 -19
View File
@@ -168,7 +168,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if err != nil {
return nil, "", err
}
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
}
// OpenAI with API key
if cfg.APIKey() == "" && cfg.APIBase == "" {
@@ -189,7 +189,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.CustomHeaders,
)
provider.SetProviderName(protocol)
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
case "azure", "azure-openai":
// Azure OpenAI uses deployment-based URLs, api-key header auth,
@@ -202,13 +202,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
)
}
return azure.NewProviderWithTimeout(
return finalizeProviderFromConfig(azure.NewProviderWithTimeout(
cfg.APIKey(),
cfg.APIBase,
cfg.Proxy,
userAgent,
cfg.RequestTimeout,
), modelID, nil
), modelID, cfg)
case "bedrock":
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
@@ -244,7 +244,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if err != nil {
return nil, "", fmt.Errorf("creating bedrock provider: %w", err)
}
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "nvidia", "venice",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
@@ -270,7 +270,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.CustomHeaders,
)
provider.SetProviderName(protocol)
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
case "gemini":
if cfg.APIKey() == "" && cfg.APIBase == "" {
@@ -280,7 +280,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
return NewGeminiProvider(
return finalizeProviderFromConfig(NewGeminiProvider(
cfg.APIKey(),
apiBase,
cfg.Proxy,
@@ -288,7 +288,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout,
cfg.ExtraBody,
cfg.CustomHeaders,
), modelID, nil
), modelID, cfg)
case "minimax":
// Minimax requires reasoning_split: true in the request body
@@ -317,7 +317,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.CustomHeaders,
)
provider.SetProviderName(protocol)
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
case "anthropic":
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
@@ -326,7 +326,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if err != nil {
return nil, "", err
}
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
}
// Use API key with HTTP API
apiBase := cfg.APIBase
@@ -347,7 +347,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.CustomHeaders,
)
provider.SetProviderName(protocol)
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
case "anthropic-messages":
// Anthropic Messages API with native format (HTTP-based, no SDK)
@@ -358,12 +358,12 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if cfg.APIKey() == "" {
return nil, "", fmt.Errorf("api_key is required for anthropic-messages protocol (model: %s)", cfg.Model)
}
return anthropicmessages.NewProviderWithTimeout(
return finalizeProviderFromConfig(anthropicmessages.NewProviderWithTimeout(
cfg.APIKey(),
apiBase,
userAgent,
cfg.RequestTimeout,
), modelID, nil
), modelID, cfg)
case "coding-plan-anthropic", "alibaba-coding-anthropic":
// Alibaba Coding Plan with Anthropic-compatible API
@@ -374,29 +374,29 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if cfg.APIKey() == "" {
return nil, "", fmt.Errorf("api_key is required for %q protocol (model: %s)", protocol, cfg.Model)
}
return anthropicmessages.NewProviderWithTimeout(
return finalizeProviderFromConfig(anthropicmessages.NewProviderWithTimeout(
cfg.APIKey(),
apiBase,
userAgent,
cfg.RequestTimeout,
), modelID, nil
), modelID, cfg)
case "antigravity":
return NewAntigravityProvider(), modelID, nil
return finalizeProviderFromConfig(NewAntigravityProvider(), modelID, cfg)
case "claude-cli", "claudecli":
workspace := cfg.Workspace
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), modelID, nil
return finalizeProviderFromConfig(NewClaudeCliProvider(workspace), modelID, cfg)
case "codex-cli", "codexcli":
workspace := cfg.Workspace
if workspace == "" {
workspace = "."
}
return NewCodexCliProvider(workspace), modelID, nil
return finalizeProviderFromConfig(NewCodexCliProvider(workspace), modelID, cfg)
case "github-copilot", "copilot":
apiBase := cfg.APIBase
@@ -411,13 +411,25 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if err != nil {
return nil, "", err
}
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
default:
return nil, "", fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
}
}
func finalizeProviderFromConfig(
provider LLMProvider,
modelID string,
cfg *config.ModelConfig,
) (LLMProvider, string, error) {
wrapped, err := wrapProviderWithToolSchemaTransform(provider, cfg.ToolSchemaTransform)
if err != nil {
return nil, "", err
}
return wrapped, modelID, nil
}
func isEmptyAPIKeyAllowed(protocol string) bool {
meta, ok := protocolMetaByName[protocol]
return ok && meta.emptyAPIKeyAllowed
+39
View File
@@ -1202,3 +1202,42 @@ func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) {
// Unexpected error - fail the test
t.Errorf("unexpected error from bedrock provider: %v", err)
}
func TestCreateProviderFromConfig_ToolSchemaTransformWrapsProvider(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "claude-cli-test",
Provider: "claude-cli",
Model: "claude-sonnet-4.6",
Workspace: t.TempDir(),
ToolSchemaTransform: "simple",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if modelID != "claude-sonnet-4.6" {
t.Fatalf("modelID = %q, want %q", modelID, "claude-sonnet-4.6")
}
if _, ok := provider.(*toolSchemaTransformProvider); !ok {
t.Fatalf("provider = %T, want *toolSchemaTransformProvider", provider)
}
}
func TestCreateProviderFromConfig_InvalidToolSchemaTransform(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "claude-cli-test",
Provider: "claude-cli",
Model: "claude-sonnet-4.6",
Workspace: t.TempDir(),
ToolSchemaTransform: "invalid",
}
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
t.Fatal("CreateProviderFromConfig() expected error for invalid tool_schema_transform")
}
if !strings.Contains(err.Error(), "tool_schema_transform") {
t.Fatalf("error = %v, want mention tool_schema_transform", err)
}
}
+1 -1
View File
@@ -264,7 +264,7 @@ func (p *GeminiProvider) buildRequestBody(
funcDecls = append(funcDecls, geminiFunctionDeclaration{
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: common.SanitizeSchemaForGemini(t.Function.Parameters),
Parameters: t.Function.Parameters,
})
}
if len(funcDecls) > 0 {
@@ -5,11 +5,8 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
)
func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) {
@@ -262,7 +259,7 @@ func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) {
}
}
func TestGeminiProvider_BuildRequestBody_SanitizesComplexToolSchemas(t *testing.T) {
func TestGeminiProvider_BuildRequestBody_PreservesComplexToolSchemasByDefault(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
schema := map[string]any{
"type": "object",
@@ -315,9 +312,8 @@ func TestGeminiProvider_BuildRequestBody_SanitizesComplexToolSchemas(t *testing.
t.Fatalf("parameters = %#v, want map", tools[0].FunctionDeclarations[0].Parameters)
}
want := providercommon.SanitizeSchemaForGemini(schema)
if !reflect.DeepEqual(got, want) {
t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want)
if got["$defs"] == nil {
t.Fatalf("parameters = %#v, want raw schema with $defs preserved by default", got)
}
}
+2 -3
View File
@@ -291,18 +291,17 @@ func (p *AntigravityProvider) buildRequest(
}
}
// Build tools (sanitize schemas for Gemini compatibility)
// Build tools
if len(tools) > 0 {
var funcDecls []antigravityFuncDecl
for _, t := range tools {
if t.Type != "function" {
continue
}
params := common.SanitizeSchemaForGemini(t.Function.Parameters)
funcDecls = append(funcDecls, antigravityFuncDecl{
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: params,
Parameters: t.Function.Parameters,
})
}
if len(funcDecls) > 0 {
@@ -1,10 +1,7 @@
package oauthprovider
import (
"reflect"
"testing"
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
)
func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) {
@@ -77,7 +74,7 @@ func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) {
}
}
func TestBuildRequest_SanitizesComplexToolSchemas(t *testing.T) {
func TestBuildRequest_PreservesComplexToolSchemasByDefault(t *testing.T) {
p := &AntigravityProvider{}
schema := map[string]any{
"type": "object",
@@ -135,9 +132,11 @@ func TestBuildRequest_SanitizesComplexToolSchemas(t *testing.T) {
t.Fatalf("request tools = %#v, want one function declaration", req.Tools)
}
got := req.Tools[0].FunctionDeclarations[0].Parameters
want := providercommon.SanitizeSchemaForGemini(schema)
if !reflect.DeepEqual(got, want) {
t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want)
got, ok := req.Tools[0].FunctionDeclarations[0].Parameters.(map[string]any)
if !ok {
t.Fatalf("parameters = %#v, want map", req.Tools[0].FunctionDeclarations[0].Parameters)
}
if got["$defs"] == nil {
t.Fatalf("parameters = %#v, want raw schema with $defs preserved by default", got)
}
}
+84
View File
@@ -0,0 +1,84 @@
package providers
import (
"context"
"github.com/sipeed/picoclaw/pkg/providers/common"
)
type toolSchemaTransformProvider struct {
delegate LLMProvider
transform string
}
type toolSchemaStreamingProvider struct {
*toolSchemaTransformProvider
}
func wrapProviderWithToolSchemaTransform(delegate LLMProvider, transform string) (LLMProvider, error) {
transform, err := common.NormalizeToolSchemaTransform(transform)
if err != nil {
return nil, err
}
if transform == common.ToolSchemaTransformOff || delegate == nil {
return delegate, nil
}
base := &toolSchemaTransformProvider{
delegate: delegate,
transform: transform,
}
if _, ok := delegate.(StreamingProvider); ok {
return &toolSchemaStreamingProvider{toolSchemaTransformProvider: base}, nil
}
return base, nil
}
func (p *toolSchemaTransformProvider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
transformed, err := common.TransformToolDefinitions(tools, p.transform)
if err != nil {
return nil, err
}
return p.delegate.Chat(ctx, messages, transformed, model, options)
}
func (p *toolSchemaTransformProvider) GetDefaultModel() string {
return p.delegate.GetDefaultModel()
}
func (p *toolSchemaStreamingProvider) ChatStream(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
onChunk func(accumulated string),
) (*LLMResponse, error) {
streaming := p.delegate.(StreamingProvider)
transformed, err := common.TransformToolDefinitions(tools, p.transform)
if err != nil {
return nil, err
}
return streaming.ChatStream(ctx, messages, transformed, model, options, onChunk)
}
func (p *toolSchemaTransformProvider) SupportsThinking() bool {
tc, ok := p.delegate.(ThinkingCapable)
return ok && tc.SupportsThinking()
}
func (p *toolSchemaTransformProvider) SupportsNativeSearch() bool {
ns, ok := p.delegate.(NativeSearchCapable)
return ok && ns.SupportsNativeSearch()
}
func (p *toolSchemaTransformProvider) Close() {
if stateful, ok := p.delegate.(StatefulProvider); ok {
stateful.Close()
}
}
+104
View File
@@ -0,0 +1,104 @@
package providers
import (
"context"
"reflect"
"testing"
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
)
type toolCaptureProvider struct {
lastTools []ToolDefinition
}
func (p *toolCaptureProvider) Chat(
_ context.Context,
_ []Message,
tools []ToolDefinition,
_ string,
_ map[string]any,
) (*LLMResponse, error) {
p.lastTools = tools
return &LLMResponse{Content: "ok"}, nil
}
func (p *toolCaptureProvider) GetDefaultModel() string {
return "test"
}
func TestWrapProviderWithToolSchemaTransform_DisabledPassesToolsThrough(t *testing.T) {
capture := &toolCaptureProvider{}
wrapped, err := wrapProviderWithToolSchemaTransform(capture, "")
if err != nil {
t.Fatalf("wrapProviderWithToolSchemaTransform() error = %v", err)
}
tools := []ToolDefinition{{
Type: "function",
Function: ToolFunctionDefinition{
Name: "noop",
Parameters: map[string]any{"type": "object"},
},
}}
_, err = wrapped.Chat(t.Context(), nil, tools, "test", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if !reflect.DeepEqual(capture.lastTools, tools) {
t.Fatalf("tools mutated with transform off\n got: %#v\nwant: %#v", capture.lastTools, tools)
}
}
func TestWrapProviderWithToolSchemaTransform_GoogleSanitizesSchemas(t *testing.T) {
capture := &toolCaptureProvider{}
wrapped, err := wrapProviderWithToolSchemaTransform(capture, "google")
if err != nil {
t.Fatalf("wrapProviderWithToolSchemaTransform() error = %v", err)
}
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"parent": map[string]any{
"anyOf": []any{
map[string]any{"$ref": "#/$defs/pageParent"},
map[string]any{"$ref": "#/$defs/databaseParent"},
},
},
},
"$defs": map[string]any{
"pageParent": map[string]any{
"type": "object",
"properties": map[string]any{
"page_id": map[string]any{"type": "string"},
},
},
"databaseParent": map[string]any{
"type": "object",
"properties": map[string]any{
"database_id": map[string]any{"type": "string"},
},
},
},
}
tools := []ToolDefinition{{
Type: "function",
Function: ToolFunctionDefinition{
Name: "mcp_notion_create",
Parameters: schema,
},
}}
_, err = wrapped.Chat(t.Context(), nil, tools, "test", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
want := providercommon.SanitizeSchemaForGoogle(schema)
got := capture.lastTools[0].Function.Parameters
if !reflect.DeepEqual(got, want) {
t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want)
}
}