mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(channels): add per-channel rate limiting and send retry with error classification
Define sentinel error types (ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed) so the Manager can classify Send failures and choose the right retry strategy: permanent errors bail immediately, rate-limit errors use a fixed 1s delay, and temporary/unknown errors use exponential backoff (500ms→1s→2s, capped at 8s, up to 3 retries). A per-channel token-bucket rate limiter (golang.org/x/time/rate) throttles outbound sends before they hit the platform API.
This commit is contained in:
@@ -23,6 +23,7 @@ require (
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
|
||||
@@ -226,6 +226,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package channels
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrNotRunning indicates the channel is not running.
|
||||
// Manager will not retry.
|
||||
ErrNotRunning = errors.New("channel not running")
|
||||
|
||||
// ErrRateLimit indicates the platform returned a rate-limit response (e.g. HTTP 429).
|
||||
// Manager will wait a fixed delay and retry.
|
||||
ErrRateLimit = errors.New("rate limited")
|
||||
|
||||
// ErrTemporary indicates a transient failure (e.g. network timeout, 5xx).
|
||||
// Manager will use exponential backoff and retry.
|
||||
ErrTemporary = errors.New("temporary failure")
|
||||
|
||||
// ErrSendFailed indicates a permanent failure (e.g. invalid chat ID, 4xx non-429).
|
||||
// Manager will not retry.
|
||||
ErrSendFailed = errors.New("send failed")
|
||||
)
|
||||
@@ -0,0 +1,56 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrorsIs(t *testing.T) {
|
||||
wrapped := fmt.Errorf("telegram API: %w", ErrRateLimit)
|
||||
if !errors.Is(wrapped, ErrRateLimit) {
|
||||
t.Error("wrapped ErrRateLimit should match")
|
||||
}
|
||||
if errors.Is(wrapped, ErrTemporary) {
|
||||
t.Error("wrapped ErrRateLimit should not match ErrTemporary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorsIsAllTypes(t *testing.T) {
|
||||
sentinels := []error{ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed}
|
||||
|
||||
for _, sentinel := range sentinels {
|
||||
wrapped := fmt.Errorf("context: %w", sentinel)
|
||||
if !errors.Is(wrapped, sentinel) {
|
||||
t.Errorf("wrapped %v should match itself", sentinel)
|
||||
}
|
||||
|
||||
// Verify it doesn't match other sentinel errors
|
||||
for _, other := range sentinels {
|
||||
if other == sentinel {
|
||||
continue
|
||||
}
|
||||
if errors.Is(wrapped, other) {
|
||||
t.Errorf("wrapped %v should not match %v", sentinel, other)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorMessages(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
want string
|
||||
}{
|
||||
{ErrNotRunning, "channel not running"},
|
||||
{ErrRateLimit, "rate limited"},
|
||||
{ErrTemporary, "temporary failure"},
|
||||
{ErrSendFailed, "send failed"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := tt.err.Error(); got != tt.want {
|
||||
t.Errorf("error message = %q, want %q", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
+103
-24
@@ -8,8 +8,13 @@ package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -19,12 +24,28 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const defaultChannelQueueSize = 100
|
||||
const (
|
||||
defaultChannelQueueSize = 100
|
||||
defaultRateLimit = 10 // default 10 msg/s
|
||||
maxRetries = 3
|
||||
rateLimitDelay = 1 * time.Second
|
||||
baseBackoff = 500 * time.Millisecond
|
||||
maxBackoff = 8 * time.Second
|
||||
)
|
||||
|
||||
// channelRateConfig maps channel name to per-second rate limit.
|
||||
var channelRateConfig = map[string]float64{
|
||||
"telegram": 20,
|
||||
"discord": 1,
|
||||
"slack": 1,
|
||||
"line": 10,
|
||||
}
|
||||
|
||||
type channelWorker struct {
|
||||
ch Channel
|
||||
queue chan bus.OutboundMessage
|
||||
done chan struct{}
|
||||
ch Channel
|
||||
queue chan bus.OutboundMessage
|
||||
done chan struct{}
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -83,11 +104,7 @@ func (m *Manager) initChannel(name, displayName string) {
|
||||
}
|
||||
}
|
||||
m.channels[name] = ch
|
||||
m.workers[name] = &channelWorker{
|
||||
ch: ch,
|
||||
queue: make(chan bus.OutboundMessage, defaultChannelQueueSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
m.workers[name] = newChannelWorker(name, ch)
|
||||
logger.InfoCF("channels", "Channel enabled successfully", map[string]any{
|
||||
"channel": displayName,
|
||||
})
|
||||
@@ -227,6 +244,23 @@ func (m *Manager) StopAll(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// newChannelWorker creates a channelWorker with a rate limiter configured
|
||||
// for the given channel name.
|
||||
func newChannelWorker(name string, ch Channel) *channelWorker {
|
||||
rateVal := float64(defaultRateLimit)
|
||||
if r, ok := channelRateConfig[name]; ok {
|
||||
rateVal = r
|
||||
}
|
||||
burst := int(math.Max(1, math.Ceil(rateVal/2)))
|
||||
|
||||
return &channelWorker{
|
||||
ch: ch,
|
||||
queue: make(chan bus.OutboundMessage, defaultChannelQueueSize),
|
||||
done: make(chan struct{}),
|
||||
limiter: rate.NewLimiter(rate.Limit(rateVal), burst),
|
||||
}
|
||||
}
|
||||
|
||||
// runWorker processes outbound messages for a single channel, splitting
|
||||
// messages that exceed the channel's maximum message length.
|
||||
func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) {
|
||||
@@ -246,18 +280,10 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker)
|
||||
for _, chunk := range chunks {
|
||||
chunkMsg := msg
|
||||
chunkMsg.Content = chunk
|
||||
if err := w.ch.Send(ctx, chunkMsg); err != nil {
|
||||
logger.ErrorCF("channels", "Error sending chunk", map[string]any{
|
||||
"channel": name, "error": err.Error(),
|
||||
})
|
||||
}
|
||||
m.sendWithRetry(ctx, name, w, chunkMsg)
|
||||
}
|
||||
} else {
|
||||
if err := w.ch.Send(ctx, msg); err != nil {
|
||||
logger.ErrorCF("channels", "Error sending message", map[string]any{
|
||||
"channel": name, "error": err.Error(),
|
||||
})
|
||||
}
|
||||
m.sendWithRetry(ctx, name, w, msg)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -265,6 +291,63 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker)
|
||||
}
|
||||
}
|
||||
|
||||
// sendWithRetry sends a message through the channel with rate limiting and
|
||||
// retry logic. It classifies errors to determine the retry strategy:
|
||||
// - ErrNotRunning / ErrSendFailed: permanent, no retry
|
||||
// - ErrRateLimit: fixed delay retry
|
||||
// - ErrTemporary / unknown: exponential backoff retry
|
||||
func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMessage) {
|
||||
// Rate limit: wait for token
|
||||
if err := w.limiter.Wait(ctx); err != nil {
|
||||
// ctx cancelled, shutting down
|
||||
return
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
lastErr = w.ch.Send(ctx, msg)
|
||||
if lastErr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Permanent failures — don't retry
|
||||
if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) {
|
||||
break
|
||||
}
|
||||
|
||||
// Last attempt exhausted — don't sleep
|
||||
if attempt == maxRetries {
|
||||
break
|
||||
}
|
||||
|
||||
// Rate limit error — fixed delay
|
||||
if errors.Is(lastErr, ErrRateLimit) {
|
||||
select {
|
||||
case <-time.After(rateLimitDelay):
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// ErrTemporary or unknown error — exponential backoff
|
||||
backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// All retries exhausted or permanent failure
|
||||
logger.ErrorCF("channels", "Send failed", map[string]any{
|
||||
"channel": name,
|
||||
"chat_id": msg.ChatID,
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
logger.InfoC("channels", "Outbound dispatcher started")
|
||||
|
||||
@@ -343,11 +426,7 @@ func (m *Manager) RegisterChannel(name string, channel Channel) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.channels[name] = channel
|
||||
m.workers[name] = &channelWorker{
|
||||
ch: channel,
|
||||
queue: make(chan bus.OutboundMessage, defaultChannelQueueSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
m.workers[name] = newChannelWorker(name, channel)
|
||||
}
|
||||
|
||||
func (m *Manager) UnregisterChannel(name string) {
|
||||
|
||||
@@ -0,0 +1,418 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
)
|
||||
|
||||
// mockChannel is a test double that delegates Send to a configurable function.
|
||||
type mockChannel struct {
|
||||
BaseChannel
|
||||
sendFn func(ctx context.Context, msg bus.OutboundMessage) error
|
||||
}
|
||||
|
||||
func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
return m.sendFn(ctx, msg)
|
||||
}
|
||||
|
||||
func (m *mockChannel) Start(ctx context.Context) error { return nil }
|
||||
func (m *mockChannel) Stop(ctx context.Context) error { return nil }
|
||||
|
||||
// newTestManager creates a minimal Manager suitable for unit tests.
|
||||
func newTestManager() *Manager {
|
||||
return &Manager{
|
||||
channels: make(map[string]Channel),
|
||||
workers: make(map[string]*channelWorker),
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_Success(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
callCount++
|
||||
return nil
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
if callCount != 1 {
|
||||
t.Fatalf("expected 1 Send call, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
callCount++
|
||||
if callCount <= 2 {
|
||||
return fmt.Errorf("network error: %w", ErrTemporary)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
if callCount != 3 {
|
||||
t.Fatalf("expected 3 Send calls (2 failures + 1 success), got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_PermanentFailure(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
callCount++
|
||||
return fmt.Errorf("bad chat ID: %w", ErrSendFailed)
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
if callCount != 1 {
|
||||
t.Fatalf("expected 1 Send call (no retry for permanent failure), got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_NotRunning(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
callCount++
|
||||
return ErrNotRunning
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
if callCount != 1 {
|
||||
t.Fatalf("expected 1 Send call (no retry for ErrNotRunning), got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_RateLimitRetry(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
callCount++
|
||||
if callCount == 1 {
|
||||
return fmt.Errorf("429: %w", ErrRateLimit)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
|
||||
start := time.Now()
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if callCount != 2 {
|
||||
t.Fatalf("expected 2 Send calls (1 rate limit + 1 success), got %d", callCount)
|
||||
}
|
||||
// Should have waited at least rateLimitDelay (1s) but allow some slack
|
||||
if elapsed < 900*time.Millisecond {
|
||||
t.Fatalf("expected at least ~1s delay for rate limit retry, got %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
callCount++
|
||||
return fmt.Errorf("timeout: %w", ErrTemporary)
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
expected := maxRetries + 1 // initial attempt + maxRetries retries
|
||||
if callCount != expected {
|
||||
t.Fatalf("expected %d Send calls, got %d", expected, callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_UnknownError(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
callCount++
|
||||
if callCount == 1 {
|
||||
return errors.New("random unexpected error")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
if callCount != 2 {
|
||||
t.Fatalf("expected 2 Send calls (unknown error treated as temporary), got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_ContextCancelled(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
callCount++
|
||||
return fmt.Errorf("timeout: %w", ErrTemporary)
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
|
||||
// Cancel context after first Send attempt returns
|
||||
ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
callCount++
|
||||
cancel()
|
||||
return fmt.Errorf("timeout: %w", ErrTemporary)
|
||||
}
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
// Should have called Send once, then noticed ctx cancelled during backoff
|
||||
if callCount != 1 {
|
||||
t.Fatalf("expected 1 Send call before context cancellation, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerRateLimiter(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var mu sync.Mutex
|
||||
var sendTimes []time.Time
|
||||
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
mu.Lock()
|
||||
sendTimes = append(sendTimes, time.Now())
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Create a worker with a low rate: 2 msg/s, burst 1
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
queue: make(chan bus.OutboundMessage, 10),
|
||||
done: make(chan struct{}),
|
||||
limiter: rate.NewLimiter(2, 1),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go m.runWorker(ctx, "test", w)
|
||||
|
||||
// Enqueue 4 messages
|
||||
for i := 0; i < 4; i++ {
|
||||
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
|
||||
}
|
||||
|
||||
// Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin)
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
mu.Lock()
|
||||
times := make([]time.Time, len(sendTimes))
|
||||
copy(times, sendTimes)
|
||||
mu.Unlock()
|
||||
|
||||
if len(times) != 4 {
|
||||
t.Fatalf("expected 4 sends, got %d", len(times))
|
||||
}
|
||||
|
||||
// Verify rate limiting: total duration should be at least 1s
|
||||
// (first message immediate, then ~500ms between each subsequent one at 2/s)
|
||||
totalDuration := times[len(times)-1].Sub(times[0])
|
||||
if totalDuration < 1*time.Second {
|
||||
t.Fatalf("expected total duration >= 1s for 4 msgs at 2/s rate, got %v", totalDuration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewChannelWorker_DefaultRate(t *testing.T) {
|
||||
ch := &mockChannel{}
|
||||
w := newChannelWorker("unknown_channel", ch)
|
||||
|
||||
if w.limiter == nil {
|
||||
t.Fatal("expected limiter to be non-nil")
|
||||
}
|
||||
if w.limiter.Limit() != rate.Limit(defaultRateLimit) {
|
||||
t.Fatalf("expected rate limit %v, got %v", rate.Limit(defaultRateLimit), w.limiter.Limit())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewChannelWorker_ConfiguredRate(t *testing.T) {
|
||||
ch := &mockChannel{}
|
||||
|
||||
for name, expectedRate := range channelRateConfig {
|
||||
w := newChannelWorker(name, ch)
|
||||
if w.limiter.Limit() != rate.Limit(expectedRate) {
|
||||
t.Fatalf("channel %s: expected rate %v, got %v", name, expectedRate, w.limiter.Limit())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunWorker_MessageSplitting(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var mu sync.Mutex
|
||||
var received []string
|
||||
|
||||
ch := &mockChannelWithLength{
|
||||
mockChannel: mockChannel{
|
||||
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
|
||||
mu.Lock()
|
||||
received = append(received, msg.Content)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
},
|
||||
},
|
||||
maxLen: 5,
|
||||
}
|
||||
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
queue: make(chan bus.OutboundMessage, 10),
|
||||
done: make(chan struct{}),
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go m.runWorker(ctx, "test", w)
|
||||
|
||||
// Send a message that should be split
|
||||
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
count := len(received)
|
||||
mu.Unlock()
|
||||
|
||||
if count < 2 {
|
||||
t.Fatalf("expected message to be split into at least 2 chunks, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// mockChannelWithLength implements MessageLengthProvider.
|
||||
type mockChannelWithLength struct {
|
||||
mockChannel
|
||||
maxLen int
|
||||
}
|
||||
|
||||
func (m *mockChannelWithLength) MaxMessageLength() int {
|
||||
return m.maxLen
|
||||
}
|
||||
|
||||
func TestSendWithRetry_ExponentialBackoff(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var callTimes []time.Time
|
||||
var callCount atomic.Int32
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
callTimes = append(callTimes, time.Now())
|
||||
callCount.Add(1)
|
||||
return fmt.Errorf("timeout: %w", ErrTemporary)
|
||||
},
|
||||
}
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
|
||||
start := time.Now()
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
totalElapsed := time.Since(start)
|
||||
|
||||
// With maxRetries=3: attempts at 0, ~500ms, ~1.5s, ~3.5s
|
||||
// Total backoff: 500ms + 1s + 2s = 3.5s
|
||||
// Allow some margin
|
||||
if totalElapsed < 3*time.Second {
|
||||
t.Fatalf("expected total elapsed >= 3s for exponential backoff, got %v", totalElapsed)
|
||||
}
|
||||
|
||||
if int(callCount.Load()) != maxRetries+1 {
|
||||
t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user