mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(channels): fail fast when all channel startups fail (#2262)
* fix(channels): fail fast when all channel startups fail
* fix(channels): preserve startup errors and cover fail-fast semantics
This commit is contained in:
+40
-1
@@ -12,6 +12,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -513,6 +514,8 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
|
||||
dispatchCtx, cancel := context.WithCancel(ctx)
|
||||
m.dispatchTask = &asyncTask{cancel: cancel}
|
||||
failedStarts := make([]error, 0, len(m.channels))
|
||||
failedNames := make([]string, 0, len(m.channels))
|
||||
|
||||
for name, channel := range m.channels {
|
||||
logger.InfoCF("channels", "Starting channel", map[string]any{
|
||||
@@ -523,6 +526,8 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
failedStarts = append(failedStarts, fmt.Errorf("channel %s: %w", name, err))
|
||||
failedNames = append(failedNames, name)
|
||||
continue
|
||||
}
|
||||
// Lazily create worker only after channel starts successfully
|
||||
@@ -532,6 +537,36 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
go m.runMediaWorker(dispatchCtx, name, w)
|
||||
}
|
||||
|
||||
if len(m.channels) > 0 && len(m.workers) == 0 {
|
||||
if m.dispatchTask != nil {
|
||||
m.dispatchTask.cancel()
|
||||
m.dispatchTask = nil
|
||||
}
|
||||
|
||||
sort.Strings(failedNames)
|
||||
if len(failedStarts) == 0 {
|
||||
return fmt.Errorf("failed to start any enabled channels")
|
||||
}
|
||||
|
||||
logger.ErrorCF("channels", "All enabled channels failed to start", map[string]any{
|
||||
"failed": len(failedNames),
|
||||
"total": len(m.channels),
|
||||
"failed_channels": failedNames,
|
||||
})
|
||||
|
||||
return fmt.Errorf("failed to start any enabled channels: %w", errors.Join(failedStarts...))
|
||||
}
|
||||
|
||||
if len(failedNames) > 0 {
|
||||
sort.Strings(failedNames)
|
||||
logger.WarnCF("channels", "Some channels failed to start", map[string]any{
|
||||
"failed": len(failedNames),
|
||||
"started": len(m.workers),
|
||||
"total": len(m.channels),
|
||||
"failed_channels": failedNames,
|
||||
})
|
||||
}
|
||||
|
||||
// Start the dispatcher that reads from the bus and routes to workers
|
||||
go m.dispatchOutbound(dispatchCtx)
|
||||
go m.dispatchOutboundMedia(dispatchCtx)
|
||||
@@ -553,7 +588,11 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
}()
|
||||
}
|
||||
|
||||
logger.InfoC("channels", "All channels started")
|
||||
logger.InfoCF("channels", "Channel startup completed", map[string]any{
|
||||
"started": len(m.workers),
|
||||
"failed": len(failedNames),
|
||||
"total": len(m.channels),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
type mockChannel struct {
|
||||
BaseChannel
|
||||
sendFn func(ctx context.Context, msg bus.OutboundMessage) error
|
||||
startFn func(ctx context.Context) error
|
||||
stopFn func(ctx context.Context) error
|
||||
sentMessages []bus.OutboundMessage
|
||||
placeholdersSent int
|
||||
editedMessages int
|
||||
@@ -33,8 +35,19 @@ func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
|
||||
return nil, m.sendFn(ctx, msg)
|
||||
}
|
||||
|
||||
func (m *mockChannel) Start(ctx context.Context) error { return nil }
|
||||
func (m *mockChannel) Stop(ctx context.Context) error { return nil }
|
||||
func (m *mockChannel) Start(ctx context.Context) error {
|
||||
if m.startFn != nil {
|
||||
return m.startFn(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockChannel) Stop(ctx context.Context) error {
|
||||
if m.stopFn != nil {
|
||||
return m.stopFn(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
m.placeholdersSent++
|
||||
@@ -86,6 +99,101 @@ func newTestManager() *Manager {
|
||||
return &Manager{
|
||||
channels: make(map[string]Channel),
|
||||
workers: make(map[string]*channelWorker),
|
||||
bus: bus.NewMessageBus(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAll_AllChannelsFail_ReturnsJoinedError(t *testing.T) {
|
||||
m := newTestManager()
|
||||
errA := errors.New("channel-a start failed")
|
||||
errB := errors.New("channel-b start failed")
|
||||
|
||||
m.channels["a"] = &mockChannel{
|
||||
startFn: func(_ context.Context) error { return errA },
|
||||
}
|
||||
m.channels["b"] = &mockChannel{
|
||||
startFn: func(_ context.Context) error { return errB },
|
||||
}
|
||||
|
||||
err := m.StartAll(t.Context())
|
||||
if err == nil {
|
||||
t.Fatal("expected StartAll to fail when all channels fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to start any enabled channels") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !errors.Is(err, errA) {
|
||||
t.Fatalf("expected error to wrap errA, got: %v", err)
|
||||
}
|
||||
if !errors.Is(err, errB) {
|
||||
t.Fatalf("expected error to wrap errB, got: %v", err)
|
||||
}
|
||||
if len(m.workers) != 0 {
|
||||
t.Fatalf("expected no workers on full startup failure, got %d", len(m.workers))
|
||||
}
|
||||
if m.dispatchTask != nil {
|
||||
t.Fatal("expected dispatch task to be cleared on full startup failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
|
||||
m := newTestManager()
|
||||
errBad := errors.New("bad channel start failed")
|
||||
processed := make(chan struct{}, 1)
|
||||
|
||||
m.channels["good"] = &mockChannel{
|
||||
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
|
||||
if msg.Channel == "good" {
|
||||
select {
|
||||
case processed <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
m.channels["bad"] = &mockChannel{
|
||||
startFn: func(_ context.Context) error { return errBad },
|
||||
}
|
||||
|
||||
err := m.StartAll(t.Context())
|
||||
if err != nil {
|
||||
t.Fatalf("expected StartAll to succeed with partial channel failures, got: %v", err)
|
||||
}
|
||||
if len(m.workers) != 1 {
|
||||
t.Fatalf("expected exactly 1 active worker, got %d", len(m.workers))
|
||||
}
|
||||
if _, ok := m.workers["good"]; !ok {
|
||||
t.Fatal("expected worker for successful channel 'good'")
|
||||
}
|
||||
if _, ok := m.workers["bad"]; ok {
|
||||
t.Fatal("did not expect worker for failed channel 'bad'")
|
||||
}
|
||||
if m.dispatchTask == nil {
|
||||
t.Fatal("expected dispatch task to run when at least one channel starts")
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: "good",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishOutbound() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-processed:
|
||||
// worker processed outbound message as expected
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected successful channel worker to process outbound message")
|
||||
}
|
||||
|
||||
stopCtx, stopCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer stopCancel()
|
||||
if err := m.StopAll(stopCtx); err != nil {
|
||||
t.Fatalf("StopAll() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user