diff --git a/pkg/channels/weixin/api.go b/pkg/channels/weixin/api.go index 7f9b3b5c6..6dc52790e 100644 --- a/pkg/channels/weixin/api.go +++ b/pkg/channels/weixin/api.go @@ -12,6 +12,14 @@ import ( "net/http" "net/url" "path" + "strconv" +) + +const ( + weixinChannelVersion = "2.1.1" + weixinIlinkAppID = "bot" + // 2.1.1 encoded as 0x00MMNNPP => 0x00020101 => 131329 + weixinClientVersion = 131329 ) type ApiClient struct { @@ -80,13 +88,9 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons } req.Header.Set("Content-Type", "application/json") - if endpoint == "ilink/bot/get_bot_qrcode" || endpoint == "ilink/bot/get_qrcode_status" { - // QR routes have different headers sometimes, but let's stick to base ones - if endpoint == "ilink/bot/get_qrcode_status" { - // Use direct map assignment to send exact header name the Tencent API expects - req.Header["iLink-App-ClientVersion"] = []string{"1"} - } - } else { + req.Header["iLink-App-Id"] = []string{weixinIlinkAppID} + req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)} + if endpoint != "ilink/bot/get_bot_qrcode" && endpoint != "ilink/bot/get_qrcode_status" { req.Header["AuthorizationType"] = []string{"ilink_bot_token"} req.Header["X-WECHAT-UIN"] = []string{randomWechatUIN()} if c.Token != "" { @@ -119,7 +123,7 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons } func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpdatesResp, error) { - req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"} + req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion} var resp GetUpdatesResp err := c.post(ctx, "ilink/bot/getupdates", req, &resp) if err != nil { @@ -129,7 +133,7 @@ func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpda } func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendMessageResp, error) { - req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"} + req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion} var resp SendMessageResp if err := c.post(ctx, "ilink/bot/sendmessage", req, &resp); err != nil { return nil, err @@ -138,7 +142,7 @@ func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendM } func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*GetUploadUrlResp, error) { - req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"} + req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion} var resp GetUploadUrlResp err := c.post(ctx, "ilink/bot/getuploadurl", req, &resp) if err != nil { @@ -148,7 +152,7 @@ func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*Get } func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfigResp, error) { - req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"} + req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion} var resp GetConfigResp if err := c.post(ctx, "ilink/bot/getconfig", req, &resp); err != nil { return nil, err @@ -157,7 +161,7 @@ func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfig } func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTypingResp, error) { - req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"} + req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion} var resp SendTypingResp if err := c.post(ctx, "ilink/bot/sendtyping", req, &resp); err != nil { return nil, err @@ -165,38 +169,51 @@ func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTyp return &resp, nil } -func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) { - // get_bot_qrcode is GET, not POST +func (c *ApiClient) getQR(ctx context.Context, endpoint string, query map[string]string, respObj any) error { u, err := url.Parse(c.BaseURL) if err != nil { - return nil, err + return err } - u.Path = path.Join(u.Path, "ilink/bot/get_bot_qrcode") + u.Path = path.Join(u.Path, endpoint) q := u.Query() - q.Set("bot_type", botType) + for key, value := range query { + q.Set(key, value) + } u.RawQuery = q.Encode() req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) if err != nil { - return nil, err + return err } + req.Header["iLink-App-Id"] = []string{weixinIlinkAppID} + req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)} resp, err := c.HttpClient.Do(req) if err != nil { - return nil, err + return err } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return err } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("get_bot_qrcode failed: %d %s", resp.StatusCode, string(respBody)) + return fmt.Errorf("%s failed: %d %s", endpoint, resp.StatusCode, string(respBody)) + } + if err := json.Unmarshal(respBody, respObj); err != nil { + return err } + return nil +} + +func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) { + // get_bot_qrcode is GET, not POST var qrcodeResp QRCodeResponse - if err := json.Unmarshal(respBody, &qrcodeResp); err != nil { + if err := c.getQR(ctx, "ilink/bot/get_bot_qrcode", map[string]string{ + "bot_type": botType, + }, &qrcodeResp); err != nil { return nil, err } return &qrcodeResp, nil @@ -204,37 +221,10 @@ func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeRespo func (c *ApiClient) GetQRCodeStatus(ctx context.Context, qrcode string) (*StatusResponse, error) { // get_qrcode_status is GET - u, err := url.Parse(c.BaseURL) - if err != nil { - return nil, err - } - u.Path = path.Join(u.Path, "ilink/bot/get_qrcode_status") - q := u.Query() - q.Set("qrcode", qrcode) - u.RawQuery = q.Encode() - - req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) - if err != nil { - return nil, err - } - req.Header["iLink-App-ClientVersion"] = []string{"1"} - - resp, err := c.HttpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("get_qrcode_status failed: %d %s", resp.StatusCode, string(respBody)) - } - var statusResp StatusResponse - if err := json.Unmarshal(respBody, &statusResp); err != nil { + if err := c.getQR(ctx, "ilink/bot/get_qrcode_status", map[string]string{ + "qrcode": qrcode, + }, &statusResp); err != nil { return nil, err } return &statusResp, nil diff --git a/pkg/channels/weixin/auth.go b/pkg/channels/weixin/auth.go index 52ec2a6df..0a0e597c1 100644 --- a/pkg/channels/weixin/auth.go +++ b/pkg/channels/weixin/auth.go @@ -40,6 +40,7 @@ func PerformLoginInteractive( if err != nil { return "", "", "", "", fmt.Errorf("failed to create api client: %w", err) } + pollAPI := api logger.InfoC("weixin", "Requesting Weixin QR code...") qrResp, err := api.GetQRCode(ctx, opts.BotType) @@ -76,7 +77,7 @@ func PerformLoginInteractive( case <-timeoutCtx.Done(): return "", "", "", "", fmt.Errorf("login timeout") case <-pollTicker.C: - statusResp, err := api.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode) + statusResp, err := pollAPI.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode) if err != nil { // Long poll timeout or temporary error continue @@ -99,6 +100,27 @@ func PerformLoginInteractive( }) return statusResp.BotToken, statusResp.IlinkUserID, statusResp.IlinkBotID, statusResp.Baseurl, nil + case "scaned_but_redirect": + if statusResp.RedirectHost == "" { + logger.WarnC( + "weixin", + "scaned_but_redirect received without redirect_host; continuing on current host", + ) + continue + } + nextBaseURL := "https://" + statusResp.RedirectHost + "/" + nextAPI, nextErr := NewApiClient(nextBaseURL, "", opts.Proxy) + if nextErr != nil { + logger.WarnCF("weixin", "Failed to switch QR polling host", map[string]any{ + "redirect_host": statusResp.RedirectHost, + "error": nextErr.Error(), + }) + continue + } + pollAPI = nextAPI + logger.InfoCF("weixin", "Switched QR polling host", map[string]any{ + "redirect_host": statusResp.RedirectHost, + }) case "expired": return "", "", "", "", fmt.Errorf("qrcode expired, please try again") default: diff --git a/pkg/channels/weixin/media.go b/pkg/channels/weixin/media.go index 72af27438..4da7f0db9 100644 --- a/pkg/channels/weixin/media.go +++ b/pkg/channels/weixin/media.go @@ -34,6 +34,8 @@ const ( weixinMediaMaxBytes = 100 << 20 weixinTypingKeepAlive = 5 * time.Second weixinUploadRetryMax = 3 + weixinDownloadRetryMax = 2 + weixinDownloadRetryDelay = 300 * time.Millisecond weixinVoiceTranscodeTimeout = 15 * time.Second ) @@ -163,49 +165,108 @@ func buildCDNDownloadURL(base, encryptedQueryParam string) string { "/download?encrypted_query_param=" + url.QueryEscape(encryptedQueryParam) } +func shouldRetryCDNDownload(statusCode int) bool { + // statusCode=0 represents transport/build errors from the HTTP client. + return statusCode == 0 || statusCode >= 500 || statusCode == http.StatusTooManyRequests +} + func buildCDNUploadURL(base, uploadParam, filekey string) string { return strings.TrimRight(base, "/") + "/upload?encrypted_query_param=" + url.QueryEscape(uploadParam) + "&filekey=" + url.QueryEscape(filekey) } -func (c *WeixinChannel) downloadCDNBuffer(ctx context.Context, encryptedQueryParam string) ([]byte, error) { - req, err := http.NewRequestWithContext( - ctx, - http.MethodGet, - buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam), - nil, - ) +func uniqCDNURLs(urls []string) []string { + seen := make(map[string]struct{}, len(urls)) + out := make([]string, 0, len(urls)) + for _, raw := range urls { + u := strings.TrimSpace(raw) + if u == "" { + continue + } + if _, ok := seen[u]; ok { + continue + } + seen[u] = struct{}{} + out = append(out, u) + } + return out +} + +func (c *WeixinChannel) downloadCDNBufferOnce(ctx context.Context, downloadURL string) ([]byte, int, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil) if err != nil { - return nil, err + return nil, 0, err } resp, err := c.api.HttpClient.Do(req) if err != nil { - return nil, err + return nil, 0, err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) - return nil, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body)) + return nil, resp.StatusCode, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body)) } data, err := io.ReadAll(io.LimitReader(resp.Body, weixinMediaMaxBytes+1)) if err != nil { - return nil, err + return nil, resp.StatusCode, err } if len(data) > weixinMediaMaxBytes { - return nil, fmt.Errorf("cdn media too large: %d bytes", len(data)) + return nil, resp.StatusCode, fmt.Errorf("cdn media too large: %d bytes", len(data)) } - return data, nil + return data, resp.StatusCode, nil +} + +func (c *WeixinChannel) downloadCDNBuffer( + ctx context.Context, + encryptedQueryParam, + fullURL string, +) ([]byte, error) { + candidates := uniqCDNURLs([]string{ + strings.TrimSpace(fullURL), + func() string { + if strings.TrimSpace(encryptedQueryParam) == "" { + return "" + } + return buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam) + }(), + }) + if len(candidates) == 0 { + return nil, fmt.Errorf("missing CDN download URL") + } + + var lastErr error + for _, downloadURL := range candidates { + for attempt := 1; attempt <= weixinDownloadRetryMax; attempt++ { + data, statusCode, err := c.downloadCDNBufferOnce(ctx, downloadURL) + if err == nil { + return data, nil + } + lastErr = fmt.Errorf("%w (attempt=%d url=%s)", err, attempt, downloadURL) + if !shouldRetryCDNDownload(statusCode) { + break + } + if attempt < weixinDownloadRetryMax { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(weixinDownloadRetryDelay): + } + } + } + } + return nil, lastErr } func (c *WeixinChannel) downloadAndDecryptCDNBuffer( ctx context.Context, encryptedQueryParam string, + fullURL string, key []byte, ) ([]byte, error) { - data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam) + data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam, fullURL) if err != nil { return nil, err } @@ -215,6 +276,33 @@ func (c *WeixinChannel) downloadAndDecryptCDNBuffer( return decryptAESECB(data, key) } +func (c *WeixinChannel) downloadImageBuffer( + ctx context.Context, + img *ImageItem, + key []byte, +) ([]byte, error) { + if img == nil { + return nil, fmt.Errorf("image item is nil") + } + if img.Media != nil { + data, err := c.downloadAndDecryptCDNBuffer(ctx, img.Media.EncryptQueryParam, img.Media.FullURL, key) + if err == nil { + return data, nil + } + if img.ThumbMedia == nil { + return nil, fmt.Errorf("image download failed: %w", err) + } + } + if img.ThumbMedia != nil { + data, err := c.downloadAndDecryptCDNBuffer(ctx, img.ThumbMedia.EncryptQueryParam, img.ThumbMedia.FullURL, key) + if err == nil { + return data, nil + } + return nil, fmt.Errorf("image download failed: %w", err) + } + return nil, fmt.Errorf("image media is nil") +} + func detectMediaMetadata(data []byte, fallbackName, fallbackContentType string) (string, string) { contentType := strings.TrimSpace(fallbackContentType) ext := filepath.Ext(fallbackName) @@ -310,15 +398,18 @@ func isDownloadableMediaItem(item *MessageItem) bool { switch item.Type { case MessageItemTypeImage: - return item.ImageItem != nil && item.ImageItem.Media != nil && item.ImageItem.Media.EncryptQueryParam != "" + return item.ImageItem != nil && item.ImageItem.Media != nil && + (item.ImageItem.Media.EncryptQueryParam != "" || item.ImageItem.Media.FullURL != "") case MessageItemTypeVideo: - return item.VideoItem != nil && item.VideoItem.Media != nil && item.VideoItem.Media.EncryptQueryParam != "" + return item.VideoItem != nil && item.VideoItem.Media != nil && + (item.VideoItem.Media.EncryptQueryParam != "" || item.VideoItem.Media.FullURL != "") case MessageItemTypeFile: - return item.FileItem != nil && item.FileItem.Media != nil && item.FileItem.Media.EncryptQueryParam != "" + return item.FileItem != nil && item.FileItem.Media != nil && + (item.FileItem.Media.EncryptQueryParam != "" || item.FileItem.Media.FullURL != "") case MessageItemTypeVoice: return item.VoiceItem != nil && item.VoiceItem.Media != nil && - item.VoiceItem.Media.EncryptQueryParam != "" && + (item.VoiceItem.Media.EncryptQueryParam != "" || item.VoiceItem.Media.FullURL != "") && strings.TrimSpace(item.VoiceItem.Text) == "" default: return false @@ -434,16 +525,20 @@ func (c *WeixinChannel) downloadMediaFromItem( switch item.Type { case MessageItemTypeImage: + if item.ImageItem == nil { + return "", fmt.Errorf("image media is nil") + } key, ok, err := imageAESKey(item.ImageItem) if err != nil { return "", err } - data, err := c.downloadAndDecryptCDNBuffer(ctx, item.ImageItem.Media.EncryptQueryParam, func() []byte { + decryptKey := func() []byte { if ok { return key } return nil - }()) + }() + data, err := c.downloadImageBuffer(ctx, item.ImageItem, decryptKey) if err != nil { return "", err } @@ -454,7 +549,12 @@ func (c *WeixinChannel) downloadMediaFromItem( if err != nil { return "", err } - silk, err := c.downloadAndDecryptCDNBuffer(ctx, item.VoiceItem.Media.EncryptQueryParam, key) + silk, err := c.downloadAndDecryptCDNBuffer( + ctx, + item.VoiceItem.Media.EncryptQueryParam, + item.VoiceItem.Media.FullURL, + key, + ) if err != nil { return "", err } @@ -468,7 +568,12 @@ func (c *WeixinChannel) downloadMediaFromItem( if err != nil { return "", err } - data, err := c.downloadAndDecryptCDNBuffer(ctx, item.FileItem.Media.EncryptQueryParam, key) + data, err := c.downloadAndDecryptCDNBuffer( + ctx, + item.FileItem.Media.EncryptQueryParam, + item.FileItem.Media.FullURL, + key, + ) if err != nil { return "", err } @@ -484,7 +589,12 @@ func (c *WeixinChannel) downloadMediaFromItem( if err != nil { return "", err } - data, err := c.downloadAndDecryptCDNBuffer(ctx, item.VideoItem.Media.EncryptQueryParam, key) + data, err := c.downloadAndDecryptCDNBuffer( + ctx, + item.VideoItem.Media.EncryptQueryParam, + item.VideoItem.Media.FullURL, + key, + ) if err != nil { return "", err } @@ -701,11 +811,13 @@ func (c *WeixinChannel) uploadLocalFile( } return nil, fmt.Errorf("getuploadurl failed: ret=%d errcode=%d errmsg=%s", resp.Ret, resp.Errcode, resp.Errmsg) } - if strings.TrimSpace(resp.UploadParam) == "" { - return nil, fmt.Errorf("getuploadurl returned empty upload_param") + uploadParam := strings.TrimSpace(resp.UploadParam) + uploadFullURL := strings.TrimSpace(resp.UploadFullURL) + if uploadParam == "" && uploadFullURL == "" { + return nil, fmt.Errorf("getuploadurl returned no upload URL") } - downloadParam, err := c.uploadBufferToCDN(ctx, data, resp.UploadParam, filekey, aesKey) + downloadParam, err := c.uploadBufferToCDN(ctx, data, uploadParam, uploadFullURL, filekey, aesKey) if err != nil { return nil, err } @@ -723,6 +835,7 @@ func (c *WeixinChannel) uploadBufferToCDN( ctx context.Context, plaintext []byte, uploadParam, + uploadFullURL, filekey string, aesKey []byte, ) (string, error) { @@ -731,7 +844,13 @@ func (c *WeixinChannel) uploadBufferToCDN( return "", err } - uploadURL := buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey) + uploadURL := strings.TrimSpace(uploadFullURL) + if uploadURL == "" { + if strings.TrimSpace(uploadParam) == "" { + return "", fmt.Errorf("missing CDN upload URL") + } + uploadURL = buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey) + } var lastErr error for attempt := 1; attempt <= weixinUploadRetryMax; attempt++ { diff --git a/pkg/channels/weixin/types.go b/pkg/channels/weixin/types.go index 74c6e63c3..f2c03894f 100644 --- a/pkg/channels/weixin/types.go +++ b/pkg/channels/weixin/types.go @@ -38,6 +38,7 @@ type GetUploadUrlResp struct { APIStatus UploadParam string `json:"upload_param,omitempty"` ThumbUploadParam string `json:"thumb_upload_param,omitempty"` + UploadFullURL string `json:"upload_full_url,omitempty"` } const ( @@ -69,6 +70,7 @@ type CDNMedia struct { EncryptQueryParam string `json:"encrypt_query_param,omitempty"` AesKey string `json:"aes_key,omitempty"` // base64 encoded EncryptType int `json:"encrypt_type,omitempty"` + FullURL string `json:"full_url,omitempty"` } type ImageItem struct { @@ -202,9 +204,10 @@ type QRCodeResponse struct { } type StatusResponse struct { - Status string `json:"status"` // "wait", "scaned", "confirmed", "expired" - BotToken string `json:"bot_token,omitempty"` - IlinkBotID string `json:"ilink_bot_id,omitempty"` - Baseurl string `json:"baseurl,omitempty"` - IlinkUserID string `json:"ilink_user_id,omitempty"` + Status string `json:"status"` // "wait", "scaned", "confirmed", "expired", "scaned_but_redirect" + BotToken string `json:"bot_token,omitempty"` + IlinkBotID string `json:"ilink_bot_id,omitempty"` + Baseurl string `json:"baseurl,omitempty"` + IlinkUserID string `json:"ilink_user_id,omitempty"` + RedirectHost string `json:"redirect_host,omitempty"` } diff --git a/pkg/channels/weixin/weixin_test.go b/pkg/channels/weixin/weixin_test.go index 62984c965..b41b930db 100644 --- a/pkg/channels/weixin/weixin_test.go +++ b/pkg/channels/weixin/weixin_test.go @@ -72,7 +72,7 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) { typingCache: make(map[string]typingTicketCacheEntry), } - got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", key) + got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "", key) if err != nil { t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err) } @@ -81,6 +81,116 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) { } } +func TestDownloadAndDecryptCDNBufferUsesFullURLWhenProvided(t *testing.T) { + key := []byte("1234567890abcdef") + plaintext := []byte("hello weixin") + ciphertext, err := encryptAESECB(plaintext, key) + if err != nil { + t.Fatalf("encryptAESECB() error = %v", err) + } + + fullURLAttempts := 0 + ch := &WeixinChannel{ + api: &ApiClient{ + HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + if r.URL.String() == "https://full.example.com/download" { + fullURLAttempts++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(ciphertext)), + Header: make(http.Header), + }, nil + } + t.Fatalf("unexpected fallback request: %s", r.URL.String()) + return nil, nil + })}, + }, + config: config.WeixinConfig{ + CDNBaseURL: "https://cdn.example.com", + }, + typingCache: make(map[string]typingTicketCacheEntry), + } + + got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "https://full.example.com/download", key) + if err != nil { + t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext) + } + if fullURLAttempts == 0 { + t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts) + } +} + +func TestDownloadAndDecryptCDNBufferFallsBackToConstructedURLWhenFullURLFails(t *testing.T) { + key := []byte("1234567890abcdef") + plaintext := []byte("hello weixin") + ciphertext, err := encryptAESECB(plaintext, key) + if err != nil { + t.Fatalf("encryptAESECB() error = %v", err) + } + + fullURLAttempts := 0 + constructedAttempts := 0 + ch := &WeixinChannel{ + api: &ApiClient{ + HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + if r.URL.String() == "https://full.example.com/download?encrypted_query_param=token&taskid=123" { + fullURLAttempts++ + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(bytes.NewReader(nil)), + Header: make(http.Header), + }, nil + } + if r.URL.String() != "https://cdn.example.com/download?encrypted_query_param=token" { + t.Fatalf("unexpected fallback request: %s", r.URL.String()) + } + constructedAttempts++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(ciphertext)), + Header: make(http.Header), + }, nil + })}, + }, + config: config.WeixinConfig{ + CDNBaseURL: "https://cdn.example.com", + }, + typingCache: make(map[string]typingTicketCacheEntry), + } + + got, err := ch.downloadAndDecryptCDNBuffer( + context.Background(), + "token", + "https://full.example.com/download?encrypted_query_param=token&taskid=123", + key, + ) + if err != nil { + t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext) + } + if fullURLAttempts == 0 { + t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts) + } + if constructedAttempts == 0 { + t.Fatalf("constructedAttempts = %d, want > 0", constructedAttempts) + } +} + +func TestBuildCDNDownloadURLEscapesOpaqueToken(t *testing.T) { + token := "MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%3D" + + got := buildCDNDownloadURL("https://cdn.example.com", token) + + if got != "https://cdn.example.com/download?encrypted_query_param=MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%253D" { + t.Fatalf("buildCDNDownloadURL() = %q", got) + } +} + func TestUploadBufferToCDN(t *testing.T) { key := []byte("1234567890abcdef") plaintext := []byte("upload me") @@ -120,7 +230,7 @@ func TestUploadBufferToCDN(t *testing.T) { typingCache: make(map[string]typingTicketCacheEntry), } - got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "file-key", key) + got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "", "file-key", key) if err != nil { t.Fatalf("uploadBufferToCDN() error = %v", err) }