diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index e051add1f..c49769761 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -96,7 +96,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { // Send sends a message to DingTalk via the chatbot reply API func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("dingtalk channel not running") + return channels.ErrNotRunning } // Get session webhook from storage @@ -197,7 +197,7 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c contentBytes, ) if err != nil { - return fmt.Errorf("failed to send reply: %w", err) + return fmt.Errorf("dingtalk send: %w", channels.ErrTemporary) } return nil diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 7977d32e1..d5524f7f9 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -113,7 +113,7 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro c.stopTyping(msg.ChatID) if !c.IsRunning() { - return fmt.Errorf("discord bot not running") + return channels.ErrNotRunning } channelID := msg.ChatID @@ -142,11 +142,11 @@ func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content strin select { case err := <-done: if err != nil { - return fmt.Errorf("failed to send discord message: %w", err) + return fmt.Errorf("discord send: %w", channels.ErrTemporary) } return nil case <-sendCtx.Done(): - return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + return sendCtx.Err() } } diff --git a/pkg/channels/errutil.go b/pkg/channels/errutil.go new file mode 100644 index 000000000..319e3c980 --- /dev/null +++ b/pkg/channels/errutil.go @@ -0,0 +1,30 @@ +package channels + +import ( + "fmt" + "net/http" +) + +// ClassifySendError wraps a raw error with the appropriate sentinel based on +// an HTTP status code. Channels that perform HTTP API calls should use this +// in their Send path. +func ClassifySendError(statusCode int, rawErr error) error { + switch { + case statusCode == http.StatusTooManyRequests: + return fmt.Errorf("%w: %v", ErrRateLimit, rawErr) + case statusCode >= 500: + return fmt.Errorf("%w: %v", ErrTemporary, rawErr) + case statusCode >= 400: + return fmt.Errorf("%w: %v", ErrSendFailed, rawErr) + default: + return rawErr + } +} + +// ClassifyNetError wraps a network/timeout error as ErrTemporary. +func ClassifyNetError(err error) error { + if err == nil { + return nil + } + return fmt.Errorf("%w: %v", ErrTemporary, err) +} diff --git a/pkg/channels/errutil_test.go b/pkg/channels/errutil_test.go new file mode 100644 index 000000000..e3d35f65b --- /dev/null +++ b/pkg/channels/errutil_test.go @@ -0,0 +1,97 @@ +package channels + +import ( + "errors" + "fmt" + "testing" +) + +func TestClassifySendError(t *testing.T) { + raw := fmt.Errorf("some API error") + + tests := []struct { + name string + statusCode int + wantIs error + wantNil bool + }{ + {"429 -> ErrRateLimit", 429, ErrRateLimit, false}, + {"500 -> ErrTemporary", 500, ErrTemporary, false}, + {"502 -> ErrTemporary", 502, ErrTemporary, false}, + {"503 -> ErrTemporary", 503, ErrTemporary, false}, + {"400 -> ErrSendFailed", 400, ErrSendFailed, false}, + {"403 -> ErrSendFailed", 403, ErrSendFailed, false}, + {"404 -> ErrSendFailed", 404, ErrSendFailed, false}, + {"200 -> raw error", 200, nil, false}, + {"201 -> raw error", 201, nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ClassifySendError(tt.statusCode, raw) + if err == nil { + t.Fatal("expected non-nil error") + } + if tt.wantIs != nil { + if !errors.Is(err, tt.wantIs) { + t.Errorf("errors.Is(err, %v) = false, want true; err = %v", tt.wantIs, err) + } + } else { + // Should return the raw error unchanged + if err != raw { + t.Errorf("expected raw error to be returned unchanged for status %d, got %v", tt.statusCode, err) + } + } + }) + } +} + +func TestClassifySendErrorNoFalsePositive(t *testing.T) { + raw := fmt.Errorf("some error") + + // 429 should NOT match ErrTemporary or ErrSendFailed + err := ClassifySendError(429, raw) + if errors.Is(err, ErrTemporary) { + t.Error("429 should not match ErrTemporary") + } + if errors.Is(err, ErrSendFailed) { + t.Error("429 should not match ErrSendFailed") + } + + // 500 should NOT match ErrRateLimit or ErrSendFailed + err = ClassifySendError(500, raw) + if errors.Is(err, ErrRateLimit) { + t.Error("500 should not match ErrRateLimit") + } + if errors.Is(err, ErrSendFailed) { + t.Error("500 should not match ErrSendFailed") + } + + // 400 should NOT match ErrRateLimit or ErrTemporary + err = ClassifySendError(400, raw) + if errors.Is(err, ErrRateLimit) { + t.Error("400 should not match ErrRateLimit") + } + if errors.Is(err, ErrTemporary) { + t.Error("400 should not match ErrTemporary") + } +} + +func TestClassifyNetError(t *testing.T) { + t.Run("nil error returns nil", func(t *testing.T) { + if err := ClassifyNetError(nil); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + + t.Run("non-nil error wraps as ErrTemporary", func(t *testing.T) { + raw := fmt.Errorf("connection refused") + err := ClassifyNetError(raw) + if err == nil { + t.Fatal("expected non-nil error") + } + if !errors.Is(err, ErrTemporary) { + t.Errorf("errors.Is(err, ErrTemporary) = false, want true; err = %v", err) + } + }) +} diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index d67823974..5245cd99d 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -91,7 +91,7 @@ func (c *FeishuChannel) Stop(ctx context.Context) error { func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("feishu channel not running") + return channels.ErrNotRunning } if msg.ChatID == "" { @@ -115,11 +115,11 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error resp, err := c.client.Im.V1.Message.Create(ctx, req) if err != nil { - return fmt.Errorf("failed to send feishu message: %w", err) + return fmt.Errorf("feishu send: %w", channels.ErrTemporary) } if !resp.Success() { - return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg) + return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) } logger.DebugCF("feishu", "Feishu message sent", map[string]any{ diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 272a53c6e..fd06334d5 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -491,7 +491,7 @@ func (c *LINEChannel) resolveChatID(source lineSource) string { // using a cached reply token, then falls back to the Push API. func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("line channel not running") + return channels.ErrNotRunning } // Load and consume quote token for this chat @@ -582,13 +582,13 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("API request failed: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody)) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody))) } return nil diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index 05213b095..b5b7259f9 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -215,7 +215,14 @@ func (c *MaixCamChannel) Stop(ctx context.Context) error { func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("maixcam channel not running") + return channels.ErrNotRunning + } + + // Check ctx before entering write path + select { + case <-ctx.Done(): + return ctx.Err() + default: } c.clientsMux.RLock() @@ -246,7 +253,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro "client": conn.RemoteAddr().String(), "error": err.Error(), }) - sendErr = err + sendErr = fmt.Errorf("maixcam send: %w", channels.ErrTemporary) } _ = conn.SetWriteDeadline(time.Time{}) } diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index e2fe541f1..76950663e 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -373,7 +373,14 @@ func (c *OneBotChannel) Stop(ctx context.Context) error { func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("OneBot channel not running") + return channels.ErrNotRunning + } + + // Check ctx before entering write path + select { + case <-ctx.Done(): + return ctx.Err() + default: } c.mu.Lock() @@ -412,7 +419,7 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error logger.ErrorCF("onebot", "Failed to send message", map[string]any{ "error": err.Error(), }) - return err + return fmt.Errorf("onebot send: %w", channels.ErrTemporary) } if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 429e23cbf..69f323e6e 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -114,7 +114,7 @@ func (c *QQChannel) Stop(ctx context.Context) error { func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("QQ bot not running") + return channels.ErrNotRunning } // 构造消息 @@ -128,7 +128,7 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{ "error": err.Error(), }) - return err + return fmt.Errorf("qq send: %w", channels.ErrTemporary) } return nil diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 53d7c0609..9e066e00a 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -112,7 +112,7 @@ func (c *SlackChannel) Stop(ctx context.Context) error { func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("slack channel not running") + return channels.ErrNotRunning } channelID, threadTS := parseSlackChatID(msg.ChatID) @@ -130,7 +130,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error _, _, err := c.api.PostMessageContext(ctx, channelID, opts...) if err != nil { - return fmt.Errorf("failed to send slack message: %w", err) + return fmt.Errorf("slack send: %w", channels.ErrTemporary) } if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index af7155799..a07eb6579 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -164,6 +164,12 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { return true }) + // Clean up placeholder state + c.placeholders.Range(func(key, value any) bool { + c.placeholders.Delete(key) + return true + }) + // Stop the bot handler if c.bh != nil { c.bh.Stop() @@ -179,12 +185,12 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("telegram bot not running") + return channels.ErrNotRunning } chatID, err := parseChatID(msg.ChatID) if err != nil { - return fmt.Errorf("invalid chat ID: %w", err) + return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } // Stop thinking animation @@ -217,8 +223,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err "error": err.Error(), }) tgMsg.ParseMode = "" - _, err = c.bot.SendMessage(ctx, tgMsg) - return err + if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + return fmt.Errorf("telegram send: %w", channels.ErrTemporary) + } } return nil diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index eb1711d75..41861e8fc 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -207,7 +207,7 @@ func (c *WeComAppChannel) Stop(ctx context.Context) error { // Send sends a message to WeCom user proactively using access token func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("wecom_app channel not running") + return channels.ErrNotRunning } accessToken := c.getAccessToken() @@ -548,10 +548,15 @@ func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, user client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send message: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) @@ -603,10 +608,15 @@ func (c *WeComAppChannel) sendMarkdownMessage(ctx context.Context, accessToken, client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send message: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index bbac8611a..7960802fb 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -166,7 +166,7 @@ func (c *WeComBotChannel) Stop(ctx context.Context) error { // For delayed responses, we use the webhook URL func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("wecom channel not running") + return channels.ErrNotRunning } logger.DebugCF("wecom", "Sending message via webhook", map[string]any{ @@ -433,10 +433,15 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send webhook reply: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index b5f3e99d7..97032334f 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -94,11 +94,22 @@ func (c *WhatsAppChannel) Stop(ctx context.Context) error { } func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + // Check ctx before acquiring lock + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + c.mu.Lock() defer c.mu.Unlock() if c.conn == nil { - return fmt.Errorf("whatsapp connection not established") + return fmt.Errorf("whatsapp connection not established: %w", channels.ErrTemporary) } payload := map[string]any{ @@ -115,7 +126,7 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err _ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { _ = c.conn.SetWriteDeadline(time.Time{}) - return fmt.Errorf("failed to send message: %w", err) + return fmt.Errorf("whatsapp send: %w", channels.ErrTemporary) } _ = c.conn.SetWriteDeadline(time.Time{})