mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into fix/update-assets
This commit is contained in:
+1
-1
@@ -558,7 +558,7 @@ Connetti PicoClaw al Social Network degli Agent semplicemente inviando un singol
|
||||
| `picoclaw skills list` | Elenca le skill installate |
|
||||
| `picoclaw skills install` | Installa una skill |
|
||||
| `picoclaw migrate` | Migra i dati dalle versioni precedenti |
|
||||
| `picoclaw auth login` | Autenticazione con i provider |
|
||||
| `picoclaw auth login` | Autenticazione con i provider |
|
||||
|
||||
### ⏰ Task Pianificati / Promemoria
|
||||
|
||||
|
||||
+1
-1
@@ -541,7 +541,7 @@ CLI または統合チャットアプリからメッセージを 1 つ送るだ
|
||||
|
||||
## 🖥️ CLI リファレンス
|
||||
|
||||
| コマンド | 説明 |
|
||||
| コマンド | 説明 |
|
||||
| ------------------------- | ------------------------------ |
|
||||
| `picoclaw onboard` | 設定&ワークスペースの初期化 |
|
||||
| `picoclaw auth weixin` | WeChat アカウントを QR で接続 |
|
||||
|
||||
@@ -6,6 +6,7 @@ package dingtalk
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
@@ -135,13 +136,17 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
ctx context.Context,
|
||||
data *chatbot.BotCallbackDataModel,
|
||||
) ([]byte, error) {
|
||||
if data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Extract message content from Text field
|
||||
content := data.Text.Content
|
||||
content := strings.TrimSpace(data.Text.Content)
|
||||
if content == "" {
|
||||
// Try to extract from Content interface{} if Text is empty
|
||||
if contentMap, ok := data.Content.(map[string]any); ok {
|
||||
if textContent, ok := contentMap["content"].(string); ok {
|
||||
content = textContent
|
||||
content = strings.TrimSpace(textContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -150,12 +155,19 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
return nil, nil // Ignore empty messages
|
||||
}
|
||||
|
||||
senderID := data.SenderStaffId
|
||||
senderNick := data.SenderNick
|
||||
chatID := senderID
|
||||
if data.ConversationType != "1" {
|
||||
// For group chats
|
||||
chatID = data.ConversationId
|
||||
senderID := strings.TrimSpace(data.SenderStaffId)
|
||||
if senderID == "" {
|
||||
senderID = strings.TrimSpace(data.SenderId)
|
||||
}
|
||||
senderNick := strings.TrimSpace(data.SenderNick)
|
||||
|
||||
chatID := strings.TrimSpace(data.ConversationId)
|
||||
if chatID == "" && data.ConversationType == "1" {
|
||||
// Fallback for direct chats when conversation_id is absent.
|
||||
chatID = senderID
|
||||
}
|
||||
if chatID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Store the session webhook for this chat so we can reply later
|
||||
@@ -171,11 +183,19 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
|
||||
var peer bus.Peer
|
||||
if data.ConversationType == "1" {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
peerID := senderID
|
||||
if peerID == "" {
|
||||
peerID = chatID
|
||||
}
|
||||
peer = bus.Peer{Kind: "direct", ID: peerID}
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: data.ConversationId}
|
||||
isMentioned := data.IsInAtList
|
||||
if isMentioned {
|
||||
content = stripLeadingAtMentions(content)
|
||||
}
|
||||
// In group chats, apply unified group trigger filtering
|
||||
respond, cleaned := c.ShouldRespondInGroup(false, content)
|
||||
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
|
||||
if !respond {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -189,10 +209,18 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
})
|
||||
|
||||
// Build sender info
|
||||
platformID := senderID
|
||||
if platformID == "" {
|
||||
platformID = chatID
|
||||
}
|
||||
resolvedSenderID := senderID
|
||||
if resolvedSenderID == "" {
|
||||
resolvedSenderID = platformID
|
||||
}
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "dingtalk",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("dingtalk", senderID),
|
||||
PlatformID: platformID,
|
||||
CanonicalID: identity.BuildCanonicalID("dingtalk", platformID),
|
||||
DisplayName: senderNick,
|
||||
}
|
||||
|
||||
@@ -201,7 +229,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
}
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata, sender)
|
||||
c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender)
|
||||
|
||||
// Return nil to indicate we've handled the message asynchronously
|
||||
// The response will be sent through the message bus
|
||||
@@ -229,3 +257,19 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stripLeadingAtMentions(content string) string {
|
||||
fields := strings.Fields(content)
|
||||
if len(fields) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(fields) && strings.HasPrefix(fields[i], "@") {
|
||||
i++
|
||||
}
|
||||
if i == 0 {
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
return strings.Join(fields[i:], " ")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
package dingtalk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func newTestDingTalkChannel(t *testing.T, cfg config.DingTalkConfig) (*DingTalkChannel, *bus.MessageBus) {
|
||||
t.Helper()
|
||||
|
||||
if cfg.ClientID == "" {
|
||||
cfg.ClientID = "test-client-id"
|
||||
}
|
||||
if cfg.ClientSecret() == "" {
|
||||
cfg.SetClientSecret("test-client-secret")
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewDingTalkChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("new channel: %v", err)
|
||||
}
|
||||
return ch, msgBus
|
||||
}
|
||||
|
||||
func mustReceiveInbound(t *testing.T, msgBus *bus.MessageBus) bus.InboundMessage {
|
||||
t.Helper()
|
||||
select {
|
||||
case msg := <-msgBus.InboundChan():
|
||||
return msg
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected inbound message")
|
||||
return bus.InboundMessage{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention(t *testing.T) {
|
||||
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{
|
||||
GroupTrigger: config.GroupTriggerConfig{MentionOnly: true},
|
||||
})
|
||||
|
||||
_, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
|
||||
Text: chatbot.BotCallbackDataTextModel{Content: " @bot /help "},
|
||||
SenderStaffId: "staff-123",
|
||||
SenderNick: "Alice",
|
||||
ConversationType: "2",
|
||||
ConversationId: "group-abc",
|
||||
SessionWebhook: "https://example.com/webhook",
|
||||
IsInAtList: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("handler returned error: %v", err)
|
||||
}
|
||||
|
||||
inbound := mustReceiveInbound(t, msgBus)
|
||||
if inbound.Channel != "dingtalk" {
|
||||
t.Fatalf("channel=%q", inbound.Channel)
|
||||
}
|
||||
if inbound.ChatID != "group-abc" {
|
||||
t.Fatalf("chat_id=%q", inbound.ChatID)
|
||||
}
|
||||
if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-abc" {
|
||||
t.Fatalf("peer=%+v", inbound.Peer)
|
||||
}
|
||||
if inbound.Content != "/help" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *testing.T) {
|
||||
ch, msgBus := newTestDingTalkChannel(t, config.DingTalkConfig{})
|
||||
|
||||
_, err := ch.onChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
|
||||
Text: chatbot.BotCallbackDataTextModel{Content: "ping"},
|
||||
SenderStaffId: "",
|
||||
SenderId: "openid-user-42",
|
||||
SenderNick: "Bob",
|
||||
ConversationType: "1",
|
||||
ConversationId: "conv-direct-42",
|
||||
SessionWebhook: "https://example.com/webhook-direct",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("handler returned error: %v", err)
|
||||
}
|
||||
|
||||
inbound := mustReceiveInbound(t, msgBus)
|
||||
if inbound.ChatID != "conv-direct-42" {
|
||||
t.Fatalf("chat_id=%q", inbound.ChatID)
|
||||
}
|
||||
if inbound.Peer.Kind != "direct" || inbound.Peer.ID != "openid-user-42" {
|
||||
t.Fatalf("peer=%+v", inbound.Peer)
|
||||
}
|
||||
if inbound.SenderID != "dingtalk:openid-user-42" {
|
||||
t.Fatalf("sender_id=%q", inbound.SenderID)
|
||||
}
|
||||
|
||||
if _, ok := ch.sessionWebhooks.Load("conv-direct-42"); !ok {
|
||||
t.Fatal("expected session webhook keyed by conversation_id")
|
||||
}
|
||||
if _, ok := ch.sessionWebhooks.Load(""); ok {
|
||||
t.Fatal("unexpected empty chat_id webhook key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripLeadingAtMentions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantOut string
|
||||
}{
|
||||
{name: "single mention and command", input: "@bot /help", wantOut: "/help"},
|
||||
{name: "multiple mentions", input: "@bot @alice /new", wantOut: "/new"},
|
||||
{name: "no mention", input: "/help", wantOut: "/help"},
|
||||
{name: "mention only", input: "@bot", wantOut: ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := stripLeadingAtMentions(tt.input)
|
||||
if got != tt.wantOut {
|
||||
t.Fatalf("stripLeadingAtMentions(%q)=%q want=%q", tt.input, got, tt.wantOut)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+42
-52
@@ -12,6 +12,14 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
weixinChannelVersion = "2.1.1"
|
||||
weixinIlinkAppID = "bot"
|
||||
// 2.1.1 encoded as 0x00MMNNPP => 0x00020101 => 131329
|
||||
weixinClientVersion = 131329
|
||||
)
|
||||
|
||||
type ApiClient struct {
|
||||
@@ -80,13 +88,9 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if endpoint == "ilink/bot/get_bot_qrcode" || endpoint == "ilink/bot/get_qrcode_status" {
|
||||
// QR routes have different headers sometimes, but let's stick to base ones
|
||||
if endpoint == "ilink/bot/get_qrcode_status" {
|
||||
// Use direct map assignment to send exact header name the Tencent API expects
|
||||
req.Header["iLink-App-ClientVersion"] = []string{"1"}
|
||||
}
|
||||
} else {
|
||||
req.Header["iLink-App-Id"] = []string{weixinIlinkAppID}
|
||||
req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)}
|
||||
if endpoint != "ilink/bot/get_bot_qrcode" && endpoint != "ilink/bot/get_qrcode_status" {
|
||||
req.Header["AuthorizationType"] = []string{"ilink_bot_token"}
|
||||
req.Header["X-WECHAT-UIN"] = []string{randomWechatUIN()}
|
||||
if c.Token != "" {
|
||||
@@ -119,7 +123,7 @@ func (c *ApiClient) post(ctx context.Context, endpoint string, body any, respons
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpdatesResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
|
||||
var resp GetUpdatesResp
|
||||
err := c.post(ctx, "ilink/bot/getupdates", req, &resp)
|
||||
if err != nil {
|
||||
@@ -129,7 +133,7 @@ func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpda
|
||||
}
|
||||
|
||||
func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendMessageResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
|
||||
var resp SendMessageResp
|
||||
if err := c.post(ctx, "ilink/bot/sendmessage", req, &resp); err != nil {
|
||||
return nil, err
|
||||
@@ -138,7 +142,7 @@ func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendM
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*GetUploadUrlResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
|
||||
var resp GetUploadUrlResp
|
||||
err := c.post(ctx, "ilink/bot/getuploadurl", req, &resp)
|
||||
if err != nil {
|
||||
@@ -148,7 +152,7 @@ func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*Get
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfigResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
|
||||
var resp GetConfigResp
|
||||
if err := c.post(ctx, "ilink/bot/getconfig", req, &resp); err != nil {
|
||||
return nil, err
|
||||
@@ -157,7 +161,7 @@ func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfig
|
||||
}
|
||||
|
||||
func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTypingResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: weixinChannelVersion}
|
||||
var resp SendTypingResp
|
||||
if err := c.post(ctx, "ilink/bot/sendtyping", req, &resp); err != nil {
|
||||
return nil, err
|
||||
@@ -165,38 +169,51 @@ func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTyp
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) {
|
||||
// get_bot_qrcode is GET, not POST
|
||||
func (c *ApiClient) getQR(ctx context.Context, endpoint string, query map[string]string, respObj any) error {
|
||||
u, err := url.Parse(c.BaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
u.Path = path.Join(u.Path, "ilink/bot/get_bot_qrcode")
|
||||
u.Path = path.Join(u.Path, endpoint)
|
||||
q := u.Query()
|
||||
q.Set("bot_type", botType)
|
||||
for key, value := range query {
|
||||
q.Set(key, value)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
req.Header["iLink-App-Id"] = []string{weixinIlinkAppID}
|
||||
req.Header["iLink-App-ClientVersion"] = []string{strconv.Itoa(weixinClientVersion)}
|
||||
|
||||
resp, err := c.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("get_bot_qrcode failed: %d %s", resp.StatusCode, string(respBody))
|
||||
return fmt.Errorf("%s failed: %d %s", endpoint, resp.StatusCode, string(respBody))
|
||||
}
|
||||
if err := json.Unmarshal(respBody, respObj); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) {
|
||||
// get_bot_qrcode is GET, not POST
|
||||
var qrcodeResp QRCodeResponse
|
||||
if err := json.Unmarshal(respBody, &qrcodeResp); err != nil {
|
||||
if err := c.getQR(ctx, "ilink/bot/get_bot_qrcode", map[string]string{
|
||||
"bot_type": botType,
|
||||
}, &qrcodeResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &qrcodeResp, nil
|
||||
@@ -204,37 +221,10 @@ func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeRespo
|
||||
|
||||
func (c *ApiClient) GetQRCodeStatus(ctx context.Context, qrcode string) (*StatusResponse, error) {
|
||||
// get_qrcode_status is GET
|
||||
u, err := url.Parse(c.BaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = path.Join(u.Path, "ilink/bot/get_qrcode_status")
|
||||
q := u.Query()
|
||||
q.Set("qrcode", qrcode)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header["iLink-App-ClientVersion"] = []string{"1"}
|
||||
|
||||
resp, err := c.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("get_qrcode_status failed: %d %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var statusResp StatusResponse
|
||||
if err := json.Unmarshal(respBody, &statusResp); err != nil {
|
||||
if err := c.getQR(ctx, "ilink/bot/get_qrcode_status", map[string]string{
|
||||
"qrcode": qrcode,
|
||||
}, &statusResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &statusResp, nil
|
||||
|
||||
@@ -40,6 +40,7 @@ func PerformLoginInteractive(
|
||||
if err != nil {
|
||||
return "", "", "", "", fmt.Errorf("failed to create api client: %w", err)
|
||||
}
|
||||
pollAPI := api
|
||||
|
||||
logger.InfoC("weixin", "Requesting Weixin QR code...")
|
||||
qrResp, err := api.GetQRCode(ctx, opts.BotType)
|
||||
@@ -76,7 +77,7 @@ func PerformLoginInteractive(
|
||||
case <-timeoutCtx.Done():
|
||||
return "", "", "", "", fmt.Errorf("login timeout")
|
||||
case <-pollTicker.C:
|
||||
statusResp, err := api.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode)
|
||||
statusResp, err := pollAPI.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode)
|
||||
if err != nil {
|
||||
// Long poll timeout or temporary error
|
||||
continue
|
||||
@@ -99,6 +100,27 @@ func PerformLoginInteractive(
|
||||
})
|
||||
|
||||
return statusResp.BotToken, statusResp.IlinkUserID, statusResp.IlinkBotID, statusResp.Baseurl, nil
|
||||
case "scaned_but_redirect":
|
||||
if statusResp.RedirectHost == "" {
|
||||
logger.WarnC(
|
||||
"weixin",
|
||||
"scaned_but_redirect received without redirect_host; continuing on current host",
|
||||
)
|
||||
continue
|
||||
}
|
||||
nextBaseURL := "https://" + statusResp.RedirectHost + "/"
|
||||
nextAPI, nextErr := NewApiClient(nextBaseURL, "", opts.Proxy)
|
||||
if nextErr != nil {
|
||||
logger.WarnCF("weixin", "Failed to switch QR polling host", map[string]any{
|
||||
"redirect_host": statusResp.RedirectHost,
|
||||
"error": nextErr.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
pollAPI = nextAPI
|
||||
logger.InfoCF("weixin", "Switched QR polling host", map[string]any{
|
||||
"redirect_host": statusResp.RedirectHost,
|
||||
})
|
||||
case "expired":
|
||||
return "", "", "", "", fmt.Errorf("qrcode expired, please try again")
|
||||
default:
|
||||
|
||||
+146
-27
@@ -34,6 +34,8 @@ const (
|
||||
weixinMediaMaxBytes = 100 << 20
|
||||
weixinTypingKeepAlive = 5 * time.Second
|
||||
weixinUploadRetryMax = 3
|
||||
weixinDownloadRetryMax = 2
|
||||
weixinDownloadRetryDelay = 300 * time.Millisecond
|
||||
weixinVoiceTranscodeTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
@@ -163,49 +165,108 @@ func buildCDNDownloadURL(base, encryptedQueryParam string) string {
|
||||
"/download?encrypted_query_param=" + url.QueryEscape(encryptedQueryParam)
|
||||
}
|
||||
|
||||
func shouldRetryCDNDownload(statusCode int) bool {
|
||||
// statusCode=0 represents transport/build errors from the HTTP client.
|
||||
return statusCode == 0 || statusCode >= 500 || statusCode == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
func buildCDNUploadURL(base, uploadParam, filekey string) string {
|
||||
return strings.TrimRight(base, "/") +
|
||||
"/upload?encrypted_query_param=" + url.QueryEscape(uploadParam) +
|
||||
"&filekey=" + url.QueryEscape(filekey)
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) downloadCDNBuffer(ctx context.Context, encryptedQueryParam string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam),
|
||||
nil,
|
||||
)
|
||||
func uniqCDNURLs(urls []string) []string {
|
||||
seen := make(map[string]struct{}, len(urls))
|
||||
out := make([]string, 0, len(urls))
|
||||
for _, raw := range urls {
|
||||
u := strings.TrimSpace(raw)
|
||||
if u == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[u]; ok {
|
||||
continue
|
||||
}
|
||||
seen[u] = struct{}{}
|
||||
out = append(out, u)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) downloadCDNBufferOnce(ctx context.Context, downloadURL string) ([]byte, int, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
resp, err := c.api.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return nil, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body))
|
||||
return nil, resp.StatusCode, fmt.Errorf("cdn download HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, weixinMediaMaxBytes+1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
if len(data) > weixinMediaMaxBytes {
|
||||
return nil, fmt.Errorf("cdn media too large: %d bytes", len(data))
|
||||
return nil, resp.StatusCode, fmt.Errorf("cdn media too large: %d bytes", len(data))
|
||||
}
|
||||
return data, nil
|
||||
return data, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) downloadCDNBuffer(
|
||||
ctx context.Context,
|
||||
encryptedQueryParam,
|
||||
fullURL string,
|
||||
) ([]byte, error) {
|
||||
candidates := uniqCDNURLs([]string{
|
||||
strings.TrimSpace(fullURL),
|
||||
func() string {
|
||||
if strings.TrimSpace(encryptedQueryParam) == "" {
|
||||
return ""
|
||||
}
|
||||
return buildCDNDownloadURL(c.cdnBaseURL(), encryptedQueryParam)
|
||||
}(),
|
||||
})
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("missing CDN download URL")
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, downloadURL := range candidates {
|
||||
for attempt := 1; attempt <= weixinDownloadRetryMax; attempt++ {
|
||||
data, statusCode, err := c.downloadCDNBufferOnce(ctx, downloadURL)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
lastErr = fmt.Errorf("%w (attempt=%d url=%s)", err, attempt, downloadURL)
|
||||
if !shouldRetryCDNDownload(statusCode) {
|
||||
break
|
||||
}
|
||||
if attempt < weixinDownloadRetryMax {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(weixinDownloadRetryDelay):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) downloadAndDecryptCDNBuffer(
|
||||
ctx context.Context,
|
||||
encryptedQueryParam string,
|
||||
fullURL string,
|
||||
key []byte,
|
||||
) ([]byte, error) {
|
||||
data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam)
|
||||
data, err := c.downloadCDNBuffer(ctx, encryptedQueryParam, fullURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -215,6 +276,33 @@ func (c *WeixinChannel) downloadAndDecryptCDNBuffer(
|
||||
return decryptAESECB(data, key)
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) downloadImageBuffer(
|
||||
ctx context.Context,
|
||||
img *ImageItem,
|
||||
key []byte,
|
||||
) ([]byte, error) {
|
||||
if img == nil {
|
||||
return nil, fmt.Errorf("image item is nil")
|
||||
}
|
||||
if img.Media != nil {
|
||||
data, err := c.downloadAndDecryptCDNBuffer(ctx, img.Media.EncryptQueryParam, img.Media.FullURL, key)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
if img.ThumbMedia == nil {
|
||||
return nil, fmt.Errorf("image download failed: %w", err)
|
||||
}
|
||||
}
|
||||
if img.ThumbMedia != nil {
|
||||
data, err := c.downloadAndDecryptCDNBuffer(ctx, img.ThumbMedia.EncryptQueryParam, img.ThumbMedia.FullURL, key)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
return nil, fmt.Errorf("image download failed: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("image media is nil")
|
||||
}
|
||||
|
||||
func detectMediaMetadata(data []byte, fallbackName, fallbackContentType string) (string, string) {
|
||||
contentType := strings.TrimSpace(fallbackContentType)
|
||||
ext := filepath.Ext(fallbackName)
|
||||
@@ -310,15 +398,18 @@ func isDownloadableMediaItem(item *MessageItem) bool {
|
||||
|
||||
switch item.Type {
|
||||
case MessageItemTypeImage:
|
||||
return item.ImageItem != nil && item.ImageItem.Media != nil && item.ImageItem.Media.EncryptQueryParam != ""
|
||||
return item.ImageItem != nil && item.ImageItem.Media != nil &&
|
||||
(item.ImageItem.Media.EncryptQueryParam != "" || item.ImageItem.Media.FullURL != "")
|
||||
case MessageItemTypeVideo:
|
||||
return item.VideoItem != nil && item.VideoItem.Media != nil && item.VideoItem.Media.EncryptQueryParam != ""
|
||||
return item.VideoItem != nil && item.VideoItem.Media != nil &&
|
||||
(item.VideoItem.Media.EncryptQueryParam != "" || item.VideoItem.Media.FullURL != "")
|
||||
case MessageItemTypeFile:
|
||||
return item.FileItem != nil && item.FileItem.Media != nil && item.FileItem.Media.EncryptQueryParam != ""
|
||||
return item.FileItem != nil && item.FileItem.Media != nil &&
|
||||
(item.FileItem.Media.EncryptQueryParam != "" || item.FileItem.Media.FullURL != "")
|
||||
case MessageItemTypeVoice:
|
||||
return item.VoiceItem != nil &&
|
||||
item.VoiceItem.Media != nil &&
|
||||
item.VoiceItem.Media.EncryptQueryParam != "" &&
|
||||
(item.VoiceItem.Media.EncryptQueryParam != "" || item.VoiceItem.Media.FullURL != "") &&
|
||||
strings.TrimSpace(item.VoiceItem.Text) == ""
|
||||
default:
|
||||
return false
|
||||
@@ -434,16 +525,20 @@ func (c *WeixinChannel) downloadMediaFromItem(
|
||||
|
||||
switch item.Type {
|
||||
case MessageItemTypeImage:
|
||||
if item.ImageItem == nil {
|
||||
return "", fmt.Errorf("image media is nil")
|
||||
}
|
||||
key, ok, err := imageAESKey(item.ImageItem)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.ImageItem.Media.EncryptQueryParam, func() []byte {
|
||||
decryptKey := func() []byte {
|
||||
if ok {
|
||||
return key
|
||||
}
|
||||
return nil
|
||||
}())
|
||||
}()
|
||||
data, err := c.downloadImageBuffer(ctx, item.ImageItem, decryptKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -454,7 +549,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
silk, err := c.downloadAndDecryptCDNBuffer(ctx, item.VoiceItem.Media.EncryptQueryParam, key)
|
||||
silk, err := c.downloadAndDecryptCDNBuffer(
|
||||
ctx,
|
||||
item.VoiceItem.Media.EncryptQueryParam,
|
||||
item.VoiceItem.Media.FullURL,
|
||||
key,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -468,7 +568,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.FileItem.Media.EncryptQueryParam, key)
|
||||
data, err := c.downloadAndDecryptCDNBuffer(
|
||||
ctx,
|
||||
item.FileItem.Media.EncryptQueryParam,
|
||||
item.FileItem.Media.FullURL,
|
||||
key,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -484,7 +589,12 @@ func (c *WeixinChannel) downloadMediaFromItem(
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data, err := c.downloadAndDecryptCDNBuffer(ctx, item.VideoItem.Media.EncryptQueryParam, key)
|
||||
data, err := c.downloadAndDecryptCDNBuffer(
|
||||
ctx,
|
||||
item.VideoItem.Media.EncryptQueryParam,
|
||||
item.VideoItem.Media.FullURL,
|
||||
key,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -701,11 +811,13 @@ func (c *WeixinChannel) uploadLocalFile(
|
||||
}
|
||||
return nil, fmt.Errorf("getuploadurl failed: ret=%d errcode=%d errmsg=%s", resp.Ret, resp.Errcode, resp.Errmsg)
|
||||
}
|
||||
if strings.TrimSpace(resp.UploadParam) == "" {
|
||||
return nil, fmt.Errorf("getuploadurl returned empty upload_param")
|
||||
uploadParam := strings.TrimSpace(resp.UploadParam)
|
||||
uploadFullURL := strings.TrimSpace(resp.UploadFullURL)
|
||||
if uploadParam == "" && uploadFullURL == "" {
|
||||
return nil, fmt.Errorf("getuploadurl returned no upload URL")
|
||||
}
|
||||
|
||||
downloadParam, err := c.uploadBufferToCDN(ctx, data, resp.UploadParam, filekey, aesKey)
|
||||
downloadParam, err := c.uploadBufferToCDN(ctx, data, uploadParam, uploadFullURL, filekey, aesKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -723,6 +835,7 @@ func (c *WeixinChannel) uploadBufferToCDN(
|
||||
ctx context.Context,
|
||||
plaintext []byte,
|
||||
uploadParam,
|
||||
uploadFullURL,
|
||||
filekey string,
|
||||
aesKey []byte,
|
||||
) (string, error) {
|
||||
@@ -731,7 +844,13 @@ func (c *WeixinChannel) uploadBufferToCDN(
|
||||
return "", err
|
||||
}
|
||||
|
||||
uploadURL := buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey)
|
||||
uploadURL := strings.TrimSpace(uploadFullURL)
|
||||
if uploadURL == "" {
|
||||
if strings.TrimSpace(uploadParam) == "" {
|
||||
return "", fmt.Errorf("missing CDN upload URL")
|
||||
}
|
||||
uploadURL = buildCDNUploadURL(c.cdnBaseURL(), uploadParam, filekey)
|
||||
}
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= weixinUploadRetryMax; attempt++ {
|
||||
|
||||
@@ -38,6 +38,7 @@ type GetUploadUrlResp struct {
|
||||
APIStatus
|
||||
UploadParam string `json:"upload_param,omitempty"`
|
||||
ThumbUploadParam string `json:"thumb_upload_param,omitempty"`
|
||||
UploadFullURL string `json:"upload_full_url,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -69,6 +70,7 @@ type CDNMedia struct {
|
||||
EncryptQueryParam string `json:"encrypt_query_param,omitempty"`
|
||||
AesKey string `json:"aes_key,omitempty"` // base64 encoded
|
||||
EncryptType int `json:"encrypt_type,omitempty"`
|
||||
FullURL string `json:"full_url,omitempty"`
|
||||
}
|
||||
|
||||
type ImageItem struct {
|
||||
@@ -202,9 +204,10 @@ type QRCodeResponse struct {
|
||||
}
|
||||
|
||||
type StatusResponse struct {
|
||||
Status string `json:"status"` // "wait", "scaned", "confirmed", "expired"
|
||||
BotToken string `json:"bot_token,omitempty"`
|
||||
IlinkBotID string `json:"ilink_bot_id,omitempty"`
|
||||
Baseurl string `json:"baseurl,omitempty"`
|
||||
IlinkUserID string `json:"ilink_user_id,omitempty"`
|
||||
Status string `json:"status"` // "wait", "scaned", "confirmed", "expired", "scaned_but_redirect"
|
||||
BotToken string `json:"bot_token,omitempty"`
|
||||
IlinkBotID string `json:"ilink_bot_id,omitempty"`
|
||||
Baseurl string `json:"baseurl,omitempty"`
|
||||
IlinkUserID string `json:"ilink_user_id,omitempty"`
|
||||
RedirectHost string `json:"redirect_host,omitempty"`
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", key)
|
||||
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "", key)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
|
||||
}
|
||||
@@ -81,6 +81,116 @@ func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndDecryptCDNBufferUsesFullURLWhenProvided(t *testing.T) {
|
||||
key := []byte("1234567890abcdef")
|
||||
plaintext := []byte("hello weixin")
|
||||
ciphertext, err := encryptAESECB(plaintext, key)
|
||||
if err != nil {
|
||||
t.Fatalf("encryptAESECB() error = %v", err)
|
||||
}
|
||||
|
||||
fullURLAttempts := 0
|
||||
ch := &WeixinChannel{
|
||||
api: &ApiClient{
|
||||
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.String() == "https://full.example.com/download" {
|
||||
fullURLAttempts++
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(ciphertext)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
t.Fatalf("unexpected fallback request: %s", r.URL.String())
|
||||
return nil, nil
|
||||
})},
|
||||
},
|
||||
config: config.WeixinConfig{
|
||||
CDNBaseURL: "https://cdn.example.com",
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", "https://full.example.com/download", key)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, plaintext) {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext)
|
||||
}
|
||||
if fullURLAttempts == 0 {
|
||||
t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndDecryptCDNBufferFallsBackToConstructedURLWhenFullURLFails(t *testing.T) {
|
||||
key := []byte("1234567890abcdef")
|
||||
plaintext := []byte("hello weixin")
|
||||
ciphertext, err := encryptAESECB(plaintext, key)
|
||||
if err != nil {
|
||||
t.Fatalf("encryptAESECB() error = %v", err)
|
||||
}
|
||||
|
||||
fullURLAttempts := 0
|
||||
constructedAttempts := 0
|
||||
ch := &WeixinChannel{
|
||||
api: &ApiClient{
|
||||
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.String() == "https://full.example.com/download?encrypted_query_param=token&taskid=123" {
|
||||
fullURLAttempts++
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Body: io.NopCloser(bytes.NewReader(nil)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
if r.URL.String() != "https://cdn.example.com/download?encrypted_query_param=token" {
|
||||
t.Fatalf("unexpected fallback request: %s", r.URL.String())
|
||||
}
|
||||
constructedAttempts++
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(ciphertext)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})},
|
||||
},
|
||||
config: config.WeixinConfig{
|
||||
CDNBaseURL: "https://cdn.example.com",
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
got, err := ch.downloadAndDecryptCDNBuffer(
|
||||
context.Background(),
|
||||
"token",
|
||||
"https://full.example.com/download?encrypted_query_param=token&taskid=123",
|
||||
key,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, plaintext) {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext)
|
||||
}
|
||||
if fullURLAttempts == 0 {
|
||||
t.Fatalf("fullURLAttempts = %d, want > 0", fullURLAttempts)
|
||||
}
|
||||
if constructedAttempts == 0 {
|
||||
t.Fatalf("constructedAttempts = %d, want > 0", constructedAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCDNDownloadURLEscapesOpaqueToken(t *testing.T) {
|
||||
token := "MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%3D"
|
||||
|
||||
got := buildCDNDownloadURL("https://cdn.example.com", token)
|
||||
|
||||
if got != "https://cdn.example.com/download?encrypted_query_param=MFcCAQAESzBJAgEAAgSieMV9AgM9CcwCBEoKPqICBGnHZB0EJDk4OWY5YWU0LTc4OGItNGQ5Ni1iMjZhLWU4YjhlMmEwOWVkZgIEIR0IAgIBAAQFAExUPQA%253D" {
|
||||
t.Fatalf("buildCDNDownloadURL() = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadBufferToCDN(t *testing.T) {
|
||||
key := []byte("1234567890abcdef")
|
||||
plaintext := []byte("upload me")
|
||||
@@ -120,7 +230,7 @@ func TestUploadBufferToCDN(t *testing.T) {
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "file-key", key)
|
||||
got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "", "file-key", key)
|
||||
if err != nil {
|
||||
t.Fatalf("uploadBufferToCDN() error = %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user