diff --git a/README.it.md b/README.it.md index 5ed82f15d..5874ca27a 100644 --- a/README.it.md +++ b/README.it.md @@ -558,7 +558,7 @@ Connetti PicoClaw al Social Network degli Agent semplicemente inviando un singol | `picoclaw skills list` | Elenca le skill installate | | `picoclaw skills install` | Installa una skill | | `picoclaw migrate` | Migra i dati dalle versioni precedenti | -| `picoclaw auth login` | Autenticazione con i provider | +| `picoclaw auth login` | Autenticazione con i provider | ### ⏰ Task Pianificati / Promemoria diff --git a/README.ja.md b/README.ja.md index 47d43eb56..2f17516af 100644 --- a/README.ja.md +++ b/README.ja.md @@ -541,7 +541,7 @@ CLI または統合チャットアプリからメッセージを 1 つ送るだ ## 🖥️ CLI リファレンス -| コマンド | 説明 | +| コマンド | 説明 | | ------------------------- | ------------------------------ | | `picoclaw onboard` | 設定&ワークスペースの初期化 | | `picoclaw auth weixin` | WeChat アカウントを QR で接続 | diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index 273e2b020..450dcce54 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -6,6 +6,7 @@ package dingtalk import ( "context" "fmt" + "strings" "sync" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" @@ -135,13 +136,17 @@ func (c *DingTalkChannel) onChatBotMessageReceived( ctx context.Context, data *chatbot.BotCallbackDataModel, ) ([]byte, error) { + if data == nil { + return nil, nil + } + // Extract message content from Text field - content := data.Text.Content + content := strings.TrimSpace(data.Text.Content) if content == "" { // Try to extract from Content interface{} if Text is empty if contentMap, ok := data.Content.(map[string]any); ok { if textContent, ok := contentMap["content"].(string); ok { - content = textContent + content = strings.TrimSpace(textContent) } } } @@ -150,12 +155,19 @@ func (c *DingTalkChannel) onChatBotMessageReceived( return nil, nil // Ignore empty messages } - senderID := data.SenderStaffId - senderNick := data.SenderNick - chatID := senderID - if data.ConversationType != "1" { - // For group chats - chatID = data.ConversationId + senderID := strings.TrimSpace(data.SenderStaffId) + if senderID == "" { + senderID = strings.TrimSpace(data.SenderId) + } + senderNick := strings.TrimSpace(data.SenderNick) + + chatID := strings.TrimSpace(data.ConversationId) + if chatID == "" && data.ConversationType == "1" { + // Fallback for direct chats when conversation_id is absent. + chatID = senderID + } + if chatID == "" { + return nil, nil } // Store the session webhook for this chat so we can reply later @@ -171,11 +183,19 @@ func (c *DingTalkChannel) onChatBotMessageReceived( var peer bus.Peer if data.ConversationType == "1" { - peer = bus.Peer{Kind: "direct", ID: senderID} + peerID := senderID + if peerID == "" { + peerID = chatID + } + peer = bus.Peer{Kind: "direct", ID: peerID} } else { peer = bus.Peer{Kind: "group", ID: data.ConversationId} + isMentioned := data.IsInAtList + if isMentioned { + content = stripLeadingAtMentions(content) + } // In group chats, apply unified group trigger filtering - respond, cleaned := c.ShouldRespondInGroup(false, content) + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) if !respond { return nil, nil } @@ -189,10 +209,18 @@ func (c *DingTalkChannel) onChatBotMessageReceived( }) // Build sender info + platformID := senderID + if platformID == "" { + platformID = chatID + } + resolvedSenderID := senderID + if resolvedSenderID == "" { + resolvedSenderID = platformID + } sender := bus.SenderInfo{ Platform: "dingtalk", - PlatformID: senderID, - CanonicalID: identity.BuildCanonicalID("dingtalk", senderID), + PlatformID: platformID, + CanonicalID: identity.BuildCanonicalID("dingtalk", platformID), DisplayName: senderNick, } @@ -201,7 +229,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived( } // Handle the message through the base channel - c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata, sender) + c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender) // Return nil to indicate we've handled the message asynchronously // The response will be sent through the message bus @@ -229,3 +257,19 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c return nil } + +func stripLeadingAtMentions(content string) string { + fields := strings.Fields(content) + if len(fields) == 0 { + return "" + } + + i := 0 + for i < len(fields) && strings.HasPrefix(fields[i], "@") { + i++ + } + if i == 0 { + return strings.TrimSpace(content) + } + return strings.Join(fields[i:], " ") +} diff --git a/pkg/channels/dingtalk/dingtalk_test.go b/pkg/channels/dingtalk/dingtalk_test.go new file mode 100644 index 000000000..3b517aef4 --- /dev/null +++ b/pkg/channels/dingtalk/dingtalk_test.go @@ -0,0 +1,131 @@ +package dingtalk + +import ( + "context" + "testing" + "time" + + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +func newTestDingTalkChannel(t *testing.T, cfg config.DingTalkConfig) (*DingTalkChannel, *bus.MessageBus) { + t.Helper() + + if cfg.ClientID == "" { + cfg.ClientID = "test-client-id" + } + if cfg.ClientSecret() == "" { + cfg.SetClientSecret("test-client-secret") + } + + msgBus := bus.NewMessageBus() + ch, err := NewDingTalkChannel(cfg, msgBus) + if err != nil { + t.Fatalf("new channel: %v", err) + } + return ch, msgBus +} + +func mustReceiveInbound(t *testing.T, msgBus *bus.MessageBus) bus.InboundMessage { + t.Helper() + select { + case msg := <-msgBus.InboundChan(): + return msg + case <-time.After(time.Second): + t.Fatal("expected inbound message") + return bus.InboundMessage{} + } +} + +func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention(t *testing.T) { + ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{ + GroupTrigger: config.GroupTriggerConfig{MentionOnly: true}, + }) + + _, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{ + Text: chatbot.BotCallbackDataTextModel{Content: " @bot /help "}, + SenderStaffId: "staff-123", + SenderNick: "Alice", + ConversationType: "2", + ConversationId: "group-abc", + SessionWebhook: "https://example.com/webhook", + IsInAtList: true, + }) + if err != nil { + t.Fatalf("handler returned error: %v", err) + } + + inbound := mustReceiveInbound(t, msgBus) + if inbound.Channel != "dingtalk" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.ChatID != "group-abc" { + t.Fatalf("chat_id=%q", inbound.ChatID) + } + if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-abc" { + t.Fatalf("peer=%+v", inbound.Peer) + } + if inbound.Content != "/help" { + t.Fatalf("content=%q", inbound.Content) + } +} + +func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *testing.T) { + ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{}) + + _, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{ + Text: chatbot.BotCallbackDataTextModel{Content: "ping"}, + SenderStaffId: "", + SenderId: "openid-user-42", + SenderNick: "Bob", + ConversationType: "1", + ConversationId: "conv-direct-42", + SessionWebhook: "https://example.com/webhook-direct", + }) + if err != nil { + t.Fatalf("handler returned error: %v", err) + } + + inbound := mustReceiveInbound(t, msgBus) + if inbound.ChatID != "conv-direct-42" { + t.Fatalf("chat_id=%q", inbound.ChatID) + } + if inbound.Peer.Kind != "direct" || inbound.Peer.ID != "openid-user-42" { + t.Fatalf("peer=%+v", inbound.Peer) + } + if inbound.SenderID != "dingtalk:openid-user-42" { + t.Fatalf("sender_id=%q", inbound.SenderID) + } + + if _, ok := ch.sessionWebhooks.Load("conv-direct-42"); !ok { + t.Fatal("expected session webhook keyed by conversation_id") + } + if _, ok := ch.sessionWebhooks.Load(""); ok { + t.Fatal("unexpected empty chat_id webhook key") + } +} + +func TestStripLeadingAtMentions(t *testing.T) { + tests := []struct { + name string + input string + wantOut string + }{ + {name: "single mention and command", input: "@bot /help", wantOut: "/help"}, + {name: "multiple mentions", input: "@bot @alice /new", wantOut: "/new"}, + {name: "no mention", input: "/help", wantOut: "/help"}, + {name: "mention only", input: "@bot", wantOut: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripLeadingAtMentions(tt.input) + if got != tt.wantOut { + t.Fatalf("stripLeadingAtMentions(%q)=%q want=%q", tt.input, got, tt.wantOut) + } + }) + } +} diff --git a/pkg/channels/weixin/api.go b/pkg/channels/weixin/api.go index 7f9b3b5c6..6dc52790e 100644 --- a/pkg/channels/weixin/api.go +++ b/pkg/channels/weixin/api.go @@ -12,6 +12,14 @@ import ( "net/http" "net/url" "path" + "strconv" +) + +const ( + weixinChannelVersion = "2.1.1" + weixinIlinkAppID = "bot" + // 2.1.1 encoded as 0x00MMNNPP => 0x00020101 => 131329 + weixinClientVersion = 131329 ) type ApiClient struct { @@ -80,13 +88,9 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons } req.Header.Set("Content-Type", "application/json") - if endpoint == "ilink/bot/get_bot_qrcode" || endpoint == "ilink/bot/get_qrcode_status" { - // QR routes have different headers sometimes, but let's stick to base ones - if endpoint == "ilink/bot/get_qrcode_status" { - // Use direct map assignment to send exact header name the Tencent API expects - req.Header["iLink-App-ClientVersion"] = []string{"1"} - } - } else { + req.Header["iLink-App-Id"] = []string{weixinIlinkAppID} + req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)} + if endpoint != "ilink/bot/get_bot_qrcode" && endpoint != "ilink/bot/get_qrcode_status" { req.Header["AuthorizationType"] = []string{"ilink_bot_token"} req.Header["X-WECHAT-UIN"] = []string{randomWechatUIN()} if c.Token != "" { @@ -119,7 +123,7 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons } func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpdatesResp, error) { - req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"} + req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion} var resp GetUpdatesResp err := c.post(ctx, "ilink/bot/getupdates", req, &resp) if err != nil { @@ -129,7 +133,7 @@ func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpda } func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendMessageResp, error) { - req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"} + req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion} var resp SendMessageResp if err := c.post(ctx, "ilink/bot/sendmessage", req, &resp); err != nil { return nil, err @@ -138,7 +142,7 @@ func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendM } func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*GetUploadUrlResp, error) { - req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"} + req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion} var resp GetUploadUrlResp err := c.post(ctx, "ilink/bot/getuploadurl", req, &resp) if err != nil { @@ -148,7 +152,7 @@ func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*Get } func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfigResp, error) { - req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"} + req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion} var resp GetConfigResp if err := c.post(ctx, "ilink/bot/getconfig", req, &resp); err != nil { return nil, err @@ -157,7 +161,7 @@ func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfig } func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTypingResp, error) { - req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"} + req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion} var resp SendTypingResp if err := c.post(ctx, "ilink/bot/sendtyping", req, &resp); err != nil { return nil, err @@ -165,38 +169,51 @@ func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTyp return &resp, nil } -func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) { - // get_bot_qrcode is GET, not POST +func (c *ApiClient) getQR(ctx context.Context, endpoint string, query map[string]string, respObj any) error { u, err := url.Parse(c.BaseURL) if err != nil { - return nil, err + return err } - u.Path = path.Join(u.Path, "ilink/bot/get_bot_qrcode") + u.Path = path.Join(u.Path, endpoint) q := u.Query() - q.Set("bot_type", botType) + for key, value := range query { + q.Set(key, value) + } u.RawQuery = q.Encode() req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) if err != nil { - return nil, err + return err } + req.Header["iLink-App-Id"] = []string{weixinIlinkAppID} + req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)} resp, err := c.HttpClient.Do(req) if err != nil { - return nil, err + return err } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return err } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("get_bot_qrcode failed: %d %s", resp.StatusCode, string(respBody)) + return fmt.Errorf("%s failed: %d %s", endpoint, resp.StatusCode, string(respBody)) + } + if err := json.Unmarshal(respBody, respObj); err != nil { + return err } + return nil +} + +func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) { + // get_bot_qrcode is GET, not POST var qrcodeResp QRCodeResponse - if err := json.Unmarshal(respBody, &qrcodeResp); err != nil { + if err := c.getQR(ctx, "ilink/bot/get_bot_qrcode", map[string]string{ + "bot_type": botType, + }, &qrcodeResp); err != nil { return nil, err } return &qrcodeResp, nil @@ -204,37 +221,10 @@ func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeRespo func (c *ApiClient) GetQRCodeStatus(ctx context.Context, qrcode string) (*StatusResponse, error) { // get_qrcode_status is GET - u, err := url.Parse(c.BaseURL) - if err != nil { - return nil, err - } - u.Path = path.Join(u.Path, "ilink/bot/get_qrcode_status") - q := u.Query() - q.Set("qrcode", qrcode) - u.RawQuery = q.Encode() - - req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) - if err != nil { - return nil, err - } - req.Header["iLink-App-ClientVersion"] = []string{"1"} - - resp, err := c.HttpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("get_qrcode_status failed: %d %s", resp.StatusCode, string(respBody)) - } - var statusResp StatusResponse - if err := json.Unmarshal(respBody, &statusResp); err != nil { + if err := c.getQR(ctx, "ilink/bot/get_qrcode_status", map[string]string{ + "qrcode": qrcode, + }, &statusResp); err != nil { return nil, err } return &statusResp, nil diff --git a/pkg/channels/weixin/auth.go b/pkg/channels/weixin/auth.go index 52ec2a6df..0a0e597c1 100644 --- a/pkg/channels/weixin/auth.go +++ b/pkg/channels/weixin/auth.go @@ -40,6 +40,7 @@ func PerformLoginInteractive( if err != nil { return "", "", "", "", fmt.Errorf("failed to create api client: %w", err) } + pollAPI := api logger.InfoC("weixin", "Requesting Weixin QR code...") qrResp, err := api.GetQRCode(ctx, opts.BotType) @@ -76,7 +77,7 @@ func PerformLoginInteractive( case <-timeoutCtx.Done(): return "", "", "", "", fmt.Errorf("login timeout") case <-pollTicker.C: - statusResp, err := api.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode) + statusResp, err := pollAPI.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode) if err != nil { // Long poll timeout or temporary error continue @@ -99,6 +100,27 @@ func PerformLoginInteractive( }) return statusResp.BotToken, statusResp.IlinkUserID, statusResp.IlinkBotID, statusResp.Baseurl, nil + case "scaned_but_redirect": + if statusResp.RedirectHost == "" { + logger.WarnC( + "weixin", + "scaned_but_redirect received without redirect_host; continuing on current host", + ) + continue + } + nextBaseURL := "https://" + statusResp.RedirectHost + "/" + nextAPI, nextErr := NewApiClient(nextBaseURL, "", opts.Proxy) + if nextErr != nil { + logger.WarnCF("weixin", "Failed to switch QR polling host", map[string]any{ + "redirect_host": statusResp.RedirectHost, + "error": nextErr.Error(), + }) + continue + } + pollAPI = nextAPI + logger.InfoCF("weixin", "Switched QR polling host", map[string]any{ + "redirect_host": statusResp.RedirectHost, + }) case "expired": return "", "", "", "", fmt.Errorf("qrcode expired, please try again") default: diff --git a/pkg/channels/weixin/media.go b/pkg/channels/weixin/media.go index 72af27438..4da7f0db9 100644 --- a/pkg/channels/weixin/media.go +++ b/pkg/channels/weixin/media.go @@ -34,6 +34,8 @@ const ( weixinMediaMaxBytes = 100 << 20 weixinTypingKeepAlive = 5 * time.Second weixinUploadRetryMax = 3 + weixinDownloadRetryMax = 2 + weixinDownloadRetryDelay = 300 * time.Millisecond weixinVoiceTranscodeTimeout = 15 * time.Second ) @@ -163,49 +165,108 @@ func buildCDNDownloadURL(base, encryptedQueryParam string) string { "/download?encrypted_query_param=" + url.QueryEscape(encryptedQueryParam) } +func shouldRetryCDNDownload(statusCode int) bool { + // statusCode=0 represents transport/build errors from the HTTP client. + return statusCode == 0 || statusCode >= 500 || statusCode == http.StatusTooManyRequests +} + func buildCDNUploadURL(base, uploadParam, filekey string) string { return strings.TrimRight(base, "/") + "/upload?encrypted_query_param=" + url.QueryEscape(uploadParam) + "&filekey=" + url.QueryEscape(filekey) } -func (c *WeixinChannel) downloadCDNBuffer(ctx context.Context, encryptedQueryParam string) ([]byte, error) { - req, err := http.NewRequestWithContext( - ctx, - http.MethodGet, - buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam), - nil, - ) +func uniqCDNURLs(urls []string) []string { + seen := make(map[string]struct{}, len(urls)) + out := make([]string, 0, len(urls)) + for _, raw := range urls { + u := strings.TrimSpace(raw) + if u == "" { + continue + } + if _, ok := seen[u]; ok { + continue + } + seen[u] = struct{}{} + out = append(out, u) + } + return out +} + +func (c *WeixinChannel) downloadCDNBufferOnce(ctx context.Context, downloadURL string) ([]byte, int, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil) if err != nil { - return nil, err + return nil, 0, err } resp, err := c.api.HttpClient.Do(req) if err != nil { - return nil, err + return nil, 0, err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) - return nil, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body)) + return nil, resp.StatusCode, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body)) } data, err := io.ReadAll(io.LimitReader(resp.Body, weixinMediaMaxBytes+1)) if err != nil { - return nil, err + return nil, resp.StatusCode, err } if len(data) > weixinMediaMaxBytes { - return nil, fmt.Errorf("cdn media too large: %d bytes", len(data)) + return nil, resp.StatusCode, fmt.Errorf("cdn media too large: %d bytes", len(data)) } - return data, nil + return data, resp.StatusCode, nil +} + +func (c *WeixinChannel) downloadCDNBuffer( + ctx context.Context, + encryptedQueryParam, + fullURL string, +) ([]byte, error) { + candidates := uniqCDNURLs([]string{ + strings.TrimSpace(fullURL), + func() string { + if strings.TrimSpace(encryptedQueryParam) == "" { + return "" + } + return buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam) + }(), + }) + if len(candidates) == 0 { + return nil, fmt.Errorf("missing CDN download URL") + } + + var lastErr error + for _, downloadURL := range candidates { + for attempt := 1; attempt <= weixinDownloadRetryMax; attempt++ { + data, statusCode, err := c.downloadCDNBufferOnce(ctx, downloadURL) + if err == nil { + return data, nil + } + lastErr = fmt.Errorf("%w (attempt=%d url=%s)", err, attempt, downloadURL) + if !shouldRetryCDNDownload(statusCode) { + break + } + if attempt < weixinDownloadRetryMax { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(weixinDownloadRetryDelay): + } + } + } + } + return nil, lastErr } func (c *WeixinChannel) downloadAndDecryptCDNBuffer( ctx context.Context, encryptedQueryParam string, + fullURL string, key []byte, ) ([]byte, error) { - data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam) + data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam, fullURL) if err != nil { return nil, err } @@ -215,6 +276,33 @@ func (c *WeixinChannel) downloadAndDecryptCDNBuffer( return decryptAESECB(data, key) } +func (c *WeixinChannel) downloadImageBuffer( + ctx context.Context, + img *ImageItem, + key []byte, +) ([]byte, error) { + if img == nil { + return nil, fmt.Errorf("image item is nil") + } + if img.Media != nil { + data, err := c.downloadAndDecryptCDNBuffer(ctx, img.Media.EncryptQueryParam, img.Media.FullURL, key) + if err == nil { + return data, nil + } + if img.ThumbMedia == nil { + return nil, fmt.Errorf("image download failed: %w", err) + } + } + if img.ThumbMedia != nil { + data, err := c.downloadAndDecryptCDNBuffer(ctx, img.ThumbMedia.EncryptQueryParam, img.ThumbMedia.FullURL, key) + if err == nil { + return data, nil + } + return nil, fmt.Errorf("image download failed: %w", err) + } + return nil, fmt.Errorf("image media is nil") +} + func detectMediaMetadata(data []byte, fallbackName, fallbackContentType string) (string, string) { contentType := strings.TrimSpace(fallbackContentType) ext := filepath.Ext(fallbackName) @@ -310,15 +398,18 @@ func isDownloadableMediaItem(item *MessageItem) bool { switch item.Type { case MessageItemTypeImage: - return item.ImageItem != nil && item.ImageItem.Media != nil && item.ImageItem.Media.EncryptQueryParam != "" + return item.ImageItem != nil && item.ImageItem.Media != nil && + (item.ImageItem.Media.EncryptQueryParam != "" || item.ImageItem.Media.FullURL != "") case MessageItemTypeVideo: - return item.VideoItem != nil && item.VideoItem.Media != nil && item.VideoItem.Media.EncryptQueryParam != "" + return item.VideoItem != nil && item.VideoItem.Media != nil && + (item.VideoItem.Media.EncryptQueryParam != "" || item.VideoItem.Media.FullURL != "") case MessageItemTypeFile: - return item.FileItem != nil && item.FileItem.Media != nil && item.FileItem.Media.EncryptQueryParam != "" + return item.FileItem != nil && item.FileItem.Media != nil && + (item.FileItem.Media.EncryptQueryParam != "" || item.FileItem.Media.FullURL != "") case MessageItemTypeVoice: return item.VoiceItem != nil && item.VoiceItem.Media != nil && - item.VoiceItem.Media.EncryptQueryParam != "" && + (item.VoiceItem.Media.EncryptQueryParam != "" || item.VoiceItem.Media.FullURL != "") && strings.TrimSpace(item.VoiceItem.Text) == "" default: return false @@ -434,16 +525,20 @@ func (c *WeixinChannel) downloadMediaFromItem( switch item.Type { case MessageItemTypeImage: + if item.ImageItem == nil { + return "", fmt.Errorf("image media is nil") + } key, ok, err := imageAESKey(item.ImageItem) if err != nil { return "", err } - data, err := c.downloadAndDecryptCDNBuffer(ctx, item.ImageItem.Media.EncryptQueryParam, func() []byte { + decryptKey := func() []byte { if ok { return key } return nil - }()) + }() + data, err := c.downloadImageBuffer(ctx, item.ImageItem, decryptKey) if err != nil { return "", err } @@ -454,7 +549,12 @@ func (c *WeixinChannel) downloadMediaFromItem( if err != nil { return "", err } - silk, err := c.downloadAndDecryptCDNBuffer(ctx, item.VoiceItem.Media.EncryptQueryParam, key) + silk, err := c.downloadAndDecryptCDNBuffer( + ctx, + item.VoiceItem.Media.EncryptQueryParam, + item.VoiceItem.Media.FullURL, + key, + ) if err != nil { return "", err } @@ -468,7 +568,12 @@ func (c *WeixinChannel) downloadMediaFromItem( if err != nil { return "", err } - data, err := c.downloadAndDecryptCDNBuffer(ctx, item.FileItem.Media.EncryptQueryParam, key) + data, err := c.downloadAndDecryptCDNBuffer( + ctx, + item.FileItem.Media.EncryptQueryParam, + item.FileItem.Media.FullURL, + key, + ) if err != nil { return "", err } @@ -484,7 +589,12 @@ func (c *WeixinChannel) downloadMediaFromItem( if err != nil { return "", err } - data, err := c.downloadAndDecryptCDNBuffer(ctx, item.VideoItem.Media.EncryptQueryParam, key) + data, err := c.downloadAndDecryptCDNBuffer( + ctx, + item.VideoItem.Media.EncryptQueryParam, + item.VideoItem.Media.FullURL, + key, + ) if err != nil { return "", err } @@ -701,11 +811,13 @@ func (c *WeixinChannel) uploadLocalFile( } return nil, fmt.Errorf("getuploadurl failed: ret=%d errcode=%d errmsg=%s", resp.Ret, resp.Errcode, resp.Errmsg) } - if strings.TrimSpace(resp.UploadParam) == "" { - return nil, fmt.Errorf("getuploadurl returned empty upload_param") + uploadParam := strings.TrimSpace(resp.UploadParam) + uploadFullURL := strings.TrimSpace(resp.UploadFullURL) + if uploadParam == "" && uploadFullURL == "" { + return nil, fmt.Errorf("getuploadurl returned no upload URL") } - downloadParam, err := c.uploadBufferToCDN(ctx, data, resp.UploadParam, filekey, aesKey) + downloadParam, err := c.uploadBufferToCDN(ctx, data, uploadParam, uploadFullURL, filekey, aesKey) if err != nil { return nil, err } @@ -723,6 +835,7 @@ func (c *WeixinChannel) uploadBufferToCDN( ctx context.Context, plaintext []byte, uploadParam, + uploadFullURL, filekey string, aesKey []byte, ) (string, error) { @@ -731,7 +844,13 @@ func (c *WeixinChannel) uploadBufferToCDN( return "", err } - uploadURL := buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey) + uploadURL := strings.TrimSpace(uploadFullURL) + if uploadURL == "" { + if strings.TrimSpace(uploadParam) == "" { + return "", fmt.Errorf("missing CDN upload URL") + } + uploadURL = buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey) + } var lastErr error for attempt := 1; attempt <= weixinUploadRetryMax; attempt++ { diff --git a/pkg/channels/weixin/types.go b/pkg/channels/weixin/types.go index 74c6e63c3..f2c03894f 100644 --- a/pkg/channels/weixin/types.go +++ b/pkg/channels/weixin/types.go @@ -38,6 +38,7 @@ type GetUploadUrlResp struct { APIStatus UploadParam string `json:"upload_param,omitempty"` ThumbUploadParam string `json:"thumb_upload_param,omitempty"` + UploadFullURL string `json:"upload_full_url,omitempty"` } const ( @@ -69,6 +70,7 @@ type CDNMedia struct { EncryptQueryParam string `json:"encrypt_query_param,omitempty"` AesKey string `json:"aes_key,omitempty"` // base64 encoded EncryptType int `json:"encrypt_type,omitempty"` + FullURL string `json:"full_url,omitempty"` } type ImageItem struct { @@ -202,9 +204,10 @@ type QRCodeResponse struct { } type StatusResponse struct { - Status string `json:"status"` // "wait", "scaned", "confirmed", "expired" - BotToken string `json:"bot_token,omitempty"` - IlinkBotID string `json:"ilink_bot_id,omitempty"` - Baseurl string `json:"baseurl,omitempty"` - IlinkUserID string `json:"ilink_user_id,omitempty"` + Status string `json:"status"` // "wait", "scaned", "confirmed", "expired", "scaned_but_redirect" + BotToken string `json:"bot_token,omitempty"` + IlinkBotID string `json:"ilink_bot_id,omitempty"` + Baseurl string `json:"baseurl,omitempty"` + IlinkUserID string `json:"ilink_user_id,omitempty"` + RedirectHost string `json:"redirect_host,omitempty"` } diff --git a/pkg/channels/weixin/weixin_test.go b/pkg/channels/weixin/weixin_test.go index 62984c965..b41b930db 100644 --- a/pkg/channels/weixin/weixin_test.go +++ b/pkg/channels/weixin/weixin_test.go @@ -72,7 +72,7 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) { typingCache: make(map[string]typingTicketCacheEntry), } - got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", key) + got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "", key) if err != nil { t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err) } @@ -81,6 +81,116 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) { } } +func TestDownloadAndDecryptCDNBufferUsesFullURLWhenProvided(t *testing.T) { + key := []byte("1234567890abcdef") + plaintext := []byte("hello weixin") + ciphertext, err := encryptAESECB(plaintext, key) + if err != nil { + t.Fatalf("encryptAESECB() error = %v", err) + } + + fullURLAttempts := 0 + ch := &WeixinChannel{ + api: &ApiClient{ + HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + if r.URL.String() == "https://full.example.com/download" { + fullURLAttempts++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(ciphertext)), + Header: make(http.Header), + }, nil + } + t.Fatalf("unexpected fallback request: %s", r.URL.String()) + return nil, nil + })}, + }, + config: config.WeixinConfig{ + CDNBaseURL: "https://cdn.example.com", + }, + typingCache: make(map[string]typingTicketCacheEntry), + } + + got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "https://full.example.com/download", key) + if err != nil { + t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext) + } + if fullURLAttempts == 0 { + t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts) + } +} + +func TestDownloadAndDecryptCDNBufferFallsBackToConstructedURLWhenFullURLFails(t *testing.T) { + key := []byte("1234567890abcdef") + plaintext := []byte("hello weixin") + ciphertext, err := encryptAESECB(plaintext, key) + if err != nil { + t.Fatalf("encryptAESECB() error = %v", err) + } + + fullURLAttempts := 0 + constructedAttempts := 0 + ch := &WeixinChannel{ + api: &ApiClient{ + HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + if r.URL.String() == "https://full.example.com/download?encrypted_query_param=token&taskid=123" { + fullURLAttempts++ + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(bytes.NewReader(nil)), + Header: make(http.Header), + }, nil + } + if r.URL.String() != "https://cdn.example.com/download?encrypted_query_param=token" { + t.Fatalf("unexpected fallback request: %s", r.URL.String()) + } + constructedAttempts++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(ciphertext)), + Header: make(http.Header), + }, nil + })}, + }, + config: config.WeixinConfig{ + CDNBaseURL: "https://cdn.example.com", + }, + typingCache: make(map[string]typingTicketCacheEntry), + } + + got, err := ch.downloadAndDecryptCDNBuffer( + context.Background(), + "token", + "https://full.example.com/download?encrypted_query_param=token&taskid=123", + key, + ) + if err != nil { + t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext) + } + if fullURLAttempts == 0 { + t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts) + } + if constructedAttempts == 0 { + t.Fatalf("constructedAttempts = %d, want > 0", constructedAttempts) + } +} + +func TestBuildCDNDownloadURLEscapesOpaqueToken(t *testing.T) { + token := "MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%3D" + + got := buildCDNDownloadURL("https://cdn.example.com", token) + + if got != "https://cdn.example.com/download?encrypted_query_param=MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%253D" { + t.Fatalf("buildCDNDownloadURL() = %q", got) + } +} + func TestUploadBufferToCDN(t *testing.T) { key := []byte("1234567890abcdef") plaintext := []byte("upload me") @@ -120,7 +230,7 @@ func TestUploadBufferToCDN(t *testing.T) { typingCache: make(map[string]typingTicketCacheEntry), } - got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "file-key", key) + got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "", "file-key", key) if err != nil { t.Fatalf("uploadBufferToCDN() error = %v", err) }