diff --git a/go.mod b/go.mod index 98e20d07d..2d7624cf7 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.10 // indirect + golang.org/x/time v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index abbb11cd6..bd5165d7e 100644 --- a/go.sum +++ b/go.sum @@ -236,6 +236,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= diff --git a/pkg/channels/errors.go b/pkg/channels/errors.go new file mode 100644 index 000000000..09ee88b3f --- /dev/null +++ b/pkg/channels/errors.go @@ -0,0 +1,21 @@ +package channels + +import "errors" + +var ( + // ErrNotRunning indicates the channel is not running. + // Manager will not retry. + ErrNotRunning = errors.New("channel not running") + + // ErrRateLimit indicates the platform returned a rate-limit response (e.g. HTTP 429). + // Manager will wait a fixed delay and retry. + ErrRateLimit = errors.New("rate limited") + + // ErrTemporary indicates a transient failure (e.g. network timeout, 5xx). + // Manager will use exponential backoff and retry. + ErrTemporary = errors.New("temporary failure") + + // ErrSendFailed indicates a permanent failure (e.g. invalid chat ID, 4xx non-429). + // Manager will not retry. + ErrSendFailed = errors.New("send failed") +) diff --git a/pkg/channels/errors_test.go b/pkg/channels/errors_test.go new file mode 100644 index 000000000..e5592345a --- /dev/null +++ b/pkg/channels/errors_test.go @@ -0,0 +1,56 @@ +package channels + +import ( + "errors" + "fmt" + "testing" +) + +func TestErrorsIs(t *testing.T) { + wrapped := fmt.Errorf("telegram API: %w", ErrRateLimit) + if !errors.Is(wrapped, ErrRateLimit) { + t.Error("wrapped ErrRateLimit should match") + } + if errors.Is(wrapped, ErrTemporary) { + t.Error("wrapped ErrRateLimit should not match ErrTemporary") + } +} + +func TestErrorsIsAllTypes(t *testing.T) { + sentinels := []error{ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed} + + for _, sentinel := range sentinels { + wrapped := fmt.Errorf("context: %w", sentinel) + if !errors.Is(wrapped, sentinel) { + t.Errorf("wrapped %v should match itself", sentinel) + } + + // Verify it doesn't match other sentinel errors + for _, other := range sentinels { + if other == sentinel { + continue + } + if errors.Is(wrapped, other) { + t.Errorf("wrapped %v should not match %v", sentinel, other) + } + } + } +} + +func TestErrorMessages(t *testing.T) { + tests := []struct { + err error + want string + }{ + {ErrNotRunning, "channel not running"}, + {ErrRateLimit, "rate limited"}, + {ErrTemporary, "temporary failure"}, + {ErrSendFailed, "send failed"}, + } + + for _, tt := range tests { + if got := tt.err.Error(); got != tt.want { + t.Errorf("error message = %q, want %q", got, tt.want) + } + } +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 37af01796..1bc321cec 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -8,8 +8,13 @@ package channels import ( "context" + "errors" "fmt" + "math" "sync" + "time" + + "golang.org/x/time/rate" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" @@ -19,12 +24,28 @@ import ( "github.com/sipeed/picoclaw/pkg/utils" ) -const defaultChannelQueueSize = 100 +const ( + defaultChannelQueueSize = 100 + defaultRateLimit = 10 // default 10 msg/s + maxRetries = 3 + rateLimitDelay = 1 * time.Second + baseBackoff = 500 * time.Millisecond + maxBackoff = 8 * time.Second +) + +// channelRateConfig maps channel name to per-second rate limit. +var channelRateConfig = map[string]float64{ + "telegram": 20, + "discord": 1, + "slack": 1, + "line": 10, +} type channelWorker struct { - ch Channel - queue chan bus.OutboundMessage - done chan struct{} + ch Channel + queue chan bus.OutboundMessage + done chan struct{} + limiter *rate.Limiter } type Manager struct { @@ -83,11 +104,7 @@ func (m *Manager) initChannel(name, displayName string) { } } m.channels[name] = ch - m.workers[name] = &channelWorker{ - ch: ch, - queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), - done: make(chan struct{}), - } + m.workers[name] = newChannelWorker(name, ch) logger.InfoCF("channels", "Channel enabled successfully", map[string]any{ "channel": displayName, }) @@ -227,6 +244,23 @@ func (m *Manager) StopAll(ctx context.Context) error { return nil } +// newChannelWorker creates a channelWorker with a rate limiter configured +// for the given channel name. +func newChannelWorker(name string, ch Channel) *channelWorker { + rateVal := float64(defaultRateLimit) + if r, ok := channelRateConfig[name]; ok { + rateVal = r + } + burst := int(math.Max(1, math.Ceil(rateVal/2))) + + return &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), + done: make(chan struct{}), + limiter: rate.NewLimiter(rate.Limit(rateVal), burst), + } +} + // runWorker processes outbound messages for a single channel, splitting // messages that exceed the channel's maximum message length. func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) { @@ -246,18 +280,10 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) for _, chunk := range chunks { chunkMsg := msg chunkMsg.Content = chunk - if err := w.ch.Send(ctx, chunkMsg); err != nil { - logger.ErrorCF("channels", "Error sending chunk", map[string]any{ - "channel": name, "error": err.Error(), - }) - } + m.sendWithRetry(ctx, name, w, chunkMsg) } } else { - if err := w.ch.Send(ctx, msg); err != nil { - logger.ErrorCF("channels", "Error sending message", map[string]any{ - "channel": name, "error": err.Error(), - }) - } + m.sendWithRetry(ctx, name, w, msg) } case <-ctx.Done(): return @@ -265,6 +291,63 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) } } +// sendWithRetry sends a message through the channel with rate limiting and +// retry logic. It classifies errors to determine the retry strategy: +// - ErrNotRunning / ErrSendFailed: permanent, no retry +// - ErrRateLimit: fixed delay retry +// - ErrTemporary / unknown: exponential backoff retry +func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMessage) { + // Rate limit: wait for token + if err := w.limiter.Wait(ctx); err != nil { + // ctx cancelled, shutting down + return + } + + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + lastErr = w.ch.Send(ctx, msg) + if lastErr == nil { + return + } + + // Permanent failures — don't retry + if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) { + break + } + + // Last attempt exhausted — don't sleep + if attempt == maxRetries { + break + } + + // Rate limit error — fixed delay + if errors.Is(lastErr, ErrRateLimit) { + select { + case <-time.After(rateLimitDelay): + continue + case <-ctx.Done(): + return + } + } + + // ErrTemporary or unknown error — exponential backoff + backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff) + select { + case <-time.After(backoff): + case <-ctx.Done(): + return + } + } + + // All retries exhausted or permanent failure + logger.ErrorCF("channels", "Send failed", map[string]any{ + "channel": name, + "chat_id": msg.ChatID, + "error": lastErr.Error(), + "retries": maxRetries, + }) +} + func (m *Manager) dispatchOutbound(ctx context.Context) { logger.InfoC("channels", "Outbound dispatcher started") @@ -343,11 +426,7 @@ func (m *Manager) RegisterChannel(name string, channel Channel) { m.mu.Lock() defer m.mu.Unlock() m.channels[name] = channel - m.workers[name] = &channelWorker{ - ch: channel, - queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), - done: make(chan struct{}), - } + m.workers[name] = newChannelWorker(name, channel) } func (m *Manager) UnregisterChannel(name string) { diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go new file mode 100644 index 000000000..162c9f8c9 --- /dev/null +++ b/pkg/channels/manager_test.go @@ -0,0 +1,418 @@ +package channels + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/time/rate" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// mockChannel is a test double that delegates Send to a configurable function. +type mockChannel struct { + BaseChannel + sendFn func(ctx context.Context, msg bus.OutboundMessage) error +} + +func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + return m.sendFn(ctx, msg) +} + +func (m *mockChannel) Start(ctx context.Context) error { return nil } +func (m *mockChannel) Stop(ctx context.Context) error { return nil } + +// newTestManager creates a minimal Manager suitable for unit tests. +func newTestManager() *Manager { + return &Manager{ + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + } +} + +func TestSendWithRetry_Success(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call, got %d", callCount) + } +} + +func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount <= 2 { + return fmt.Errorf("network error: %w", ErrTemporary) + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 3 { + t.Fatalf("expected 3 Send calls (2 failures + 1 success), got %d", callCount) + } +} + +func TestSendWithRetry_PermanentFailure(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("bad chat ID: %w", ErrSendFailed) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call (no retry for permanent failure), got %d", callCount) + } +} + +func TestSendWithRetry_NotRunning(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return ErrNotRunning + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call (no retry for ErrNotRunning), got %d", callCount) + } +} + +func TestSendWithRetry_RateLimitRetry(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount == 1 { + return fmt.Errorf("429: %w", ErrRateLimit) + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + start := time.Now() + m.sendWithRetry(ctx, "test", w, msg) + elapsed := time.Since(start) + + if callCount != 2 { + t.Fatalf("expected 2 Send calls (1 rate limit + 1 success), got %d", callCount) + } + // Should have waited at least rateLimitDelay (1s) but allow some slack + if elapsed < 900*time.Millisecond { + t.Fatalf("expected at least ~1s delay for rate limit retry, got %v", elapsed) + } +} + +func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + expected := maxRetries + 1 // initial attempt + maxRetries retries + if callCount != expected { + t.Fatalf("expected %d Send calls, got %d", expected, callCount) + } +} + +func TestSendWithRetry_UnknownError(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount == 1 { + return errors.New("random unexpected error") + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 2 { + t.Fatalf("expected 2 Send calls (unknown error treated as temporary), got %d", callCount) + } +} + +func TestSendWithRetry_ContextCancelled(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + // Cancel context after first Send attempt returns + ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + cancel() + return fmt.Errorf("timeout: %w", ErrTemporary) + } + + m.sendWithRetry(ctx, "test", w, msg) + + // Should have called Send once, then noticed ctx cancelled during backoff + if callCount != 1 { + t.Fatalf("expected 1 Send call before context cancellation, got %d", callCount) + } +} + +func TestWorkerRateLimiter(t *testing.T) { + m := newTestManager() + + var mu sync.Mutex + var sendTimes []time.Time + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + mu.Lock() + sendTimes = append(sendTimes, time.Now()) + mu.Unlock() + return nil + }, + } + + // Create a worker with a low rate: 2 msg/s, burst 1 + w := &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, 10), + done: make(chan struct{}), + limiter: rate.NewLimiter(2, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go m.runWorker(ctx, "test", w) + + // Enqueue 4 messages + for i := 0; i < 4; i++ { + w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)} + } + + // Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin) + time.Sleep(3 * time.Second) + + mu.Lock() + times := make([]time.Time, len(sendTimes)) + copy(times, sendTimes) + mu.Unlock() + + if len(times) != 4 { + t.Fatalf("expected 4 sends, got %d", len(times)) + } + + // Verify rate limiting: total duration should be at least 1s + // (first message immediate, then ~500ms between each subsequent one at 2/s) + totalDuration := times[len(times)-1].Sub(times[0]) + if totalDuration < 1*time.Second { + t.Fatalf("expected total duration >= 1s for 4 msgs at 2/s rate, got %v", totalDuration) + } +} + +func TestNewChannelWorker_DefaultRate(t *testing.T) { + ch := &mockChannel{} + w := newChannelWorker("unknown_channel", ch) + + if w.limiter == nil { + t.Fatal("expected limiter to be non-nil") + } + if w.limiter.Limit() != rate.Limit(defaultRateLimit) { + t.Fatalf("expected rate limit %v, got %v", rate.Limit(defaultRateLimit), w.limiter.Limit()) + } +} + +func TestNewChannelWorker_ConfiguredRate(t *testing.T) { + ch := &mockChannel{} + + for name, expectedRate := range channelRateConfig { + w := newChannelWorker(name, ch) + if w.limiter.Limit() != rate.Limit(expectedRate) { + t.Fatalf("channel %s: expected rate %v, got %v", name, expectedRate, w.limiter.Limit()) + } + } +} + +func TestRunWorker_MessageSplitting(t *testing.T) { + m := newTestManager() + + var mu sync.Mutex + var received []string + + ch := &mockChannelWithLength{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + mu.Lock() + received = append(received, msg.Content) + mu.Unlock() + return nil + }, + }, + maxLen: 5, + } + + w := &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, 10), + done: make(chan struct{}), + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go m.runWorker(ctx, "test", w) + + // Send a message that should be split + w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"} + + time.Sleep(100 * time.Millisecond) + + mu.Lock() + count := len(received) + mu.Unlock() + + if count < 2 { + t.Fatalf("expected message to be split into at least 2 chunks, got %d", count) + } +} + +// mockChannelWithLength implements MessageLengthProvider. +type mockChannelWithLength struct { + mockChannel + maxLen int +} + +func (m *mockChannelWithLength) MaxMessageLength() int { + return m.maxLen +} + +func TestSendWithRetry_ExponentialBackoff(t *testing.T) { + m := newTestManager() + + var callTimes []time.Time + var callCount atomic.Int32 + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callTimes = append(callTimes, time.Now()) + callCount.Add(1) + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + start := time.Now() + m.sendWithRetry(ctx, "test", w, msg) + totalElapsed := time.Since(start) + + // With maxRetries=3: attempts at 0, ~500ms, ~1.5s, ~3.5s + // Total backoff: 500ms + 1s + 2s = 3.5s + // Allow some margin + if totalElapsed < 3*time.Second { + t.Fatalf("expected total elapsed >= 3s for exponential backoff, got %v", totalElapsed) + } + + if int(callCount.Load()) != maxRetries+1 { + t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load()) + } +}