fix/wechat-new-protocol (#2106)

* fix/wechat-new-protocol

* fix cdn download logic
This commit is contained in:
Hua Audio
2026-03-28 11:18:01 +01:00
committed by GitHub
parent d7c0205052
commit 0e13f6bdec
5 changed files with 331 additions and 87 deletions
+42 -52
View File
@@ -12,6 +12,14 @@ import (
"net/http"
"net/url"
"path"
"strconv"
)
const (
weixinChannelVersion = "2.1.1"
weixinIlinkAppID = "bot"
// 2.1.1 encoded as 0x00MMNNPP => 0x00020101 => 131329
weixinClientVersion = 131329
)
type ApiClient struct {
@@ -80,13 +88,9 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons
}
req.Header.Set("Content-Type", "application/json")
if endpoint == "ilink/bot/get_bot_qrcode" || endpoint == "ilink/bot/get_qrcode_status" {
// QR routes have different headers sometimes, but let's stick to base ones
if endpoint == "ilink/bot/get_qrcode_status" {
// Use direct map assignment to send exact header name the Tencent API expects
req.Header["iLink-App-ClientVersion"] = []string{"1"}
}
} else {
req.Header["iLink-App-Id"] = []string{weixinIlinkAppID}
req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)}
if endpoint != "ilink/bot/get_bot_qrcode" && endpoint != "ilink/bot/get_qrcode_status" {
req.Header["AuthorizationType"] = []string{"ilink_bot_token"}
req.Header["X-WECHAT-UIN"] = []string{randomWechatUIN()}
if c.Token != "" {
@@ -119,7 +123,7 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons
}
func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpdatesResp, error) {
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
var resp GetUpdatesResp
err := c.post(ctx, "ilink/bot/getupdates", req, &resp)
if err != nil {
@@ -129,7 +133,7 @@ func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpda
}
func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendMessageResp, error) {
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
var resp SendMessageResp
if err := c.post(ctx, "ilink/bot/sendmessage", req, &resp); err != nil {
return nil, err
@@ -138,7 +142,7 @@ func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendM
}
func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*GetUploadUrlResp, error) {
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
var resp GetUploadUrlResp
err := c.post(ctx, "ilink/bot/getuploadurl", req, &resp)
if err != nil {
@@ -148,7 +152,7 @@ func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*Get
}
func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfigResp, error) {
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
var resp GetConfigResp
if err := c.post(ctx, "ilink/bot/getconfig", req, &resp); err != nil {
return nil, err
@@ -157,7 +161,7 @@ func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfig
}
func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTypingResp, error) {
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
var resp SendTypingResp
if err := c.post(ctx, "ilink/bot/sendtyping", req, &resp); err != nil {
return nil, err
@@ -165,38 +169,51 @@ func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTyp
return &resp, nil
}
func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) {
// get_bot_qrcode is GET, not POST
func (c *ApiClient) getQR(ctx context.Context, endpoint string, query map[string]string, respObj any) error {
u, err := url.Parse(c.BaseURL)
if err != nil {
return nil, err
return err
}
u.Path = path.Join(u.Path, "ilink/bot/get_bot_qrcode")
u.Path = path.Join(u.Path, endpoint)
q := u.Query()
q.Set("bot_type", botType)
for key, value := range query {
q.Set(key, value)
}
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, err
return err
}
req.Header["iLink-App-Id"] = []string{weixinIlinkAppID}
req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)}
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
return err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
return err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("get_bot_qrcode failed: %d %s", resp.StatusCode, string(respBody))
return fmt.Errorf("%s failed: %d %s", endpoint, resp.StatusCode, string(respBody))
}
if err := json.Unmarshal(respBody, respObj); err != nil {
return err
}
return nil
}
func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) {
// get_bot_qrcode is GET, not POST
var qrcodeResp QRCodeResponse
if err := json.Unmarshal(respBody, &qrcodeResp); err != nil {
if err := c.getQR(ctx, "ilink/bot/get_bot_qrcode", map[string]string{
"bot_type": botType,
}, &qrcodeResp); err != nil {
return nil, err
}
return &qrcodeResp, nil
@@ -204,37 +221,10 @@ func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeRespo
func (c *ApiClient) GetQRCodeStatus(ctx context.Context, qrcode string) (*StatusResponse, error) {
// get_qrcode_status is GET
u, err := url.Parse(c.BaseURL)
if err != nil {
return nil, err
}
u.Path = path.Join(u.Path, "ilink/bot/get_qrcode_status")
q := u.Query()
q.Set("qrcode", qrcode)
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, err
}
req.Header["iLink-App-ClientVersion"] = []string{"1"}
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("get_qrcode_status failed: %d %s", resp.StatusCode, string(respBody))
}
var statusResp StatusResponse
if err := json.Unmarshal(respBody, &statusResp); err != nil {
if err := c.getQR(ctx, "ilink/bot/get_qrcode_status", map[string]string{
"qrcode": qrcode,
}, &statusResp); err != nil {
return nil, err
}
return &statusResp, nil
+23 -1
View File
@@ -40,6 +40,7 @@ func PerformLoginInteractive(
if err != nil {
return "", "", "", "", fmt.Errorf("failed to create api client: %w", err)
}
pollAPI := api
logger.InfoC("weixin", "Requesting Weixin QR code...")
qrResp, err := api.GetQRCode(ctx, opts.BotType)
@@ -76,7 +77,7 @@ func PerformLoginInteractive(
case <-timeoutCtx.Done():
return "", "", "", "", fmt.Errorf("login timeout")
case <-pollTicker.C:
statusResp, err := api.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode)
statusResp, err := pollAPI.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode)
if err != nil {
// Long poll timeout or temporary error
continue
@@ -99,6 +100,27 @@ func PerformLoginInteractive(
})
return statusResp.BotToken, statusResp.IlinkUserID, statusResp.IlinkBotID, statusResp.Baseurl, nil
case "scaned_but_redirect":
if statusResp.RedirectHost == "" {
logger.WarnC(
"weixin",
"scaned_but_redirect received without redirect_host; continuing on current host",
)
continue
}
nextBaseURL := "https://" + statusResp.RedirectHost + "/"
nextAPI, nextErr := NewApiClient(nextBaseURL, "", opts.Proxy)
if nextErr != nil {
logger.WarnCF("weixin", "Failed to switch QR polling host", map[string]any{
"redirect_host": statusResp.RedirectHost,
"error": nextErr.Error(),
})
continue
}
pollAPI = nextAPI
logger.InfoCF("weixin", "Switched QR polling host", map[string]any{
"redirect_host": statusResp.RedirectHost,
})
case "expired":
return "", "", "", "", fmt.Errorf("qrcode expired, please try again")
default:
+146 -27
View File
@@ -34,6 +34,8 @@ const (
weixinMediaMaxBytes = 100 << 20
weixinTypingKeepAlive = 5 * time.Second
weixinUploadRetryMax = 3
weixinDownloadRetryMax = 2
weixinDownloadRetryDelay = 300 * time.Millisecond
weixinVoiceTranscodeTimeout = 15 * time.Second
)
@@ -163,49 +165,108 @@ func buildCDNDownloadURL(base, encryptedQueryParam string) string {
"/download?encrypted_query_param=" + url.QueryEscape(encryptedQueryParam)
}
func shouldRetryCDNDownload(statusCode int) bool {
// statusCode=0 represents transport/build errors from the HTTP client.
return statusCode == 0 || statusCode >= 500 || statusCode == http.StatusTooManyRequests
}
func buildCDNUploadURL(base, uploadParam, filekey string) string {
return strings.TrimRight(base, "/") +
"/upload?encrypted_query_param=" + url.QueryEscape(uploadParam) +
"&filekey=" + url.QueryEscape(filekey)
}
func (c *WeixinChannel) downloadCDNBuffer(ctx context.Context, encryptedQueryParam string) ([]byte, error) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodGet,
buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam),
nil,
)
func uniqCDNURLs(urls []string) []string {
seen := make(map[string]struct{}, len(urls))
out := make([]string, 0, len(urls))
for _, raw := range urls {
u := strings.TrimSpace(raw)
if u == "" {
continue
}
if _, ok := seen[u]; ok {
continue
}
seen[u] = struct{}{}
out = append(out, u)
}
return out
}
func (c *WeixinChannel) downloadCDNBufferOnce(ctx context.Context, downloadURL string) ([]byte, int, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil)
if err != nil {
return nil, err
return nil, 0, err
}
resp, err := c.api.HttpClient.Do(req)
if err != nil {
return nil, err
return nil, 0, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body))
return nil, resp.StatusCode, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body))
}
data, err := io.ReadAll(io.LimitReader(resp.Body, weixinMediaMaxBytes+1))
if err != nil {
return nil, err
return nil, resp.StatusCode, err
}
if len(data) > weixinMediaMaxBytes {
return nil, fmt.Errorf("cdn media too large: %d bytes", len(data))
return nil, resp.StatusCode, fmt.Errorf("cdn media too large: %d bytes", len(data))
}
return data, nil
return data, resp.StatusCode, nil
}
func (c *WeixinChannel) downloadCDNBuffer(
ctx context.Context,
encryptedQueryParam,
fullURL string,
) ([]byte, error) {
candidates := uniqCDNURLs([]string{
strings.TrimSpace(fullURL),
func() string {
if strings.TrimSpace(encryptedQueryParam) == "" {
return ""
}
return buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam)
}(),
})
if len(candidates) == 0 {
return nil, fmt.Errorf("missing CDN download URL")
}
var lastErr error
for _, downloadURL := range candidates {
for attempt := 1; attempt <= weixinDownloadRetryMax; attempt++ {
data, statusCode, err := c.downloadCDNBufferOnce(ctx, downloadURL)
if err == nil {
return data, nil
}
lastErr = fmt.Errorf("%w (attempt=%d url=%s)", err, attempt, downloadURL)
if !shouldRetryCDNDownload(statusCode) {
break
}
if attempt < weixinDownloadRetryMax {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(weixinDownloadRetryDelay):
}
}
}
}
return nil, lastErr
}
func (c *WeixinChannel) downloadAndDecryptCDNBuffer(
ctx context.Context,
encryptedQueryParam string,
fullURL string,
key []byte,
) ([]byte, error) {
data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam)
data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam, fullURL)
if err != nil {
return nil, err
}
@@ -215,6 +276,33 @@ func (c *WeixinChannel) downloadAndDecryptCDNBuffer(
return decryptAESECB(data, key)
}
func (c *WeixinChannel) downloadImageBuffer(
ctx context.Context,
img *ImageItem,
key []byte,
) ([]byte, error) {
if img == nil {
return nil, fmt.Errorf("image item is nil")
}
if img.Media != nil {
data, err := c.downloadAndDecryptCDNBuffer(ctx, img.Media.EncryptQueryParam, img.Media.FullURL, key)
if err == nil {
return data, nil
}
if img.ThumbMedia == nil {
return nil, fmt.Errorf("image download failed: %w", err)
}
}
if img.ThumbMedia != nil {
data, err := c.downloadAndDecryptCDNBuffer(ctx, img.ThumbMedia.EncryptQueryParam, img.ThumbMedia.FullURL, key)
if err == nil {
return data, nil
}
return nil, fmt.Errorf("image download failed: %w", err)
}
return nil, fmt.Errorf("image media is nil")
}
func detectMediaMetadata(data []byte, fallbackName, fallbackContentType string) (string, string) {
contentType := strings.TrimSpace(fallbackContentType)
ext := filepath.Ext(fallbackName)
@@ -310,15 +398,18 @@ func isDownloadableMediaItem(item *MessageItem) bool {
switch item.Type {
case MessageItemTypeImage:
return item.ImageItem != nil && item.ImageItem.Media != nil && item.ImageItem.Media.EncryptQueryParam != ""
return item.ImageItem != nil && item.ImageItem.Media != nil &&
(item.ImageItem.Media.EncryptQueryParam != "" || item.ImageItem.Media.FullURL != "")
case MessageItemTypeVideo:
return item.VideoItem != nil && item.VideoItem.Media != nil && item.VideoItem.Media.EncryptQueryParam != ""
return item.VideoItem != nil && item.VideoItem.Media != nil &&
(item.VideoItem.Media.EncryptQueryParam != "" || item.VideoItem.Media.FullURL != "")
case MessageItemTypeFile:
return item.FileItem != nil && item.FileItem.Media != nil && item.FileItem.Media.EncryptQueryParam != ""
return item.FileItem != nil && item.FileItem.Media != nil &&
(item.FileItem.Media.EncryptQueryParam != "" || item.FileItem.Media.FullURL != "")
case MessageItemTypeVoice:
return item.VoiceItem != nil &&
item.VoiceItem.Media != nil &&
item.VoiceItem.Media.EncryptQueryParam != "" &&
(item.VoiceItem.Media.EncryptQueryParam != "" || item.VoiceItem.Media.FullURL != "") &&
strings.TrimSpace(item.VoiceItem.Text) == ""
default:
return false
@@ -434,16 +525,20 @@ func (c *WeixinChannel) downloadMediaFromItem(
switch item.Type {
case MessageItemTypeImage:
if item.ImageItem == nil {
return "", fmt.Errorf("image media is nil")
}
key, ok, err := imageAESKey(item.ImageItem)
if err != nil {
return "", err
}
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.ImageItem.Media.EncryptQueryParam, func() []byte {
decryptKey := func() []byte {
if ok {
return key
}
return nil
}())
}()
data, err := c.downloadImageBuffer(ctx, item.ImageItem, decryptKey)
if err != nil {
return "", err
}
@@ -454,7 +549,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
if err != nil {
return "", err
}
silk, err := c.downloadAndDecryptCDNBuffer(ctx, item.VoiceItem.Media.EncryptQueryParam, key)
silk, err := c.downloadAndDecryptCDNBuffer(
ctx,
item.VoiceItem.Media.EncryptQueryParam,
item.VoiceItem.Media.FullURL,
key,
)
if err != nil {
return "", err
}
@@ -468,7 +568,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
if err != nil {
return "", err
}
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.FileItem.Media.EncryptQueryParam, key)
data, err := c.downloadAndDecryptCDNBuffer(
ctx,
item.FileItem.Media.EncryptQueryParam,
item.FileItem.Media.FullURL,
key,
)
if err != nil {
return "", err
}
@@ -484,7 +589,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
if err != nil {
return "", err
}
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.VideoItem.Media.EncryptQueryParam, key)
data, err := c.downloadAndDecryptCDNBuffer(
ctx,
item.VideoItem.Media.EncryptQueryParam,
item.VideoItem.Media.FullURL,
key,
)
if err != nil {
return "", err
}
@@ -701,11 +811,13 @@ func (c *WeixinChannel) uploadLocalFile(
}
return nil, fmt.Errorf("getuploadurl failed: ret=%d errcode=%d errmsg=%s", resp.Ret, resp.Errcode, resp.Errmsg)
}
if strings.TrimSpace(resp.UploadParam) == "" {
return nil, fmt.Errorf("getuploadurl returned empty upload_param")
uploadParam := strings.TrimSpace(resp.UploadParam)
uploadFullURL := strings.TrimSpace(resp.UploadFullURL)
if uploadParam == "" && uploadFullURL == "" {
return nil, fmt.Errorf("getuploadurl returned no upload URL")
}
downloadParam, err := c.uploadBufferToCDN(ctx, data, resp.UploadParam, filekey, aesKey)
downloadParam, err := c.uploadBufferToCDN(ctx, data, uploadParam, uploadFullURL, filekey, aesKey)
if err != nil {
return nil, err
}
@@ -723,6 +835,7 @@ func (c *WeixinChannel) uploadBufferToCDN(
ctx context.Context,
plaintext []byte,
uploadParam,
uploadFullURL,
filekey string,
aesKey []byte,
) (string, error) {
@@ -731,7 +844,13 @@ func (c *WeixinChannel) uploadBufferToCDN(
return "", err
}
uploadURL := buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey)
uploadURL := strings.TrimSpace(uploadFullURL)
if uploadURL == "" {
if strings.TrimSpace(uploadParam) == "" {
return "", fmt.Errorf("missing CDN upload URL")
}
uploadURL = buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey)
}
var lastErr error
for attempt := 1; attempt <= weixinUploadRetryMax; attempt++ {
+8 -5
View File
@@ -38,6 +38,7 @@ type GetUploadUrlResp struct {
APIStatus
UploadParam string `json:"upload_param,omitempty"`
ThumbUploadParam string `json:"thumb_upload_param,omitempty"`
UploadFullURL string `json:"upload_full_url,omitempty"`
}
const (
@@ -69,6 +70,7 @@ type CDNMedia struct {
EncryptQueryParam string `json:"encrypt_query_param,omitempty"`
AesKey string `json:"aes_key,omitempty"` // base64 encoded
EncryptType int `json:"encrypt_type,omitempty"`
FullURL string `json:"full_url,omitempty"`
}
type ImageItem struct {
@@ -202,9 +204,10 @@ type QRCodeResponse struct {
}
type StatusResponse struct {
Status string `json:"status"` // "wait", "scaned", "confirmed", "expired"
BotToken string `json:"bot_token,omitempty"`
IlinkBotID string `json:"ilink_bot_id,omitempty"`
Baseurl string `json:"baseurl,omitempty"`
IlinkUserID string `json:"ilink_user_id,omitempty"`
Status string `json:"status"` // "wait", "scaned", "confirmed", "expired", "scaned_but_redirect"
BotToken string `json:"bot_token,omitempty"`
IlinkBotID string `json:"ilink_bot_id,omitempty"`
Baseurl string `json:"baseurl,omitempty"`
IlinkUserID string `json:"ilink_user_id,omitempty"`
RedirectHost string `json:"redirect_host,omitempty"`
}
+112 -2
View File
@@ -72,7 +72,7 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
typingCache: make(map[string]typingTicketCacheEntry),
}
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", key)
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "", key)
if err != nil {
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
}
@@ -81,6 +81,116 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
}
}
func TestDownloadAndDecryptCDNBufferUsesFullURLWhenProvided(t *testing.T) {
key := []byte("1234567890abcdef")
plaintext := []byte("hello weixin")
ciphertext, err := encryptAESECB(plaintext, key)
if err != nil {
t.Fatalf("encryptAESECB() error = %v", err)
}
fullURLAttempts := 0
ch := &WeixinChannel{
api: &ApiClient{
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.String() == "https://full.example.com/download" {
fullURLAttempts++
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(ciphertext)),
Header: make(http.Header),
}, nil
}
t.Fatalf("unexpected fallback request: %s", r.URL.String())
return nil, nil
})},
},
config: config.WeixinConfig{
CDNBaseURL: "https://cdn.example.com",
},
typingCache: make(map[string]typingTicketCacheEntry),
}
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "https://full.example.com/download", key)
if err != nil {
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
}
if !bytes.Equal(got, plaintext) {
t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext)
}
if fullURLAttempts == 0 {
t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts)
}
}
func TestDownloadAndDecryptCDNBufferFallsBackToConstructedURLWhenFullURLFails(t *testing.T) {
key := []byte("1234567890abcdef")
plaintext := []byte("hello weixin")
ciphertext, err := encryptAESECB(plaintext, key)
if err != nil {
t.Fatalf("encryptAESECB() error = %v", err)
}
fullURLAttempts := 0
constructedAttempts := 0
ch := &WeixinChannel{
api: &ApiClient{
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.String() == "https://full.example.com/download?encrypted_query_param=token&taskid=123" {
fullURLAttempts++
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(bytes.NewReader(nil)),
Header: make(http.Header),
}, nil
}
if r.URL.String() != "https://cdn.example.com/download?encrypted_query_param=token" {
t.Fatalf("unexpected fallback request: %s", r.URL.String())
}
constructedAttempts++
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(ciphertext)),
Header: make(http.Header),
}, nil
})},
},
config: config.WeixinConfig{
CDNBaseURL: "https://cdn.example.com",
},
typingCache: make(map[string]typingTicketCacheEntry),
}
got, err := ch.downloadAndDecryptCDNBuffer(
context.Background(),
"token",
"https://full.example.com/download?encrypted_query_param=token&taskid=123",
key,
)
if err != nil {
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
}
if !bytes.Equal(got, plaintext) {
t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext)
}
if fullURLAttempts == 0 {
t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts)
}
if constructedAttempts == 0 {
t.Fatalf("constructedAttempts = %d, want > 0", constructedAttempts)
}
}
func TestBuildCDNDownloadURLEscapesOpaqueToken(t *testing.T) {
token := "MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%3D"
got := buildCDNDownloadURL("https://cdn.example.com", token)
if got != "https://cdn.example.com/download?encrypted_query_param=MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%253D" {
t.Fatalf("buildCDNDownloadURL() = %q", got)
}
}
func TestUploadBufferToCDN(t *testing.T) {
key := []byte("1234567890abcdef")
plaintext := []byte("upload me")
@@ -120,7 +230,7 @@ func TestUploadBufferToCDN(t *testing.T) {
typingCache: make(map[string]typingTicketCacheEntry),
}
got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "file-key", key)
got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "", "file-key", key)
if err != nil {
t.Fatalf("uploadBufferToCDN() error = %v", err)
}