From 1e26312cb3ebfb75e1189151f3359fbd597239e6 Mon Sep 17 00:00:00 2001 From: yinwm Date: Thu, 19 Feb 2026 12:45:12 +0800 Subject: [PATCH] feat(config): validate duplicate model names Add validation to ensure model_name is unique across all entries in model_list. This prevents potential conflicts when multiple model configs share the same model_name identifier. --- pkg/config/config.go | 15 ++++++++++++++- pkg/config/model_config_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index c2b5ee01f..0e6063e73 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -294,6 +294,11 @@ func LoadConfig(path string) (*Config, error) { cfg.ModelList = ConvertProvidersToModelList(cfg) } + // Validate model_list for uniqueness and required fields + if err := cfg.ValidateModelList(); err != nil { + return nil, err + } + return cfg, nil } @@ -471,12 +476,20 @@ func (c *Config) HasProvidersConfig() bool { } // ValidateModelList validates all ModelConfig entries in the model_list. -// It checks that each model_name/model combination is valid. +// It checks that each model_name/model combination is valid and that +// model_name is unique across all entries. func (c *Config) ValidateModelList() error { + seen := make(map[string]int) for i := range c.ModelList { if err := c.ModelList[i].Validate(); err != nil { return fmt.Errorf("model_list[%d]: %w", i, err) } + // Check for duplicate model_name + name := c.ModelList[i].ModelName + if prevIdx, exists := seen[name]; exists { + return fmt.Errorf("model_list: duplicate model_name %q at index %d and %d", name, prevIdx, i) + } + seen[name] = i } return nil } diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go index 9d817964a..867e9ebf1 100644 --- a/pkg/config/model_config_test.go +++ b/pkg/config/model_config_test.go @@ -6,6 +6,7 @@ package config import ( + "strings" "sync" "testing" ) @@ -163,6 +164,7 @@ func TestConfig_ValidateModelList(t *testing.T) { name string config *Config wantErr bool + errMsg string // partial error message to check }{ { name: "valid list", @@ -183,6 +185,7 @@ func TestConfig_ValidateModelList(t *testing.T) { }, }, wantErr: true, + errMsg: "model_name is required", }, { name: "empty list", @@ -191,6 +194,29 @@ func TestConfig_ValidateModelList(t *testing.T) { }, wantErr: false, }, + { + name: "duplicate model_name", + config: &Config{ + ModelList: []ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4o", APIKey: "key1"}, + {ModelName: "gpt-4", Model: "openai/gpt-4-turbo", APIKey: "key2"}, + }, + }, + wantErr: true, + errMsg: "duplicate model_name", + }, + { + name: "duplicate model_name non-adjacent", + config: &Config{ + ModelList: []ModelConfig{ + {ModelName: "model-a", Model: "openai/gpt-4o"}, + {ModelName: "model-b", Model: "anthropic/claude"}, + {ModelName: "model-a", Model: "openai/gpt-4-turbo"}, + }, + }, + wantErr: true, + errMsg: "duplicate model_name \"model-a\"", + }, } for _, tt := range tests { @@ -199,6 +225,11 @@ func TestConfig_ValidateModelList(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("ValidateModelList() error = %v, wantErr %v", err, tt.wantErr) } + if err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("ValidateModelList() error = %v, want error containing %q", err, tt.errMsg) + } + } }) } }