refactor(wecom): streamline AES encryption/decryption and improve task management logic

This commit is contained in:
Zhang Rui
2026-02-28 23:14:10 +08:00
parent 8f3d611a4c
commit 880c402ab7
2 changed files with 144 additions and 157 deletions
+43 -123
View File
@@ -3,11 +3,8 @@ package wecom
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
@@ -194,6 +191,12 @@ func (c *WeComAIBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) e
}
c.taskMu.Lock()
queue := c.chatTasks[msg.ChatID]
// Only compact Finished tasks at the head of the queue.
// Tasks that are Finished in the middle are NOT removed here: doing a full
// scan on every Send() call would be O(n) and is unnecessary given that
// removeTask() always splices the task out of the queue immediately.
// Any Finished task left stranded in the middle (e.g. due to an unexpected
// code path) will be collected by cleanupOldTasks.
for len(queue) > 0 && queue[0].Finished {
queue = queue[1:]
}
@@ -620,41 +623,6 @@ func (c *WeComAIBotChannel) handleImageMessage(
imageURL := msg.Image.URL
// Download and decrypt image
_, err := c.downloadAndDecryptImage(ctx, imageURL)
if err != nil {
logger.ErrorCF("wecom_aibot", "Failed to process image", map[string]any{
"error": err,
"url": imageURL,
})
return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
MsgType: "stream",
Stream: struct {
ID string `json:"id"`
Finish bool `json:"finish"`
Content string `json:"content,omitempty"`
MsgItem []struct {
MsgType string `json:"msgtype"`
Image *struct {
Base64 string `json:"base64"`
MD5 string `json:"md5"`
} `json:"image,omitempty"`
} `json:"msg_item,omitempty"`
}{
ID: c.generateStreamID(),
Finish: true,
Content: fmt.Sprintf(
"Image received (URL: %s), but image messages are not yet supported",
imageURL,
),
},
})
}
// Echo back the image (simple demo behavior)
// streamID := c.generateStreamID()
// return c.encryptImageResponse(streamID, timestamp, nonce, imageData)
// For now, just acknowledge receipt without echoing the image
return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
MsgType: "stream",
@@ -943,70 +911,24 @@ func (c *WeComAIBotChannel) encryptEmptyResponse(timestamp, nonce string) string
// encryptMessage encrypts a plain text message for WeCom AI Bot
func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string, error) {
// Decode AES key
aesKey, err := base64.StdEncoding.DecodeString(c.config.EncodingAESKey + "=")
aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey)
if err != nil {
return "", fmt.Errorf("failed to decode AES key: %w", err)
return "", err
}
if len(aesKey) != 32 {
return "", fmt.Errorf("invalid AES key length: %d", len(aesKey))
}
// Generate 16-byte random string
randomBytes := make([]byte, 16)
for i := range 16 {
n, randErr := rand.Int(rand.Reader, big.NewInt(10))
if randErr != nil {
return "", fmt.Errorf("failed to generate random: %w", randErr)
}
randomBytes[i] = byte('0' + n.Int64())
}
// Build message: random(16) + msg_len(4) + msg + receiveid
plaintextBytes := []byte(plaintext)
receiveidBytes := []byte(receiveid)
msgLen := uint32(len(plaintextBytes))
msgLenBytes := make([]byte, 4)
binary.BigEndian.PutUint32(msgLenBytes, msgLen)
// Concatenate
var buffer bytes.Buffer
buffer.Write(randomBytes)
buffer.Write(msgLenBytes)
buffer.Write(plaintextBytes)
buffer.Write(receiveidBytes)
// PKCS7 padding
plainData := buffer.Bytes()
plainData = pkcs7Pad(plainData, blockSize)
// AES-CBC encrypt
block, err := aes.NewCipher(aesKey)
frame, err := packWeComFrame(plaintext, receiveid)
if err != nil {
return "", fmt.Errorf("failed to create cipher: %w", err)
return "", err
}
ciphertext := make([]byte, len(plainData))
iv := aesKey[:aes.BlockSize]
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(ciphertext, plainData)
// Base64 encode
encoded := base64.StdEncoding.EncodeToString(ciphertext)
return encoded, nil
}
// pkcs7Pad adds PKCS7 padding
func pkcs7Pad(data []byte, blockSize int) []byte {
padding := blockSize - (len(data) % blockSize)
if padding == 0 {
padding = blockSize
// PKCS7 padding then AES-CBC encrypt
paddedFrame := pkcs7Pad(frame, blockSize)
ciphertext, err := encryptAESCBC(aesKey, paddedFrame)
if err != nil {
return "", err
}
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(data, padText...)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// generateStreamID generates a random stream ID
@@ -1060,35 +982,15 @@ func (c *WeComAIBotChannel) downloadAndDecryptImage(
})
// Decode AES key
aesKey, err := base64.StdEncoding.DecodeString(c.config.EncodingAESKey + "=")
aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey)
if err != nil {
return nil, fmt.Errorf("failed to decode AES key: %w", err)
return nil, err
}
if len(aesKey) != 32 {
return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
}
// Decrypt image (AES-CBC)
block, err := aes.NewCipher(aesKey)
// Decrypt image (AES-CBC with IV = first 16 bytes of key, PKCS7 padding stripped)
decryptedData, err := decryptAESCBC(aesKey, encryptedData)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
if len(encryptedData)%aes.BlockSize != 0 {
return nil, fmt.Errorf("encrypted data size not multiple of block size")
}
iv := aesKey[:aes.BlockSize]
mode := cipher.NewCBCDecrypter(block, iv)
decryptedData := make([]byte, len(encryptedData))
mode.CryptBlocks(decryptedData, encryptedData)
// Remove PKCS7 padding
decryptedData, err = pkcs7Unpad(decryptedData)
if err != nil {
return nil, fmt.Errorf("failed to unpad: %w", err)
return nil, fmt.Errorf("failed to decrypt image: %w", err)
}
logger.DebugCF("wecom_aibot", "Image decrypted", map[string]any{
@@ -1157,14 +1059,32 @@ func (c *WeComAIBotChannel) cleanupOldTasks() {
// (agent had enough time to reply; it is not coming back).
for chatID, queue := range c.chatTasks {
filtered := queue[:0]
for _, t := range queue {
for i, t := range queue {
absoluteExpired := t.CreatedTime.Before(cutoff)
graceExpired := t.StreamClosed &&
!t.StreamClosedAt.IsZero() &&
t.StreamClosedAt.Before(now.Add(-streamClosedGracePeriod))
if !t.Finished && !absoluteExpired && !graceExpired {
if t.Finished {
// Finished tasks should have been removed by removeTask().
// Finding one here (especially not at position 0) means an
// unexpected code path left it stranded, causing the queue to
// grow silently. Log a warning so it is visible, then drop it.
if i > 0 {
logger.WarnCF("wecom_aibot",
"Found stranded Finished task in the middle of chatTasks queue; "+
"this should not happen — removeTask() should have spliced it out",
map[string]any{
"chat_id": chatID,
"stream_id": t.StreamID,
"position": i,
})
}
// The task is already finished; its context was already canceled
// by removeTask(), so no further action is required.
continue
} else if !absoluteExpired && !graceExpired {
filtered = append(filtered, t)
} else if !t.Finished {
} else {
t.cancel() // cancel any lingering agent goroutine
}
}
+101 -34
View File
@@ -1,12 +1,15 @@
package wecom
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"fmt"
"math/big"
"sort"
"strings"
)
@@ -51,64 +54,128 @@ func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (s
return string(decoded), nil
}
// Decode AES key (base64)
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
aesKey, err := decodeWeComAESKey(encodingAESKey)
if err != nil {
return "", fmt.Errorf("failed to decode AES key: %w", err)
return "", err
}
// Decode encrypted message
cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
if err != nil {
return "", fmt.Errorf("failed to decode message: %w", err)
}
// AES decrypt
plainText, err := decryptAESCBC(aesKey, cipherText)
if err != nil {
return "", err
}
return unpackWeComFrame(plainText, receiveid)
}
// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is
// appended automatically) and validates that the result is exactly 32 bytes.
// It is the single place that handles this repeated pattern in both encrypt and decrypt paths.
func decodeWeComAESKey(encodingAESKey string) ([]byte, error) {
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
if err != nil {
return nil, fmt.Errorf("failed to decode AES key: %w", err)
}
if len(aesKey) != 32 {
return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
}
return aesKey, nil
}
// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring
// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the
// plaintext to a multiple of aes.BlockSize before calling.
func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(aesKey)
if err != nil {
return "", fmt.Errorf("failed to create cipher: %w", err)
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
if len(cipherText) < aes.BlockSize {
return "", fmt.Errorf("ciphertext too short")
}
// IV is the first 16 bytes of AESKey
iv := aesKey[:aes.BlockSize]
mode := cipher.NewCBCDecrypter(block, iv)
plainText := make([]byte, len(cipherText))
mode.CryptBlocks(plainText, cipherText)
ciphertext := make([]byte, len(plaintext))
cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
return ciphertext, nil
}
// Remove PKCS7 padding
plainText, err = pkcs7Unpad(plainText)
if err != nil {
return "", fmt.Errorf("failed to unpad: %w", err)
// packWeComFrame builds the WeCom wire format:
//
// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid
func packWeComFrame(msg, receiveid string) ([]byte, error) {
randomBytes := make([]byte, 16)
for i := range 16 {
n, err := rand.Int(rand.Reader, big.NewInt(10))
if err != nil {
return nil, fmt.Errorf("failed to generate random: %w", err)
}
randomBytes[i] = byte('0' + n.Int64())
}
msgBytes := []byte(msg)
msgLenBytes := make([]byte, 4)
binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes)))
var buf bytes.Buffer
buf.Write(randomBytes)
buf.Write(msgLenBytes)
buf.Write(msgBytes)
buf.WriteString(receiveid)
return buf.Bytes(), nil
}
// Parse message structure
// Format: random(16) + msg_len(4) + msg + receiveid
if len(plainText) < 20 {
return "", fmt.Errorf("decrypted message too short")
// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame.
// If receiveid is non-empty it verifies the frame's trailing receiveid field.
func unpackWeComFrame(data []byte, receiveid string) (string, error) {
if len(data) < 20 {
return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data))
}
msgLen := binary.BigEndian.Uint32(plainText[16:20])
if int(msgLen) > len(plainText)-20 {
return "", fmt.Errorf("invalid message length")
msgLen := binary.BigEndian.Uint32(data[16:20])
if int(msgLen) > len(data)-20 {
return "", fmt.Errorf("invalid message length: %d", msgLen)
}
msg := plainText[20 : 20+msgLen]
// Verify receiveid if provided
if receiveid != "" && len(plainText) > 20+int(msgLen) {
actualReceiveID := string(plainText[20+msgLen:])
msg := data[20 : 20+msgLen]
if receiveid != "" && len(data) > 20+int(msgLen) {
actualReceiveID := string(data[20+msgLen:])
if actualReceiveID != receiveid {
return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
}
}
return string(msg), nil
}
// decryptAESCBC decrypts ciphertext using AES-CBC with the given key.
// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext.
func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) {
if len(ciphertext) == 0 {
return nil, fmt.Errorf("ciphertext is empty")
}
if len(ciphertext)%aes.BlockSize != 0 {
return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
}
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
iv := aesKey[:aes.BlockSize]
plaintext := make([]byte, len(ciphertext))
cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
plaintext, err = pkcs7Unpad(plaintext)
if err != nil {
return nil, fmt.Errorf("failed to unpad: %w", err)
}
return plaintext, nil
}
// pkcs7Pad adds PKCS7 padding
func pkcs7Pad(data []byte, blockSize int) []byte {
padding := blockSize - (len(data) % blockSize)
if padding == 0 {
padding = blockSize
}
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(data, padText...)
}
// pkcs7Unpad removes PKCS7 padding with validation
func pkcs7Unpad(data []byte) ([]byte, error) {
if len(data) == 0 {