diff --git a/docs/rate-limiting.md b/docs/rate-limiting.md new file mode 100644 index 000000000..b54c757f8 --- /dev/null +++ b/docs/rate-limiting.md @@ -0,0 +1,95 @@ +# Dynamic Rate Limiting + +PicoClaw prevents 429 errors from LLM provider APIs by enforcing configurable per-model request-rate limits **before** sending each request. Unlike the reactive cooldown/fallback system (which activates *after* a 429 is received), rate limiting is **proactive**: it keeps outbound QPS within the provider's free-tier or plan limits. + +## How it works + +### Token-bucket algorithm + +Each rate-limited model gets a token bucket: + +- **Capacity** = `rpm` (burst size equals the per-minute limit) +- **Refill rate** = `rpm / 60` tokens per second +- Tokens are consumed one per LLM call; if the bucket is empty, the call blocks until a token refills or the request context is cancelled + +### Call chain integration + +``` +AgentLoop.callLLM() + └─ FallbackChain.Execute() ← iterate candidates + ├─ CooldownTracker.IsAvailable() ← skip if post-429 cooldown active + ├─ RateLimiterRegistry.Wait() ← NEW: block until token available + └─ provider.Chat() ← actual LLM HTTP call +``` + +The rate limiter runs **after** the cooldown check and **before** the provider call, so: +- Candidates already in cooldown are skipped entirely (no token consumed) +- Candidates that are available get throttled to the configured RPM + +The same check applies in `ExecuteImage`. + +### Thread safety + +`RateLimiterRegistry` is safe for concurrent use. The per-limiter token bucket uses a fine-grained mutex so concurrent goroutines each acquire their own token independently. + +## Configuration + +Set `rpm` on any model in `model_list`: + +```yaml +model_list: + - model_name: gpt-4o-free + model: openai/gpt-4o + api_base: https://api.openai.com/v1 + rpm: 3 # max 3 requests per minute + api_keys: + - sk-... + + - model_name: claude-haiku + model: anthropic/claude-haiku-4-5 + rpm: 60 # 60 rpm (Anthropic free tier) + api_keys: + - sk-ant-... + + - model_name: local-llm + model: openai/llama3 + api_base: http://localhost:11434/v1 + # no rpm → unrestricted +``` + +| Field | Type | Default | Description | +|---|---|---|---| +| `rpm` | `int` | `0` | Requests per minute. `0` means no limit. | + +### Interaction with fallbacks + +When a model has fallbacks configured, each candidate is rate-limited **independently**: + +```yaml +model_list: + - model_name: gpt4-with-fallback + model: openai/gpt-4o + rpm: 5 + fallbacks: + - gpt-4o-mini # must also be in model_list; its own rpm applies +``` + +If the current candidate's bucket is empty and there are more candidates available, PicoClaw skips the locally saturated candidate and tries the next fallback immediately. Only the last remaining candidate waits for a token to refill. If the context deadline is hit while waiting on that last candidate, the wait error propagates. + +For `model_list` aliases that resolve to the same underlying provider/model, rate limiting is keyed by the stable config identity (for example `model_name`) rather than the resolved runtime model string. This preserves distinct RPM settings for multi-key and alias-based configurations. + +### Burst behaviour + +The bucket starts **full** (burst = RPM). For `rpm: 3`, the first 3 requests fire instantly; subsequent requests are spaced ~20 s apart. + +To reduce burstiness for strict APIs, set a lower `rpm` and rely on the steady-state refill. + +## Files changed + +| File | What | +|---|---| +| `pkg/providers/ratelimiter.go` | `RateLimiter` (token bucket) + `RateLimiterRegistry` | +| `pkg/providers/ratelimiter_test.go` | Unit tests for limiter and registry | +| `pkg/providers/fallback.go` | `FallbackCandidate.RPM` field; `FallbackChain.rl`; `Wait()` call in `Execute`/`ExecuteImage` | +| `pkg/agent/model_resolution.go` | Resolves candidates from `model_list`, preserving stable config identity and propagating `RPM` into `FallbackCandidate` | +| `pkg/agent/loop.go` | Build `RateLimiterRegistry`, register all agents' candidates, pass to `NewFallbackChain` | diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go index 7c043d88f..ba907e88b 100644 --- a/pkg/agent/instance_test.go +++ b/pkg/agent/instance_test.go @@ -165,6 +165,58 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) { } } +func TestNewAgentInstance_PreservesDistinctLimiterIdentityForSharedResolvedModel(t *testing.T) { + tmpDir := t.TempDir() + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "glm-4.7", + ModelFallbacks: []string{"glm-4.7__key_1"}, + }, + }, + ModelList: []*config.ModelConfig{ + { + ModelName: "glm-4.7", + Model: "zhipu/glm-4.7", + RPM: 1, + }, + { + ModelName: "glm-4.7__key_1", + Model: "zhipu/glm-4.7", + RPM: 3, + }, + }, + } + + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{}) + if len(agent.Candidates) != 2 { + t.Fatalf("len(Candidates) = %d, want 2", len(agent.Candidates)) + } + + first := agent.Candidates[0] + second := agent.Candidates[1] + if first.Provider != "zhipu" || first.Model != "glm-4.7" { + t.Fatalf("first candidate = %s/%s, want zhipu/glm-4.7", first.Provider, first.Model) + } + if second.Provider != "zhipu" || second.Model != "glm-4.7" { + t.Fatalf("second candidate = %s/%s, want zhipu/glm-4.7", second.Provider, second.Model) + } + if first.IdentityKey != "model_name:glm-4.7" { + t.Fatalf("first identity key = %q, want %q", first.IdentityKey, "model_name:glm-4.7") + } + if second.IdentityKey != "model_name:glm-4.7__key_1" { + t.Fatalf("second identity key = %q, want %q", second.IdentityKey, "model_name:glm-4.7__key_1") + } + if first.RPM != 1 { + t.Fatalf("first RPM = %d, want 1", first.RPM) + } + if second.RPM != 3 { + t.Fatalf("second RPM = %d, want 3", second.RPM) + } +} + func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) { workspace := t.TempDir() mediaDir := media.TempDir() diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 624ff261b..808d12c07 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -119,9 +119,18 @@ func NewAgentLoop( ) *AgentLoop { registry := NewAgentRegistry(cfg, provider) - // Set up shared fallback chain + // Set up shared fallback chain with rate limiting. cooldown := providers.NewCooldownTracker() - fallbackChain := providers.NewFallbackChain(cooldown) + rl := providers.NewRateLimiterRegistry() + // Register rate limiters for all agents' candidates so that RPM limits + // configured in ModelConfig are enforced before each LLM call. + for _, agentID := range registry.ListAgentIDs() { + if agent, ok := registry.GetAgent(agentID); ok { + rl.RegisterCandidates(agent.Candidates) + rl.RegisterCandidates(agent.LightCandidates) + } + } + fallbackChain := providers.NewFallbackChain(cooldown, rl) // Create state manager using default agent's workspace for channel recording defaultAgent := registry.GetDefaultAgent() @@ -1032,8 +1041,15 @@ func (al *AgentLoop) ReloadProviderAndConfig( al.cfg = cfg al.registry = registry - // Also update fallback chain with new config - al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker()) + // Also update fallback chain with new config; rebuild rate limiter registry. + newRL := providers.NewRateLimiterRegistry() + for _, agentID := range registry.ListAgentIDs() { + if agent, ok := registry.GetAgent(agentID); ok { + newRL.RegisterCandidates(agent.Candidates) + newRL.RegisterCandidates(agent.LightCandidates) + } + } + al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), newRL) al.mu.Unlock() @@ -3229,7 +3245,7 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt return "", fmt.Errorf("failed to initialize model %q: %w", value, err) } - nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks) + nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, value, agent.Fallbacks) if len(nextCandidates) == 0 { return "", fmt.Errorf("model %q did not resolve to any provider candidates", value) } diff --git a/pkg/agent/model_resolution.go b/pkg/agent/model_resolution.go index 140cff718..7cbf3a8d6 100644 --- a/pkg/agent/model_resolution.go +++ b/pkg/agent/model_resolution.go @@ -8,44 +8,102 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" ) -func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) { - ensureProtocol := func(model string) string { - model = strings.TrimSpace(model) - if model == "" { - return "" - } - if strings.Contains(model, "/") { - return model - } - return "openai/" + model +func ensureProtocolModel(model string) string { + model = strings.TrimSpace(model) + if model == "" { + return "" + } + if strings.Contains(model, "/") { + return model + } + return "openai/" + model +} + +func modelConfigIdentityKey(mc *config.ModelConfig) string { + if mc == nil { + return "" + } + if name := strings.TrimSpace(mc.ModelName); name != "" { + return "model_name:" + name + } + return "" +} + +func candidateFromModelConfig( + defaultProvider string, + mc *config.ModelConfig, +) (providers.FallbackCandidate, bool) { + if mc == nil { + return providers.FallbackCandidate{}, false } - return func(raw string) (string, bool) { - raw = strings.TrimSpace(raw) - if raw == "" || cfg == nil { - return "", false - } - - if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" { - return ensureProtocol(mc.Model), true - } - - for i := range cfg.ModelList { - fullModel := strings.TrimSpace(cfg.ModelList[i].Model) - if fullModel == "" { - continue - } - if fullModel == raw { - return ensureProtocol(fullModel), true - } - _, modelID := providers.ExtractProtocol(fullModel) - if modelID == raw { - return ensureProtocol(fullModel), true - } - } - - return "", false + ref := providers.ParseModelRef(ensureProtocolModel(mc.Model), defaultProvider) + if ref == nil { + return providers.FallbackCandidate{}, false } + + return providers.FallbackCandidate{ + Provider: ref.Provider, + Model: ref.Model, + RPM: mc.RPM, + IdentityKey: modelConfigIdentityKey(mc), + }, true +} + +func lookupModelConfigByRef(cfg *config.Config, raw string) *config.ModelConfig { + raw = strings.TrimSpace(raw) + if raw == "" || cfg == nil { + return nil + } + + if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" { + return mc + } + + for i := range cfg.ModelList { + mc := cfg.ModelList[i] + if mc == nil { + continue + } + fullModel := strings.TrimSpace(mc.Model) + if fullModel == "" { + continue + } + if fullModel == raw { + return mc + } + _, modelID := providers.ExtractProtocol(fullModel) + if modelID == raw { + return mc + } + } + + return nil +} + +func resolveModelCandidate( + cfg *config.Config, + defaultProvider string, + raw string, +) (providers.FallbackCandidate, bool) { + raw = strings.TrimSpace(raw) + if raw == "" { + return providers.FallbackCandidate{}, false + } + + if mc := lookupModelConfigByRef(cfg, raw); mc != nil { + return candidateFromModelConfig(defaultProvider, mc) + } + + ref := providers.ParseModelRef(raw, defaultProvider) + if ref == nil { + return providers.FallbackCandidate{}, false + } + + return providers.FallbackCandidate{ + Provider: ref.Provider, + Model: ref.Model, + }, true } func resolveModelCandidates( @@ -54,14 +112,29 @@ func resolveModelCandidates( primary string, fallbacks []string, ) []providers.FallbackCandidate { - return providers.ResolveCandidatesWithLookup( - providers.ModelConfig{ - Primary: primary, - Fallbacks: fallbacks, - }, - defaultProvider, - buildModelListResolver(cfg), - ) + seen := make(map[string]bool) + candidates := make([]providers.FallbackCandidate, 0, 1+len(fallbacks)) + + addCandidate := func(raw string) { + candidate, ok := resolveModelCandidate(cfg, defaultProvider, raw) + if !ok { + return + } + + key := candidate.StableKey() + if seen[key] { + return + } + seen[key] = true + candidates = append(candidates, candidate) + } + + addCandidate(primary) + for _, fallback := range fallbacks { + addCandidate(fallback) + } + + return candidates } func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string { diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go index 549ec7837..36092105b 100644 --- a/pkg/providers/fallback.go +++ b/pkg/providers/fallback.go @@ -10,12 +10,24 @@ import ( // FallbackChain orchestrates model fallback across multiple candidates. type FallbackChain struct { cooldown *CooldownTracker + rl *RateLimiterRegistry } // FallbackCandidate represents one model/provider to try. type FallbackCandidate struct { - Provider string - Model string + Provider string + Model string + RPM int // requests per minute; 0 means unrestricted + IdentityKey string // optional stable config identity for cooldown/rate limiting +} + +// StableKey returns the candidate's config-level identity when available, +// otherwise it falls back to the runtime provider/model key. +func (c FallbackCandidate) StableKey() string { + if key := strings.TrimSpace(c.IdentityKey); key != "" { + return key + } + return ModelKey(c.Provider, c.Model) } // FallbackResult contains the successful response and metadata about all attempts. @@ -36,9 +48,10 @@ type FallbackAttempt struct { Skipped bool // true if skipped due to cooldown } -// NewFallbackChain creates a new fallback chain with the given cooldown tracker. -func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain { - return &FallbackChain{cooldown: cooldown} +// NewFallbackChain creates a new fallback chain with the given cooldown tracker +// and rate limiter registry. +func NewFallbackChain(cooldown *CooldownTracker, rl *RateLimiterRegistry) *FallbackChain { + return &FallbackChain{cooldown: cooldown, rl: rl} } // ResolveCandidates parses model config into a deduplicated candidate list. @@ -117,9 +130,9 @@ func (fc *FallbackChain) Execute( return nil, context.Canceled } - // Check cooldown (per provider/model, not just provider). - // This allows multi-key failover where different keys use different model names. - cooldownKey := ModelKey(candidate.Provider, candidate.Model) + // Check cooldown per stable candidate identity, not just provider/model. + // This allows aliases and multi-key configs to fail over independently. + cooldownKey := candidate.StableKey() if !fc.cooldown.IsAvailable(cooldownKey) { remaining := fc.cooldown.CooldownRemaining(cooldownKey) result.Attempts = append(result.Attempts, FallbackAttempt{ @@ -136,6 +149,33 @@ func (fc *FallbackChain) Execute( continue } + // Enforce per-candidate rate limit before calling the provider. + // If this candidate is locally saturated, try other candidates first. + if fc.rl != nil { + if !fc.rl.TryAcquire(cooldownKey) { + if i < len(candidates)-1 { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Skipped: true, + Reason: FailoverRateLimit, + Error: fmt.Errorf("%s waiting for local rate limit token", cooldownKey), + }) + continue + } + if waitErr := fc.rl.Wait(ctx, cooldownKey); waitErr != nil { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Skipped: true, + Reason: FailoverRateLimit, + Error: waitErr, + }) + return nil, waitErr + } + } + } + // Execute the run function. start := time.Now() resp, err := run(ctx, candidate.Provider, candidate.Model) @@ -229,6 +269,34 @@ func (fc *FallbackChain) ExecuteImage( return nil, context.Canceled } + // Enforce per-candidate rate limit before calling the provider. + // If this candidate is locally saturated, try other candidates first. + imageKey := candidate.StableKey() + if fc.rl != nil { + if !fc.rl.TryAcquire(imageKey) { + if i < len(candidates)-1 { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Skipped: true, + Reason: FailoverRateLimit, + Error: fmt.Errorf("%s waiting for local rate limit token", imageKey), + }) + continue + } + if waitErr := fc.rl.Wait(ctx, imageKey); waitErr != nil { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Skipped: true, + Reason: FailoverRateLimit, + Error: waitErr, + }) + return nil, waitErr + } + } + } + start := time.Now() resp, err := run(ctx, candidate.Provider, candidate.Model) elapsed := time.Since(start) diff --git a/pkg/providers/fallback_multikey_test.go b/pkg/providers/fallback_multikey_test.go index 9ed8fa73c..10481ec61 100644 --- a/pkg/providers/fallback_multikey_test.go +++ b/pkg/providers/fallback_multikey_test.go @@ -25,7 +25,7 @@ func TestMultiKeyFailover(t *testing.T) { // Create fallback chain cooldown := NewCooldownTracker() - chain := NewFallbackChain(cooldown) + chain := NewFallbackChain(cooldown, nil) // Mock run function: first call fails with 429, second succeeds callCount := 0 @@ -82,7 +82,7 @@ func TestMultiKeyFailoverAllFail(t *testing.T) { candidates := ResolveCandidates(cfg, "zhipu") cooldown := NewCooldownTracker() - chain := NewFallbackChain(cooldown) + chain := NewFallbackChain(cooldown, nil) // Mock run function: all calls fail with rate limit callCount := 0 @@ -127,7 +127,7 @@ func TestMultiKeyFailoverCooldown(t *testing.T) { candidates := ResolveCandidates(cfg, "zhipu") cooldown := NewCooldownTracker() - chain := NewFallbackChain(cooldown) + chain := NewFallbackChain(cooldown, nil) // Put the first model in cooldown (using ModelKey now, not just provider) cooldownKey := ModelKey(candidates[0].Provider, candidates[0].Model) @@ -183,7 +183,7 @@ func TestMultiKeyFailoverWithFormatError(t *testing.T) { candidates := ResolveCandidates(cfg, "zhipu") cooldown := NewCooldownTracker() - chain := NewFallbackChain(cooldown) + chain := NewFallbackChain(cooldown, nil) // Mock run function: first call fails with format error (bad request) callCount := 0 @@ -263,7 +263,7 @@ func TestMultiKeyWithModelFallback(t *testing.T) { } cooldown := NewCooldownTracker() - chain := NewFallbackChain(cooldown) + chain := NewFallbackChain(cooldown, nil) // Mock run function: first two fail, third succeeds (model fallback) callCount := 0 @@ -337,7 +337,7 @@ func TestMultiKeyFailoverMixedErrors(t *testing.T) { candidates := ResolveCandidates(cfg, "zhipu") cooldown := NewCooldownTracker() - chain := NewFallbackChain(cooldown) + chain := NewFallbackChain(cooldown, nil) // Mock run function: different errors for each key callCount := 0 diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go index 1a1118e33..54fb9b6ea 100644 --- a/pkg/providers/fallback_test.go +++ b/pkg/providers/fallback_test.go @@ -19,7 +19,7 @@ func successRun(content string) func(ctx context.Context, provider, model string func TestFallback_SingleCandidate_Success(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} result, err := fc.Execute(context.Background(), candidates, successRun("hello")) @@ -36,7 +36,7 @@ func TestFallback_SingleCandidate_Success(t *testing.T) { func TestFallback_SecondCandidateSuccess(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4"), @@ -69,7 +69,7 @@ func TestFallback_SecondCandidateSuccess(t *testing.T) { func TestFallback_AllFail(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4"), @@ -96,7 +96,7 @@ func TestFallback_AllFail(t *testing.T) { func TestFallback_ContextCanceled(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) ctx, cancel := context.WithCancel(context.Background()) candidates := []FallbackCandidate{ @@ -123,7 +123,7 @@ func TestFallback_ContextCanceled(t *testing.T) { func TestFallback_NonRetriableError(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4"), @@ -155,7 +155,7 @@ func TestFallback_NonRetriableError(t *testing.T) { func TestFallback_CooldownSkip(t *testing.T) { now := time.Now() ct, _ := newTestTracker(now) - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) // Put openai/gpt-4 in cooldown (using ModelKey now) ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit) @@ -193,7 +193,7 @@ func TestFallback_CooldownSkip(t *testing.T) { func TestFallback_AllInCooldown(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) // Put all models in cooldown (using ModelKey now) ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit) @@ -221,7 +221,7 @@ func TestFallback_AllInCooldown(t *testing.T) { func TestFallback_NoCandidates(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) _, err := fc.Execute(context.Background(), nil, successRun("ok")) if err == nil { @@ -232,7 +232,7 @@ func TestFallback_NoCandidates(t *testing.T) { func TestFallback_EmptyFallbacks(t *testing.T) { // Single primary, no fallbacks: should work like direct call ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} result, err := fc.Execute(context.Background(), candidates, successRun("ok")) @@ -246,7 +246,7 @@ func TestFallback_EmptyFallbacks(t *testing.T) { func TestFallback_UnclassifiedError(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4"), @@ -270,7 +270,7 @@ func TestFallback_UnclassifiedError(t *testing.T) { func TestFallback_SuccessResetsCooldown(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} modelKey := ModelKey("openai", "gpt-4") @@ -293,11 +293,78 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) { } } +func assertLocalRateLimitSkipsToHealthyFallback( + t *testing.T, + primaryKey string, + fallbackKey string, + fallbackProvider string, + fallbackModel string, + execute func(context.Context, *FallbackChain, []FallbackCandidate, + func(context.Context, string, string) (*LLMResponse, error), + ) (*FallbackResult, error), + responseContent string, +) { + t.Helper() + + ct := NewCooldownTracker() + rl := NewRateLimiterRegistry() + rl.Register(primaryKey, 1) + if err := rl.Wait(context.Background(), primaryKey); err != nil { + t.Fatalf("failed to pre-drain primary limiter: %v", err) + } + + fc := NewFallbackChain(ct, rl) + candidates := []FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", IdentityKey: primaryKey}, + {Provider: fallbackProvider, Model: fallbackModel, IdentityKey: fallbackKey}, + } + + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + if provider != fallbackProvider || model != fallbackModel { + t.Fatalf("expected fallback candidate to run, got %s/%s", provider, model) + } + return &LLMResponse{Content: responseContent, FinishReason: "stop"}, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + + result, err := execute(ctx, fc, candidates, run) + if err != nil { + t.Fatalf("expected fallback success, got error: %v", err) + } + if result.Provider != fallbackProvider || result.Model != fallbackModel { + t.Fatalf("result = %s/%s, want %s/%s", result.Provider, result.Model, fallbackProvider, fallbackModel) + } + if len(result.Attempts) != 1 || !result.Attempts[0].Skipped { + t.Fatalf("expected one skipped primary attempt, got %+v", result.Attempts) + } +} + +func TestFallback_LocalRateLimitSkipsToHealthyFallback(t *testing.T) { + assertLocalRateLimitSkipsToHealthyFallback( + t, + "model_name:primary", + "model_name:fallback", + "anthropic", + "claude", + func( + ctx context.Context, + fc *FallbackChain, + candidates []FallbackCandidate, + run func(context.Context, string, string) (*LLMResponse, error), + ) (*FallbackResult, error) { + return fc.Execute(ctx, candidates, run) + }, + "fallback ok", + ) +} + // --- Image Fallback Tests --- func TestImageFallback_Success(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")} result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result")) @@ -311,7 +378,7 @@ func TestImageFallback_Success(t *testing.T) { func TestImageFallback_DimensionError(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4o"), @@ -335,7 +402,7 @@ func TestImageFallback_DimensionError(t *testing.T) { func TestImageFallback_SizeError(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4o"), @@ -359,7 +426,7 @@ func TestImageFallback_SizeError(t *testing.T) { func TestImageFallback_RetryOnOtherErrors(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) candidates := []FallbackCandidate{ makeCandidate("openai", "gpt-4o"), @@ -384,9 +451,28 @@ func TestImageFallback_RetryOnOtherErrors(t *testing.T) { } } +func TestImageFallback_LocalRateLimitSkipsToHealthyFallback(t *testing.T) { + assertLocalRateLimitSkipsToHealthyFallback( + t, + "model_name:primary-image", + "model_name:fallback-image", + "anthropic", + "claude-sonnet", + func( + ctx context.Context, + fc *FallbackChain, + candidates []FallbackCandidate, + run func(context.Context, string, string) (*LLMResponse, error), + ) (*FallbackResult, error) { + return fc.ExecuteImage(ctx, candidates, run) + }, + "image fallback ok", + ) +} + func TestImageFallback_NoCandidates(t *testing.T) { ct := NewCooldownTracker() - fc := NewFallbackChain(ct) + fc := NewFallbackChain(ct, nil) _, err := fc.ExecuteImage(context.Background(), nil, successRun("ok")) if err == nil { diff --git a/pkg/providers/ratelimiter.go b/pkg/providers/ratelimiter.go new file mode 100644 index 000000000..f475b58fb --- /dev/null +++ b/pkg/providers/ratelimiter.go @@ -0,0 +1,144 @@ +package providers + +import ( + "context" + "sync" + "time" +) + +// RateLimiter implements a token-bucket rate limiter for a single key. +// Allows up to RPM requests per minute with a burst equal to RPM. +// Thread-safe. +type RateLimiter struct { + mu sync.Mutex + rpm int + tokens float64 + maxBurst float64 + lastTick time.Time + nowFunc func() time.Time // for testing +} + +func (rl *RateLimiter) refillLocked(now time.Time) { + elapsed := now.Sub(rl.lastTick).Seconds() + rl.lastTick = now + + // Refill tokens proportional to elapsed time. + refill := elapsed * float64(rl.rpm) / 60.0 + rl.tokens = min(rl.maxBurst, rl.tokens+refill) +} + +// newRateLimiter creates a RateLimiter that allows rpm requests/minute. +func newRateLimiter(rpm int) *RateLimiter { + return &RateLimiter{ + rpm: rpm, + tokens: float64(rpm), // start full + maxBurst: float64(rpm), + lastTick: time.Now(), + nowFunc: time.Now, + } +} + +// Wait blocks until a token is available or ctx is canceled. +// Returns ctx.Err() if canceled while waiting. +func (rl *RateLimiter) Wait(ctx context.Context) error { + for { + rl.mu.Lock() + now := rl.nowFunc() + rl.refillLocked(now) + + if rl.tokens >= 1.0 { + rl.tokens-- + rl.mu.Unlock() + return nil + } + + // Calculate how long until a token is available. + deficit := 1.0 - rl.tokens + waitSec := deficit / (float64(rl.rpm) / 60.0) + rl.mu.Unlock() + + timer := time.NewTimer(time.Duration(waitSec * float64(time.Second))) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return ctx.Err() + case <-timer.C: + // Loop to re-check (another goroutine may have consumed the token). + } + } +} + +// TryAcquire attempts to consume a token without blocking. +func (rl *RateLimiter) TryAcquire() bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + rl.refillLocked(rl.nowFunc()) + if rl.tokens < 1.0 { + return false + } + rl.tokens-- + return true +} + +// RateLimiterRegistry holds per-candidate rate limiters. +// Candidates with RPM=0 are unrestricted. +// Thread-safe for concurrent reads/writes. +type RateLimiterRegistry struct { + mu sync.RWMutex + limiters map[string]*RateLimiter +} + +// NewRateLimiterRegistry creates an empty registry. +func NewRateLimiterRegistry() *RateLimiterRegistry { + return &RateLimiterRegistry{ + limiters: make(map[string]*RateLimiter), + } +} + +// Register adds a rate limiter for the given key at the given RPM. +// If rpm <= 0, no limiter is registered (unrestricted). +func (r *RateLimiterRegistry) Register(key string, rpm int) { + if rpm <= 0 { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.limiters[key] = newRateLimiter(rpm) +} + +// Wait acquires a token for the given key, blocking if needed. +// If no limiter is registered for key, returns immediately. +func (r *RateLimiterRegistry) Wait(ctx context.Context, key string) error { + r.mu.RLock() + rl := r.limiters[key] + r.mu.RUnlock() + if rl == nil { + return nil + } + return rl.Wait(ctx) +} + +// TryAcquire attempts to consume a token for the given key without blocking. +// If no limiter is registered for key, it returns true. +func (r *RateLimiterRegistry) TryAcquire(key string) bool { + r.mu.RLock() + rl := r.limiters[key] + r.mu.RUnlock() + if rl == nil { + return true + } + return rl.TryAcquire() +} + +// RegisterCandidates registers rate limiters for all candidates that have RPM > 0. +// Candidates with RPM == 0 are ignored (no restriction). +func (r *RateLimiterRegistry) RegisterCandidates(candidates []FallbackCandidate) { + for _, c := range candidates { + if c.RPM > 0 { + r.Register(c.StableKey(), c.RPM) + } + } +} diff --git a/pkg/providers/ratelimiter_test.go b/pkg/providers/ratelimiter_test.go new file mode 100644 index 000000000..9972616e9 --- /dev/null +++ b/pkg/providers/ratelimiter_test.go @@ -0,0 +1,209 @@ +package providers + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestRateLimiter_AllowsUpToRPM verifies that up to RPM requests pass immediately +// (burst capacity) and the (RPM+1)-th request is delayed. +func TestRateLimiter_AllowsUpToRPM(t *testing.T) { + rpm := 5 + rl := newRateLimiter(rpm) + + // All rpm tokens should be available immediately (bucket starts full). + for i := 0; i < rpm; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + if err := rl.Wait(ctx); err != nil { + t.Fatalf("request %d should pass immediately, got: %v", i+1, err) + } + cancel() + } + + // The next request must wait; cancel it to confirm it blocks. + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + err := rl.Wait(ctx) + if err == nil { + t.Fatal("expected request beyond RPM to block, but it passed immediately") + } +} + +// TestRateLimiter_ContextCancellation verifies that a blocked Wait respects cancellation. +func TestRateLimiter_ContextCancellation(t *testing.T) { + rl := newRateLimiter(1) + + // Drain the one token. + ctx := context.Background() + if err := rl.Wait(ctx); err != nil { + t.Fatalf("first request failed: %v", err) + } + + // Second request should block; cancel it. + cancelCtx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + err := rl.Wait(cancelCtx) + if err == nil { + t.Fatal("expected cancellation error, got nil") + } +} + +// TestRateLimiter_TokenRefill verifies that tokens refill over time. +func TestRateLimiter_TokenRefill(t *testing.T) { + rpm := 60 // 1 token per second + rl := newRateLimiter(rpm) + + // Drain all tokens. + for i := 0; i < rpm; i++ { + rl.Wait(context.Background()) //nolint:errcheck + } + + // Advance time via nowFunc: simulate 2 seconds passing (should give 2 tokens). + start := time.Now() + rl.nowFunc = func() time.Time { return start.Add(2 * time.Second) } + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + if err := rl.Wait(ctx); err != nil { + t.Fatalf("expected refilled token to be available: %v", err) + } +} + +// TestRateLimiterRegistry_NoLimiter verifies that keys without a registered limiter pass freely. +func TestRateLimiterRegistry_NoLimiter(t *testing.T) { + r := NewRateLimiterRegistry() + ctx := context.Background() + for i := 0; i < 100; i++ { + if err := r.Wait(ctx, "unregistered/key"); err != nil { + t.Fatalf("unregistered key should not block: %v", err) + } + } +} + +// TestRateLimiterRegistry_ZeroRPM verifies that RPM=0 means no limiter is registered. +func TestRateLimiterRegistry_ZeroRPM(t *testing.T) { + r := NewRateLimiterRegistry() + r.Register("some/key", 0) + ctx := context.Background() + for i := 0; i < 50; i++ { + if err := r.Wait(ctx, "some/key"); err != nil { + t.Fatalf("zero-RPM key should not block: %v", err) + } + } +} + +// TestRateLimiterRegistry_Enforcement verifies the registry enforces RPM per key. +func TestRateLimiterRegistry_Enforcement(t *testing.T) { + r := NewRateLimiterRegistry() + r.Register("openai/gpt-4o", 3) + + // First 3 calls should pass (burst = RPM). + for i := 0; i < 3; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + if err := r.Wait(ctx, "openai/gpt-4o"); err != nil { + t.Fatalf("call %d should pass: %v", i+1, err) + } + cancel() + } + + // 4th call should block. + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + if err := r.Wait(ctx, "openai/gpt-4o"); err == nil { + t.Fatal("4th call should have been rate-limited") + } +} + +// TestRateLimiterRegistry_RegisterCandidates verifies that RegisterCandidates +// correctly picks up RPM from FallbackCandidate. +func TestRateLimiterRegistry_RegisterCandidates(t *testing.T) { + r := NewRateLimiterRegistry() + candidates := []FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", RPM: 2}, + {Provider: "anthropic", Model: "claude-3", RPM: 0}, // no limit + } + r.RegisterCandidates(candidates) + + // openai/gpt-4o: 2 tokens burst, 3rd should block. + for i := 0; i < 2; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + if err := r.Wait(ctx, "openai/gpt-4o"); err != nil { + t.Fatalf("openai call %d should pass: %v", i+1, err) + } + cancel() + } + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + if err := r.Wait(ctx, "openai/gpt-4o"); err == nil { + t.Fatal("openai 3rd call should have been limited") + } + + // anthropic/claude-3: no limit, should always pass. + for i := 0; i < 10; i++ { + if err := r.Wait(context.Background(), "anthropic/claude-3"); err != nil { + t.Fatalf("anthropic call should not be limited: %v", err) + } + } +} + +func TestRateLimiterRegistry_RegisterCandidatesUsesStableIdentity(t *testing.T) { + r := NewRateLimiterRegistry() + candidates := []FallbackCandidate{ + {Provider: "openai", Model: "gpt-4o", RPM: 1, IdentityKey: "model_name:primary"}, + {Provider: "openai", Model: "gpt-4o", RPM: 2, IdentityKey: "model_name:fallback"}, + } + r.RegisterCandidates(candidates) + + if err := r.Wait(context.Background(), "model_name:primary"); err != nil { + t.Fatalf("primary first call should pass: %v", err) + } + if err := r.Wait(context.Background(), "model_name:fallback"); err != nil { + t.Fatalf("fallback first call should pass: %v", err) + } + if err := r.Wait(context.Background(), "model_name:fallback"); err != nil { + t.Fatalf("fallback second call should pass: %v", err) + } + + ctxPrimary, cancelPrimary := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancelPrimary() + if err := r.Wait(ctxPrimary, "model_name:primary"); err == nil { + t.Fatal("primary second call should have been limited") + } + + ctxFallback, cancelFallback := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancelFallback() + if err := r.Wait(ctxFallback, "model_name:fallback"); err == nil { + t.Fatal("fallback third call should have been limited") + } +} + +// TestRateLimiter_Concurrency verifies thread safety under concurrent access. +func TestRateLimiter_Concurrency(t *testing.T) { + rpm := 20 + rl := newRateLimiter(rpm) + var passed atomic.Int64 + var wg sync.WaitGroup + + // Launch 30 goroutines; only ~20 should pass immediately. + for i := 0; i < 30; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + if rl.Wait(ctx) == nil { + passed.Add(1) + } + }() + } + wg.Wait() + + got := passed.Load() + // Allow small timing slack: between rpm-2 and rpm+2. + if got < int64(rpm-2) || got > int64(rpm+2) { + t.Fatalf("expected ~%d immediate passes, got %d", rpm, got) + } +}