From 257aa0ff573dbcfbf3e88770b5569e4c5f43c836 Mon Sep 17 00:00:00 2001 From: SakoroYou <165740095+Sakurapainting@users.noreply.github.com> Date: Thu, 2 Apr 2026 14:14:47 +0800 Subject: [PATCH] fix(channels): fail fast when all channel startups fail (#2262) * fix(channels): fail fast when all channel startups fail * fix(channels): preserve startup errors and cover fail-fast semantics --- pkg/channels/manager.go | 41 ++++++++++++- pkg/channels/manager_test.go | 112 ++++++++++++++++++++++++++++++++++- 2 files changed, 150 insertions(+), 3 deletions(-) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 5fbf35ebf..239448a1c 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -12,6 +12,7 @@ import ( "fmt" "math" "net/http" + "sort" "sync" "time" @@ -513,6 +514,8 @@ func (m *Manager) StartAll(ctx context.Context) error { dispatchCtx, cancel := context.WithCancel(ctx) m.dispatchTask = &asyncTask{cancel: cancel} + failedStarts := make([]error, 0, len(m.channels)) + failedNames := make([]string, 0, len(m.channels)) for name, channel := range m.channels { logger.InfoCF("channels", "Starting channel", map[string]any{ @@ -523,6 +526,8 @@ func (m *Manager) StartAll(ctx context.Context) error { "channel": name, "error": err.Error(), }) + failedStarts = append(failedStarts, fmt.Errorf("channel %s: %w", name, err)) + failedNames = append(failedNames, name) continue } // Lazily create worker only after channel starts successfully @@ -532,6 +537,36 @@ func (m *Manager) StartAll(ctx context.Context) error { go m.runMediaWorker(dispatchCtx, name, w) } + if len(m.channels) > 0 && len(m.workers) == 0 { + if m.dispatchTask != nil { + m.dispatchTask.cancel() + m.dispatchTask = nil + } + + sort.Strings(failedNames) + if len(failedStarts) == 0 { + return fmt.Errorf("failed to start any enabled channels") + } + + logger.ErrorCF("channels", "All enabled channels failed to start", map[string]any{ + "failed": len(failedNames), + "total": len(m.channels), + "failed_channels": failedNames, + }) + + return fmt.Errorf("failed to start any enabled channels: %w", errors.Join(failedStarts...)) + } + + if len(failedNames) > 0 { + sort.Strings(failedNames) + logger.WarnCF("channels", "Some channels failed to start", map[string]any{ + "failed": len(failedNames), + "started": len(m.workers), + "total": len(m.channels), + "failed_channels": failedNames, + }) + } + // Start the dispatcher that reads from the bus and routes to workers go m.dispatchOutbound(dispatchCtx) go m.dispatchOutboundMedia(dispatchCtx) @@ -553,7 +588,11 @@ func (m *Manager) StartAll(ctx context.Context) error { }() } - logger.InfoC("channels", "All channels started") + logger.InfoCF("channels", "Channel startup completed", map[string]any{ + "started": len(m.workers), + "failed": len(failedNames), + "total": len(m.channels), + }) return nil } diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index e76212905..937b32d2c 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -19,6 +19,8 @@ import ( type mockChannel struct { BaseChannel sendFn func(ctx context.Context, msg bus.OutboundMessage) error + startFn func(ctx context.Context) error + stopFn func(ctx context.Context) error sentMessages []bus.OutboundMessage placeholdersSent int editedMessages int @@ -33,8 +35,19 @@ func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri return nil, m.sendFn(ctx, msg) } -func (m *mockChannel) Start(ctx context.Context) error { return nil } -func (m *mockChannel) Stop(ctx context.Context) error { return nil } +func (m *mockChannel) Start(ctx context.Context) error { + if m.startFn != nil { + return m.startFn(ctx) + } + return nil +} + +func (m *mockChannel) Stop(ctx context.Context) error { + if m.stopFn != nil { + return m.stopFn(ctx) + } + return nil +} func (m *mockChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { m.placeholdersSent++ @@ -86,6 +99,101 @@ func newTestManager() *Manager { return &Manager{ channels: make(map[string]Channel), workers: make(map[string]*channelWorker), + bus: bus.NewMessageBus(), + } +} + +func TestStartAll_AllChannelsFail_ReturnsJoinedError(t *testing.T) { + m := newTestManager() + errA := errors.New("channel-a start failed") + errB := errors.New("channel-b start failed") + + m.channels["a"] = &mockChannel{ + startFn: func(_ context.Context) error { return errA }, + } + m.channels["b"] = &mockChannel{ + startFn: func(_ context.Context) error { return errB }, + } + + err := m.StartAll(t.Context()) + if err == nil { + t.Fatal("expected StartAll to fail when all channels fail") + } + if !strings.Contains(err.Error(), "failed to start any enabled channels") { + t.Fatalf("unexpected error: %v", err) + } + if !errors.Is(err, errA) { + t.Fatalf("expected error to wrap errA, got: %v", err) + } + if !errors.Is(err, errB) { + t.Fatalf("expected error to wrap errB, got: %v", err) + } + if len(m.workers) != 0 { + t.Fatalf("expected no workers on full startup failure, got %d", len(m.workers)) + } + if m.dispatchTask != nil { + t.Fatal("expected dispatch task to be cleared on full startup failure") + } +} + +func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) { + m := newTestManager() + errBad := errors.New("bad channel start failed") + processed := make(chan struct{}, 1) + + m.channels["good"] = &mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + if msg.Channel == "good" { + select { + case processed <- struct{}{}: + default: + } + } + return nil + }, + } + m.channels["bad"] = &mockChannel{ + startFn: func(_ context.Context) error { return errBad }, + } + + err := m.StartAll(t.Context()) + if err != nil { + t.Fatalf("expected StartAll to succeed with partial channel failures, got: %v", err) + } + if len(m.workers) != 1 { + t.Fatalf("expected exactly 1 active worker, got %d", len(m.workers)) + } + if _, ok := m.workers["good"]; !ok { + t.Fatal("expected worker for successful channel 'good'") + } + if _, ok := m.workers["bad"]; ok { + t.Fatal("did not expect worker for failed channel 'bad'") + } + if m.dispatchTask == nil { + t.Fatal("expected dispatch task to run when at least one channel starts") + } + + pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer pubCancel() + if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ + Channel: "good", + ChatID: "chat-1", + Content: "hello", + }); err != nil { + t.Fatalf("PublishOutbound() error = %v", err) + } + + select { + case <-processed: + // worker processed outbound message as expected + case <-time.After(2 * time.Second): + t.Fatal("expected successful channel worker to process outbound message") + } + + stopCtx, stopCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer stopCancel() + if err := m.StopAll(stopCtx); err != nil { + t.Fatalf("StopAll() error = %v", err) } }