mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(tool): tool schema semplification
This commit is contained in:
+46
-39
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
// rrCounter is a global counter for round-robin load balancing across models.
|
||||
@@ -553,12 +554,13 @@ type ModelConfig struct {
|
||||
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
|
||||
|
||||
// Optional optimizations
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
|
||||
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
ToolSchemaTransform string `json:"tool_schema_transform,omitempty"` // Optional tool schema compatibility transform (e.g. "simple")
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
|
||||
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
|
||||
|
||||
APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
|
||||
|
||||
@@ -595,6 +597,9 @@ func (c *ModelConfig) Validate() error {
|
||||
if c.Model == "" {
|
||||
return fmt.Errorf("model is required")
|
||||
}
|
||||
if _, err := providercommon.NormalizeToolSchemaTransform(c.ToolSchemaTransform); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1419,23 +1424,24 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
|
||||
// Create a copy for the additional key
|
||||
additionalEntry := &ModelConfig{
|
||||
ModelName: expandedName,
|
||||
Provider: m.Provider,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
APIKeys: SimpleSecureStrings(keys[i]),
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
UserAgent: m.UserAgent,
|
||||
isVirtual: true,
|
||||
ModelName: expandedName,
|
||||
Provider: m.Provider,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
APIKeys: SimpleSecureStrings(keys[i]),
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ToolSchemaTransform: m.ToolSchemaTransform,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
UserAgent: m.UserAgent,
|
||||
isVirtual: true,
|
||||
}
|
||||
expanded = append(expanded, additionalEntry)
|
||||
fallbackNames = append(fallbackNames, expandedName)
|
||||
@@ -1443,22 +1449,23 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
|
||||
// Create the primary entry with first key and fallbacks
|
||||
primaryEntry := &ModelConfig{
|
||||
ModelName: originalName,
|
||||
Provider: m.Provider,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
UserAgent: m.UserAgent,
|
||||
APIKeys: SimpleSecureStrings(keys[0]),
|
||||
ModelName: originalName,
|
||||
Provider: m.Provider,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ToolSchemaTransform: m.ToolSchemaTransform,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
UserAgent: m.UserAgent,
|
||||
APIKeys: SimpleSecureStrings(keys[0]),
|
||||
}
|
||||
|
||||
// Prepend new fallbacks to existing ones
|
||||
|
||||
@@ -1945,6 +1945,36 @@ func TestModelConfig_CustomHeadersRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_ToolSchemaTransformRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
cfg := &Config{
|
||||
Version: CurrentVersion,
|
||||
ModelList: []*ModelConfig{
|
||||
{
|
||||
ModelName: "test-model",
|
||||
Model: "openai/test",
|
||||
APIKeys: SimpleSecureStrings("sk-test"),
|
||||
ToolSchemaTransform: "simple",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig error: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig error: %v", err)
|
||||
}
|
||||
|
||||
if got := loaded.ModelList[0].ToolSchemaTransform; got != "simple" {
|
||||
t.Fatalf("ToolSchemaTransform = %q, want %q", got, "simple")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_MinimaxExtraBody(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
|
||||
@@ -158,6 +158,15 @@ func TestModelConfig_Validate(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid tool schema transform",
|
||||
config: ModelConfig{
|
||||
ModelName: "test",
|
||||
Model: "openai/gpt-4o",
|
||||
ToolSchemaTransform: "simple",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing model_name",
|
||||
config: ModelConfig{
|
||||
@@ -177,6 +186,15 @@ func TestModelConfig_Validate(t *testing.T) {
|
||||
config: ModelConfig{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool schema transform",
|
||||
config: ModelConfig{
|
||||
ModelName: "test",
|
||||
Model: "openai/gpt-4o",
|
||||
ToolSchemaTransform: "invalid",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -187,15 +187,16 @@ func TestExpandMultiKeyModels_Deduplication(t *testing.T) {
|
||||
|
||||
func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
|
||||
modelCfg := &ModelConfig{
|
||||
ModelName: "gpt-4",
|
||||
Provider: "openrouter",
|
||||
Model: "openai/gpt-4o",
|
||||
APIBase: "https://api.example.com",
|
||||
Proxy: "http://proxy:8080",
|
||||
RPM: 60,
|
||||
MaxTokensField: "max_completion_tokens",
|
||||
RequestTimeout: 30,
|
||||
ThinkingLevel: "high",
|
||||
ModelName: "gpt-4",
|
||||
Provider: "openrouter",
|
||||
Model: "openai/gpt-4o",
|
||||
APIBase: "https://api.example.com",
|
||||
Proxy: "http://proxy:8080",
|
||||
RPM: 60,
|
||||
MaxTokensField: "max_completion_tokens",
|
||||
RequestTimeout: 30,
|
||||
ThinkingLevel: "high",
|
||||
ToolSchemaTransform: "simple",
|
||||
}
|
||||
modelCfg.APIKeys = SimpleSecureStrings("key0", "key1") // Use internal field for multi-key testing
|
||||
models := []*ModelConfig{modelCfg}
|
||||
@@ -225,6 +226,9 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
|
||||
if primary.ThinkingLevel != "high" {
|
||||
t.Errorf("expected thinking_level preserved, got %q", primary.ThinkingLevel)
|
||||
}
|
||||
if primary.ToolSchemaTransform != "simple" {
|
||||
t.Errorf("expected tool_schema_transform preserved, got %q", primary.ToolSchemaTransform)
|
||||
}
|
||||
|
||||
// Check additional entry also preserves fields
|
||||
additional := result[0]
|
||||
@@ -237,6 +241,9 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
|
||||
if additional.RPM != 60 {
|
||||
t.Errorf("expected additional rpm preserved, got %d", additional.RPM)
|
||||
}
|
||||
if additional.ToolSchemaTransform != "simple" {
|
||||
t.Errorf("expected additional tool_schema_transform preserved, got %q", additional.ToolSchemaTransform)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandMultiKeyModels_IsVirtualFlag(t *testing.T) {
|
||||
|
||||
@@ -16,11 +16,11 @@ var geminiSupportedTypes = map[string]bool{
|
||||
"string": true,
|
||||
}
|
||||
|
||||
// SanitizeSchemaForGemini reduces a JSON Schema to the conservative subset
|
||||
// accepted by Gemini-style function declarations. It resolves local refs,
|
||||
// collapses composition keywords like anyOf/oneOf/allOf, and strips advanced
|
||||
// keywords that Gemini rejects.
|
||||
func SanitizeSchemaForGemini(schema map[string]any) map[string]any {
|
||||
// SanitizeSchemaForGoogle reduces a JSON Schema to the conservative subset
|
||||
// accepted by Google/Gemini-style function declarations. It resolves local
|
||||
// refs, collapses composition keywords like anyOf/oneOf/allOf, and strips
|
||||
// advanced keywords that Gemini-compatible backends often reject.
|
||||
func SanitizeSchemaForGoogle(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -39,6 +39,12 @@ func SanitizeSchemaForGemini(schema map[string]any) map[string]any {
|
||||
return result
|
||||
}
|
||||
|
||||
// SanitizeSchemaForGemini is kept as a compatibility alias for the original
|
||||
// Google/Gemini sanitizer name.
|
||||
func SanitizeSchemaForGemini(schema map[string]any) map[string]any {
|
||||
return SanitizeSchemaForGoogle(schema)
|
||||
}
|
||||
|
||||
type geminiSchemaSanitizer struct {
|
||||
root map[string]any
|
||||
}
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
ToolSchemaTransformOff = ""
|
||||
ToolSchemaTransformSimple = "simple"
|
||||
)
|
||||
|
||||
// NormalizeToolSchemaTransform resolves user-facing aliases to a canonical
|
||||
// transform mode. Empty values and explicit "off"-style values disable schema
|
||||
// transformation.
|
||||
func NormalizeToolSchemaTransform(raw string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "", "off", "none", "native":
|
||||
return ToolSchemaTransformOff, nil
|
||||
case "simple", "basic", "strict", "flat":
|
||||
return ToolSchemaTransformSimple, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported tool_schema_transform %q (supported: off, simple)", raw)
|
||||
}
|
||||
}
|
||||
|
||||
// TransformToolDefinitions clones tool definitions and applies the configured
|
||||
// schema transform to function parameter schemas. When the transform is off, the
|
||||
// original slice is returned unchanged.
|
||||
func TransformToolDefinitions(tools []ToolDefinition, transform string) ([]ToolDefinition, error) {
|
||||
transform, err := NormalizeToolSchemaTransform(transform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if transform == ToolSchemaTransformOff || len(tools) == 0 {
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
out := make([]ToolDefinition, len(tools))
|
||||
for i, tool := range tools {
|
||||
out[i] = tool
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
out[i].Function = tool.Function
|
||||
out[i].Function.Parameters = transformToolSchema(tool.Function.Parameters, transform)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func transformToolSchema(schema map[string]any, transform string) map[string]any {
|
||||
switch transform {
|
||||
case ToolSchemaTransformSimple:
|
||||
return SanitizeSchemaForGoogle(schema)
|
||||
default:
|
||||
return cloneGeminiSchemaMap(schema)
|
||||
}
|
||||
}
|
||||
@@ -168,7 +168,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
}
|
||||
// OpenAI with API key
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
@@ -189,7 +189,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.CustomHeaders,
|
||||
)
|
||||
provider.SetProviderName(protocol)
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "azure", "azure-openai":
|
||||
// Azure OpenAI uses deployment-based URLs, api-key header auth,
|
||||
@@ -202,13 +202,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
|
||||
)
|
||||
}
|
||||
return azure.NewProviderWithTimeout(
|
||||
return finalizeProviderFromConfig(azure.NewProviderWithTimeout(
|
||||
cfg.APIKey(),
|
||||
cfg.APIBase,
|
||||
cfg.Proxy,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
), modelID, cfg)
|
||||
|
||||
case "bedrock":
|
||||
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
|
||||
@@ -244,7 +244,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("creating bedrock provider: %w", err)
|
||||
}
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "nvidia", "venice",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
@@ -270,7 +270,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.CustomHeaders,
|
||||
)
|
||||
provider.SetProviderName(protocol)
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "gemini":
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
@@ -280,7 +280,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewGeminiProvider(
|
||||
return finalizeProviderFromConfig(NewGeminiProvider(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
@@ -288,7 +288,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
), modelID, cfg)
|
||||
|
||||
case "minimax":
|
||||
// Minimax requires reasoning_split: true in the request body
|
||||
@@ -317,7 +317,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.CustomHeaders,
|
||||
)
|
||||
provider.SetProviderName(protocol)
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "anthropic":
|
||||
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
|
||||
@@ -326,7 +326,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
}
|
||||
// Use API key with HTTP API
|
||||
apiBase := cfg.APIBase
|
||||
@@ -347,7 +347,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.CustomHeaders,
|
||||
)
|
||||
provider.SetProviderName(protocol)
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "anthropic-messages":
|
||||
// Anthropic Messages API with native format (HTTP-based, no SDK)
|
||||
@@ -358,12 +358,12 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if cfg.APIKey() == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for anthropic-messages protocol (model: %s)", cfg.Model)
|
||||
}
|
||||
return anthropicmessages.NewProviderWithTimeout(
|
||||
return finalizeProviderFromConfig(anthropicmessages.NewProviderWithTimeout(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
), modelID, cfg)
|
||||
|
||||
case "coding-plan-anthropic", "alibaba-coding-anthropic":
|
||||
// Alibaba Coding Plan with Anthropic-compatible API
|
||||
@@ -374,29 +374,29 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if cfg.APIKey() == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for %q protocol (model: %s)", protocol, cfg.Model)
|
||||
}
|
||||
return anthropicmessages.NewProviderWithTimeout(
|
||||
return finalizeProviderFromConfig(anthropicmessages.NewProviderWithTimeout(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
), modelID, cfg)
|
||||
|
||||
case "antigravity":
|
||||
return NewAntigravityProvider(), modelID, nil
|
||||
return finalizeProviderFromConfig(NewAntigravityProvider(), modelID, cfg)
|
||||
|
||||
case "claude-cli", "claudecli":
|
||||
workspace := cfg.Workspace
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), modelID, nil
|
||||
return finalizeProviderFromConfig(NewClaudeCliProvider(workspace), modelID, cfg)
|
||||
|
||||
case "codex-cli", "codexcli":
|
||||
workspace := cfg.Workspace
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), modelID, nil
|
||||
return finalizeProviderFromConfig(NewCodexCliProvider(workspace), modelID, cfg)
|
||||
|
||||
case "github-copilot", "copilot":
|
||||
apiBase := cfg.APIBase
|
||||
@@ -411,13 +411,25 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
default:
|
||||
return nil, "", fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func finalizeProviderFromConfig(
|
||||
provider LLMProvider,
|
||||
modelID string,
|
||||
cfg *config.ModelConfig,
|
||||
) (LLMProvider, string, error) {
|
||||
wrapped, err := wrapProviderWithToolSchemaTransform(provider, cfg.ToolSchemaTransform)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return wrapped, modelID, nil
|
||||
}
|
||||
|
||||
func isEmptyAPIKeyAllowed(protocol string) bool {
|
||||
meta, ok := protocolMetaByName[protocol]
|
||||
return ok && meta.emptyAPIKeyAllowed
|
||||
|
||||
@@ -1202,3 +1202,42 @@ func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) {
|
||||
// Unexpected error - fail the test
|
||||
t.Errorf("unexpected error from bedrock provider: %v", err)
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_ToolSchemaTransformWrapsProvider(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "claude-cli-test",
|
||||
Provider: "claude-cli",
|
||||
Model: "claude-sonnet-4.6",
|
||||
Workspace: t.TempDir(),
|
||||
ToolSchemaTransform: "simple",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if modelID != "claude-sonnet-4.6" {
|
||||
t.Fatalf("modelID = %q, want %q", modelID, "claude-sonnet-4.6")
|
||||
}
|
||||
if _, ok := provider.(*toolSchemaTransformProvider); !ok {
|
||||
t.Fatalf("provider = %T, want *toolSchemaTransformProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_InvalidToolSchemaTransform(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "claude-cli-test",
|
||||
Provider: "claude-cli",
|
||||
Model: "claude-sonnet-4.6",
|
||||
Workspace: t.TempDir(),
|
||||
ToolSchemaTransform: "invalid",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for invalid tool_schema_transform")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "tool_schema_transform") {
|
||||
t.Fatalf("error = %v, want mention tool_schema_transform", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +264,7 @@ func (p *GeminiProvider) buildRequestBody(
|
||||
funcDecls = append(funcDecls, geminiFunctionDeclaration{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: common.SanitizeSchemaForGemini(t.Function.Parameters),
|
||||
Parameters: t.Function.Parameters,
|
||||
})
|
||||
}
|
||||
if len(funcDecls) > 0 {
|
||||
|
||||
@@ -5,11 +5,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
func TestGeminiProvider_ChatSeparatesThoughtAndToolCall(t *testing.T) {
|
||||
@@ -262,7 +259,7 @@ func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_SanitizesComplexToolSchemas(t *testing.T) {
|
||||
func TestGeminiProvider_BuildRequestBody_PreservesComplexToolSchemasByDefault(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
@@ -315,9 +312,8 @@ func TestGeminiProvider_BuildRequestBody_SanitizesComplexToolSchemas(t *testing.
|
||||
t.Fatalf("parameters = %#v, want map", tools[0].FunctionDeclarations[0].Parameters)
|
||||
}
|
||||
|
||||
want := providercommon.SanitizeSchemaForGemini(schema)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want)
|
||||
if got["$defs"] == nil {
|
||||
t.Fatalf("parameters = %#v, want raw schema with $defs preserved by default", got)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -291,18 +291,17 @@ func (p *AntigravityProvider) buildRequest(
|
||||
}
|
||||
}
|
||||
|
||||
// Build tools (sanitize schemas for Gemini compatibility)
|
||||
// Build tools
|
||||
if len(tools) > 0 {
|
||||
var funcDecls []antigravityFuncDecl
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" {
|
||||
continue
|
||||
}
|
||||
params := common.SanitizeSchemaForGemini(t.Function.Parameters)
|
||||
funcDecls = append(funcDecls, antigravityFuncDecl{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: params,
|
||||
Parameters: t.Function.Parameters,
|
||||
})
|
||||
}
|
||||
if len(funcDecls) > 0 {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
package oauthprovider
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) {
|
||||
@@ -77,7 +74,7 @@ func TestParseSSEResponse_SplitsThoughtAndVisibleContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRequest_SanitizesComplexToolSchemas(t *testing.T) {
|
||||
func TestBuildRequest_PreservesComplexToolSchemasByDefault(t *testing.T) {
|
||||
p := &AntigravityProvider{}
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
@@ -135,9 +132,11 @@ func TestBuildRequest_SanitizesComplexToolSchemas(t *testing.T) {
|
||||
t.Fatalf("request tools = %#v, want one function declaration", req.Tools)
|
||||
}
|
||||
|
||||
got := req.Tools[0].FunctionDeclarations[0].Parameters
|
||||
want := providercommon.SanitizeSchemaForGemini(schema)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want)
|
||||
got, ok := req.Tools[0].FunctionDeclarations[0].Parameters.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters = %#v, want map", req.Tools[0].FunctionDeclarations[0].Parameters)
|
||||
}
|
||||
if got["$defs"] == nil {
|
||||
t.Fatalf("parameters = %#v, want raw schema with $defs preserved by default", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
type toolSchemaTransformProvider struct {
|
||||
delegate LLMProvider
|
||||
transform string
|
||||
}
|
||||
|
||||
type toolSchemaStreamingProvider struct {
|
||||
*toolSchemaTransformProvider
|
||||
}
|
||||
|
||||
func wrapProviderWithToolSchemaTransform(delegate LLMProvider, transform string) (LLMProvider, error) {
|
||||
transform, err := common.NormalizeToolSchemaTransform(transform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if transform == common.ToolSchemaTransformOff || delegate == nil {
|
||||
return delegate, nil
|
||||
}
|
||||
base := &toolSchemaTransformProvider{
|
||||
delegate: delegate,
|
||||
transform: transform,
|
||||
}
|
||||
if _, ok := delegate.(StreamingProvider); ok {
|
||||
return &toolSchemaStreamingProvider{toolSchemaTransformProvider: base}, nil
|
||||
}
|
||||
return base, nil
|
||||
}
|
||||
|
||||
func (p *toolSchemaTransformProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
transformed, err := common.TransformToolDefinitions(tools, p.transform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.delegate.Chat(ctx, messages, transformed, model, options)
|
||||
}
|
||||
|
||||
func (p *toolSchemaTransformProvider) GetDefaultModel() string {
|
||||
return p.delegate.GetDefaultModel()
|
||||
}
|
||||
|
||||
func (p *toolSchemaStreamingProvider) ChatStream(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
onChunk func(accumulated string),
|
||||
) (*LLMResponse, error) {
|
||||
streaming := p.delegate.(StreamingProvider)
|
||||
transformed, err := common.TransformToolDefinitions(tools, p.transform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return streaming.ChatStream(ctx, messages, transformed, model, options, onChunk)
|
||||
}
|
||||
|
||||
func (p *toolSchemaTransformProvider) SupportsThinking() bool {
|
||||
tc, ok := p.delegate.(ThinkingCapable)
|
||||
return ok && tc.SupportsThinking()
|
||||
}
|
||||
|
||||
func (p *toolSchemaTransformProvider) SupportsNativeSearch() bool {
|
||||
ns, ok := p.delegate.(NativeSearchCapable)
|
||||
return ok && ns.SupportsNativeSearch()
|
||||
}
|
||||
|
||||
func (p *toolSchemaTransformProvider) Close() {
|
||||
if stateful, ok := p.delegate.(StatefulProvider); ok {
|
||||
stateful.Close()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
type toolCaptureProvider struct {
|
||||
lastTools []ToolDefinition
|
||||
}
|
||||
|
||||
func (p *toolCaptureProvider) Chat(
|
||||
_ context.Context,
|
||||
_ []Message,
|
||||
tools []ToolDefinition,
|
||||
_ string,
|
||||
_ map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
p.lastTools = tools
|
||||
return &LLMResponse{Content: "ok"}, nil
|
||||
}
|
||||
|
||||
func (p *toolCaptureProvider) GetDefaultModel() string {
|
||||
return "test"
|
||||
}
|
||||
|
||||
func TestWrapProviderWithToolSchemaTransform_DisabledPassesToolsThrough(t *testing.T) {
|
||||
capture := &toolCaptureProvider{}
|
||||
wrapped, err := wrapProviderWithToolSchemaTransform(capture, "")
|
||||
if err != nil {
|
||||
t.Fatalf("wrapProviderWithToolSchemaTransform() error = %v", err)
|
||||
}
|
||||
|
||||
tools := []ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "noop",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
}}
|
||||
|
||||
_, err = wrapped.Chat(t.Context(), nil, tools, "test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(capture.lastTools, tools) {
|
||||
t.Fatalf("tools mutated with transform off\n got: %#v\nwant: %#v", capture.lastTools, tools)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapProviderWithToolSchemaTransform_GoogleSanitizesSchemas(t *testing.T) {
|
||||
capture := &toolCaptureProvider{}
|
||||
wrapped, err := wrapProviderWithToolSchemaTransform(capture, "google")
|
||||
if err != nil {
|
||||
t.Fatalf("wrapProviderWithToolSchemaTransform() error = %v", err)
|
||||
}
|
||||
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"parent": map[string]any{
|
||||
"anyOf": []any{
|
||||
map[string]any{"$ref": "#/$defs/pageParent"},
|
||||
map[string]any{"$ref": "#/$defs/databaseParent"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"$defs": map[string]any{
|
||||
"pageParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"page_id": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
"databaseParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"database_id": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tools := []ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "mcp_notion_create",
|
||||
Parameters: schema,
|
||||
},
|
||||
}}
|
||||
|
||||
_, err = wrapped.Chat(t.Context(), nil, tools, "test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
want := providercommon.SanitizeSchemaForGoogle(schema)
|
||||
got := capture.lastTools[0].Function.Parameters
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("sanitized parameters mismatch\n got: %#v\nwant: %#v", got, want)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user