From 528c57dda0d3bd234a050cec4ca2532a77f8de11 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 7 Apr 2026 21:19:11 +0800 Subject: [PATCH] refactor(channels): merge non-web fixes from main --- pkg/channels/manager.go | 41 +++++++++- pkg/channels/pico/pico.go | 126 +++++++++++++++++++++++++++++- pkg/channels/pico/protocol.go | 9 +++ pkg/channels/telegram/telegram.go | 59 +++++++++++++- 4 files changed, 229 insertions(+), 6 deletions(-) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 60cea9e78..7cd93c266 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -12,6 +12,7 @@ import ( "fmt" "math" "net/http" + "sort" "sync" "time" @@ -531,6 +532,8 @@ func (m *Manager) StartAll(ctx context.Context) error { dispatchCtx, cancel := context.WithCancel(ctx) m.dispatchTask = &asyncTask{cancel: cancel} + failedStarts := make([]error, 0, len(m.channels)) + failedNames := make([]string, 0, len(m.channels)) for name, channel := range m.channels { logger.InfoCF("channels", "Starting channel", map[string]any{ @@ -541,6 +544,8 @@ func (m *Manager) StartAll(ctx context.Context) error { "channel": name, "error": err.Error(), }) + failedStarts = append(failedStarts, fmt.Errorf("channel %s: %w", name, err)) + failedNames = append(failedNames, name) continue } // Lazily create worker only after channel starts successfully @@ -550,6 +555,36 @@ func (m *Manager) StartAll(ctx context.Context) error { go m.runMediaWorker(dispatchCtx, name, w) } + if len(m.channels) > 0 && len(m.workers) == 0 { + if m.dispatchTask != nil { + m.dispatchTask.cancel() + m.dispatchTask = nil + } + + sort.Strings(failedNames) + if len(failedStarts) == 0 { + return fmt.Errorf("failed to start any enabled channels") + } + + logger.ErrorCF("channels", "All enabled channels failed to start", map[string]any{ + "failed": len(failedNames), + "total": len(m.channels), + "failed_channels": failedNames, + }) + + return fmt.Errorf("failed to start any enabled channels: %w", errors.Join(failedStarts...)) + } + + if len(failedNames) > 0 { + sort.Strings(failedNames) + logger.WarnCF("channels", "Some channels failed to start", map[string]any{ + "failed": len(failedNames), + "started": len(m.workers), + "total": len(m.channels), + "failed_channels": failedNames, + }) + } + // Start the dispatcher that reads from the bus and routes to workers go m.dispatchOutbound(dispatchCtx) go m.dispatchOutboundMedia(dispatchCtx) @@ -571,7 +606,11 @@ func (m *Manager) StartAll(ctx context.Context) error { }() } - logger.InfoC("channels", "All channels started") + logger.InfoCF("channels", "Channel startup completed", map[string]any{ + "started": len(m.workers), + "failed": len(failedNames), + "total": len(m.channels), + }) return nil } diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index 4f3f4aba3..80ab84cf1 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -2,6 +2,7 @@ package pico import ( "context" + "encoding/base64" "encoding/json" "fmt" "net/http" @@ -30,6 +31,14 @@ type picoConn struct { cancel context.CancelFunc // cancels per-connection goroutines (e.g. pingLoop) } +var allowedInlineImageMIMETypes = map[string]struct{}{ + "image/jpeg": {}, + "image/png": {}, + "image/gif": {}, + "image/webp": {}, + "image/bmp": {}, +} + // writeJSON sends a JSON message to the connection with write locking. func (pc *picoConn) writeJSON(v any) error { if pc.closed.Load() { @@ -516,6 +525,9 @@ func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) { case TypeMessageSend: c.handleMessageSend(pc, msg) + case TypeMediaSend: + c.handleMessageSend(pc, msg) + default: errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type)) pc.writeJSON(errMsg) @@ -525,8 +537,19 @@ func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) { // handleMessageSend processes an inbound message.send from a client. func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { content, _ := msg.Payload["content"].(string) - if strings.TrimSpace(content) == "" { - errMsg := newError("empty_content", "message content is empty") + media, err := parseInlineImageMedia(msg.Payload) + if err != nil { + errMsg := newErrorWithPayload("invalid_media", err.Error(), map[string]any{ + "request_id": msg.ID, + }) + pc.writeJSON(errMsg) + return + } + + if strings.TrimSpace(content) == "" && len(media) == 0 { + errMsg := newErrorWithPayload("empty_content", "message content is empty", map[string]any{ + "request_id": msg.ID, + }) pc.writeJSON(errMsg) return } @@ -548,6 +571,7 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { logger.DebugCF("pico", "Received message", map[string]any{ "session_id": sessionID, "preview": truncate(content, 50), + "media": len(media), }) sender := bus.SenderInfo{ @@ -569,7 +593,7 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { Raw: metadata, } - c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender) + c.HandleInboundContext(c.ctx, chatID, content, media, inboundCtx, sender) } // truncate truncates a string to maxLen runes. @@ -580,3 +604,99 @@ func truncate(s string, maxLen int) string { } return string(runes[:maxLen]) + "..." } + +func parseInlineImageMedia(payload map[string]any) ([]string, error) { + if len(payload) == 0 { + return nil, nil + } + + raw, ok := payload["media"] + if !ok || raw == nil { + return nil, nil + } + + switch values := raw.(type) { + case []any: + media := make([]string, 0, len(values)) + for i, item := range values { + value, err := inlineImageValue(item) + if err != nil { + return nil, fmt.Errorf("media[%d]: %w", i, err) + } + if err := validateInlineImageDataURL(value); err != nil { + return nil, fmt.Errorf("media[%d]: %w", i, err) + } + media = append(media, value) + } + return media, nil + case []string: + media := make([]string, 0, len(values)) + for i, value := range values { + value = strings.TrimSpace(value) + if err := validateInlineImageDataURL(value); err != nil { + return nil, fmt.Errorf("media[%d]: %w", i, err) + } + media = append(media, value) + } + return media, nil + case string: + value := strings.TrimSpace(values) + if err := validateInlineImageDataURL(value); err != nil { + return nil, err + } + return []string{value}, nil + default: + return nil, fmt.Errorf("media must be a string or array of strings") + } +} + +func inlineImageValue(item any) (string, error) { + switch value := item.(type) { + case string: + value = strings.TrimSpace(value) + if value == "" { + return "", fmt.Errorf("image payload is empty") + } + return value, nil + case map[string]any: + for _, key := range []string{"url", "data_url"} { + if raw, ok := value[key].(string); ok && strings.TrimSpace(raw) != "" { + return strings.TrimSpace(raw), nil + } + } + return "", fmt.Errorf("image payload must include url or data_url") + default: + return "", fmt.Errorf("image payload must be a string or object") + } +} + +func validateInlineImageDataURL(mediaURL string) error { + if mediaURL == "" { + return fmt.Errorf("image payload is empty") + } + if !strings.HasPrefix(mediaURL, "data:image/") { + return fmt.Errorf("only inline image data URLs are supported") + } + + header, data, found := strings.Cut(mediaURL, ",") + if !found || strings.TrimSpace(data) == "" { + return fmt.Errorf("image data URL is malformed") + } + if !strings.Contains(header, ";base64") { + return fmt.Errorf("image data URL must be base64 encoded") + } + mimeType, _, _ := strings.Cut(strings.TrimPrefix(header, "data:"), ";") + if _, ok := allowedInlineImageMIMETypes[mimeType]; !ok { + return fmt.Errorf("unsupported image format: %s", mimeType) + } + + data = strings.TrimSpace(data) + if base64.StdEncoding.DecodedLen(len(data)) > config.DefaultMaxMediaSize { + return fmt.Errorf("image exceeds %d byte limit", config.DefaultMaxMediaSize) + } + if _, err := base64.StdEncoding.DecodeString(data); err != nil { + return fmt.Errorf("invalid base64 image data") + } + + return nil +} diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go index 192c96164..17fb12d2b 100644 --- a/pkg/channels/pico/protocol.go +++ b/pkg/channels/pico/protocol.go @@ -46,3 +46,12 @@ func newError(code, message string) PicoMessage { "message": message, }) } + +func newErrorWithPayload(code, message string, payload map[string]any) PicoMessage { + if payload == nil { + payload = map[string]any{} + } + payload["code"] = code + payload["message"] = message + return newMessage(TypeError, payload) +} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 31a5afb30..464551351 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/binary" + "errors" "fmt" "io" "net/http" @@ -377,8 +378,38 @@ func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messag } _, err = c.bot.EditMessageText(ctx, editMsg) if err != nil { - logParseFailed(err, useMarkdownV2) - _, err = c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(cid), mid, content)) + // If it failed because it was already modified (likely from a previous + // attempt that timed out on our end but landed on Telegram), we treat + // it as success to prevent the Manager from sending a duplicate message. + if strings.Contains(err.Error(), "message is not modified") { + return nil + } + + // Only fallback to plain text if the error looks like a parsing failure (Bad Request). + // Network errors or timeouts should NOT trigger a retry with different content. + if strings.Contains(err.Error(), "Bad Request") { + logParseFailed(err, useMarkdownV2) + _, err = c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(cid), mid, content)) + } + } + + if err != nil { + if strings.Contains(err.Error(), "message is not modified") { + return nil + } + + if isPostConnectError(err) { + logger.WarnCF( + "telegram", + "EditMessage likely landed but result is unknown; swallowing error to prevent duplicate", + map[string]any{ + "chat_id": chatID, + "mid": mid, + "error": err.Error(), + }, + ) + return nil // Swallow to prevent Manager fallback to a new SendMessage + } } return err @@ -1135,3 +1166,27 @@ func cryptoRandInt() int { _, _ = rand.Read(b[:]) return int(binary.BigEndian.Uint32(b[:])) | 1 // ensure non-zero } + +// isPostConnectError identifies network errors that likely occurred after +// the request was transmitted to Telegram (e.g. dropped connection while +// waiting for response). Swallowing these for edits prevents duplicate +// fallbacks, at the small risk of leaving a stale placeholder if the +// edit never actually reached the server. +func isPostConnectError(err error) bool { + if err == nil { + return false + } + + // Context errors (timeout/canceled) are too broad; they can be triggered + // locally before any data is sent. Never swallow them. + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return false + } + + msg := strings.ToLower(err.Error()) + // Narrowly target connection dropouts where the request likely landed. + return strings.Contains(msg, "connection reset by peer") || + strings.Contains(msg, "unexpected eof") || + strings.Contains(msg, "connection closed by foreign host") || + strings.Contains(msg, "broken pipe") +}