mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
b114dcaeb1
* feat(model): rate limiting
* fix(agent): preserve per-model identity in rate limiting and fallback
* fix test
210 lines
6.5 KiB
Go
210 lines
6.5 KiB
Go
package providers
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// TestRateLimiter_AllowsUpToRPM verifies that up to RPM requests pass immediately
|
|
// (burst capacity) and the (RPM+1)-th request is delayed.
|
|
func TestRateLimiter_AllowsUpToRPM(t *testing.T) {
|
|
rpm := 5
|
|
rl := newRateLimiter(rpm)
|
|
|
|
// All rpm tokens should be available immediately (bucket starts full).
|
|
for i := 0; i < rpm; i++ {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
if err := rl.Wait(ctx); err != nil {
|
|
t.Fatalf("request %d should pass immediately, got: %v", i+1, err)
|
|
}
|
|
cancel()
|
|
}
|
|
|
|
// The next request must wait; cancel it to confirm it blocks.
|
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
|
defer cancel()
|
|
err := rl.Wait(ctx)
|
|
if err == nil {
|
|
t.Fatal("expected request beyond RPM to block, but it passed immediately")
|
|
}
|
|
}
|
|
|
|
// TestRateLimiter_ContextCancellation verifies that a blocked Wait respects cancellation.
|
|
func TestRateLimiter_ContextCancellation(t *testing.T) {
|
|
rl := newRateLimiter(1)
|
|
|
|
// Drain the one token.
|
|
ctx := context.Background()
|
|
if err := rl.Wait(ctx); err != nil {
|
|
t.Fatalf("first request failed: %v", err)
|
|
}
|
|
|
|
// Second request should block; cancel it.
|
|
cancelCtx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
|
|
defer cancel()
|
|
err := rl.Wait(cancelCtx)
|
|
if err == nil {
|
|
t.Fatal("expected cancellation error, got nil")
|
|
}
|
|
}
|
|
|
|
// TestRateLimiter_TokenRefill verifies that tokens refill over time.
|
|
func TestRateLimiter_TokenRefill(t *testing.T) {
|
|
rpm := 60 // 1 token per second
|
|
rl := newRateLimiter(rpm)
|
|
|
|
// Drain all tokens.
|
|
for i := 0; i < rpm; i++ {
|
|
rl.Wait(context.Background()) //nolint:errcheck
|
|
}
|
|
|
|
// Advance time via nowFunc: simulate 2 seconds passing (should give 2 tokens).
|
|
start := time.Now()
|
|
rl.nowFunc = func() time.Time { return start.Add(2 * time.Second) }
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
defer cancel()
|
|
if err := rl.Wait(ctx); err != nil {
|
|
t.Fatalf("expected refilled token to be available: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestRateLimiterRegistry_NoLimiter verifies that keys without a registered limiter pass freely.
|
|
func TestRateLimiterRegistry_NoLimiter(t *testing.T) {
|
|
r := NewRateLimiterRegistry()
|
|
ctx := context.Background()
|
|
for i := 0; i < 100; i++ {
|
|
if err := r.Wait(ctx, "unregistered/key"); err != nil {
|
|
t.Fatalf("unregistered key should not block: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestRateLimiterRegistry_ZeroRPM verifies that RPM=0 means no limiter is registered.
|
|
func TestRateLimiterRegistry_ZeroRPM(t *testing.T) {
|
|
r := NewRateLimiterRegistry()
|
|
r.Register("some/key", 0)
|
|
ctx := context.Background()
|
|
for i := 0; i < 50; i++ {
|
|
if err := r.Wait(ctx, "some/key"); err != nil {
|
|
t.Fatalf("zero-RPM key should not block: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestRateLimiterRegistry_Enforcement verifies the registry enforces RPM per key.
|
|
func TestRateLimiterRegistry_Enforcement(t *testing.T) {
|
|
r := NewRateLimiterRegistry()
|
|
r.Register("openai/gpt-4o", 3)
|
|
|
|
// First 3 calls should pass (burst = RPM).
|
|
for i := 0; i < 3; i++ {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
if err := r.Wait(ctx, "openai/gpt-4o"); err != nil {
|
|
t.Fatalf("call %d should pass: %v", i+1, err)
|
|
}
|
|
cancel()
|
|
}
|
|
|
|
// 4th call should block.
|
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
|
defer cancel()
|
|
if err := r.Wait(ctx, "openai/gpt-4o"); err == nil {
|
|
t.Fatal("4th call should have been rate-limited")
|
|
}
|
|
}
|
|
|
|
// TestRateLimiterRegistry_RegisterCandidates verifies that RegisterCandidates
|
|
// correctly picks up RPM from FallbackCandidate.
|
|
func TestRateLimiterRegistry_RegisterCandidates(t *testing.T) {
|
|
r := NewRateLimiterRegistry()
|
|
candidates := []FallbackCandidate{
|
|
{Provider: "openai", Model: "gpt-4o", RPM: 2},
|
|
{Provider: "anthropic", Model: "claude-3", RPM: 0}, // no limit
|
|
}
|
|
r.RegisterCandidates(candidates)
|
|
|
|
// openai/gpt-4o: 2 tokens burst, 3rd should block.
|
|
for i := 0; i < 2; i++ {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
if err := r.Wait(ctx, "openai/gpt-4o"); err != nil {
|
|
t.Fatalf("openai call %d should pass: %v", i+1, err)
|
|
}
|
|
cancel()
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
|
defer cancel()
|
|
if err := r.Wait(ctx, "openai/gpt-4o"); err == nil {
|
|
t.Fatal("openai 3rd call should have been limited")
|
|
}
|
|
|
|
// anthropic/claude-3: no limit, should always pass.
|
|
for i := 0; i < 10; i++ {
|
|
if err := r.Wait(context.Background(), "anthropic/claude-3"); err != nil {
|
|
t.Fatalf("anthropic call should not be limited: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimiterRegistry_RegisterCandidatesUsesStableIdentity(t *testing.T) {
|
|
r := NewRateLimiterRegistry()
|
|
candidates := []FallbackCandidate{
|
|
{Provider: "openai", Model: "gpt-4o", RPM: 1, IdentityKey: "model_name:primary"},
|
|
{Provider: "openai", Model: "gpt-4o", RPM: 2, IdentityKey: "model_name:fallback"},
|
|
}
|
|
r.RegisterCandidates(candidates)
|
|
|
|
if err := r.Wait(context.Background(), "model_name:primary"); err != nil {
|
|
t.Fatalf("primary first call should pass: %v", err)
|
|
}
|
|
if err := r.Wait(context.Background(), "model_name:fallback"); err != nil {
|
|
t.Fatalf("fallback first call should pass: %v", err)
|
|
}
|
|
if err := r.Wait(context.Background(), "model_name:fallback"); err != nil {
|
|
t.Fatalf("fallback second call should pass: %v", err)
|
|
}
|
|
|
|
ctxPrimary, cancelPrimary := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
|
defer cancelPrimary()
|
|
if err := r.Wait(ctxPrimary, "model_name:primary"); err == nil {
|
|
t.Fatal("primary second call should have been limited")
|
|
}
|
|
|
|
ctxFallback, cancelFallback := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
|
defer cancelFallback()
|
|
if err := r.Wait(ctxFallback, "model_name:fallback"); err == nil {
|
|
t.Fatal("fallback third call should have been limited")
|
|
}
|
|
}
|
|
|
|
// TestRateLimiter_Concurrency verifies thread safety under concurrent access.
|
|
func TestRateLimiter_Concurrency(t *testing.T) {
|
|
rpm := 20
|
|
rl := newRateLimiter(rpm)
|
|
var passed atomic.Int64
|
|
var wg sync.WaitGroup
|
|
|
|
// Launch 30 goroutines; only ~20 should pass immediately.
|
|
for i := 0; i < 30; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
|
defer cancel()
|
|
if rl.Wait(ctx) == nil {
|
|
passed.Add(1)
|
|
}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
got := passed.Load()
|
|
// Allow small timing slack: between rpm-2 and rpm+2.
|
|
if got < int64(rpm-2) || got > int64(rpm+2) {
|
|
t.Fatalf("expected ~%d immediate passes, got %d", rpm, got)
|
|
}
|
|
}
|