Merge branch 'main' into fix/update-assets

This commit is contained in:
BeaconCat
2026-03-28 18:47:49 +08:00
committed by GitHub
9 changed files with 521 additions and 102 deletions
+1 -1
View File
@@ -558,7 +558,7 @@ Connetti PicoClaw al Social Network degli Agent semplicemente inviando un singol
| `picoclaw skills list` | Elenca le skill installate |
| `picoclaw skills install` | Installa una skill |
| `picoclaw migrate` | Migra i dati dalle versioni precedenti |
| `picoclaw auth login` | Autenticazione con i provider |
| `picoclaw auth login` | Autenticazione con i provider |
### ⏰ Task Pianificati / Promemoria
+1 -1
View File
@@ -541,7 +541,7 @@ CLI または統合チャットアプリからメッセージを 1 つ送るだ
## 🖥️ CLI リファレンス
| コマンド | 説明 |
| コマンド | 説明 |
| ------------------------- | ------------------------------ |
| `picoclaw onboard` | 設定&ワークスペースの初期化 |
| `picoclaw auth weixin` | WeChat アカウントを QR で接続 |
+57 -13
View File
@@ -6,6 +6,7 @@ package dingtalk
import (
"context"
"fmt"
"strings"
"sync"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
@@ -135,13 +136,17 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
ctx context.Context,
data *chatbot.BotCallbackDataModel,
) ([]byte, error) {
if data == nil {
return nil, nil
}
// Extract message content from Text field
content := data.Text.Content
content := strings.TrimSpace(data.Text.Content)
if content == "" {
// Try to extract from Content interface{} if Text is empty
if contentMap, ok := data.Content.(map[string]any); ok {
if textContent, ok := contentMap["content"].(string); ok {
content = textContent
content = strings.TrimSpace(textContent)
}
}
}
@@ -150,12 +155,19 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
return nil, nil // Ignore empty messages
}
senderID := data.SenderStaffId
senderNick := data.SenderNick
chatID := senderID
if data.ConversationType != "1" {
// For group chats
chatID = data.ConversationId
senderID := strings.TrimSpace(data.SenderStaffId)
if senderID == "" {
senderID = strings.TrimSpace(data.SenderId)
}
senderNick := strings.TrimSpace(data.SenderNick)
chatID := strings.TrimSpace(data.ConversationId)
if chatID == "" && data.ConversationType == "1" {
// Fallback for direct chats when conversation_id is absent.
chatID = senderID
}
if chatID == "" {
return nil, nil
}
// Store the session webhook for this chat so we can reply later
@@ -171,11 +183,19 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
var peer bus.Peer
if data.ConversationType == "1" {
peer = bus.Peer{Kind: "direct", ID: senderID}
peerID := senderID
if peerID == "" {
peerID = chatID
}
peer = bus.Peer{Kind: "direct", ID: peerID}
} else {
peer = bus.Peer{Kind: "group", ID: data.ConversationId}
isMentioned := data.IsInAtList
if isMentioned {
content = stripLeadingAtMentions(content)
}
// In group chats, apply unified group trigger filtering
respond, cleaned := c.ShouldRespondInGroup(false, content)
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
if !respond {
return nil, nil
}
@@ -189,10 +209,18 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
})
// Build sender info
platformID := senderID
if platformID == "" {
platformID = chatID
}
resolvedSenderID := senderID
if resolvedSenderID == "" {
resolvedSenderID = platformID
}
sender := bus.SenderInfo{
Platform: "dingtalk",
PlatformID: senderID,
CanonicalID: identity.BuildCanonicalID("dingtalk", senderID),
PlatformID: platformID,
CanonicalID: identity.BuildCanonicalID("dingtalk", platformID),
DisplayName: senderNick,
}
@@ -201,7 +229,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
}
// Handle the message through the base channel
c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata, sender)
c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender)
// Return nil to indicate we've handled the message asynchronously
// The response will be sent through the message bus
@@ -229,3 +257,19 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c
return nil
}
func stripLeadingAtMentions(content string) string {
fields := strings.Fields(content)
if len(fields) == 0 {
return ""
}
i := 0
for i < len(fields) && strings.HasPrefix(fields[i], "@") {
i++
}
if i == 0 {
return strings.TrimSpace(content)
}
return strings.Join(fields[i:], " ")
}
+131
View File
@@ -0,0 +1,131 @@
package dingtalk
import (
"context"
"testing"
"time"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
func newTestDingTalkChannel(t *testing.T, cfg config.DingTalkConfig) (*DingTalkChannel, *bus.MessageBus) {
t.Helper()
if cfg.ClientID == "" {
cfg.ClientID = "test-client-id"
}
if cfg.ClientSecret() == "" {
cfg.SetClientSecret("test-client-secret")
}
msgBus := bus.NewMessageBus()
ch, err := NewDingTalkChannel(cfg, msgBus)
if err != nil {
t.Fatalf("new channel: %v", err)
}
return ch, msgBus
}
func mustReceiveInbound(t *testing.T, msgBus *bus.MessageBus) bus.InboundMessage {
t.Helper()
select {
case msg := <-msgBus.InboundChan():
return msg
case <-time.After(time.Second):
t.Fatal("expected inbound message")
return bus.InboundMessage{}
}
}
func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention(t *testing.T) {
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{
GroupTrigger: config.GroupTriggerConfig{MentionOnly: true},
})
_, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
Text: chatbot.BotCallbackDataTextModel{Content: " @bot /help "},
SenderStaffId: "staff-123",
SenderNick: "Alice",
ConversationType: "2",
ConversationId: "group-abc",
SessionWebhook: "https://example.com/webhook",
IsInAtList: true,
})
if err != nil {
t.Fatalf("handler returned error: %v", err)
}
inbound := mustReceiveInbound(t, msgBus)
if inbound.Channel != "dingtalk" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.ChatID != "group-abc" {
t.Fatalf("chat_id=%q", inbound.ChatID)
}
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-abc" {
t.Fatalf("peer=%+v", inbound.Peer)
}
if inbound.Content != "/help" {
t.Fatalf("content=%q", inbound.Content)
}
}
func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *testing.T) {
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{})
_, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
Text: chatbot.BotCallbackDataTextModel{Content: "ping"},
SenderStaffId: "",
SenderId: "openid-user-42",
SenderNick: "Bob",
ConversationType: "1",
ConversationId: "conv-direct-42",
SessionWebhook: "https://example.com/webhook-direct",
})
if err != nil {
t.Fatalf("handler returned error: %v", err)
}
inbound := mustReceiveInbound(t, msgBus)
if inbound.ChatID != "conv-direct-42" {
t.Fatalf("chat_id=%q", inbound.ChatID)
}
if inbound.Peer.Kind != "direct" || inbound.Peer.ID != "openid-user-42" {
t.Fatalf("peer=%+v", inbound.Peer)
}
if inbound.SenderID != "dingtalk:openid-user-42" {
t.Fatalf("sender_id=%q", inbound.SenderID)
}
if _, ok := ch.sessionWebhooks.Load("conv-direct-42"); !ok {
t.Fatal("expected session webhook keyed by conversation_id")
}
if _, ok := ch.sessionWebhooks.Load(""); ok {
t.Fatal("unexpected empty chat_id webhook key")
}
}
func TestStripLeadingAtMentions(t *testing.T) {
tests := []struct {
name string
input string
wantOut string
}{
{name: "single mention and command", input: "@bot /help", wantOut: "/help"},
{name: "multiple mentions", input: "@bot @alice /new", wantOut: "/new"},
{name: "no mention", input: "/help", wantOut: "/help"},
{name: "mention only", input: "@bot", wantOut: ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := stripLeadingAtMentions(tt.input)
if got != tt.wantOut {
t.Fatalf("stripLeadingAtMentions(%q)=%q want=%q", tt.input, got, tt.wantOut)
}
})
}
}
+42 -52
View File
@@ -12,6 +12,14 @@ import (
"net/http"
"net/url"
"path"
"strconv"
)
const (
weixinChannelVersion = "2.1.1"
weixinIlinkAppID = "bot"
// 2.1.1 encoded as 0x00MMNNPP => 0x00020101 => 131329
weixinClientVersion = 131329
)
type ApiClient struct {
@@ -80,13 +88,9 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons
}
req.Header.Set("Content-Type", "application/json")
if endpoint == "ilink/bot/get_bot_qrcode" || endpoint == "ilink/bot/get_qrcode_status" {
// QR routes have different headers sometimes, but let's stick to base ones
if endpoint == "ilink/bot/get_qrcode_status" {
// Use direct map assignment to send exact header name the Tencent API expects
req.Header["iLink-App-ClientVersion"] = []string{"1"}
}
} else {
req.Header["iLink-App-Id"] = []string{weixinIlinkAppID}
req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)}
if endpoint != "ilink/bot/get_bot_qrcode" && endpoint != "ilink/bot/get_qrcode_status" {
req.Header["AuthorizationType"] = []string{"ilink_bot_token"}
req.Header["X-WECHAT-UIN"] = []string{randomWechatUIN()}
if c.Token != "" {
@@ -119,7 +123,7 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons
}
func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpdatesResp, error) {
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
var resp GetUpdatesResp
err := c.post(ctx, "ilink/bot/getupdates", req, &resp)
if err != nil {
@@ -129,7 +133,7 @@ func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpda
}
func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendMessageResp, error) {
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
var resp SendMessageResp
if err := c.post(ctx, "ilink/bot/sendmessage", req, &resp); err != nil {
return nil, err
@@ -138,7 +142,7 @@ func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendM
}
func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*GetUploadUrlResp, error) {
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
var resp GetUploadUrlResp
err := c.post(ctx, "ilink/bot/getuploadurl", req, &resp)
if err != nil {
@@ -148,7 +152,7 @@ func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*Get
}
func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfigResp, error) {
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
var resp GetConfigResp
if err := c.post(ctx, "ilink/bot/getconfig", req, &resp); err != nil {
return nil, err
@@ -157,7 +161,7 @@ func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfig
}
func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTypingResp, error) {
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
var resp SendTypingResp
if err := c.post(ctx, "ilink/bot/sendtyping", req, &resp); err != nil {
return nil, err
@@ -165,38 +169,51 @@ func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTyp
return &resp, nil
}
func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) {
// get_bot_qrcode is GET, not POST
func (c *ApiClient) getQR(ctx context.Context, endpoint string, query map[string]string, respObj any) error {
u, err := url.Parse(c.BaseURL)
if err != nil {
return nil, err
return err
}
u.Path = path.Join(u.Path, "ilink/bot/get_bot_qrcode")
u.Path = path.Join(u.Path, endpoint)
q := u.Query()
q.Set("bot_type", botType)
for key, value := range query {
q.Set(key, value)
}
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, err
return err
}
req.Header["iLink-App-Id"] = []string{weixinIlinkAppID}
req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)}
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
return err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
return err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("get_bot_qrcode failed: %d %s", resp.StatusCode, string(respBody))
return fmt.Errorf("%s failed: %d %s", endpoint, resp.StatusCode, string(respBody))
}
if err := json.Unmarshal(respBody, respObj); err != nil {
return err
}
return nil
}
func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) {
// get_bot_qrcode is GET, not POST
var qrcodeResp QRCodeResponse
if err := json.Unmarshal(respBody, &qrcodeResp); err != nil {
if err := c.getQR(ctx, "ilink/bot/get_bot_qrcode", map[string]string{
"bot_type": botType,
}, &qrcodeResp); err != nil {
return nil, err
}
return &qrcodeResp, nil
@@ -204,37 +221,10 @@ func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeRespo
func (c *ApiClient) GetQRCodeStatus(ctx context.Context, qrcode string) (*StatusResponse, error) {
// get_qrcode_status is GET
u, err := url.Parse(c.BaseURL)
if err != nil {
return nil, err
}
u.Path = path.Join(u.Path, "ilink/bot/get_qrcode_status")
q := u.Query()
q.Set("qrcode", qrcode)
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, err
}
req.Header["iLink-App-ClientVersion"] = []string{"1"}
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("get_qrcode_status failed: %d %s", resp.StatusCode, string(respBody))
}
var statusResp StatusResponse
if err := json.Unmarshal(respBody, &statusResp); err != nil {
if err := c.getQR(ctx, "ilink/bot/get_qrcode_status", map[string]string{
"qrcode": qrcode,
}, &statusResp); err != nil {
return nil, err
}
return &statusResp, nil
+23 -1
View File
@@ -40,6 +40,7 @@ func PerformLoginInteractive(
if err != nil {
return "", "", "", "", fmt.Errorf("failed to create api client: %w", err)
}
pollAPI := api
logger.InfoC("weixin", "Requesting Weixin QR code...")
qrResp, err := api.GetQRCode(ctx, opts.BotType)
@@ -76,7 +77,7 @@ func PerformLoginInteractive(
case <-timeoutCtx.Done():
return "", "", "", "", fmt.Errorf("login timeout")
case <-pollTicker.C:
statusResp, err := api.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode)
statusResp, err := pollAPI.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode)
if err != nil {
// Long poll timeout or temporary error
continue
@@ -99,6 +100,27 @@ func PerformLoginInteractive(
})
return statusResp.BotToken, statusResp.IlinkUserID, statusResp.IlinkBotID, statusResp.Baseurl, nil
case "scaned_but_redirect":
if statusResp.RedirectHost == "" {
logger.WarnC(
"weixin",
"scaned_but_redirect received without redirect_host; continuing on current host",
)
continue
}
nextBaseURL := "https://" + statusResp.RedirectHost + "/"
nextAPI, nextErr := NewApiClient(nextBaseURL, "", opts.Proxy)
if nextErr != nil {
logger.WarnCF("weixin", "Failed to switch QR polling host", map[string]any{
"redirect_host": statusResp.RedirectHost,
"error": nextErr.Error(),
})
continue
}
pollAPI = nextAPI
logger.InfoCF("weixin", "Switched QR polling host", map[string]any{
"redirect_host": statusResp.RedirectHost,
})
case "expired":
return "", "", "", "", fmt.Errorf("qrcode expired, please try again")
default:
+146 -27
View File
@@ -34,6 +34,8 @@ const (
weixinMediaMaxBytes = 100 << 20
weixinTypingKeepAlive = 5 * time.Second
weixinUploadRetryMax = 3
weixinDownloadRetryMax = 2
weixinDownloadRetryDelay = 300 * time.Millisecond
weixinVoiceTranscodeTimeout = 15 * time.Second
)
@@ -163,49 +165,108 @@ func buildCDNDownloadURL(base, encryptedQueryParam string) string {
"/download?encrypted_query_param=" + url.QueryEscape(encryptedQueryParam)
}
func shouldRetryCDNDownload(statusCode int) bool {
// statusCode=0 represents transport/build errors from the HTTP client.
return statusCode == 0 || statusCode >= 500 || statusCode == http.StatusTooManyRequests
}
func buildCDNUploadURL(base, uploadParam, filekey string) string {
return strings.TrimRight(base, "/") +
"/upload?encrypted_query_param=" + url.QueryEscape(uploadParam) +
"&filekey=" + url.QueryEscape(filekey)
}
func (c *WeixinChannel) downloadCDNBuffer(ctx context.Context, encryptedQueryParam string) ([]byte, error) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodGet,
buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam),
nil,
)
func uniqCDNURLs(urls []string) []string {
seen := make(map[string]struct{}, len(urls))
out := make([]string, 0, len(urls))
for _, raw := range urls {
u := strings.TrimSpace(raw)
if u == "" {
continue
}
if _, ok := seen[u]; ok {
continue
}
seen[u] = struct{}{}
out = append(out, u)
}
return out
}
func (c *WeixinChannel) downloadCDNBufferOnce(ctx context.Context, downloadURL string) ([]byte, int, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil)
if err != nil {
return nil, err
return nil, 0, err
}
resp, err := c.api.HttpClient.Do(req)
if err != nil {
return nil, err
return nil, 0, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body))
return nil, resp.StatusCode, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body))
}
data, err := io.ReadAll(io.LimitReader(resp.Body, weixinMediaMaxBytes+1))
if err != nil {
return nil, err
return nil, resp.StatusCode, err
}
if len(data) > weixinMediaMaxBytes {
return nil, fmt.Errorf("cdn media too large: %d bytes", len(data))
return nil, resp.StatusCode, fmt.Errorf("cdn media too large: %d bytes", len(data))
}
return data, nil
return data, resp.StatusCode, nil
}
func (c *WeixinChannel) downloadCDNBuffer(
ctx context.Context,
encryptedQueryParam,
fullURL string,
) ([]byte, error) {
candidates := uniqCDNURLs([]string{
strings.TrimSpace(fullURL),
func() string {
if strings.TrimSpace(encryptedQueryParam) == "" {
return ""
}
return buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam)
}(),
})
if len(candidates) == 0 {
return nil, fmt.Errorf("missing CDN download URL")
}
var lastErr error
for _, downloadURL := range candidates {
for attempt := 1; attempt <= weixinDownloadRetryMax; attempt++ {
data, statusCode, err := c.downloadCDNBufferOnce(ctx, downloadURL)
if err == nil {
return data, nil
}
lastErr = fmt.Errorf("%w (attempt=%d url=%s)", err, attempt, downloadURL)
if !shouldRetryCDNDownload(statusCode) {
break
}
if attempt < weixinDownloadRetryMax {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(weixinDownloadRetryDelay):
}
}
}
}
return nil, lastErr
}
func (c *WeixinChannel) downloadAndDecryptCDNBuffer(
ctx context.Context,
encryptedQueryParam string,
fullURL string,
key []byte,
) ([]byte, error) {
data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam)
data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam, fullURL)
if err != nil {
return nil, err
}
@@ -215,6 +276,33 @@ func (c *WeixinChannel) downloadAndDecryptCDNBuffer(
return decryptAESECB(data, key)
}
func (c *WeixinChannel) downloadImageBuffer(
ctx context.Context,
img *ImageItem,
key []byte,
) ([]byte, error) {
if img == nil {
return nil, fmt.Errorf("image item is nil")
}
if img.Media != nil {
data, err := c.downloadAndDecryptCDNBuffer(ctx, img.Media.EncryptQueryParam, img.Media.FullURL, key)
if err == nil {
return data, nil
}
if img.ThumbMedia == nil {
return nil, fmt.Errorf("image download failed: %w", err)
}
}
if img.ThumbMedia != nil {
data, err := c.downloadAndDecryptCDNBuffer(ctx, img.ThumbMedia.EncryptQueryParam, img.ThumbMedia.FullURL, key)
if err == nil {
return data, nil
}
return nil, fmt.Errorf("image download failed: %w", err)
}
return nil, fmt.Errorf("image media is nil")
}
func detectMediaMetadata(data []byte, fallbackName, fallbackContentType string) (string, string) {
contentType := strings.TrimSpace(fallbackContentType)
ext := filepath.Ext(fallbackName)
@@ -310,15 +398,18 @@ func isDownloadableMediaItem(item *MessageItem) bool {
switch item.Type {
case MessageItemTypeImage:
return item.ImageItem != nil && item.ImageItem.Media != nil && item.ImageItem.Media.EncryptQueryParam != ""
return item.ImageItem != nil && item.ImageItem.Media != nil &&
(item.ImageItem.Media.EncryptQueryParam != "" || item.ImageItem.Media.FullURL != "")
case MessageItemTypeVideo:
return item.VideoItem != nil && item.VideoItem.Media != nil && item.VideoItem.Media.EncryptQueryParam != ""
return item.VideoItem != nil && item.VideoItem.Media != nil &&
(item.VideoItem.Media.EncryptQueryParam != "" || item.VideoItem.Media.FullURL != "")
case MessageItemTypeFile:
return item.FileItem != nil && item.FileItem.Media != nil && item.FileItem.Media.EncryptQueryParam != ""
return item.FileItem != nil && item.FileItem.Media != nil &&
(item.FileItem.Media.EncryptQueryParam != "" || item.FileItem.Media.FullURL != "")
case MessageItemTypeVoice:
return item.VoiceItem != nil &&
item.VoiceItem.Media != nil &&
item.VoiceItem.Media.EncryptQueryParam != "" &&
(item.VoiceItem.Media.EncryptQueryParam != "" || item.VoiceItem.Media.FullURL != "") &&
strings.TrimSpace(item.VoiceItem.Text) == ""
default:
return false
@@ -434,16 +525,20 @@ func (c *WeixinChannel) downloadMediaFromItem(
switch item.Type {
case MessageItemTypeImage:
if item.ImageItem == nil {
return "", fmt.Errorf("image media is nil")
}
key, ok, err := imageAESKey(item.ImageItem)
if err != nil {
return "", err
}
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.ImageItem.Media.EncryptQueryParam, func() []byte {
decryptKey := func() []byte {
if ok {
return key
}
return nil
}())
}()
data, err := c.downloadImageBuffer(ctx, item.ImageItem, decryptKey)
if err != nil {
return "", err
}
@@ -454,7 +549,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
if err != nil {
return "", err
}
silk, err := c.downloadAndDecryptCDNBuffer(ctx, item.VoiceItem.Media.EncryptQueryParam, key)
silk, err := c.downloadAndDecryptCDNBuffer(
ctx,
item.VoiceItem.Media.EncryptQueryParam,
item.VoiceItem.Media.FullURL,
key,
)
if err != nil {
return "", err
}
@@ -468,7 +568,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
if err != nil {
return "", err
}
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.FileItem.Media.EncryptQueryParam, key)
data, err := c.downloadAndDecryptCDNBuffer(
ctx,
item.FileItem.Media.EncryptQueryParam,
item.FileItem.Media.FullURL,
key,
)
if err != nil {
return "", err
}
@@ -484,7 +589,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
if err != nil {
return "", err
}
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.VideoItem.Media.EncryptQueryParam, key)
data, err := c.downloadAndDecryptCDNBuffer(
ctx,
item.VideoItem.Media.EncryptQueryParam,
item.VideoItem.Media.FullURL,
key,
)
if err != nil {
return "", err
}
@@ -701,11 +811,13 @@ func (c *WeixinChannel) uploadLocalFile(
}
return nil, fmt.Errorf("getuploadurl failed: ret=%d errcode=%d errmsg=%s", resp.Ret, resp.Errcode, resp.Errmsg)
}
if strings.TrimSpace(resp.UploadParam) == "" {
return nil, fmt.Errorf("getuploadurl returned empty upload_param")
uploadParam := strings.TrimSpace(resp.UploadParam)
uploadFullURL := strings.TrimSpace(resp.UploadFullURL)
if uploadParam == "" && uploadFullURL == "" {
return nil, fmt.Errorf("getuploadurl returned no upload URL")
}
downloadParam, err := c.uploadBufferToCDN(ctx, data, resp.UploadParam, filekey, aesKey)
downloadParam, err := c.uploadBufferToCDN(ctx, data, uploadParam, uploadFullURL, filekey, aesKey)
if err != nil {
return nil, err
}
@@ -723,6 +835,7 @@ func (c *WeixinChannel) uploadBufferToCDN(
ctx context.Context,
plaintext []byte,
uploadParam,
uploadFullURL,
filekey string,
aesKey []byte,
) (string, error) {
@@ -731,7 +844,13 @@ func (c *WeixinChannel) uploadBufferToCDN(
return "", err
}
uploadURL := buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey)
uploadURL := strings.TrimSpace(uploadFullURL)
if uploadURL == "" {
if strings.TrimSpace(uploadParam) == "" {
return "", fmt.Errorf("missing CDN upload URL")
}
uploadURL = buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey)
}
var lastErr error
for attempt := 1; attempt <= weixinUploadRetryMax; attempt++ {
+8 -5
View File
@@ -38,6 +38,7 @@ type GetUploadUrlResp struct {
APIStatus
UploadParam string `json:"upload_param,omitempty"`
ThumbUploadParam string `json:"thumb_upload_param,omitempty"`
UploadFullURL string `json:"upload_full_url,omitempty"`
}
const (
@@ -69,6 +70,7 @@ type CDNMedia struct {
EncryptQueryParam string `json:"encrypt_query_param,omitempty"`
AesKey string `json:"aes_key,omitempty"` // base64 encoded
EncryptType int `json:"encrypt_type,omitempty"`
FullURL string `json:"full_url,omitempty"`
}
type ImageItem struct {
@@ -202,9 +204,10 @@ type QRCodeResponse struct {
}
type StatusResponse struct {
Status string `json:"status"` // "wait", "scaned", "confirmed", "expired"
BotToken string `json:"bot_token,omitempty"`
IlinkBotID string `json:"ilink_bot_id,omitempty"`
Baseurl string `json:"baseurl,omitempty"`
IlinkUserID string `json:"ilink_user_id,omitempty"`
Status string `json:"status"` // "wait", "scaned", "confirmed", "expired", "scaned_but_redirect"
BotToken string `json:"bot_token,omitempty"`
IlinkBotID string `json:"ilink_bot_id,omitempty"`
Baseurl string `json:"baseurl,omitempty"`
IlinkUserID string `json:"ilink_user_id,omitempty"`
RedirectHost string `json:"redirect_host,omitempty"`
}
+112 -2
View File
@@ -72,7 +72,7 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
typingCache: make(map[string]typingTicketCacheEntry),
}
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", key)
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "", key)
if err != nil {
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
}
@@ -81,6 +81,116 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
}
}
func TestDownloadAndDecryptCDNBufferUsesFullURLWhenProvided(t *testing.T) {
key := []byte("1234567890abcdef")
plaintext := []byte("hello weixin")
ciphertext, err := encryptAESECB(plaintext, key)
if err != nil {
t.Fatalf("encryptAESECB() error = %v", err)
}
fullURLAttempts := 0
ch := &WeixinChannel{
api: &ApiClient{
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.String() == "https://full.example.com/download" {
fullURLAttempts++
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(ciphertext)),
Header: make(http.Header),
}, nil
}
t.Fatalf("unexpected fallback request: %s", r.URL.String())
return nil, nil
})},
},
config: config.WeixinConfig{
CDNBaseURL: "https://cdn.example.com",
},
typingCache: make(map[string]typingTicketCacheEntry),
}
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "https://full.example.com/download", key)
if err != nil {
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
}
if !bytes.Equal(got, plaintext) {
t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext)
}
if fullURLAttempts == 0 {
t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts)
}
}
func TestDownloadAndDecryptCDNBufferFallsBackToConstructedURLWhenFullURLFails(t *testing.T) {
key := []byte("1234567890abcdef")
plaintext := []byte("hello weixin")
ciphertext, err := encryptAESECB(plaintext, key)
if err != nil {
t.Fatalf("encryptAESECB() error = %v", err)
}
fullURLAttempts := 0
constructedAttempts := 0
ch := &WeixinChannel{
api: &ApiClient{
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.String() == "https://full.example.com/download?encrypted_query_param=token&taskid=123" {
fullURLAttempts++
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(bytes.NewReader(nil)),
Header: make(http.Header),
}, nil
}
if r.URL.String() != "https://cdn.example.com/download?encrypted_query_param=token" {
t.Fatalf("unexpected fallback request: %s", r.URL.String())
}
constructedAttempts++
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(ciphertext)),
Header: make(http.Header),
}, nil
})},
},
config: config.WeixinConfig{
CDNBaseURL: "https://cdn.example.com",
},
typingCache: make(map[string]typingTicketCacheEntry),
}
got, err := ch.downloadAndDecryptCDNBuffer(
context.Background(),
"token",
"https://full.example.com/download?encrypted_query_param=token&taskid=123",
key,
)
if err != nil {
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
}
if !bytes.Equal(got, plaintext) {
t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext)
}
if fullURLAttempts == 0 {
t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts)
}
if constructedAttempts == 0 {
t.Fatalf("constructedAttempts = %d, want > 0", constructedAttempts)
}
}
func TestBuildCDNDownloadURLEscapesOpaqueToken(t *testing.T) {
token := "MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%3D"
got := buildCDNDownloadURL("https://cdn.example.com", token)
if got != "https://cdn.example.com/download?encrypted_query_param=MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%253D" {
t.Fatalf("buildCDNDownloadURL() = %q", got)
}
}
func TestUploadBufferToCDN(t *testing.T) {
key := []byte("1234567890abcdef")
plaintext := []byte("upload me")
@@ -120,7 +230,7 @@ func TestUploadBufferToCDN(t *testing.T) {
typingCache: make(map[string]typingTicketCacheEntry),
}
got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "file-key", key)
got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "", "file-key", key)
if err != nil {
t.Fatalf("uploadBufferToCDN() error = %v", err)
}