diff --git a/pkg/config/config.go b/pkg/config/config.go index 6694ef3a1..ca0b6cbe7 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1029,7 +1029,7 @@ func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) { } // Multiple configs - use round-robin for load balancing - idx := rrCounter.Add(1) % uint64(len(matches)) + idx := (rrCounter.Add(1) - 1) % uint64(len(matches)) return &matches[idx], nil } diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go index da6e506f8..9bc600ed9 100644 --- a/pkg/config/model_config_test.go +++ b/pkg/config/model_config_test.go @@ -80,6 +80,36 @@ func TestGetModelConfig_RoundRobin(t *testing.T) { } } +func TestGetModelConfig_RoundRobinStartsFromFirstMatch(t *testing.T) { + rrCounter.Store(0) + + cfg := &Config{ + ModelList: []ModelConfig{ + {ModelName: "lb-model", Model: "openai/gpt-4o-1", APIKey: "key1"}, + {ModelName: "lb-model", Model: "openai/gpt-4o-2", APIKey: "key2"}, + {ModelName: "lb-model", Model: "openai/gpt-4o-3", APIKey: "key3"}, + }, + } + + wantOrder := []string{ + "openai/gpt-4o-1", + "openai/gpt-4o-2", + "openai/gpt-4o-3", + "openai/gpt-4o-1", + "openai/gpt-4o-2", + } + + for i, want := range wantOrder { + result, err := cfg.GetModelConfig("lb-model") + if err != nil { + t.Fatalf("GetModelConfig() call %d error = %v", i, err) + } + if result.Model != want { + t.Fatalf("GetModelConfig() call %d model = %q, want %q", i, result.Model, want) + } + } +} + func TestGetModelConfig_Concurrent(t *testing.T) { cfg := &Config{ ModelList: []ModelConfig{