fix(channels): fail fast when all channel startups fail (#2262)

* fix(channels): fail fast when all channel startups fail

* fix(channels): preserve startup errors and cover fail-fast semantics
This commit is contained in:
SakoroYou
2026-04-02 14:14:47 +08:00
committed by GitHub
parent adf78092da
commit 257aa0ff57
2 changed files with 150 additions and 3 deletions
+40 -1
View File
@@ -12,6 +12,7 @@ import (
"fmt"
"math"
"net/http"
"sort"
"sync"
"time"
@@ -513,6 +514,8 @@ func (m *Manager) StartAll(ctx context.Context) error {
dispatchCtx, cancel := context.WithCancel(ctx)
m.dispatchTask = &asyncTask{cancel: cancel}
failedStarts := make([]error, 0, len(m.channels))
failedNames := make([]string, 0, len(m.channels))
for name, channel := range m.channels {
logger.InfoCF("channels", "Starting channel", map[string]any{
@@ -523,6 +526,8 @@ func (m *Manager) StartAll(ctx context.Context) error {
"channel": name,
"error": err.Error(),
})
failedStarts = append(failedStarts, fmt.Errorf("channel %s: %w", name, err))
failedNames = append(failedNames, name)
continue
}
// Lazily create worker only after channel starts successfully
@@ -532,6 +537,36 @@ func (m *Manager) StartAll(ctx context.Context) error {
go m.runMediaWorker(dispatchCtx, name, w)
}
if len(m.channels) > 0 && len(m.workers) == 0 {
if m.dispatchTask != nil {
m.dispatchTask.cancel()
m.dispatchTask = nil
}
sort.Strings(failedNames)
if len(failedStarts) == 0 {
return fmt.Errorf("failed to start any enabled channels")
}
logger.ErrorCF("channels", "All enabled channels failed to start", map[string]any{
"failed": len(failedNames),
"total": len(m.channels),
"failed_channels": failedNames,
})
return fmt.Errorf("failed to start any enabled channels: %w", errors.Join(failedStarts...))
}
if len(failedNames) > 0 {
sort.Strings(failedNames)
logger.WarnCF("channels", "Some channels failed to start", map[string]any{
"failed": len(failedNames),
"started": len(m.workers),
"total": len(m.channels),
"failed_channels": failedNames,
})
}
// Start the dispatcher that reads from the bus and routes to workers
go m.dispatchOutbound(dispatchCtx)
go m.dispatchOutboundMedia(dispatchCtx)
@@ -553,7 +588,11 @@ func (m *Manager) StartAll(ctx context.Context) error {
}()
}
logger.InfoC("channels", "All channels started")
logger.InfoCF("channels", "Channel startup completed", map[string]any{
"started": len(m.workers),
"failed": len(failedNames),
"total": len(m.channels),
})
return nil
}
+110 -2
View File
@@ -19,6 +19,8 @@ import (
type mockChannel struct {
BaseChannel
sendFn func(ctx context.Context, msg bus.OutboundMessage) error
startFn func(ctx context.Context) error
stopFn func(ctx context.Context) error
sentMessages []bus.OutboundMessage
placeholdersSent int
editedMessages int
@@ -33,8 +35,19 @@ func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
return nil, m.sendFn(ctx, msg)
}
func (m *mockChannel) Start(ctx context.Context) error { return nil }
func (m *mockChannel) Stop(ctx context.Context) error { return nil }
func (m *mockChannel) Start(ctx context.Context) error {
if m.startFn != nil {
return m.startFn(ctx)
}
return nil
}
func (m *mockChannel) Stop(ctx context.Context) error {
if m.stopFn != nil {
return m.stopFn(ctx)
}
return nil
}
func (m *mockChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
m.placeholdersSent++
@@ -86,6 +99,101 @@ func newTestManager() *Manager {
return &Manager{
channels: make(map[string]Channel),
workers: make(map[string]*channelWorker),
bus: bus.NewMessageBus(),
}
}
func TestStartAll_AllChannelsFail_ReturnsJoinedError(t *testing.T) {
m := newTestManager()
errA := errors.New("channel-a start failed")
errB := errors.New("channel-b start failed")
m.channels["a"] = &mockChannel{
startFn: func(_ context.Context) error { return errA },
}
m.channels["b"] = &mockChannel{
startFn: func(_ context.Context) error { return errB },
}
err := m.StartAll(t.Context())
if err == nil {
t.Fatal("expected StartAll to fail when all channels fail")
}
if !strings.Contains(err.Error(), "failed to start any enabled channels") {
t.Fatalf("unexpected error: %v", err)
}
if !errors.Is(err, errA) {
t.Fatalf("expected error to wrap errA, got: %v", err)
}
if !errors.Is(err, errB) {
t.Fatalf("expected error to wrap errB, got: %v", err)
}
if len(m.workers) != 0 {
t.Fatalf("expected no workers on full startup failure, got %d", len(m.workers))
}
if m.dispatchTask != nil {
t.Fatal("expected dispatch task to be cleared on full startup failure")
}
}
func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
m := newTestManager()
errBad := errors.New("bad channel start failed")
processed := make(chan struct{}, 1)
m.channels["good"] = &mockChannel{
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
if msg.Channel == "good" {
select {
case processed <- struct{}{}:
default:
}
}
return nil
},
}
m.channels["bad"] = &mockChannel{
startFn: func(_ context.Context) error { return errBad },
}
err := m.StartAll(t.Context())
if err != nil {
t.Fatalf("expected StartAll to succeed with partial channel failures, got: %v", err)
}
if len(m.workers) != 1 {
t.Fatalf("expected exactly 1 active worker, got %d", len(m.workers))
}
if _, ok := m.workers["good"]; !ok {
t.Fatal("expected worker for successful channel 'good'")
}
if _, ok := m.workers["bad"]; ok {
t.Fatal("did not expect worker for failed channel 'bad'")
}
if m.dispatchTask == nil {
t.Fatal("expected dispatch task to run when at least one channel starts")
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: "good",
ChatID: "chat-1",
Content: "hello",
}); err != nil {
t.Fatalf("PublishOutbound() error = %v", err)
}
select {
case <-processed:
// worker processed outbound message as expected
case <-time.After(2 * time.Second):
t.Fatal("expected successful channel worker to process outbound message")
}
stopCtx, stopCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer stopCancel()
if err := m.StopAll(stopCtx); err != nil {
t.Fatalf("StopAll() error = %v", err)
}
}