refactor(channels): add per-channel rate limiting and send retry with error classification

Define sentinel error types (ErrNotRunning, ErrRateLimit, ErrTemporary,
ErrSendFailed) so the Manager can classify Send failures and choose the
right retry strategy: permanent errors bail immediately, rate-limit
errors use a fixed 1s delay, and temporary/unknown errors use exponential
backoff (500ms→1s→2s, capped at 8s, up to 3 retries). A per-channel
token-bucket rate limiter (golang.org/x/time/rate) throttles outbound
sends before they hit the platform API.
This commit is contained in:
Hoshina
2026-02-22 23:51:55 +08:00
parent 038fdf5000
commit 38a26d702c
6 changed files with 601 additions and 24 deletions
+1
View File
@@ -26,6 +26,7 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
golang.org/x/time v0.14.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
+2
View File
@@ -236,6 +236,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
+21
View File
@@ -0,0 +1,21 @@
package channels
import "errors"
var (
// ErrNotRunning indicates the channel is not running.
// Manager will not retry.
ErrNotRunning = errors.New("channel not running")
// ErrRateLimit indicates the platform returned a rate-limit response (e.g. HTTP 429).
// Manager will wait a fixed delay and retry.
ErrRateLimit = errors.New("rate limited")
// ErrTemporary indicates a transient failure (e.g. network timeout, 5xx).
// Manager will use exponential backoff and retry.
ErrTemporary = errors.New("temporary failure")
// ErrSendFailed indicates a permanent failure (e.g. invalid chat ID, 4xx non-429).
// Manager will not retry.
ErrSendFailed = errors.New("send failed")
)
+56
View File
@@ -0,0 +1,56 @@
package channels
import (
"errors"
"fmt"
"testing"
)
func TestErrorsIs(t *testing.T) {
wrapped := fmt.Errorf("telegram API: %w", ErrRateLimit)
if !errors.Is(wrapped, ErrRateLimit) {
t.Error("wrapped ErrRateLimit should match")
}
if errors.Is(wrapped, ErrTemporary) {
t.Error("wrapped ErrRateLimit should not match ErrTemporary")
}
}
func TestErrorsIsAllTypes(t *testing.T) {
sentinels := []error{ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed}
for _, sentinel := range sentinels {
wrapped := fmt.Errorf("context: %w", sentinel)
if !errors.Is(wrapped, sentinel) {
t.Errorf("wrapped %v should match itself", sentinel)
}
// Verify it doesn't match other sentinel errors
for _, other := range sentinels {
if other == sentinel {
continue
}
if errors.Is(wrapped, other) {
t.Errorf("wrapped %v should not match %v", sentinel, other)
}
}
}
}
func TestErrorMessages(t *testing.T) {
tests := []struct {
err error
want string
}{
{ErrNotRunning, "channel not running"},
{ErrRateLimit, "rate limited"},
{ErrTemporary, "temporary failure"},
{ErrSendFailed, "send failed"},
}
for _, tt := range tests {
if got := tt.err.Error(); got != tt.want {
t.Errorf("error message = %q, want %q", got, tt.want)
}
}
}
+103 -24
View File
@@ -8,8 +8,13 @@ package channels
import (
"context"
"errors"
"fmt"
"math"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
@@ -19,12 +24,28 @@ import (
"github.com/sipeed/picoclaw/pkg/utils"
)
const defaultChannelQueueSize = 100
const (
defaultChannelQueueSize = 100
defaultRateLimit = 10 // default 10 msg/s
maxRetries = 3
rateLimitDelay = 1 * time.Second
baseBackoff = 500 * time.Millisecond
maxBackoff = 8 * time.Second
)
// channelRateConfig maps channel name to per-second rate limit.
var channelRateConfig = map[string]float64{
"telegram": 20,
"discord": 1,
"slack": 1,
"line": 10,
}
type channelWorker struct {
ch Channel
queue chan bus.OutboundMessage
done chan struct{}
ch Channel
queue chan bus.OutboundMessage
done chan struct{}
limiter *rate.Limiter
}
type Manager struct {
@@ -83,11 +104,7 @@ func (m *Manager) initChannel(name, displayName string) {
}
}
m.channels[name] = ch
m.workers[name] = &channelWorker{
ch: ch,
queue: make(chan bus.OutboundMessage, defaultChannelQueueSize),
done: make(chan struct{}),
}
m.workers[name] = newChannelWorker(name, ch)
logger.InfoCF("channels", "Channel enabled successfully", map[string]any{
"channel": displayName,
})
@@ -227,6 +244,23 @@ func (m *Manager) StopAll(ctx context.Context) error {
return nil
}
// newChannelWorker creates a channelWorker with a rate limiter configured
// for the given channel name.
func newChannelWorker(name string, ch Channel) *channelWorker {
rateVal := float64(defaultRateLimit)
if r, ok := channelRateConfig[name]; ok {
rateVal = r
}
burst := int(math.Max(1, math.Ceil(rateVal/2)))
return &channelWorker{
ch: ch,
queue: make(chan bus.OutboundMessage, defaultChannelQueueSize),
done: make(chan struct{}),
limiter: rate.NewLimiter(rate.Limit(rateVal), burst),
}
}
// runWorker processes outbound messages for a single channel, splitting
// messages that exceed the channel's maximum message length.
func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) {
@@ -246,18 +280,10 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker)
for _, chunk := range chunks {
chunkMsg := msg
chunkMsg.Content = chunk
if err := w.ch.Send(ctx, chunkMsg); err != nil {
logger.ErrorCF("channels", "Error sending chunk", map[string]any{
"channel": name, "error": err.Error(),
})
}
m.sendWithRetry(ctx, name, w, chunkMsg)
}
} else {
if err := w.ch.Send(ctx, msg); err != nil {
logger.ErrorCF("channels", "Error sending message", map[string]any{
"channel": name, "error": err.Error(),
})
}
m.sendWithRetry(ctx, name, w, msg)
}
case <-ctx.Done():
return
@@ -265,6 +291,63 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker)
}
}
// sendWithRetry sends a message through the channel with rate limiting and
// retry logic. It classifies errors to determine the retry strategy:
// - ErrNotRunning / ErrSendFailed: permanent, no retry
// - ErrRateLimit: fixed delay retry
// - ErrTemporary / unknown: exponential backoff retry
func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMessage) {
// Rate limit: wait for token
if err := w.limiter.Wait(ctx); err != nil {
// ctx cancelled, shutting down
return
}
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
lastErr = w.ch.Send(ctx, msg)
if lastErr == nil {
return
}
// Permanent failures — don't retry
if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) {
break
}
// Last attempt exhausted — don't sleep
if attempt == maxRetries {
break
}
// Rate limit error — fixed delay
if errors.Is(lastErr, ErrRateLimit) {
select {
case <-time.After(rateLimitDelay):
continue
case <-ctx.Done():
return
}
}
// ErrTemporary or unknown error — exponential backoff
backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff)
select {
case <-time.After(backoff):
case <-ctx.Done():
return
}
}
// All retries exhausted or permanent failure
logger.ErrorCF("channels", "Send failed", map[string]any{
"channel": name,
"chat_id": msg.ChatID,
"error": lastErr.Error(),
"retries": maxRetries,
})
}
func (m *Manager) dispatchOutbound(ctx context.Context) {
logger.InfoC("channels", "Outbound dispatcher started")
@@ -343,11 +426,7 @@ func (m *Manager) RegisterChannel(name string, channel Channel) {
m.mu.Lock()
defer m.mu.Unlock()
m.channels[name] = channel
m.workers[name] = &channelWorker{
ch: channel,
queue: make(chan bus.OutboundMessage, defaultChannelQueueSize),
done: make(chan struct{}),
}
m.workers[name] = newChannelWorker(name, channel)
}
func (m *Manager) UnregisterChannel(name string) {
+418
View File
@@ -0,0 +1,418 @@
package channels
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/time/rate"
"github.com/sipeed/picoclaw/pkg/bus"
)
// mockChannel is a test double that delegates Send to a configurable function.
type mockChannel struct {
BaseChannel
sendFn func(ctx context.Context, msg bus.OutboundMessage) error
}
func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
return m.sendFn(ctx, msg)
}
func (m *mockChannel) Start(ctx context.Context) error { return nil }
func (m *mockChannel) Stop(ctx context.Context) error { return nil }
// newTestManager creates a minimal Manager suitable for unit tests.
func newTestManager() *Manager {
return &Manager{
channels: make(map[string]Channel),
workers: make(map[string]*channelWorker),
}
}
func TestSendWithRetry_Success(t *testing.T) {
m := newTestManager()
var callCount int
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
return nil
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
m.sendWithRetry(ctx, "test", w, msg)
if callCount != 1 {
t.Fatalf("expected 1 Send call, got %d", callCount)
}
}
func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
m := newTestManager()
var callCount int
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
if callCount <= 2 {
return fmt.Errorf("network error: %w", ErrTemporary)
}
return nil
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
m.sendWithRetry(ctx, "test", w, msg)
if callCount != 3 {
t.Fatalf("expected 3 Send calls (2 failures + 1 success), got %d", callCount)
}
}
func TestSendWithRetry_PermanentFailure(t *testing.T) {
m := newTestManager()
var callCount int
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
return fmt.Errorf("bad chat ID: %w", ErrSendFailed)
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
m.sendWithRetry(ctx, "test", w, msg)
if callCount != 1 {
t.Fatalf("expected 1 Send call (no retry for permanent failure), got %d", callCount)
}
}
func TestSendWithRetry_NotRunning(t *testing.T) {
m := newTestManager()
var callCount int
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
return ErrNotRunning
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
m.sendWithRetry(ctx, "test", w, msg)
if callCount != 1 {
t.Fatalf("expected 1 Send call (no retry for ErrNotRunning), got %d", callCount)
}
}
func TestSendWithRetry_RateLimitRetry(t *testing.T) {
m := newTestManager()
var callCount int
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
if callCount == 1 {
return fmt.Errorf("429: %w", ErrRateLimit)
}
return nil
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
start := time.Now()
m.sendWithRetry(ctx, "test", w, msg)
elapsed := time.Since(start)
if callCount != 2 {
t.Fatalf("expected 2 Send calls (1 rate limit + 1 success), got %d", callCount)
}
// Should have waited at least rateLimitDelay (1s) but allow some slack
if elapsed < 900*time.Millisecond {
t.Fatalf("expected at least ~1s delay for rate limit retry, got %v", elapsed)
}
}
func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
m := newTestManager()
var callCount int
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
return fmt.Errorf("timeout: %w", ErrTemporary)
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
m.sendWithRetry(ctx, "test", w, msg)
expected := maxRetries + 1 // initial attempt + maxRetries retries
if callCount != expected {
t.Fatalf("expected %d Send calls, got %d", expected, callCount)
}
}
func TestSendWithRetry_UnknownError(t *testing.T) {
m := newTestManager()
var callCount int
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
if callCount == 1 {
return errors.New("random unexpected error")
}
return nil
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
m.sendWithRetry(ctx, "test", w, msg)
if callCount != 2 {
t.Fatalf("expected 2 Send calls (unknown error treated as temporary), got %d", callCount)
}
}
func TestSendWithRetry_ContextCancelled(t *testing.T) {
m := newTestManager()
var callCount int
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
return fmt.Errorf("timeout: %w", ErrTemporary)
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx, cancel := context.WithCancel(context.Background())
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
// Cancel context after first Send attempt returns
ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error {
callCount++
cancel()
return fmt.Errorf("timeout: %w", ErrTemporary)
}
m.sendWithRetry(ctx, "test", w, msg)
// Should have called Send once, then noticed ctx cancelled during backoff
if callCount != 1 {
t.Fatalf("expected 1 Send call before context cancellation, got %d", callCount)
}
}
func TestWorkerRateLimiter(t *testing.T) {
m := newTestManager()
var mu sync.Mutex
var sendTimes []time.Time
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
mu.Lock()
sendTimes = append(sendTimes, time.Now())
mu.Unlock()
return nil
},
}
// Create a worker with a low rate: 2 msg/s, burst 1
w := &channelWorker{
ch: ch,
queue: make(chan bus.OutboundMessage, 10),
done: make(chan struct{}),
limiter: rate.NewLimiter(2, 1),
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.runWorker(ctx, "test", w)
// Enqueue 4 messages
for i := 0; i < 4; i++ {
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
}
// Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin)
time.Sleep(3 * time.Second)
mu.Lock()
times := make([]time.Time, len(sendTimes))
copy(times, sendTimes)
mu.Unlock()
if len(times) != 4 {
t.Fatalf("expected 4 sends, got %d", len(times))
}
// Verify rate limiting: total duration should be at least 1s
// (first message immediate, then ~500ms between each subsequent one at 2/s)
totalDuration := times[len(times)-1].Sub(times[0])
if totalDuration < 1*time.Second {
t.Fatalf("expected total duration >= 1s for 4 msgs at 2/s rate, got %v", totalDuration)
}
}
func TestNewChannelWorker_DefaultRate(t *testing.T) {
ch := &mockChannel{}
w := newChannelWorker("unknown_channel", ch)
if w.limiter == nil {
t.Fatal("expected limiter to be non-nil")
}
if w.limiter.Limit() != rate.Limit(defaultRateLimit) {
t.Fatalf("expected rate limit %v, got %v", rate.Limit(defaultRateLimit), w.limiter.Limit())
}
}
func TestNewChannelWorker_ConfiguredRate(t *testing.T) {
ch := &mockChannel{}
for name, expectedRate := range channelRateConfig {
w := newChannelWorker(name, ch)
if w.limiter.Limit() != rate.Limit(expectedRate) {
t.Fatalf("channel %s: expected rate %v, got %v", name, expectedRate, w.limiter.Limit())
}
}
}
func TestRunWorker_MessageSplitting(t *testing.T) {
m := newTestManager()
var mu sync.Mutex
var received []string
ch := &mockChannelWithLength{
mockChannel: mockChannel{
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
mu.Lock()
received = append(received, msg.Content)
mu.Unlock()
return nil
},
},
maxLen: 5,
}
w := &channelWorker{
ch: ch,
queue: make(chan bus.OutboundMessage, 10),
done: make(chan struct{}),
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.runWorker(ctx, "test", w)
// Send a message that should be split
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}
time.Sleep(100 * time.Millisecond)
mu.Lock()
count := len(received)
mu.Unlock()
if count < 2 {
t.Fatalf("expected message to be split into at least 2 chunks, got %d", count)
}
}
// mockChannelWithLength implements MessageLengthProvider.
type mockChannelWithLength struct {
mockChannel
maxLen int
}
func (m *mockChannelWithLength) MaxMessageLength() int {
return m.maxLen
}
func TestSendWithRetry_ExponentialBackoff(t *testing.T) {
m := newTestManager()
var callTimes []time.Time
var callCount atomic.Int32
ch := &mockChannel{
sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
callTimes = append(callTimes, time.Now())
callCount.Add(1)
return fmt.Errorf("timeout: %w", ErrTemporary)
},
}
w := &channelWorker{
ch: ch,
limiter: rate.NewLimiter(rate.Inf, 1),
}
ctx := context.Background()
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
start := time.Now()
m.sendWithRetry(ctx, "test", w, msg)
totalElapsed := time.Since(start)
// With maxRetries=3: attempts at 0, ~500ms, ~1.5s, ~3.5s
// Total backoff: 500ms + 1s + 2s = 3.5s
// Allow some margin
if totalElapsed < 3*time.Second {
t.Fatalf("expected total elapsed >= 3s for exponential backoff, got %v", totalElapsed)
}
if int(callCount.Load()) != maxRetries+1 {
t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load())
}
}