From 14ccfb39d94cd2ede66af4d2b21b1765116153fc Mon Sep 17 00:00:00 2001 From: swordkee Date: Fri, 20 Feb 2026 18:28:10 +0800 Subject: [PATCH] feat: add wecom and wecomApp test --- pkg/channels/wecom.go | 265 +++++++++++++-------------------- pkg/channels/wecom_app.go | 115 +------------- pkg/channels/wecom_app_test.go | 20 +-- pkg/channels/wecom_common.go | 117 +++++++++++++++ pkg/channels/wecom_test.go | 210 +++++++++++++++++--------- 5 files changed, 371 insertions(+), 356 deletions(-) create mode 100644 pkg/channels/wecom_common.go diff --git a/pkg/channels/wecom.go b/pkg/channels/wecom.go index 5d4e14697..33afef17a 100644 --- a/pkg/channels/wecom.go +++ b/pkg/channels/wecom.go @@ -7,17 +7,11 @@ package channels import ( "bytes" "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" "encoding/json" "encoding/xml" "fmt" "io" "net/http" - "sort" "strings" "sync" "time" @@ -40,40 +34,54 @@ type WeComBotChannel struct { msgMu sync.RWMutex } -// WeComBotXMLMessage represents the XML message structure from WeCom Bot -type WeComBotXMLMessage struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - FromUserName string `xml:"FromUserName"` - CreateTime int64 `xml:"CreateTime"` - MsgType string `xml:"MsgType"` - Content string `xml:"Content"` - MsgId int64 `xml:"MsgId"` - PicUrl string `xml:"PicUrl"` - MediaId string `xml:"MediaId"` - Format string `xml:"Format"` - Recognition string `xml:"Recognition"` // Voice recognition result +// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT) +type WeComBotMessage struct { + MsgID string `json:"msgid"` + AIBotID string `json:"aibotid"` + ChatID string `json:"chatid"` // Session ID, only present for group chats + ChatType string `json:"chattype"` // "single" for DM, "group" for group chat + From struct { + UserID string `json:"userid"` + } `json:"from"` + ResponseURL string `json:"response_url"` + MsgType string `json:"msgtype"` // text, image, voice, file, mixed + Text struct { + Content string `json:"content"` + } `json:"text"` + Image struct { + URL string `json:"url"` + } `json:"image"` + Voice struct { + Content string `json:"content"` // Voice to text content + } `json:"voice"` + File struct { + URL string `json:"url"` + } `json:"file"` + Mixed struct { + MsgItem []struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text"` + Image struct { + URL string `json:"url"` + } `json:"image"` + } `json:"msg_item"` + } `json:"mixed"` + Quote struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text"` + } `json:"quote"` } // WeComBotReplyMessage represents the reply message structure type WeComBotReplyMessage struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - FromUserName string `xml:"FromUserName"` - CreateTime int64 `xml:"CreateTime"` - MsgType string `xml:"MsgType"` - Content string `xml:"Content"` -} - -// WeComBotWebhookReply represents the webhook API reply -type WeComBotWebhookReply struct { MsgType string `json:"msgtype"` Text struct { Content string `json:"content"` } `json:"text,omitempty"` - Markdown struct { - Content string `json:"content"` - } `json:"markdown,omitempty"` } // NewWeComBotChannel creates a new WeCom Bot channel instance @@ -205,14 +213,14 @@ func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.Respons } // Verify signature - if !c.verifySignature(msgSignature, timestamp, nonce, echostr) { + if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { logger.WarnC("wecom", "Signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return } // Decrypt echostr - decryptedEchoStr, err := c.decryptMessage(echostr) + decryptedEchoStr, err := WeComDecryptMessage(echostr, c.config.EncodingAESKey) if err != nil { logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]interface{}{ "error": err.Error(), @@ -265,14 +273,14 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp } // Verify signature - if !c.verifySignature(msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { logger.WarnC("wecom", "Message signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return } // Decrypt message - decryptedMsg, err := c.decryptMessage(encryptedMsg.Encrypt) + decryptedMsg, err := WeComDecryptMessage(encryptedMsg.Encrypt, c.config.EncodingAESKey) if err != nil { logger.ErrorCF("wecom", "Failed to decrypt message", map[string]interface{}{ "error": err.Error(), @@ -281,9 +289,9 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp return } - // Parse decrypted XML message - var msg WeComBotXMLMessage - if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil { + // Parse decrypted JSON message (AIBOT uses JSON format) + var msg WeComBotMessage + if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil { logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]interface{}{ "error": err.Error(), }) @@ -300,9 +308,9 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp } // processMessage processes the received message -func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotXMLMessage) { - // Skip non-text messages for now (can be extended) - if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" { +func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) { + // Skip unsupported message types + if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" && msg.MsgType != "mixed" { logger.DebugCF("wecom", "Skipping non-supported message type", map[string]interface{}{ "msg_type": msg.MsgType, }) @@ -310,8 +318,7 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotXMLMes } // Message deduplication: Use msg_id to prevent duplicate processing - // As per WeCom documentation, use msg_id for deduplication - msgID := fmt.Sprintf("%d", msg.MsgId) + msgID := msg.MsgID c.msgMu.Lock() if c.processedMsgs[msgID] { c.msgMu.Unlock() @@ -330,141 +337,73 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotXMLMes c.msgMu.Unlock() } - senderID := msg.FromUserName - chatID := senderID // WeCom Bot uses user ID as chat ID + senderID := msg.From.UserID - // Use voice recognition result if available - content := msg.Content - if msg.MsgType == "voice" && msg.Recognition != "" { - content = msg.Recognition + // Determine if this is a group chat or direct message + // ChatType: "single" for DM, "group" for group chat + isGroupChat := msg.ChatType == "group" + + var chatID, peerKind, peerID string + if isGroupChat { + // Group chat: use ChatID as chatID and peer_id + chatID = msg.ChatID + peerKind = "group" + peerID = msg.ChatID + } else { + // Direct message: use senderID as chatID and peer_id + chatID = senderID + peerKind = "direct" + peerID = senderID + } + + // Extract content based on message type + var content string + switch msg.MsgType { + case "text": + content = msg.Text.Content + case "voice": + content = msg.Voice.Content // Voice to text content + case "mixed": + // For mixed messages, concatenate text items + for _, item := range msg.Mixed.MsgItem { + if item.MsgType == "text" { + content += item.Text.Content + } + } + case "image", "file": + // For image and file, we don't have text content + content = "" } // Build metadata - // WeCom Bot only supports direct messages (private chat) metadata := map[string]string{ - "msg_type": msg.MsgType, - "msg_id": fmt.Sprintf("%d", msg.MsgId), - "platform": "wecom", - "media_id": msg.MediaId, - "create_time": fmt.Sprintf("%d", msg.CreateTime), - "peer_kind": "direct", - "peer_id": senderID, + "msg_type": msg.MsgType, + "msg_id": msg.MsgID, + "platform": "wecom", + "peer_kind": peerKind, + "peer_id": peerID, + "response_url": msg.ResponseURL, + } + if isGroupChat { + metadata["chat_id"] = msg.ChatID + metadata["sender_id"] = senderID } logger.DebugCF("wecom", "Received message", map[string]interface{}{ - "sender_id": senderID, - "msg_type": msg.MsgType, - "preview": utils.Truncate(content, 50), + "sender_id": senderID, + "msg_type": msg.MsgType, + "peer_kind": peerKind, + "is_group_chat": isGroupChat, + "preview": utils.Truncate(content, 50), }) // Handle the message through the base channel c.HandleMessage(senderID, chatID, content, nil, metadata) } -// verifySignature verifies the message signature -func (c *WeComBotChannel) verifySignature(msgSignature, timestamp, nonce, msgEncrypt string) bool { - if c.config.Token == "" { - return true // Skip verification if token is not set - } - - // Sort parameters - params := []string{c.config.Token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - - // Concatenate - str := strings.Join(params, "") - - // SHA1 hash - hash := sha1.Sum([]byte(str)) - expectedSignature := fmt.Sprintf("%x", hash) - - return expectedSignature == msgSignature -} - -// decryptMessage decrypts the encrypted message using AES -func (c *WeComBotChannel) decryptMessage(encryptedMsg string) (string, error) { - if c.config.EncodingAESKey == "" { - // No encryption, return as is (base64 decode) - decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", err - } - return string(decoded), nil - } - - // Decode AES key (base64) - aesKey, err := base64.StdEncoding.DecodeString(c.config.EncodingAESKey + "=") - if err != nil { - return "", fmt.Errorf("failed to decode AES key: %w", err) - } - - // Decode encrypted message - cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", fmt.Errorf("failed to decode message: %w", err) - } - - // AES decrypt - block, err := aes.NewCipher(aesKey) - if err != nil { - return "", fmt.Errorf("failed to create cipher: %w", err) - } - - if len(cipherText) < aes.BlockSize { - return "", fmt.Errorf("ciphertext too short") - } - - mode := cipher.NewCBCDecrypter(block, aesKey[:aes.BlockSize]) - plainText := make([]byte, len(cipherText)) - mode.CryptBlocks(plainText, cipherText) - - // Remove PKCS7 padding - plainText, err = pkcs7UnpadWeCom(plainText) - if err != nil { - return "", fmt.Errorf("failed to unpad: %w", err) - } - - // Parse message structure - // Format: random(16) + msg_len(4) + msg + corp_id - if len(plainText) < 20 { - return "", fmt.Errorf("decrypted message too short") - } - - msgLen := binary.BigEndian.Uint32(plainText[16:20]) - if int(msgLen) > len(plainText)-20 { - return "", fmt.Errorf("invalid message length") - } - - msg := plainText[20 : 20+msgLen] - // corpID := plainText[20+msgLen:] // Could be used for verification - - return string(msg), nil -} - -// pkcs7UnpadWeCom removes PKCS7 padding with validation -func pkcs7UnpadWeCom(data []byte) ([]byte, error) { - if len(data) == 0 { - return data, nil - } - padding := int(data[len(data)-1]) - if padding == 0 || padding > aes.BlockSize { - return nil, fmt.Errorf("invalid padding size: %d", padding) - } - if padding > len(data) { - return nil, fmt.Errorf("padding size larger than data") - } - // Verify all padding bytes - for i := 0; i < padding; i++ { - if data[len(data)-1-i] != byte(padding) { - return nil, fmt.Errorf("invalid padding byte at position %d", i) - } - } - return data[:len(data)-padding], nil -} - // sendWebhookReply sends a reply using the webhook URL func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error { - reply := WeComBotWebhookReply{ + reply := WeComBotReplyMessage{ MsgType: "text", } reply.Text.Content = content diff --git a/pkg/channels/wecom_app.go b/pkg/channels/wecom_app.go index c1d0ebaad..783d381f2 100644 --- a/pkg/channels/wecom_app.go +++ b/pkg/channels/wecom_app.go @@ -7,18 +7,12 @@ package channels import ( "bytes" "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" "encoding/json" "encoding/xml" "fmt" "io" "net/http" "net/url" - "sort" "strings" "sync" "time" @@ -265,14 +259,14 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons } // Verify signature - if !c.verifySignature(msgSignature, timestamp, nonce, echostr) { + if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { logger.WarnC("wecom_app", "Signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return } // Decrypt echostr - decryptedEchoStr, err := c.decryptMessage(echostr) + decryptedEchoStr, err := WeComDecryptMessage(echostr, c.config.EncodingAESKey) if err != nil { logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]interface{}{ "error": err.Error(), @@ -325,14 +319,14 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp } // Verify signature - if !c.verifySignature(msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { logger.WarnC("wecom_app", "Message signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return } // Decrypt message - decryptedMsg, err := c.decryptMessage(encryptedMsg.Encrypt) + decryptedMsg, err := WeComDecryptMessage(encryptedMsg.Encrypt, c.config.EncodingAESKey) if err != nil { logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]interface{}{ "error": err.Error(), @@ -418,107 +412,6 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag c.HandleMessage(senderID, chatID, content, nil, metadata) } -// verifySignature verifies the message signature -func (c *WeComAppChannel) verifySignature(msgSignature, timestamp, nonce, msgEncrypt string) bool { - if c.config.Token == "" { - return true // Skip verification if token is not set - } - - // Sort parameters - params := []string{c.config.Token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - - // Concatenate - str := strings.Join(params, "") - - // SHA1 hash - hash := sha1.Sum([]byte(str)) - expectedSignature := fmt.Sprintf("%x", hash) - - return expectedSignature == msgSignature -} - -// decryptMessage decrypts the encrypted message using AES -func (c *WeComAppChannel) decryptMessage(encryptedMsg string) (string, error) { - if c.config.EncodingAESKey == "" { - // No encryption, return as is (base64 decode) - decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", err - } - return string(decoded), nil - } - - // Decode AES key (base64) - aesKey, err := base64.StdEncoding.DecodeString(c.config.EncodingAESKey + "=") - if err != nil { - return "", fmt.Errorf("failed to decode AES key: %w", err) - } - - // Decode encrypted message - cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", fmt.Errorf("failed to decode message: %w", err) - } - - // AES decrypt - block, err := aes.NewCipher(aesKey) - if err != nil { - return "", fmt.Errorf("failed to create cipher: %w", err) - } - - if len(cipherText) < aes.BlockSize { - return "", fmt.Errorf("ciphertext too short") - } - - mode := cipher.NewCBCDecrypter(block, aesKey[:aes.BlockSize]) - plainText := make([]byte, len(cipherText)) - mode.CryptBlocks(plainText, cipherText) - - // Remove PKCS7 padding - plainText, err = pkcs7Unpad(plainText) - if err != nil { - return "", fmt.Errorf("failed to unpad: %w", err) - } - - // Parse message structure - // Format: random(16) + msg_len(4) + msg + corp_id - if len(plainText) < 20 { - return "", fmt.Errorf("decrypted message too short") - } - - msgLen := binary.BigEndian.Uint32(plainText[16:20]) - if int(msgLen) > len(plainText)-20 { - return "", fmt.Errorf("invalid message length") - } - - msg := plainText[20 : 20+msgLen] - // corpID := plainText[20+msgLen:] // Can be used for verification - - return string(msg), nil -} - -// pkcs7Unpad removes PKCS7 padding with validation -func pkcs7Unpad(data []byte) ([]byte, error) { - if len(data) == 0 { - return data, nil - } - padding := int(data[len(data)-1]) - if padding == 0 || padding > aes.BlockSize { - return nil, fmt.Errorf("invalid padding size: %d", padding) - } - if padding > len(data) { - return nil, fmt.Errorf("padding size larger than data") - } - // Verify all padding bytes - for i := 0; i < padding; i++ { - if data[len(data)-1-i] != byte(padding) { - return nil, fmt.Errorf("invalid padding byte at position %d", i) - } - } - return data[:len(data)-padding], nil -} - // tokenRefreshLoop periodically refreshes the access token func (c *WeComAppChannel) tokenRefreshLoop() { ticker := time.NewTicker(5 * time.Minute) diff --git a/pkg/channels/wecom_app_test.go b/pkg/channels/wecom_app_test.go index 4283c07e6..bc40806bb 100644 --- a/pkg/channels/wecom_app_test.go +++ b/pkg/channels/wecom_app_test.go @@ -197,7 +197,7 @@ func TestWeComAppVerifySignature(t *testing.T) { msgEncrypt := "test_message" expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt) - if !ch.verifySignature(expectedSig, timestamp, nonce, msgEncrypt) { + if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { t.Error("valid signature should pass verification") } }) @@ -207,7 +207,7 @@ func TestWeComAppVerifySignature(t *testing.T) { nonce := "test_nonce" msgEncrypt := "test_message" - if ch.verifySignature("invalid_sig", timestamp, nonce, msgEncrypt) { + if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { t.Error("invalid signature should fail verification") } }) @@ -221,7 +221,7 @@ func TestWeComAppVerifySignature(t *testing.T) { } chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) - if !chEmpty.verifySignature("any_sig", "any_ts", "any_nonce", "any_msg") { + if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { t.Error("empty token should skip verification and return true") } }) @@ -243,7 +243,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { plainText := "hello world" encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - result, err := ch.decryptMessage(encoded) + result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -268,12 +268,12 @@ func TestWeComAppDecryptMessage(t *testing.T) { t.Fatalf("failed to encrypt test message: %v", err) } - result, err := ch.decryptMessage(encrypted) + result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } if result != originalMsg { - t.Errorf("decryptMessage() = %q, want %q", result, originalMsg) + t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) } }) @@ -286,7 +286,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { } ch, _ := NewWeComAppChannel(cfg, msgBus) - _, err := ch.decryptMessage("invalid_base64!!!") + _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid base64, got nil") } @@ -301,7 +301,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { } ch, _ := NewWeComAppChannel(cfg, msgBus) - _, err := ch.decryptMessage(base64.StdEncoding.EncodeToString([]byte("test"))) + _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid AES key, got nil") } @@ -319,7 +319,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { // Encrypt a very short message that results in ciphertext less than block size shortData := make([]byte, 8) - _, err := ch.decryptMessage(base64.StdEncoding.EncodeToString(shortData)) + _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for short ciphertext, got nil") } @@ -361,7 +361,7 @@ func TestWeComAppPKCS7Unpad(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7Unpad(tt.input) + result, err := pkcs7UnpadWeCom(tt.input) if tt.expected == nil { // This case should return an error if err == nil { diff --git a/pkg/channels/wecom_common.go b/pkg/channels/wecom_common.go new file mode 100644 index 000000000..16a25fad6 --- /dev/null +++ b/pkg/channels/wecom_common.go @@ -0,0 +1,117 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// WeCom common utilities for both WeCom Bot and WeCom App + +package channels + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "fmt" + "sort" + "strings" +) + +// WeComVerifySignature verifies the message signature for WeCom +// This is a common function used by both WeCom Bot and WeCom App +func WeComVerifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { + if token == "" { + return true // Skip verification if token is not set + } + + // Sort parameters + params := []string{token, timestamp, nonce, msgEncrypt} + sort.Strings(params) + + // Concatenate + str := strings.Join(params, "") + + // SHA1 hash + hash := sha1.Sum([]byte(str)) + expectedSignature := fmt.Sprintf("%x", hash) + + return expectedSignature == msgSignature +} + +// WeComDecryptMessage decrypts the encrypted message using AES +// This is a common function used by both WeCom Bot and WeCom App +func WeComDecryptMessage(encryptedMsg, encodingAESKey string) (string, error) { + if encodingAESKey == "" { + // No encryption, return as is (base64 decode) + decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", err + } + return string(decoded), nil + } + + // Decode AES key (base64) + aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + if err != nil { + return "", fmt.Errorf("failed to decode AES key: %w", err) + } + + // Decode encrypted message + cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", fmt.Errorf("failed to decode message: %w", err) + } + + // AES decrypt + block, err := aes.NewCipher(aesKey) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + if len(cipherText) < aes.BlockSize { + return "", fmt.Errorf("ciphertext too short") + } + + mode := cipher.NewCBCDecrypter(block, aesKey[:aes.BlockSize]) + plainText := make([]byte, len(cipherText)) + mode.CryptBlocks(plainText, cipherText) + + // Remove PKCS7 padding + plainText, err = pkcs7UnpadWeCom(plainText) + if err != nil { + return "", fmt.Errorf("failed to unpad: %w", err) + } + + // Parse message structure + // Format: random(16) + msg_len(4) + msg + corp_id + if len(plainText) < 20 { + return "", fmt.Errorf("decrypted message too short") + } + + msgLen := binary.BigEndian.Uint32(plainText[16:20]) + if int(msgLen) > len(plainText)-20 { + return "", fmt.Errorf("invalid message length") + } + + msg := plainText[20 : 20+msgLen] + + return string(msg), nil +} + +// pkcs7UnpadWeCom removes PKCS7 padding with validation +func pkcs7UnpadWeCom(data []byte) ([]byte, error) { + if len(data) == 0 { + return data, nil + } + padding := int(data[len(data)-1]) + if padding == 0 || padding > aes.BlockSize { + return nil, fmt.Errorf("invalid padding size: %d", padding) + } + if padding > len(data) { + return nil, fmt.Errorf("padding size larger than data") + } + // Verify all padding bytes + for i := 0; i < padding; i++ { + if data[len(data)-1-i] != byte(padding) { + return nil, fmt.Errorf("invalid padding byte at position %d", i) + } + } + return data[:len(data)-padding], nil +} diff --git a/pkg/channels/wecom_test.go b/pkg/channels/wecom_test.go index a2015a8d3..c3f889c64 100644 --- a/pkg/channels/wecom_test.go +++ b/pkg/channels/wecom_test.go @@ -11,6 +11,7 @@ import ( "crypto/sha1" "encoding/base64" "encoding/binary" + "encoding/json" "encoding/xml" "fmt" "net/http" @@ -34,7 +35,7 @@ func generateTestAESKey() string { return base64.StdEncoding.EncodeToString(key)[:43] } -// encryptTestMessage encrypts a message for testing +// encryptTestMessage encrypts a message for testing (AIBOT JSON format) func encryptTestMessage(message, aesKey string) (string, error) { // Decode AES key key, err := base64.StdEncoding.DecodeString(aesKey + "=") @@ -42,14 +43,14 @@ func encryptTestMessage(message, aesKey string) (string, error) { return "", err } - // Prepare message: random(16) + msg_len(4) + msg + corp_id + // Prepare message: random(16) + msg_len(4) + msg + receiveid random := make([]byte, 0, 16) for i := 0; i < 16; i++ { random = append(random, byte(i)) } msgBytes := []byte(message) - corpID := []byte("test_corp_id") + receiveID := []byte("test_aibot_id") msgLen := uint32(len(msgBytes)) lenBytes := make([]byte, 4) @@ -57,7 +58,7 @@ func encryptTestMessage(message, aesKey string) (string, error) { plainText := append(random, lenBytes...) plainText = append(plainText, msgBytes...) - plainText = append(plainText, corpID...) + plainText = append(plainText, receiveID...) // PKCS7 padding blockSize := aes.BlockSize @@ -176,7 +177,7 @@ func TestWeComBotVerifySignature(t *testing.T) { msgEncrypt := "test_message" expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt) - if !ch.verifySignature(expectedSig, timestamp, nonce, msgEncrypt) { + if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { t.Error("valid signature should pass verification") } }) @@ -186,7 +187,7 @@ func TestWeComBotVerifySignature(t *testing.T) { nonce := "test_nonce" msgEncrypt := "test_message" - if ch.verifySignature("invalid_sig", timestamp, nonce, msgEncrypt) { + if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { t.Error("invalid signature should fail verification") } }) @@ -203,7 +204,7 @@ func TestWeComBotVerifySignature(t *testing.T) { config: cfgEmpty, } - if !chEmpty.verifySignature("any_sig", "any_ts", "any_nonce", "any_msg") { + if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { t.Error("empty token should skip verification and return true") } }) @@ -224,7 +225,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { plainText := "hello world" encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - result, err := ch.decryptMessage(encoded) + result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -248,12 +249,12 @@ func TestWeComBotDecryptMessage(t *testing.T) { t.Fatalf("failed to encrypt test message: %v", err) } - result, err := ch.decryptMessage(encrypted) + result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } if result != originalMsg { - t.Errorf("decryptMessage() = %q, want %q", result, originalMsg) + t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) } }) @@ -265,7 +266,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - _, err := ch.decryptMessage("invalid_base64!!!") + _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid base64, got nil") } @@ -279,7 +280,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - _, err := ch.decryptMessage(base64.StdEncoding.EncodeToString([]byte("test"))) + _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid AES key, got nil") } @@ -408,20 +409,62 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - t.Run("valid message callback", func(t *testing.T) { - // Create XML message - xmlMsg := WeComBotXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "text", - Content: "Hello World", - MsgId: 123456, - } - xmlData, _ := xml.Marshal(xmlMsg) + t.Run("valid direct message callback", func(t *testing.T) { + // Create JSON message for direct chat (single) + jsonMsg := `{ + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chattype": "single", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }` // Encrypt message - encrypted, _ := encryptTestMessage(string(xmlData), aesKey) + encrypted, _ := encryptTestMessage(jsonMsg, aesKey) + + // Create encrypted XML wrapper + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: encrypted, + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encrypted) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "success" { + t.Errorf("response body = %q, want %q", w.Body.String(), "success") + } + }) + + t.Run("valid group message callback", func(t *testing.T) { + // Create JSON message for group chat + jsonMsg := `{ + "msgid": "test_msg_id_456", + "aibotid": "test_aibot_id", + "chatid": "group_chat_id_123", + "chattype": "group", + "from": {"userid": "user456"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello Group"} + }` + + // Encrypt message + encrypted, _ := encryptTestMessage(jsonMsg, aesKey) // Create encrypted XML wrapper encryptedWrapper := struct { @@ -506,42 +549,61 @@ func TestWeComBotProcessMessage(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - t.Run("process text message", func(t *testing.T) { - msg := WeComBotXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "text", - Content: "Hello World", - MsgId: 123456, + t.Run("process direct text message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_123", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "text", } + msg.From.UserID = "user123" + msg.Text.Content = "Hello World" // Should not panic ch.processMessage(context.Background(), msg) }) - t.Run("process voice message with recognition", func(t *testing.T) { - msg := WeComBotXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "voice", - Recognition: "Voice message text", - MsgId: 123456, + t.Run("process group text message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_456", + AIBotID: "test_aibot_id", + ChatID: "group_chat_id_123", + ChatType: "group", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "text", } + msg.From.UserID = "user456" + msg.Text.Content = "Hello Group" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process voice message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_789", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "voice", + } + msg.From.UserID = "user123" + msg.Voice.Content = "Voice message text" // Should not panic ch.processMessage(context.Background(), msg) }) t.Run("skip unsupported message type", func(t *testing.T) { - msg := WeComBotXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "video", - MsgId: 123456, + msg := WeComBotMessage{ + MsgID: "test_msg_id_000", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "video", } + msg.From.UserID = "user123" // Should not panic ch.processMessage(context.Background(), msg) @@ -637,8 +699,8 @@ func TestWeComBotHandleHealth(t *testing.T) { } } -func TestWeComBotWebhookReplyMessage(t *testing.T) { - msg := WeComBotWebhookReply{ +func TestWeComBotReplyMessage(t *testing.T) { + msg := WeComBotReplyMessage{ MsgType: "text", } msg.Text.Content = "Hello World" @@ -651,39 +713,43 @@ func TestWeComBotWebhookReplyMessage(t *testing.T) { } } -func TestWeComBotXMLMessageStructure(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - 1234567890123456 -` +func TestWeComBotMessageStructure(t *testing.T) { + jsonData := `{ + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chatid": "group_chat_id_123", + "chattype": "group", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }` - var msg WeComBotXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) + var msg WeComBotMessage + err := json.Unmarshal([]byte(jsonData), &msg) if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) + t.Fatalf("failed to unmarshal JSON: %v", err) } - if msg.ToUserName != "corp_id" { - t.Errorf("ToUserName = %q, want %q", msg.ToUserName, "corp_id") + if msg.MsgID != "test_msg_id_123" { + t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123") } - if msg.FromUserName != "user123" { - t.Errorf("FromUserName = %q, want %q", msg.FromUserName, "user123") + if msg.AIBotID != "test_aibot_id" { + t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id") } - if msg.CreateTime != 1234567890 { - t.Errorf("CreateTime = %d, want %d", msg.CreateTime, 1234567890) + if msg.ChatID != "group_chat_id_123" { + t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123") + } + if msg.ChatType != "group" { + t.Errorf("ChatType = %q, want %q", msg.ChatType, "group") + } + if msg.From.UserID != "user123" { + t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123") } if msg.MsgType != "text" { t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") } - if msg.Content != "Hello World" { - t.Errorf("Content = %q, want %q", msg.Content, "Hello World") - } - if msg.MsgId != 1234567890123456 { - t.Errorf("MsgId = %d, want %d", msg.MsgId, 1234567890123456) + if msg.Text.Content != "Hello World" { + t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") } }