Files
picoclaw/pkg/channels/manager_test.go
T
Hoshina 38a26d702c refactor(channels): add per-channel rate limiting and send retry with error classification
Define sentinel error types (ErrNotRunning, ErrRateLimit, ErrTemporary,
ErrSendFailed) so the Manager can classify Send failures and choose the
right retry strategy: permanent errors bail immediately, rate-limit
errors use a fixed 1s delay, and temporary/unknown errors use exponential
backoff (500ms→1s→2s, capped at 8s, up to 3 retries). A per-channel
token-bucket rate limiter (golang.org/x/time/rate) throttles outbound
sends before they hit the platform API.
2026-02-22 23:51:55 +08:00

419 lines
10 KiB
Go

package channels
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/time/rate"
"github.com/sipeed/picoclaw/pkg/bus"
)
// mockChannel is a Channel test double: its lifecycle methods are no-ops and
// every Send call is forwarded to the sendFn hook supplied by the test.
type mockChannel struct {
	BaseChannel
	sendFn func(ctx context.Context, msg bus.OutboundMessage) error
}

// Send forwards the call to the configured sendFn hook.
func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
	hook := m.sendFn
	return hook(ctx, msg)
}

// Start does nothing and always reports success.
func (m *mockChannel) Start(_ context.Context) error { return nil }

// Stop does nothing and always reports success.
func (m *mockChannel) Stop(_ context.Context) error { return nil }
// newTestManager returns a bare Manager whose channel and worker maps are
// initialised but empty, which is enough for unit-testing worker logic.
func newTestManager() *Manager {
	m := &Manager{}
	m.channels = make(map[string]Channel)
	m.workers = make(map[string]*channelWorker)
	return m
}
func TestSendWithRetry_Success(t *testing.T) {
m := newTestManager()
var callCount int
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
return nil
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
m.sendWithRetry(ctx, "test", w, msg)
if callCount != 1 {
t.Fatalf("expected 1 Send call, got %d", callCount)
}
}
// TestSendWithRetry_TemporaryThenSuccess checks that ErrTemporary failures
// are retried until Send eventually succeeds.
func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
	const failuresBeforeSuccess = 2

	m := newTestManager()
	attempts := 0
	flaky := &mockChannel{
		sendFn: func(context.Context, bus.OutboundMessage) error {
			attempts++
			if attempts > failuresBeforeSuccess {
				return nil
			}
			return fmt.Errorf("network error: %w", ErrTemporary)
		},
	}
	w := &channelWorker{ch: flaky, limiter: rate.NewLimiter(rate.Inf, 1)}

	m.sendWithRetry(context.Background(), "test", w,
		bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})

	if attempts != failuresBeforeSuccess+1 {
		t.Fatalf("expected 3 Send calls (2 failures + 1 success), got %d", attempts)
	}
}
func TestSendWithRetry_PermanentFailure(t *testing.T) {
m := newTestManager()
var callCount int
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
return fmt.Errorf("bad chat ID: %w", ErrSendFailed)
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
m.sendWithRetry(ctx, "test", w, msg)
if callCount != 1 {
t.Fatalf("expected 1 Send call (no retry for permanent failure), got %d", callCount)
}
}
// TestSendWithRetry_NotRunning checks that ErrNotRunning short-circuits the
// retry loop after a single Send attempt.
func TestSendWithRetry_NotRunning(t *testing.T) {
	m := newTestManager()
	ctx := context.Background()

	n := 0
	stopped := &mockChannel{}
	stopped.sendFn = func(context.Context, bus.OutboundMessage) error {
		n++
		return ErrNotRunning
	}

	w := &channelWorker{limiter: rate.NewLimiter(rate.Inf, 1)}
	w.ch = stopped

	m.sendWithRetry(ctx, "test", w, bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})

	if n != 1 {
		t.Fatalf("expected 1 Send call (no retry for ErrNotRunning), got %d", n)
	}
}
// TestSendWithRetry_RateLimitRetry checks that an ErrRateLimit failure is
// retried once after the fixed rate-limit delay.
func TestSendWithRetry_RateLimitRetry(t *testing.T) {
	m := newTestManager()
	attempts := 0
	ch := &mockChannel{
		sendFn: func(context.Context, bus.OutboundMessage) error {
			attempts++
			if attempts > 1 {
				return nil
			}
			return fmt.Errorf("429: %w", ErrRateLimit)
		},
	}
	w := &channelWorker{ch: ch, limiter: rate.NewLimiter(rate.Inf, 1)}
	msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}

	began := time.Now()
	m.sendWithRetry(context.Background(), "test", w, msg)
	took := time.Since(began)

	if attempts != 2 {
		t.Fatalf("expected 2 Send calls (1 rate limit + 1 success), got %d", attempts)
	}
	// rateLimitDelay is nominally 1s; accept anything from 900ms to absorb timer jitter.
	if minWait := 900 * time.Millisecond; took < minWait {
		t.Fatalf("expected at least ~1s delay for rate limit retry, got %v", took)
	}
}
// TestSendWithRetry_MaxRetriesExhausted checks that a Send that keeps failing
// with ErrTemporary is attempted exactly maxRetries+1 times and then dropped.
func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
	m := newTestManager()

	var attempts int
	alwaysTimeout := func(context.Context, bus.OutboundMessage) error {
		attempts++
		return fmt.Errorf("timeout: %w", ErrTemporary)
	}
	w := &channelWorker{
		ch:      &mockChannel{sendFn: alwaysTimeout},
		limiter: rate.NewLimiter(rate.Inf, 1),
	}

	m.sendWithRetry(context.Background(), "test", w,
		bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})

	// One initial attempt plus maxRetries retries.
	if want := maxRetries + 1; attempts != want {
		t.Fatalf("expected %d Send calls, got %d", want, attempts)
	}
}
// TestSendWithRetry_UnknownError checks that an unclassified error is handled
// like a temporary one, i.e. it is retried.
func TestSendWithRetry_UnknownError(t *testing.T) {
	m := newTestManager()
	calls := 0
	ch := &mockChannel{
		sendFn: func(context.Context, bus.OutboundMessage) error {
			calls++
			if calls > 1 {
				return nil
			}
			return errors.New("random unexpected error")
		},
	}
	w := &channelWorker{ch: ch, limiter: rate.NewLimiter(rate.Inf, 1)}

	m.sendWithRetry(context.Background(), "test", w,
		bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})

	if calls != 2 {
		t.Fatalf("expected 2 Send calls (unknown error treated as temporary), got %d", calls)
	}
}
// TestSendWithRetry_ContextCancelled checks that sendWithRetry stops retrying
// once its context is cancelled: the context is cancelled from inside the
// first Send, so no second attempt should be made.
func TestSendWithRetry_ContextCancelled(t *testing.T) {
	m := newTestManager()
	ctx, cancel := context.WithCancel(context.Background())
	// Calling cancel again after the test is a no-op; this guarantees release
	// even if the Send hook is never reached.
	defer cancel()

	var callCount int
	ch := &mockChannel{
		// Cancel the context after the first Send attempt, then report a
		// temporary error so sendWithRetry would otherwise back off and retry.
		sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
			callCount++
			cancel()
			return fmt.Errorf("timeout: %w", ErrTemporary)
		},
	}
	w := &channelWorker{
		ch:      ch,
		limiter: rate.NewLimiter(rate.Inf, 1),
	}
	msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}

	m.sendWithRetry(ctx, "test", w, msg)

	// Should have called Send once, then noticed ctx cancelled during backoff.
	if callCount != 1 {
		t.Fatalf("expected 1 Send call before context cancellation, got %d", callCount)
	}
}
// TestWorkerRateLimiter checks that runWorker honours the worker's token-bucket
// limiter: with 2 msg/s and a burst of 1, four queued messages must be spread
// over at least ~1s.
func TestWorkerRateLimiter(t *testing.T) {
	m := newTestManager()
	var mu sync.Mutex
	var sendTimes []time.Time
	ch := &mockChannel{
		sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
			mu.Lock()
			sendTimes = append(sendTimes, time.Now())
			mu.Unlock()
			return nil
		},
	}
	// Create a worker with a low rate: 2 msg/s, burst 1.
	w := &channelWorker{
		ch:      ch,
		queue:   make(chan bus.OutboundMessage, 10),
		done:    make(chan struct{}),
		limiter: rate.NewLimiter(2, 1),
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go m.runWorker(ctx, "test", w)

	const numMsgs = 4
	for i := 0; i < numMsgs; i++ {
		w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
	}

	// snapshot returns a copy of the recorded send times taken under the lock.
	snapshot := func() []time.Time {
		mu.Lock()
		defer mu.Unlock()
		return append([]time.Time(nil), sendTimes...)
	}

	// Poll until every message has been sent (nominally ~1.5s at 2/s with a
	// burst of 1) instead of sleeping a fixed interval; the deadline is generous
	// so a loaded CI machine does not cause a spurious failure.
	deadline := time.Now().Add(5 * time.Second)
	times := snapshot()
	for len(times) < numMsgs && time.Now().Before(deadline) {
		time.Sleep(20 * time.Millisecond)
		times = snapshot()
	}

	if len(times) != numMsgs {
		t.Fatalf("expected %d sends, got %d", numMsgs, len(times))
	}
	// Verify rate limiting: the first message goes out immediately, then
	// ~500ms between each subsequent one at 2/s, so the span is at least 1s.
	totalDuration := times[len(times)-1].Sub(times[0])
	if totalDuration < 1*time.Second {
		t.Fatalf("expected total duration >= 1s for 4 msgs at 2/s rate, got %v", totalDuration)
	}
}
func TestNewChannelWorker_DefaultRate(t *testing.T) {
ch := &mockChannel{}
w := newChannelWorker("unknown_channel", ch)
if w.limiter == nil {
t.Fatal("expected limiter to be non-nil")
}
if w.limiter.Limit() != rate.Limit(defaultRateLimit) {
t.Fatalf("expected rate limit %v, got %v", rate.Limit(defaultRateLimit), w.limiter.Limit())
}
}
// TestNewChannelWorker_ConfiguredRate checks that every channel listed in
// channelRateConfig gets a limiter running at its configured rate. Each
// channel runs as its own subtest, so one misconfigured entry does not hide
// the others.
func TestNewChannelWorker_ConfiguredRate(t *testing.T) {
	ch := &mockChannel{}
	for name, expectedRate := range channelRateConfig {
		// Subtests without t.Parallel run synchronously, so capturing the loop
		// variables here is safe on every Go version.
		t.Run(name, func(t *testing.T) {
			w := newChannelWorker(name, ch)
			if w.limiter == nil {
				t.Fatalf("channel %s: expected limiter to be non-nil", name)
			}
			if got, want := w.limiter.Limit(), rate.Limit(expectedRate); got != want {
				t.Fatalf("channel %s: expected rate %v, got %v", name, want, got)
			}
		})
	}
}
// TestRunWorker_MessageSplitting checks that runWorker splits content longer
// than the channel's MaxMessageLength into more than one Send call.
func TestRunWorker_MessageSplitting(t *testing.T) {
	m := newTestManager()
	var mu sync.Mutex
	var received []string
	ch := &mockChannelWithLength{
		mockChannel: mockChannel{
			sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
				mu.Lock()
				received = append(received, msg.Content)
				mu.Unlock()
				return nil
			},
		},
		maxLen: 5,
	}
	w := &channelWorker{
		ch:      ch,
		queue:   make(chan bus.OutboundMessage, 10),
		done:    make(chan struct{}),
		limiter: rate.NewLimiter(rate.Inf, 1),
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go m.runWorker(ctx, "test", w)

	// "hello world" is 11 characters, longer than maxLen (5), so it should be split.
	w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}

	// chunkCount reads the number of chunks delivered so far under the lock.
	chunkCount := func() int {
		mu.Lock()
		defer mu.Unlock()
		return len(received)
	}

	// Poll rather than sleep a fixed 100ms, so a slow scheduler cannot make
	// the test observe the worker before it has sent anything.
	deadline := time.Now().Add(2 * time.Second)
	count := chunkCount()
	for count < 2 && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
		count = chunkCount()
	}

	if count < 2 {
		t.Fatalf("expected message to be split into at least 2 chunks, got %d", count)
	}
}
// mockChannelWithLength is a mockChannel that also reports a maximum message
// length through MaxMessageLength (intended to satisfy MessageLengthProvider).
type mockChannelWithLength struct {
	mockChannel
	maxLen int
}

// MaxMessageLength returns the configured maxLen limit.
func (m *mockChannelWithLength) MaxMessageLength() int { return m.maxLen }
// TestSendWithRetry_ExponentialBackoff checks that persistent ErrTemporary
// failures are retried maxRetries times and that the delay between attempts
// grows following the documented 500ms -> 1s -> 2s schedule.
func TestSendWithRetry_ExponentialBackoff(t *testing.T) {
	m := newTestManager()
	// callTimes is read only after sendWithRetry returns. The other tests read
	// plain counters the same way, which suggests Send runs on this goroutine.
	var callTimes []time.Time
	var callCount atomic.Int32
	ch := &mockChannel{
		sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
			callTimes = append(callTimes, time.Now())
			callCount.Add(1)
			return fmt.Errorf("timeout: %w", ErrTemporary)
		},
	}
	w := &channelWorker{
		ch:      ch,
		limiter: rate.NewLimiter(rate.Inf, 1),
	}
	ctx := context.Background()
	msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}

	start := time.Now()
	m.sendWithRetry(ctx, "test", w, msg)
	totalElapsed := time.Since(start)

	if int(callCount.Load()) != maxRetries+1 {
		t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load())
	}
	// With maxRetries=3: attempts at 0, ~500ms, ~1.5s, ~3.5s.
	// Total backoff: 500ms + 1s + 2s = 3.5s. Allow some margin.
	if totalElapsed < 3*time.Second {
		t.Fatalf("expected total elapsed >= 3s for exponential backoff, got %v", totalElapsed)
	}

	// Check each individual gap, not just the total, so a constant delay
	// cannot pass. Each gap must reach 80% of its nominal value to tolerate
	// timer imprecision while still proving the delay grows.
	expectedGaps := []time.Duration{500 * time.Millisecond, time.Second, 2 * time.Second}
	for i := 1; i < len(callTimes) && i <= len(expectedGaps); i++ {
		gap := callTimes[i].Sub(callTimes[i-1])
		if minGap := expectedGaps[i-1] * 8 / 10; gap < minGap {
			t.Fatalf("expected gap before attempt %d to be >= %v, got %v", i+1, minGap, gap)
		}
	}
}