From 38e1fe435a1a0431bd44452c50c22bd3f85b1c09 Mon Sep 17 00:00:00 2001 From: Bijin <38134380+sliverp@users.noreply.github.com> Date: Thu, 19 Mar 2026 21:24:46 +0800 Subject: [PATCH] fix(config): model_list inherits api_key/api_base from providers (#1786) When both providers and model_list are configured, model_list entries with empty api_key or api_base now automatically inherit from the matching provider (matched by protocol prefix in the Model field). Example: a model_list entry with model='deepseek/deepseek-chat' and no api_key will inherit from providers.deepseek.api_key. Explicit model_list values always take precedence. Changes: - Add InheritProviderCredentials() in migration.go - Call it in LoadConfig() after provider-to-model-list conversion - Add protocolProviderMapping for all 25 supported protocols - 6 new tests covering inheritance, precedence, and edge cases Closes #1635 --- pkg/config/config.go | 9 +++ pkg/config/migration.go | 81 ++++++++++++++++++++ pkg/config/migration_test.go | 140 +++++++++++++++++++++++++++++++++++ 3 files changed, 230 insertions(+) diff --git a/pkg/config/config.go b/pkg/config/config.go index d226bba51..4f8026d27 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -916,6 +916,15 @@ func LoadConfig(path string) (*Config, error) { cfg.ModelList = ConvertProvidersToModelList(cfg) } + // Inherit credentials from providers to model_list entries (#1635). + // When both providers and model_list are present, model_list entries + // whose api_key/api_base are empty will inherit from the matching + // provider (matched by protocol prefix). Explicit model_list values + // always take precedence. + if cfg.HasProvidersConfig() { + InheritProviderCredentials(cfg.ModelList, cfg.Providers) + } + // Validate model_list for uniqueness and required fields if err := cfg.ValidateModelList(); err != nil { return nil, err diff --git a/pkg/config/migration.go b/pkg/config/migration.go index c7fc214d5..832d8bf17 100644 --- a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -468,3 +468,84 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return result } + +// protocolProviderMapping maps a model protocol prefix (the part before "/" in +// the Model field) to a function that extracts the corresponding ProviderConfig +// from the legacy ProvidersConfig. Used by InheritProviderCredentials. +var protocolProviderMapping = map[string]func(p ProvidersConfig) ProviderConfig{ + "openai": func(p ProvidersConfig) ProviderConfig { return p.OpenAI.ProviderConfig }, + "anthropic": func(p ProvidersConfig) ProviderConfig { return p.Anthropic }, + "litellm": func(p ProvidersConfig) ProviderConfig { return p.LiteLLM }, + "openrouter": func(p ProvidersConfig) ProviderConfig { return p.OpenRouter }, + "groq": func(p ProvidersConfig) ProviderConfig { return p.Groq }, + "zhipu": func(p ProvidersConfig) ProviderConfig { return p.Zhipu }, + "vllm": func(p ProvidersConfig) ProviderConfig { return p.VLLM }, + "gemini": func(p ProvidersConfig) ProviderConfig { return p.Gemini }, + "nvidia": func(p ProvidersConfig) ProviderConfig { return p.Nvidia }, + "ollama": func(p ProvidersConfig) ProviderConfig { return p.Ollama }, + "moonshot": func(p ProvidersConfig) ProviderConfig { return p.Moonshot }, + "shengsuanyun": func(p ProvidersConfig) ProviderConfig { return p.ShengSuanYun }, + "deepseek": func(p ProvidersConfig) ProviderConfig { return p.DeepSeek }, + "cerebras": func(p ProvidersConfig) ProviderConfig { return p.Cerebras }, + "vivgrid": func(p ProvidersConfig) ProviderConfig { return p.Vivgrid }, + "volcengine": func(p ProvidersConfig) ProviderConfig { return p.VolcEngine }, + "github-copilot": func(p ProvidersConfig) ProviderConfig { return p.GitHubCopilot }, + "antigravity": func(p ProvidersConfig) ProviderConfig { return p.Antigravity }, + "qwen": func(p ProvidersConfig) ProviderConfig { return p.Qwen }, + "mistral": func(p ProvidersConfig) ProviderConfig { return p.Mistral }, + "avian": func(p ProvidersConfig) ProviderConfig { return p.Avian }, + "minimax": func(p ProvidersConfig) ProviderConfig { return p.Minimax }, + "longcat": func(p ProvidersConfig) ProviderConfig { return p.LongCat }, + "modelscope": func(p ProvidersConfig) ProviderConfig { return p.ModelScope }, + "novita": func(p ProvidersConfig) ProviderConfig { return p.Novita }, +} + +// InheritProviderCredentials fills in missing api_key, api_base, proxy, and +// request_timeout on model_list entries from the matching legacy providers +// configuration. The match is determined by the protocol prefix in the Model +// field (e.g. "deepseek/deepseek-chat" matches providers.deepseek). +// +// Only empty fields are filled — any value explicitly set on a model_list entry +// takes precedence. This function modifies the slice in place. +// +// This bridges the gap described in issue #1635: users who configure +// credentials once in the providers section expect model_list entries using +// the same protocol to "just work" without duplicating credentials. +func InheritProviderCredentials(models []ModelConfig, providers ProvidersConfig) { + if providers.IsEmpty() { + return + } + + for i := range models { + m := &models[i] + + // Extract protocol prefix from Model field + protocol := "" + if idx := strings.Index(m.Model, "/"); idx > 0 { + protocol = strings.ToLower(m.Model[:idx]) + } + if protocol == "" { + continue + } + + getProvider, ok := protocolProviderMapping[protocol] + if !ok { + continue + } + pc := getProvider(providers) + + // Only fill empty fields — explicit model_list values win + if m.APIKey == "" && pc.APIKey != "" { + m.APIKey = pc.APIKey + } + if m.APIBase == "" && pc.APIBase != "" { + m.APIBase = pc.APIBase + } + if m.Proxy == "" && pc.Proxy != "" { + m.Proxy = pc.Proxy + } + if m.RequestTimeout == 0 && pc.RequestTimeout != 0 { + m.RequestTimeout = pc.RequestTimeout + } + } +} diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index 1b6e5b032..bea5b9034 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -613,3 +613,143 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T) t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto") } } + +// ---------- InheritProviderCredentials tests ---------- + +func TestInheritProviderCredentials_FillsMissingAPIKey(t *testing.T) { + models := []ModelConfig{ + {ModelName: "my-deepseek", Model: "deepseek/deepseek-chat"}, + } + providers := ProvidersConfig{ + DeepSeek: ProviderConfig{ + APIKey: "sk-deepseek-from-providers", + APIBase: "https://api.deepseek.com/v1", + }, + } + + InheritProviderCredentials(models, providers) + + if models[0].APIKey != "sk-deepseek-from-providers" { + t.Errorf("APIKey = %q, want %q", models[0].APIKey, "sk-deepseek-from-providers") + } + if models[0].APIBase != "https://api.deepseek.com/v1" { + t.Errorf("APIBase = %q, want %q", models[0].APIBase, "https://api.deepseek.com/v1") + } +} + +func TestInheritProviderCredentials_ExplicitValuesTakePrecedence(t *testing.T) { + models := []ModelConfig{ + { + ModelName: "my-openai", + Model: "openai/gpt-5.4", + APIKey: "sk-explicit-model-key", + APIBase: "https://my-custom-endpoint.com/v1", + }, + } + providers := ProvidersConfig{ + OpenAI: OpenAIProviderConfig{ + ProviderConfig: ProviderConfig{ + APIKey: "sk-provider-key", + APIBase: "https://api.openai.com/v1", + }, + }, + } + + InheritProviderCredentials(models, providers) + + if models[0].APIKey != "sk-explicit-model-key" { + t.Errorf("APIKey = %q, want %q (explicit should win)", models[0].APIKey, "sk-explicit-model-key") + } + if models[0].APIBase != "https://my-custom-endpoint.com/v1" { + t.Errorf("APIBase = %q, want %q (explicit should win)", models[0].APIBase, "https://my-custom-endpoint.com/v1") + } +} + +func TestInheritProviderCredentials_MultipleModels(t *testing.T) { + models := []ModelConfig{ + {ModelName: "groq-llama", Model: "groq/llama-3.1-70b"}, + {ModelName: "zhipu-glm", Model: "zhipu/glm-4"}, + {ModelName: "custom-openai", Model: "openai/gpt-5.4", APIKey: "sk-already-set"}, + } + providers := ProvidersConfig{ + Groq: ProviderConfig{APIKey: "gsk-groq-key", Proxy: "http://proxy:8080"}, + Zhipu: ProviderConfig{APIKey: "zhipu-key-123", APIBase: "https://zhipu.example.com"}, + OpenAI: OpenAIProviderConfig{ + ProviderConfig: ProviderConfig{APIKey: "sk-should-not-override"}, + }, + } + + InheritProviderCredentials(models, providers) + + // groq model should inherit + if models[0].APIKey != "gsk-groq-key" { + t.Errorf("groq APIKey = %q, want %q", models[0].APIKey, "gsk-groq-key") + } + if models[0].Proxy != "http://proxy:8080" { + t.Errorf("groq Proxy = %q, want %q", models[0].Proxy, "http://proxy:8080") + } + + // zhipu model should inherit + if models[1].APIKey != "zhipu-key-123" { + t.Errorf("zhipu APIKey = %q, want %q", models[1].APIKey, "zhipu-key-123") + } + if models[1].APIBase != "https://zhipu.example.com" { + t.Errorf("zhipu APIBase = %q, want %q", models[1].APIBase, "https://zhipu.example.com") + } + + // openai model already has key — should NOT be overridden + if models[2].APIKey != "sk-already-set" { + t.Errorf("openai APIKey = %q, want %q (should not be overridden)", models[2].APIKey, "sk-already-set") + } +} + +func TestInheritProviderCredentials_NoMatchingProvider(t *testing.T) { + models := []ModelConfig{ + {ModelName: "my-model", Model: "novelai/some-model"}, + } + providers := ProvidersConfig{ + DeepSeek: ProviderConfig{APIKey: "sk-deepseek"}, + } + + InheritProviderCredentials(models, providers) + + // No matching provider for "novelai" protocol — should stay empty + if models[0].APIKey != "" { + t.Errorf("APIKey = %q, want empty (no matching provider)", models[0].APIKey) + } +} + +func TestInheritProviderCredentials_EmptyProviders(t *testing.T) { + models := []ModelConfig{ + {ModelName: "my-model", Model: "openai/gpt-5.4"}, + } + providers := ProvidersConfig{} // all empty + + InheritProviderCredentials(models, providers) + + // Empty providers — nothing to inherit + if models[0].APIKey != "" { + t.Errorf("APIKey = %q, want empty", models[0].APIKey) + } +} + +func TestInheritProviderCredentials_InheritsRequestTimeout(t *testing.T) { + models := []ModelConfig{ + {ModelName: "my-ollama", Model: "ollama/llama3.2:3b"}, + } + providers := ProvidersConfig{ + Ollama: ProviderConfig{ + APIBase: "http://localhost:11434", + RequestTimeout: 120, + }, + } + + InheritProviderCredentials(models, providers) + + if models[0].APIBase != "http://localhost:11434" { + t.Errorf("APIBase = %q, want %q", models[0].APIBase, "http://localhost:11434") + } + if models[0].RequestTimeout != 120 { + t.Errorf("RequestTimeout = %d, want 120", models[0].RequestTimeout) + } +}