mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(model): llm rate limiting (#2198)
* feat(model): rate limiting * fix(agent): preserve per-model identity in rate limiting and fallback * fix test
This commit is contained in:
@@ -165,6 +165,58 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_PreservesDistinctLimiterIdentityForSharedResolvedModel(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "glm-4.7",
|
||||
ModelFallbacks: []string{"glm-4.7__key_1"},
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "glm-4.7",
|
||||
Model: "zhipu/glm-4.7",
|
||||
RPM: 1,
|
||||
},
|
||||
{
|
||||
ModelName: "glm-4.7__key_1",
|
||||
Model: "zhipu/glm-4.7",
|
||||
RPM: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
|
||||
if len(agent.Candidates) != 2 {
|
||||
t.Fatalf("len(Candidates) = %d, want 2", len(agent.Candidates))
|
||||
}
|
||||
|
||||
first := agent.Candidates[0]
|
||||
second := agent.Candidates[1]
|
||||
if first.Provider != "zhipu" || first.Model != "glm-4.7" {
|
||||
t.Fatalf("first candidate = %s/%s, want zhipu/glm-4.7", first.Provider, first.Model)
|
||||
}
|
||||
if second.Provider != "zhipu" || second.Model != "glm-4.7" {
|
||||
t.Fatalf("second candidate = %s/%s, want zhipu/glm-4.7", second.Provider, second.Model)
|
||||
}
|
||||
if first.IdentityKey != "model_name:glm-4.7" {
|
||||
t.Fatalf("first identity key = %q, want %q", first.IdentityKey, "model_name:glm-4.7")
|
||||
}
|
||||
if second.IdentityKey != "model_name:glm-4.7__key_1" {
|
||||
t.Fatalf("second identity key = %q, want %q", second.IdentityKey, "model_name:glm-4.7__key_1")
|
||||
}
|
||||
if first.RPM != 1 {
|
||||
t.Fatalf("first RPM = %d, want 1", first.RPM)
|
||||
}
|
||||
if second.RPM != 3 {
|
||||
t.Fatalf("second RPM = %d, want 3", second.RPM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
mediaDir := media.TempDir()
|
||||
|
||||
+21
-5
@@ -119,9 +119,18 @@ func NewAgentLoop(
|
||||
) *AgentLoop {
|
||||
registry := NewAgentRegistry(cfg, provider)
|
||||
|
||||
// Set up shared fallback chain
|
||||
// Set up shared fallback chain with rate limiting.
|
||||
cooldown := providers.NewCooldownTracker()
|
||||
fallbackChain := providers.NewFallbackChain(cooldown)
|
||||
rl := providers.NewRateLimiterRegistry()
|
||||
// Register rate limiters for all agents' candidates so that RPM limits
|
||||
// configured in ModelConfig are enforced before each LLM call.
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
if agent, ok := registry.GetAgent(agentID); ok {
|
||||
rl.RegisterCandidates(agent.Candidates)
|
||||
rl.RegisterCandidates(agent.LightCandidates)
|
||||
}
|
||||
}
|
||||
fallbackChain := providers.NewFallbackChain(cooldown, rl)
|
||||
|
||||
// Create state manager using default agent's workspace for channel recording
|
||||
defaultAgent := registry.GetDefaultAgent()
|
||||
@@ -1032,8 +1041,15 @@ func (al *AgentLoop) ReloadProviderAndConfig(
|
||||
al.cfg = cfg
|
||||
al.registry = registry
|
||||
|
||||
// Also update fallback chain with new config
|
||||
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker())
|
||||
// Also update fallback chain with new config; rebuild rate limiter registry.
|
||||
newRL := providers.NewRateLimiterRegistry()
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
if agent, ok := registry.GetAgent(agentID); ok {
|
||||
newRL.RegisterCandidates(agent.Candidates)
|
||||
newRL.RegisterCandidates(agent.LightCandidates)
|
||||
}
|
||||
}
|
||||
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), newRL)
|
||||
|
||||
al.mu.Unlock()
|
||||
|
||||
@@ -3229,7 +3245,7 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
|
||||
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
|
||||
}
|
||||
|
||||
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
|
||||
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, value, agent.Fallbacks)
|
||||
if len(nextCandidates) == 0 {
|
||||
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
|
||||
}
|
||||
|
||||
+116
-43
@@ -8,44 +8,102 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) {
|
||||
ensureProtocol := func(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
return model
|
||||
}
|
||||
return "openai/" + model
|
||||
// ensureProtocolModel normalizes a model reference so that it always carries
// a protocol prefix. Surrounding whitespace is trimmed; an empty result is
// returned as "". A value that already contains "/" is returned unchanged
// (after trimming); anything else is prefixed with "openai/".
func ensureProtocolModel(model string) string {
	trimmed := strings.TrimSpace(model)
	switch {
	case trimmed == "":
		return ""
	case strings.Contains(trimmed, "/"):
		return trimmed
	default:
		return "openai/" + trimmed
	}
}
|
||||
|
||||
func modelConfigIdentityKey(mc *config.ModelConfig) string {
|
||||
if mc == nil {
|
||||
return ""
|
||||
}
|
||||
if name := strings.TrimSpace(mc.ModelName); name != "" {
|
||||
return "model_name:" + name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func candidateFromModelConfig(
|
||||
defaultProvider string,
|
||||
mc *config.ModelConfig,
|
||||
) (providers.FallbackCandidate, bool) {
|
||||
if mc == nil {
|
||||
return providers.FallbackCandidate{}, false
|
||||
}
|
||||
|
||||
return func(raw string) (string, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || cfg == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return ensureProtocol(mc.Model), true
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
ref := providers.ParseModelRef(ensureProtocolModel(mc.Model), defaultProvider)
|
||||
if ref == nil {
|
||||
return providers.FallbackCandidate{}, false
|
||||
}
|
||||
|
||||
return providers.FallbackCandidate{
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
RPM: mc.RPM,
|
||||
IdentityKey: modelConfigIdentityKey(mc),
|
||||
}, true
|
||||
}
|
||||
|
||||
func lookupModelConfigByRef(cfg *config.Config, raw string) *config.ModelConfig {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return mc
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
mc := cfg.ModelList[i]
|
||||
if mc == nil {
|
||||
continue
|
||||
}
|
||||
fullModel := strings.TrimSpace(mc.Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return mc
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return mc
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveModelCandidate(
|
||||
cfg *config.Config,
|
||||
defaultProvider string,
|
||||
raw string,
|
||||
) (providers.FallbackCandidate, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return providers.FallbackCandidate{}, false
|
||||
}
|
||||
|
||||
if mc := lookupModelConfigByRef(cfg, raw); mc != nil {
|
||||
return candidateFromModelConfig(defaultProvider, mc)
|
||||
}
|
||||
|
||||
ref := providers.ParseModelRef(raw, defaultProvider)
|
||||
if ref == nil {
|
||||
return providers.FallbackCandidate{}, false
|
||||
}
|
||||
|
||||
return providers.FallbackCandidate{
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
}, true
|
||||
}
|
||||
|
||||
func resolveModelCandidates(
|
||||
@@ -54,14 +112,29 @@ func resolveModelCandidates(
|
||||
primary string,
|
||||
fallbacks []string,
|
||||
) []providers.FallbackCandidate {
|
||||
return providers.ResolveCandidatesWithLookup(
|
||||
providers.ModelConfig{
|
||||
Primary: primary,
|
||||
Fallbacks: fallbacks,
|
||||
},
|
||||
defaultProvider,
|
||||
buildModelListResolver(cfg),
|
||||
)
|
||||
seen := make(map[string]bool)
|
||||
candidates := make([]providers.FallbackCandidate, 0, 1+len(fallbacks))
|
||||
|
||||
addCandidate := func(raw string) {
|
||||
candidate, ok := resolveModelCandidate(cfg, defaultProvider, raw)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
key := candidate.StableKey()
|
||||
if seen[key] {
|
||||
return
|
||||
}
|
||||
seen[key] = true
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
|
||||
addCandidate(primary)
|
||||
for _, fallback := range fallbacks {
|
||||
addCandidate(fallback)
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {
|
||||
|
||||
@@ -10,12 +10,24 @@ import (
|
||||
// FallbackChain orchestrates model fallback across multiple candidates.
|
||||
type FallbackChain struct {
|
||||
cooldown *CooldownTracker
|
||||
rl *RateLimiterRegistry
|
||||
}
|
||||
|
||||
// FallbackCandidate represents one model/provider to try.
|
||||
type FallbackCandidate struct {
|
||||
Provider string
|
||||
Model string
|
||||
Provider string
|
||||
Model string
|
||||
RPM int // requests per minute; 0 means unrestricted
|
||||
IdentityKey string // optional stable config identity for cooldown/rate limiting
|
||||
}
|
||||
|
||||
// StableKey returns the candidate's config-level identity when available,
|
||||
// otherwise it falls back to the runtime provider/model key.
|
||||
func (c FallbackCandidate) StableKey() string {
|
||||
if key := strings.TrimSpace(c.IdentityKey); key != "" {
|
||||
return key
|
||||
}
|
||||
return ModelKey(c.Provider, c.Model)
|
||||
}
|
||||
|
||||
// FallbackResult contains the successful response and metadata about all attempts.
|
||||
@@ -36,9 +48,10 @@ type FallbackAttempt struct {
|
||||
Skipped bool // true if skipped due to cooldown
|
||||
}
|
||||
|
||||
// NewFallbackChain creates a new fallback chain with the given cooldown tracker.
|
||||
func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain {
|
||||
return &FallbackChain{cooldown: cooldown}
|
||||
// NewFallbackChain creates a new fallback chain with the given cooldown tracker
|
||||
// and rate limiter registry.
|
||||
func NewFallbackChain(cooldown *CooldownTracker, rl *RateLimiterRegistry) *FallbackChain {
|
||||
return &FallbackChain{cooldown: cooldown, rl: rl}
|
||||
}
|
||||
|
||||
// ResolveCandidates parses model config into a deduplicated candidate list.
|
||||
@@ -117,9 +130,9 @@ func (fc *FallbackChain) Execute(
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Check cooldown (per provider/model, not just provider).
|
||||
// This allows multi-key failover where different keys use different model names.
|
||||
cooldownKey := ModelKey(candidate.Provider, candidate.Model)
|
||||
// Check cooldown per stable candidate identity, not just provider/model.
|
||||
// This allows aliases and multi-key configs to fail over independently.
|
||||
cooldownKey := candidate.StableKey()
|
||||
if !fc.cooldown.IsAvailable(cooldownKey) {
|
||||
remaining := fc.cooldown.CooldownRemaining(cooldownKey)
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
@@ -136,6 +149,33 @@ func (fc *FallbackChain) Execute(
|
||||
continue
|
||||
}
|
||||
|
||||
// Enforce per-candidate rate limit before calling the provider.
|
||||
// If this candidate is locally saturated, try other candidates first.
|
||||
if fc.rl != nil {
|
||||
if !fc.rl.TryAcquire(cooldownKey) {
|
||||
if i < len(candidates)-1 {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Skipped: true,
|
||||
Reason: FailoverRateLimit,
|
||||
Error: fmt.Errorf("%s waiting for local rate limit token", cooldownKey),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if waitErr := fc.rl.Wait(ctx, cooldownKey); waitErr != nil {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Skipped: true,
|
||||
Reason: FailoverRateLimit,
|
||||
Error: waitErr,
|
||||
})
|
||||
return nil, waitErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the run function.
|
||||
start := time.Now()
|
||||
resp, err := run(ctx, candidate.Provider, candidate.Model)
|
||||
@@ -229,6 +269,34 @@ func (fc *FallbackChain) ExecuteImage(
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Enforce per-candidate rate limit before calling the provider.
|
||||
// If this candidate is locally saturated, try other candidates first.
|
||||
imageKey := candidate.StableKey()
|
||||
if fc.rl != nil {
|
||||
if !fc.rl.TryAcquire(imageKey) {
|
||||
if i < len(candidates)-1 {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Skipped: true,
|
||||
Reason: FailoverRateLimit,
|
||||
Error: fmt.Errorf("%s waiting for local rate limit token", imageKey),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if waitErr := fc.rl.Wait(ctx, imageKey); waitErr != nil {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Skipped: true,
|
||||
Reason: FailoverRateLimit,
|
||||
Error: waitErr,
|
||||
})
|
||||
return nil, waitErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := run(ctx, candidate.Provider, candidate.Model)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
@@ -25,7 +25,7 @@ func TestMultiKeyFailover(t *testing.T) {
|
||||
|
||||
// Create fallback chain
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
chain := NewFallbackChain(cooldown, nil)
|
||||
|
||||
// Mock run function: first call fails with 429, second succeeds
|
||||
callCount := 0
|
||||
@@ -82,7 +82,7 @@ func TestMultiKeyFailoverAllFail(t *testing.T) {
|
||||
candidates := ResolveCandidates(cfg, "zhipu")
|
||||
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
chain := NewFallbackChain(cooldown, nil)
|
||||
|
||||
// Mock run function: all calls fail with rate limit
|
||||
callCount := 0
|
||||
@@ -127,7 +127,7 @@ func TestMultiKeyFailoverCooldown(t *testing.T) {
|
||||
candidates := ResolveCandidates(cfg, "zhipu")
|
||||
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
chain := NewFallbackChain(cooldown, nil)
|
||||
|
||||
// Put the first model in cooldown (using ModelKey now, not just provider)
|
||||
cooldownKey := ModelKey(candidates[0].Provider, candidates[0].Model)
|
||||
@@ -183,7 +183,7 @@ func TestMultiKeyFailoverWithFormatError(t *testing.T) {
|
||||
candidates := ResolveCandidates(cfg, "zhipu")
|
||||
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
chain := NewFallbackChain(cooldown, nil)
|
||||
|
||||
// Mock run function: first call fails with format error (bad request)
|
||||
callCount := 0
|
||||
@@ -263,7 +263,7 @@ func TestMultiKeyWithModelFallback(t *testing.T) {
|
||||
}
|
||||
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
chain := NewFallbackChain(cooldown, nil)
|
||||
|
||||
// Mock run function: first two fail, third succeeds (model fallback)
|
||||
callCount := 0
|
||||
@@ -337,7 +337,7 @@ func TestMultiKeyFailoverMixedErrors(t *testing.T) {
|
||||
candidates := ResolveCandidates(cfg, "zhipu")
|
||||
|
||||
cooldown := NewCooldownTracker()
|
||||
chain := NewFallbackChain(cooldown)
|
||||
chain := NewFallbackChain(cooldown, nil)
|
||||
|
||||
// Mock run function: different errors for each key
|
||||
callCount := 0
|
||||
|
||||
+102
-16
@@ -19,7 +19,7 @@ func successRun(content string) func(ctx context.Context, provider, model string
|
||||
|
||||
func TestFallback_SingleCandidate_Success(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
result, err := fc.Execute(context.Background(), candidates, successRun("hello"))
|
||||
@@ -36,7 +36,7 @@ func TestFallback_SingleCandidate_Success(t *testing.T) {
|
||||
|
||||
func TestFallback_SecondCandidateSuccess(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
@@ -69,7 +69,7 @@ func TestFallback_SecondCandidateSuccess(t *testing.T) {
|
||||
|
||||
func TestFallback_AllFail(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
@@ -96,7 +96,7 @@ func TestFallback_AllFail(t *testing.T) {
|
||||
|
||||
func TestFallback_ContextCanceled(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
candidates := []FallbackCandidate{
|
||||
@@ -123,7 +123,7 @@ func TestFallback_ContextCanceled(t *testing.T) {
|
||||
|
||||
func TestFallback_NonRetriableError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
@@ -155,7 +155,7 @@ func TestFallback_NonRetriableError(t *testing.T) {
|
||||
func TestFallback_CooldownSkip(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, _ := newTestTracker(now)
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
// Put openai/gpt-4 in cooldown (using ModelKey now)
|
||||
ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit)
|
||||
@@ -193,7 +193,7 @@ func TestFallback_CooldownSkip(t *testing.T) {
|
||||
|
||||
func TestFallback_AllInCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
// Put all models in cooldown (using ModelKey now)
|
||||
ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit)
|
||||
@@ -221,7 +221,7 @@ func TestFallback_AllInCooldown(t *testing.T) {
|
||||
|
||||
func TestFallback_NoCandidates(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
_, err := fc.Execute(context.Background(), nil, successRun("ok"))
|
||||
if err == nil {
|
||||
@@ -232,7 +232,7 @@ func TestFallback_NoCandidates(t *testing.T) {
|
||||
func TestFallback_EmptyFallbacks(t *testing.T) {
|
||||
// Single primary, no fallbacks: should work like direct call
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
result, err := fc.Execute(context.Background(), candidates, successRun("ok"))
|
||||
@@ -246,7 +246,7 @@ func TestFallback_EmptyFallbacks(t *testing.T) {
|
||||
|
||||
func TestFallback_UnclassifiedError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
@@ -270,7 +270,7 @@ func TestFallback_UnclassifiedError(t *testing.T) {
|
||||
|
||||
func TestFallback_SuccessResetsCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
modelKey := ModelKey("openai", "gpt-4")
|
||||
@@ -293,11 +293,78 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func assertLocalRateLimitSkipsToHealthyFallback(
|
||||
t *testing.T,
|
||||
primaryKey string,
|
||||
fallbackKey string,
|
||||
fallbackProvider string,
|
||||
fallbackModel string,
|
||||
execute func(context.Context, *FallbackChain, []FallbackCandidate,
|
||||
func(context.Context, string, string) (*LLMResponse, error),
|
||||
) (*FallbackResult, error),
|
||||
responseContent string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
ct := NewCooldownTracker()
|
||||
rl := NewRateLimiterRegistry()
|
||||
rl.Register(primaryKey, 1)
|
||||
if err := rl.Wait(context.Background(), primaryKey); err != nil {
|
||||
t.Fatalf("failed to pre-drain primary limiter: %v", err)
|
||||
}
|
||||
|
||||
fc := NewFallbackChain(ct, rl)
|
||||
candidates := []FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-4o", IdentityKey: primaryKey},
|
||||
{Provider: fallbackProvider, Model: fallbackModel, IdentityKey: fallbackKey},
|
||||
}
|
||||
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
if provider != fallbackProvider || model != fallbackModel {
|
||||
t.Fatalf("expected fallback candidate to run, got %s/%s", provider, model)
|
||||
}
|
||||
return &LLMResponse{Content: responseContent, FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
result, err := execute(ctx, fc, candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback success, got error: %v", err)
|
||||
}
|
||||
if result.Provider != fallbackProvider || result.Model != fallbackModel {
|
||||
t.Fatalf("result = %s/%s, want %s/%s", result.Provider, result.Model, fallbackProvider, fallbackModel)
|
||||
}
|
||||
if len(result.Attempts) != 1 || !result.Attempts[0].Skipped {
|
||||
t.Fatalf("expected one skipped primary attempt, got %+v", result.Attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_LocalRateLimitSkipsToHealthyFallback(t *testing.T) {
|
||||
assertLocalRateLimitSkipsToHealthyFallback(
|
||||
t,
|
||||
"model_name:primary",
|
||||
"model_name:fallback",
|
||||
"anthropic",
|
||||
"claude",
|
||||
func(
|
||||
ctx context.Context,
|
||||
fc *FallbackChain,
|
||||
candidates []FallbackCandidate,
|
||||
run func(context.Context, string, string) (*LLMResponse, error),
|
||||
) (*FallbackResult, error) {
|
||||
return fc.Execute(ctx, candidates, run)
|
||||
},
|
||||
"fallback ok",
|
||||
)
|
||||
}
|
||||
|
||||
// --- Image Fallback Tests ---
|
||||
|
||||
func TestImageFallback_Success(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")}
|
||||
result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result"))
|
||||
@@ -311,7 +378,7 @@ func TestImageFallback_Success(t *testing.T) {
|
||||
|
||||
func TestImageFallback_DimensionError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
@@ -335,7 +402,7 @@ func TestImageFallback_DimensionError(t *testing.T) {
|
||||
|
||||
func TestImageFallback_SizeError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
@@ -359,7 +426,7 @@ func TestImageFallback_SizeError(t *testing.T) {
|
||||
|
||||
func TestImageFallback_RetryOnOtherErrors(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
@@ -384,9 +451,28 @@ func TestImageFallback_RetryOnOtherErrors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_LocalRateLimitSkipsToHealthyFallback(t *testing.T) {
|
||||
assertLocalRateLimitSkipsToHealthyFallback(
|
||||
t,
|
||||
"model_name:primary-image",
|
||||
"model_name:fallback-image",
|
||||
"anthropic",
|
||||
"claude-sonnet",
|
||||
func(
|
||||
ctx context.Context,
|
||||
fc *FallbackChain,
|
||||
candidates []FallbackCandidate,
|
||||
run func(context.Context, string, string) (*LLMResponse, error),
|
||||
) (*FallbackResult, error) {
|
||||
return fc.ExecuteImage(ctx, candidates, run)
|
||||
},
|
||||
"image fallback ok",
|
||||
)
|
||||
}
|
||||
|
||||
func TestImageFallback_NoCandidates(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
fc := NewFallbackChain(ct, nil)
|
||||
|
||||
_, err := fc.ExecuteImage(context.Background(), nil, successRun("ok"))
|
||||
if err == nil {
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimiter implements a token-bucket rate limiter for a single key.
// It allows up to rpm requests per minute with a burst equal to rpm.
// A limiter whose rpm is zero or negative is unrestricted: Wait and
// TryAcquire always succeed immediately.
// All mutable state is guarded by mu, so a *RateLimiter may be shared
// between goroutines.
type RateLimiter struct {
	mu       sync.Mutex
	rpm      int              // configured requests per minute; <= 0 means unrestricted
	tokens   float64          // tokens currently available
	maxBurst float64          // bucket capacity
	lastTick time.Time        // time of the last refill
	nowFunc  func() time.Time // for testing
}

// refillLocked credits the bucket with tokens for the time elapsed since
// lastTick, capped at maxBurst, and advances lastTick to now.
// The caller must hold rl.mu.
func (rl *RateLimiter) refillLocked(now time.Time) {
	elapsed := now.Sub(rl.lastTick).Seconds()
	rl.lastTick = now

	// Refill tokens proportional to elapsed time.
	rl.tokens += elapsed * float64(rl.rpm) / 60.0
	if rl.tokens > rl.maxBurst {
		rl.tokens = rl.maxBurst
	}
}

// newRateLimiter creates a RateLimiter that allows rpm requests/minute.
// The bucket starts full. An rpm <= 0 yields an unrestricted limiter.
func newRateLimiter(rpm int) *RateLimiter {
	return &RateLimiter{
		rpm:      rpm,
		tokens:   float64(rpm), // start full
		maxBurst: float64(rpm),
		lastTick: time.Now(),
		nowFunc:  time.Now,
	}
}

// Wait blocks until a token is available or ctx is canceled.
// Returns ctx.Err() if canceled while waiting, nil once a token was consumed.
// An unrestricted limiter (rpm <= 0) returns nil immediately.
func (rl *RateLimiter) Wait(ctx context.Context) error {
	for {
		rl.mu.Lock()
		if rl.rpm <= 0 {
			// Without this guard the wait computation below divides by zero
			// and the loop would spin without ever obtaining a token.
			rl.mu.Unlock()
			return nil
		}
		rl.refillLocked(rl.nowFunc())

		if rl.tokens >= 1.0 {
			rl.tokens--
			rl.mu.Unlock()
			return nil
		}

		// Calculate how long until a token is available.
		deficit := 1.0 - rl.tokens
		waitSec := deficit / (float64(rl.rpm) / 60.0)
		rl.mu.Unlock()

		timer := time.NewTimer(time.Duration(waitSec * float64(time.Second)))
		select {
		case <-ctx.Done():
			if !timer.Stop() {
				<-timer.C
			}
			return ctx.Err()
		case <-timer.C:
			// Loop to re-check (another goroutine may have consumed the token).
		}
	}
}

// TryAcquire attempts to consume a token without blocking.
// It reports whether a token was consumed. An unrestricted limiter
// (rpm <= 0) always reports true.
func (rl *RateLimiter) TryAcquire() bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if rl.rpm <= 0 {
		return true
	}

	rl.refillLocked(rl.nowFunc())
	if rl.tokens < 1.0 {
		return false
	}
	rl.tokens--
	return true
}
|
||||
|
||||
// RateLimiterRegistry holds per-candidate rate limiters.
|
||||
// Candidates with RPM=0 are unrestricted.
|
||||
// Thread-safe for concurrent reads/writes.
|
||||
type RateLimiterRegistry struct {
|
||||
mu sync.RWMutex
|
||||
limiters map[string]*RateLimiter
|
||||
}
|
||||
|
||||
// NewRateLimiterRegistry creates an empty registry.
|
||||
func NewRateLimiterRegistry() *RateLimiterRegistry {
|
||||
return &RateLimiterRegistry{
|
||||
limiters: make(map[string]*RateLimiter),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a rate limiter for the given key at the given RPM.
|
||||
// If rpm <= 0, no limiter is registered (unrestricted).
|
||||
func (r *RateLimiterRegistry) Register(key string, rpm int) {
|
||||
if rpm <= 0 {
|
||||
return
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.limiters[key] = newRateLimiter(rpm)
|
||||
}
|
||||
|
||||
// Wait acquires a token for the given key, blocking if needed.
|
||||
// If no limiter is registered for key, returns immediately.
|
||||
func (r *RateLimiterRegistry) Wait(ctx context.Context, key string) error {
|
||||
r.mu.RLock()
|
||||
rl := r.limiters[key]
|
||||
r.mu.RUnlock()
|
||||
if rl == nil {
|
||||
return nil
|
||||
}
|
||||
return rl.Wait(ctx)
|
||||
}
|
||||
|
||||
// TryAcquire attempts to consume a token for the given key without blocking.
|
||||
// If no limiter is registered for key, it returns true.
|
||||
func (r *RateLimiterRegistry) TryAcquire(key string) bool {
|
||||
r.mu.RLock()
|
||||
rl := r.limiters[key]
|
||||
r.mu.RUnlock()
|
||||
if rl == nil {
|
||||
return true
|
||||
}
|
||||
return rl.TryAcquire()
|
||||
}
|
||||
|
||||
// RegisterCandidates registers rate limiters for all candidates that have RPM > 0.
|
||||
// Candidates with RPM == 0 are ignored (no restriction).
|
||||
func (r *RateLimiterRegistry) RegisterCandidates(candidates []FallbackCandidate) {
|
||||
for _, c := range candidates {
|
||||
if c.RPM > 0 {
|
||||
r.Register(c.StableKey(), c.RPM)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestRateLimiter_AllowsUpToRPM verifies that up to RPM requests pass immediately
|
||||
// (burst capacity) and the (RPM+1)-th request is delayed.
|
||||
func TestRateLimiter_AllowsUpToRPM(t *testing.T) {
|
||||
rpm := 5
|
||||
rl := newRateLimiter(rpm)
|
||||
|
||||
// All rpm tokens should be available immediately (bucket starts full).
|
||||
for i := 0; i < rpm; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
if err := rl.Wait(ctx); err != nil {
|
||||
t.Fatalf("request %d should pass immediately, got: %v", i+1, err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
|
||||
// The next request must wait; cancel it to confirm it blocks.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
err := rl.Wait(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected request beyond RPM to block, but it passed immediately")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiter_ContextCancellation verifies that a blocked Wait respects cancellation.
|
||||
func TestRateLimiter_ContextCancellation(t *testing.T) {
|
||||
rl := newRateLimiter(1)
|
||||
|
||||
// Drain the one token.
|
||||
ctx := context.Background()
|
||||
if err := rl.Wait(ctx); err != nil {
|
||||
t.Fatalf("first request failed: %v", err)
|
||||
}
|
||||
|
||||
// Second request should block; cancel it.
|
||||
cancelCtx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
||||
defer cancel()
|
||||
err := rl.Wait(cancelCtx)
|
||||
if err == nil {
|
||||
t.Fatal("expected cancellation error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiter_TokenRefill verifies that tokens refill over time.
|
||||
func TestRateLimiter_TokenRefill(t *testing.T) {
|
||||
rpm := 60 // 1 token per second
|
||||
rl := newRateLimiter(rpm)
|
||||
|
||||
// Drain all tokens.
|
||||
for i := 0; i < rpm; i++ {
|
||||
rl.Wait(context.Background()) //nolint:errcheck
|
||||
}
|
||||
|
||||
// Advance time via nowFunc: simulate 2 seconds passing (should give 2 tokens).
|
||||
start := time.Now()
|
||||
rl.nowFunc = func() time.Time { return start.Add(2 * time.Second) }
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
if err := rl.Wait(ctx); err != nil {
|
||||
t.Fatalf("expected refilled token to be available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiterRegistry_NoLimiter verifies that keys without a registered limiter pass freely.
|
||||
func TestRateLimiterRegistry_NoLimiter(t *testing.T) {
|
||||
r := NewRateLimiterRegistry()
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := r.Wait(ctx, "unregistered/key"); err != nil {
|
||||
t.Fatalf("unregistered key should not block: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiterRegistry_ZeroRPM verifies that RPM=0 means no limiter is registered.
|
||||
func TestRateLimiterRegistry_ZeroRPM(t *testing.T) {
|
||||
r := NewRateLimiterRegistry()
|
||||
r.Register("some/key", 0)
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 50; i++ {
|
||||
if err := r.Wait(ctx, "some/key"); err != nil {
|
||||
t.Fatalf("zero-RPM key should not block: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiterRegistry_Enforcement verifies the registry enforces RPM per key.
|
||||
func TestRateLimiterRegistry_Enforcement(t *testing.T) {
|
||||
r := NewRateLimiterRegistry()
|
||||
r.Register("openai/gpt-4o", 3)
|
||||
|
||||
// First 3 calls should pass (burst = RPM).
|
||||
for i := 0; i < 3; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
if err := r.Wait(ctx, "openai/gpt-4o"); err != nil {
|
||||
t.Fatalf("call %d should pass: %v", i+1, err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
|
||||
// 4th call should block.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
if err := r.Wait(ctx, "openai/gpt-4o"); err == nil {
|
||||
t.Fatal("4th call should have been rate-limited")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiterRegistry_RegisterCandidates verifies that RegisterCandidates
|
||||
// correctly picks up RPM from FallbackCandidate.
|
||||
func TestRateLimiterRegistry_RegisterCandidates(t *testing.T) {
|
||||
r := NewRateLimiterRegistry()
|
||||
candidates := []FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-4o", RPM: 2},
|
||||
{Provider: "anthropic", Model: "claude-3", RPM: 0}, // no limit
|
||||
}
|
||||
r.RegisterCandidates(candidates)
|
||||
|
||||
// openai/gpt-4o: 2 tokens burst, 3rd should block.
|
||||
for i := 0; i < 2; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
if err := r.Wait(ctx, "openai/gpt-4o"); err != nil {
|
||||
t.Fatalf("openai call %d should pass: %v", i+1, err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
if err := r.Wait(ctx, "openai/gpt-4o"); err == nil {
|
||||
t.Fatal("openai 3rd call should have been limited")
|
||||
}
|
||||
|
||||
// anthropic/claude-3: no limit, should always pass.
|
||||
for i := 0; i < 10; i++ {
|
||||
if err := r.Wait(context.Background(), "anthropic/claude-3"); err != nil {
|
||||
t.Fatalf("anthropic call should not be limited: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterRegistry_RegisterCandidatesUsesStableIdentity(t *testing.T) {
|
||||
r := NewRateLimiterRegistry()
|
||||
candidates := []FallbackCandidate{
|
||||
{Provider: "openai", Model: "gpt-4o", RPM: 1, IdentityKey: "model_name:primary"},
|
||||
{Provider: "openai", Model: "gpt-4o", RPM: 2, IdentityKey: "model_name:fallback"},
|
||||
}
|
||||
r.RegisterCandidates(candidates)
|
||||
|
||||
if err := r.Wait(context.Background(), "model_name:primary"); err != nil {
|
||||
t.Fatalf("primary first call should pass: %v", err)
|
||||
}
|
||||
if err := r.Wait(context.Background(), "model_name:fallback"); err != nil {
|
||||
t.Fatalf("fallback first call should pass: %v", err)
|
||||
}
|
||||
if err := r.Wait(context.Background(), "model_name:fallback"); err != nil {
|
||||
t.Fatalf("fallback second call should pass: %v", err)
|
||||
}
|
||||
|
||||
ctxPrimary, cancelPrimary := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancelPrimary()
|
||||
if err := r.Wait(ctxPrimary, "model_name:primary"); err == nil {
|
||||
t.Fatal("primary second call should have been limited")
|
||||
}
|
||||
|
||||
ctxFallback, cancelFallback := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancelFallback()
|
||||
if err := r.Wait(ctxFallback, "model_name:fallback"); err == nil {
|
||||
t.Fatal("fallback third call should have been limited")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiter_Concurrency verifies thread safety under concurrent access.
|
||||
func TestRateLimiter_Concurrency(t *testing.T) {
|
||||
rpm := 20
|
||||
rl := newRateLimiter(rpm)
|
||||
var passed atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Launch 30 goroutines; only ~20 should pass immediately.
|
||||
for i := 0; i < 30; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
if rl.Wait(ctx) == nil {
|
||||
passed.Add(1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
got := passed.Load()
|
||||
// Allow small timing slack: between rpm-2 and rpm+2.
|
||||
if got < int64(rpm-2) || got > int64(rpm+2) {
|
||||
t.Fatalf("expected ~%d immediate passes, got %d", rpm, got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user