From 880c402ab7025fd2a65bda487486c854c52647f5 Mon Sep 17 00:00:00 2001 From: Zhang Rui Date: Sat, 28 Feb 2026 23:14:10 +0800 Subject: [PATCH] refactor(wecom): streamline AES encryption/decryption and improve task management logic --- pkg/channels/wecom/aibot.go | 166 +++++++++-------------------------- pkg/channels/wecom/common.go | 135 +++++++++++++++++++++------- 2 files changed, 144 insertions(+), 157 deletions(-) diff --git a/pkg/channels/wecom/aibot.go b/pkg/channels/wecom/aibot.go index 788305e36..9003b0777 100644 --- a/pkg/channels/wecom/aibot.go +++ b/pkg/channels/wecom/aibot.go @@ -3,11 +3,8 @@ package wecom import ( "bytes" "context" - "crypto/aes" - "crypto/cipher" "crypto/rand" "encoding/base64" - "encoding/binary" "encoding/json" "fmt" "io" @@ -194,6 +191,12 @@ func (c *WeComAIBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) e } c.taskMu.Lock() queue := c.chatTasks[msg.ChatID] + // Only compact Finished tasks at the head of the queue. + // Tasks that are Finished in the middle are NOT removed here: doing a full + // scan on every Send() call would be O(n) and is unnecessary given that + // removeTask() always splices the task out of the queue immediately. + // Any Finished task left stranded in the middle (e.g. due to an unexpected + // code path) will be collected by cleanupOldTasks. for len(queue) > 0 && queue[0].Finished { queue = queue[1:] } @@ -620,41 +623,6 @@ func (c *WeComAIBotChannel) handleImageMessage( imageURL := msg.Image.URL - // Download and decrypt image - _, err := c.downloadAndDecryptImage(ctx, imageURL) - if err != nil { - logger.ErrorCF("wecom_aibot", "Failed to process image", map[string]any{ - "error": err, - "url": imageURL, - }) - return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ - MsgType: "stream", - Stream: struct { - ID string `json:"id"` - Finish bool `json:"finish"` - Content string `json:"content,omitempty"` - MsgItem []struct { - MsgType string `json:"msgtype"` - Image *struct { - Base64 string `json:"base64"` - MD5 string `json:"md5"` - } `json:"image,omitempty"` - } `json:"msg_item,omitempty"` - }{ - ID: c.generateStreamID(), - Finish: true, - Content: fmt.Sprintf( - "Image received (URL: %s), but image messages are not yet supported", - imageURL, - ), - }, - }) - } - - // Echo back the image (simple demo behavior) - // streamID := c.generateStreamID() - // return c.encryptImageResponse(streamID, timestamp, nonce, imageData) - // For now, just acknowledge receipt without echoing the image return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ MsgType: "stream", @@ -943,70 +911,24 @@ func (c *WeComAIBotChannel) encryptEmptyResponse(timestamp, nonce string) string // encryptMessage encrypts a plain text message for WeCom AI Bot func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string, error) { - // Decode AES key - aesKey, err := base64.StdEncoding.DecodeString(c.config.EncodingAESKey + "=") + aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey) if err != nil { - return "", fmt.Errorf("failed to decode AES key: %w", err) + return "", err } - if len(aesKey) != 32 { - return "", fmt.Errorf("invalid AES key length: %d", len(aesKey)) - } - - // Generate 16-byte random string - randomBytes := make([]byte, 16) - for i := range 16 { - n, randErr := rand.Int(rand.Reader, big.NewInt(10)) - if randErr != nil { - return "", fmt.Errorf("failed to generate random: %w", randErr) - } - randomBytes[i] = byte('0' + n.Int64()) - } - - // Build message: random(16) + msg_len(4) + msg + receiveid - plaintextBytes := []byte(plaintext) - receiveidBytes := []byte(receiveid) - - msgLen := uint32(len(plaintextBytes)) - msgLenBytes := make([]byte, 4) - binary.BigEndian.PutUint32(msgLenBytes, msgLen) - - // Concatenate - var buffer bytes.Buffer - buffer.Write(randomBytes) - buffer.Write(msgLenBytes) - buffer.Write(plaintextBytes) - buffer.Write(receiveidBytes) - - // PKCS7 padding - plainData := buffer.Bytes() - plainData = pkcs7Pad(plainData, blockSize) - - // AES-CBC encrypt - block, err := aes.NewCipher(aesKey) + frame, err := packWeComFrame(plaintext, receiveid) if err != nil { - return "", fmt.Errorf("failed to create cipher: %w", err) + return "", err } - ciphertext := make([]byte, len(plainData)) - iv := aesKey[:aes.BlockSize] - mode := cipher.NewCBCEncrypter(block, iv) - mode.CryptBlocks(ciphertext, plainData) - - // Base64 encode - encoded := base64.StdEncoding.EncodeToString(ciphertext) - - return encoded, nil -} - -// pkcs7Pad adds PKCS7 padding -func pkcs7Pad(data []byte, blockSize int) []byte { - padding := blockSize - (len(data) % blockSize) - if padding == 0 { - padding = blockSize + // PKCS7 padding then AES-CBC encrypt + paddedFrame := pkcs7Pad(frame, blockSize) + ciphertext, err := encryptAESCBC(aesKey, paddedFrame) + if err != nil { + return "", err } - padText := bytes.Repeat([]byte{byte(padding)}, padding) - return append(data, padText...) + + return base64.StdEncoding.EncodeToString(ciphertext), nil } // generateStreamID generates a random stream ID @@ -1060,35 +982,15 @@ func (c *WeComAIBotChannel) downloadAndDecryptImage( }) // Decode AES key - aesKey, err := base64.StdEncoding.DecodeString(c.config.EncodingAESKey + "=") + aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey) if err != nil { - return nil, fmt.Errorf("failed to decode AES key: %w", err) + return nil, err } - if len(aesKey) != 32 { - return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey)) - } - - // Decrypt image (AES-CBC) - block, err := aes.NewCipher(aesKey) + // Decrypt image (AES-CBC with IV = first 16 bytes of key, PKCS7 padding stripped) + decryptedData, err := decryptAESCBC(aesKey, encryptedData) if err != nil { - return nil, fmt.Errorf("failed to create cipher: %w", err) - } - - if len(encryptedData)%aes.BlockSize != 0 { - return nil, fmt.Errorf("encrypted data size not multiple of block size") - } - - iv := aesKey[:aes.BlockSize] - mode := cipher.NewCBCDecrypter(block, iv) - - decryptedData := make([]byte, len(encryptedData)) - mode.CryptBlocks(decryptedData, encryptedData) - - // Remove PKCS7 padding - decryptedData, err = pkcs7Unpad(decryptedData) - if err != nil { - return nil, fmt.Errorf("failed to unpad: %w", err) + return nil, fmt.Errorf("failed to decrypt image: %w", err) } logger.DebugCF("wecom_aibot", "Image decrypted", map[string]any{ @@ -1157,14 +1059,32 @@ func (c *WeComAIBotChannel) cleanupOldTasks() { // (agent had enough time to reply; it is not coming back). for chatID, queue := range c.chatTasks { filtered := queue[:0] - for _, t := range queue { + for i, t := range queue { absoluteExpired := t.CreatedTime.Before(cutoff) graceExpired := t.StreamClosed && !t.StreamClosedAt.IsZero() && t.StreamClosedAt.Before(now.Add(-streamClosedGracePeriod)) - if !t.Finished && !absoluteExpired && !graceExpired { + if t.Finished { + // Finished tasks should have been removed by removeTask(). + // Finding one here (especially not at position 0) means an + // unexpected code path left it stranded, causing the queue to + // grow silently. Log a warning so it is visible, then drop it. + if i > 0 { + logger.WarnCF("wecom_aibot", + "Found stranded Finished task in the middle of chatTasks queue; "+ + "this should not happen — removeTask() should have spliced it out", + map[string]any{ + "chat_id": chatID, + "stream_id": t.StreamID, + "position": i, + }) + } + // The task is already finished; its context was already canceled + // by removeTask(), so no further action is required. + continue + } else if !absoluteExpired && !graceExpired { filtered = append(filtered, t) - } else if !t.Finished { + } else { t.cancel() // cancel any lingering agent goroutine } } diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go index b1b5399f4..6510e6f81 100644 --- a/pkg/channels/wecom/common.go +++ b/pkg/channels/wecom/common.go @@ -1,12 +1,15 @@ package wecom import ( + "bytes" "crypto/aes" "crypto/cipher" + "crypto/rand" "crypto/sha1" "encoding/base64" "encoding/binary" "fmt" + "math/big" "sort" "strings" ) @@ -51,64 +54,128 @@ func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (s return string(decoded), nil } - // Decode AES key (base64) - aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + aesKey, err := decodeWeComAESKey(encodingAESKey) if err != nil { - return "", fmt.Errorf("failed to decode AES key: %w", err) + return "", err } - // Decode encrypted message cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) if err != nil { return "", fmt.Errorf("failed to decode message: %w", err) } - // AES decrypt + plainText, err := decryptAESCBC(aesKey, cipherText) + if err != nil { + return "", err + } + + return unpackWeComFrame(plainText, receiveid) +} + +// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is +// appended automatically) and validates that the result is exactly 32 bytes. +// It is the single place that handles this repeated pattern in both encrypt and decrypt paths. +func decodeWeComAESKey(encodingAESKey string) ([]byte, error) { + aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + if err != nil { + return nil, fmt.Errorf("failed to decode AES key: %w", err) + } + if len(aesKey) != 32 { + return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey)) + } + return aesKey, nil +} + +// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring +// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the +// plaintext to a multiple of aes.BlockSize before calling. +func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) { block, err := aes.NewCipher(aesKey) if err != nil { - return "", fmt.Errorf("failed to create cipher: %w", err) + return nil, fmt.Errorf("failed to create cipher: %w", err) } - - if len(cipherText) < aes.BlockSize { - return "", fmt.Errorf("ciphertext too short") - } - - // IV is the first 16 bytes of AESKey iv := aesKey[:aes.BlockSize] - mode := cipher.NewCBCDecrypter(block, iv) - plainText := make([]byte, len(cipherText)) - mode.CryptBlocks(plainText, cipherText) + ciphertext := make([]byte, len(plaintext)) + cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext) + return ciphertext, nil +} - // Remove PKCS7 padding - plainText, err = pkcs7Unpad(plainText) - if err != nil { - return "", fmt.Errorf("failed to unpad: %w", err) +// packWeComFrame builds the WeCom wire format: +// +// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid +func packWeComFrame(msg, receiveid string) ([]byte, error) { + randomBytes := make([]byte, 16) + for i := range 16 { + n, err := rand.Int(rand.Reader, big.NewInt(10)) + if err != nil { + return nil, fmt.Errorf("failed to generate random: %w", err) + } + randomBytes[i] = byte('0' + n.Int64()) } + msgBytes := []byte(msg) + msgLenBytes := make([]byte, 4) + binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes))) + var buf bytes.Buffer + buf.Write(randomBytes) + buf.Write(msgLenBytes) + buf.Write(msgBytes) + buf.WriteString(receiveid) + return buf.Bytes(), nil +} - // Parse message structure - // Format: random(16) + msg_len(4) + msg + receiveid - if len(plainText) < 20 { - return "", fmt.Errorf("decrypted message too short") +// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame. +// If receiveid is non-empty it verifies the frame's trailing receiveid field. +func unpackWeComFrame(data []byte, receiveid string) (string, error) { + if len(data) < 20 { + return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data)) } - - msgLen := binary.BigEndian.Uint32(plainText[16:20]) - if int(msgLen) > len(plainText)-20 { - return "", fmt.Errorf("invalid message length") + msgLen := binary.BigEndian.Uint32(data[16:20]) + if int(msgLen) > len(data)-20 { + return "", fmt.Errorf("invalid message length: %d", msgLen) } - - msg := plainText[20 : 20+msgLen] - - // Verify receiveid if provided - if receiveid != "" && len(plainText) > 20+int(msgLen) { - actualReceiveID := string(plainText[20+msgLen:]) + msg := data[20 : 20+msgLen] + if receiveid != "" && len(data) > 20+int(msgLen) { + actualReceiveID := string(data[20+msgLen:]) if actualReceiveID != receiveid { return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) } } - return string(msg), nil } +// decryptAESCBC decrypts ciphertext using AES-CBC with the given key. +// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext. +func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) { + if len(ciphertext) == 0 { + return nil, fmt.Errorf("ciphertext is empty") + } + if len(ciphertext)%aes.BlockSize != 0 { + return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext)) + } + block, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + iv := aesKey[:aes.BlockSize] + plaintext := make([]byte, len(ciphertext)) + cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext) + plaintext, err = pkcs7Unpad(plaintext) + if err != nil { + return nil, fmt.Errorf("failed to unpad: %w", err) + } + return plaintext, nil +} + +// pkcs7Pad adds PKCS7 padding +func pkcs7Pad(data []byte, blockSize int) []byte { + padding := blockSize - (len(data) % blockSize) + if padding == 0 { + padding = blockSize + } + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padText...) +} + // pkcs7Unpad removes PKCS7 padding with validation func pkcs7Unpad(data []byte) ([]byte, error) { if len(data) == 0 {