mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(wecom): streamline AES encryption/decryption and improve task management logic
This commit is contained in:
+43
-123
@@ -3,11 +3,8 @@ package wecom
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -194,6 +191,12 @@ func (c *WeComAIBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) e
|
||||
}
|
||||
c.taskMu.Lock()
|
||||
queue := c.chatTasks[msg.ChatID]
|
||||
// Only compact Finished tasks at the head of the queue.
|
||||
// Tasks that are Finished in the middle are NOT removed here: doing a full
|
||||
// scan on every Send() call would be O(n) and is unnecessary given that
|
||||
// removeTask() always splices the task out of the queue immediately.
|
||||
// Any Finished task left stranded in the middle (e.g. due to an unexpected
|
||||
// code path) will be collected by cleanupOldTasks.
|
||||
for len(queue) > 0 && queue[0].Finished {
|
||||
queue = queue[1:]
|
||||
}
|
||||
@@ -620,41 +623,6 @@ func (c *WeComAIBotChannel) handleImageMessage(
|
||||
|
||||
imageURL := msg.Image.URL
|
||||
|
||||
// Download and decrypt image
|
||||
_, err := c.downloadAndDecryptImage(ctx, imageURL)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_aibot", "Failed to process image", map[string]any{
|
||||
"error": err,
|
||||
"url": imageURL,
|
||||
})
|
||||
return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
|
||||
MsgType: "stream",
|
||||
Stream: struct {
|
||||
ID string `json:"id"`
|
||||
Finish bool `json:"finish"`
|
||||
Content string `json:"content,omitempty"`
|
||||
MsgItem []struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Image *struct {
|
||||
Base64 string `json:"base64"`
|
||||
MD5 string `json:"md5"`
|
||||
} `json:"image,omitempty"`
|
||||
} `json:"msg_item,omitempty"`
|
||||
}{
|
||||
ID: c.generateStreamID(),
|
||||
Finish: true,
|
||||
Content: fmt.Sprintf(
|
||||
"Image received (URL: %s), but image messages are not yet supported",
|
||||
imageURL,
|
||||
),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Echo back the image (simple demo behavior)
|
||||
// streamID := c.generateStreamID()
|
||||
// return c.encryptImageResponse(streamID, timestamp, nonce, imageData)
|
||||
|
||||
// For now, just acknowledge receipt without echoing the image
|
||||
return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
|
||||
MsgType: "stream",
|
||||
@@ -943,70 +911,24 @@ func (c *WeComAIBotChannel) encryptEmptyResponse(timestamp, nonce string) string
|
||||
|
||||
// encryptMessage encrypts a plain text message for WeCom AI Bot
|
||||
func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string, error) {
|
||||
// Decode AES key
|
||||
aesKey, err := base64.StdEncoding.DecodeString(c.config.EncodingAESKey + "=")
|
||||
aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode AES key: %w", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(aesKey) != 32 {
|
||||
return "", fmt.Errorf("invalid AES key length: %d", len(aesKey))
|
||||
}
|
||||
|
||||
// Generate 16-byte random string
|
||||
randomBytes := make([]byte, 16)
|
||||
for i := range 16 {
|
||||
n, randErr := rand.Int(rand.Reader, big.NewInt(10))
|
||||
if randErr != nil {
|
||||
return "", fmt.Errorf("failed to generate random: %w", randErr)
|
||||
}
|
||||
randomBytes[i] = byte('0' + n.Int64())
|
||||
}
|
||||
|
||||
// Build message: random(16) + msg_len(4) + msg + receiveid
|
||||
plaintextBytes := []byte(plaintext)
|
||||
receiveidBytes := []byte(receiveid)
|
||||
|
||||
msgLen := uint32(len(plaintextBytes))
|
||||
msgLenBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(msgLenBytes, msgLen)
|
||||
|
||||
// Concatenate
|
||||
var buffer bytes.Buffer
|
||||
buffer.Write(randomBytes)
|
||||
buffer.Write(msgLenBytes)
|
||||
buffer.Write(plaintextBytes)
|
||||
buffer.Write(receiveidBytes)
|
||||
|
||||
// PKCS7 padding
|
||||
plainData := buffer.Bytes()
|
||||
plainData = pkcs7Pad(plainData, blockSize)
|
||||
|
||||
// AES-CBC encrypt
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
frame, err := packWeComFrame(plaintext, receiveid)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create cipher: %w", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := make([]byte, len(plainData))
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
mode.CryptBlocks(ciphertext, plainData)
|
||||
|
||||
// Base64 encode
|
||||
encoded := base64.StdEncoding.EncodeToString(ciphertext)
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// pkcs7Pad adds PKCS7 padding
|
||||
func pkcs7Pad(data []byte, blockSize int) []byte {
|
||||
padding := blockSize - (len(data) % blockSize)
|
||||
if padding == 0 {
|
||||
padding = blockSize
|
||||
// PKCS7 padding then AES-CBC encrypt
|
||||
paddedFrame := pkcs7Pad(frame, blockSize)
|
||||
ciphertext, err := encryptAESCBC(aesKey, paddedFrame)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(data, padText...)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// generateStreamID generates a random stream ID
|
||||
@@ -1060,35 +982,15 @@ func (c *WeComAIBotChannel) downloadAndDecryptImage(
|
||||
})
|
||||
|
||||
// Decode AES key
|
||||
aesKey, err := base64.StdEncoding.DecodeString(c.config.EncodingAESKey + "=")
|
||||
aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode AES key: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(aesKey) != 32 {
|
||||
return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
|
||||
}
|
||||
|
||||
// Decrypt image (AES-CBC)
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
// Decrypt image (AES-CBC with IV = first 16 bytes of key, PKCS7 padding stripped)
|
||||
decryptedData, err := decryptAESCBC(aesKey, encryptedData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
if len(encryptedData)%aes.BlockSize != 0 {
|
||||
return nil, fmt.Errorf("encrypted data size not multiple of block size")
|
||||
}
|
||||
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
|
||||
decryptedData := make([]byte, len(encryptedData))
|
||||
mode.CryptBlocks(decryptedData, encryptedData)
|
||||
|
||||
// Remove PKCS7 padding
|
||||
decryptedData, err = pkcs7Unpad(decryptedData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unpad: %w", err)
|
||||
return nil, fmt.Errorf("failed to decrypt image: %w", err)
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom_aibot", "Image decrypted", map[string]any{
|
||||
@@ -1157,14 +1059,32 @@ func (c *WeComAIBotChannel) cleanupOldTasks() {
|
||||
// (agent had enough time to reply; it is not coming back).
|
||||
for chatID, queue := range c.chatTasks {
|
||||
filtered := queue[:0]
|
||||
for _, t := range queue {
|
||||
for i, t := range queue {
|
||||
absoluteExpired := t.CreatedTime.Before(cutoff)
|
||||
graceExpired := t.StreamClosed &&
|
||||
!t.StreamClosedAt.IsZero() &&
|
||||
t.StreamClosedAt.Before(now.Add(-streamClosedGracePeriod))
|
||||
if !t.Finished && !absoluteExpired && !graceExpired {
|
||||
if t.Finished {
|
||||
// Finished tasks should have been removed by removeTask().
|
||||
// Finding one here (especially not at position 0) means an
|
||||
// unexpected code path left it stranded, causing the queue to
|
||||
// grow silently. Log a warning so it is visible, then drop it.
|
||||
if i > 0 {
|
||||
logger.WarnCF("wecom_aibot",
|
||||
"Found stranded Finished task in the middle of chatTasks queue; "+
|
||||
"this should not happen — removeTask() should have spliced it out",
|
||||
map[string]any{
|
||||
"chat_id": chatID,
|
||||
"stream_id": t.StreamID,
|
||||
"position": i,
|
||||
})
|
||||
}
|
||||
// The task is already finished; its context was already canceled
|
||||
// by removeTask(), so no further action is required.
|
||||
continue
|
||||
} else if !absoluteExpired && !graceExpired {
|
||||
filtered = append(filtered, t)
|
||||
} else if !t.Finished {
|
||||
} else {
|
||||
t.cancel() // cancel any lingering agent goroutine
|
||||
}
|
||||
}
|
||||
|
||||
+101
-34
@@ -1,12 +1,15 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
@@ -51,64 +54,128 @@ func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (s
|
||||
return string(decoded), nil
|
||||
}
|
||||
|
||||
// Decode AES key (base64)
|
||||
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
|
||||
aesKey, err := decodeWeComAESKey(encodingAESKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode AES key: %w", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Decode encrypted message
|
||||
cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode message: %w", err)
|
||||
}
|
||||
|
||||
// AES decrypt
|
||||
plainText, err := decryptAESCBC(aesKey, cipherText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return unpackWeComFrame(plainText, receiveid)
|
||||
}
|
||||
|
||||
// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is
|
||||
// appended automatically) and validates that the result is exactly 32 bytes.
|
||||
// It is the single place that handles this repeated pattern in both encrypt and decrypt paths.
|
||||
func decodeWeComAESKey(encodingAESKey string) ([]byte, error) {
|
||||
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode AES key: %w", err)
|
||||
}
|
||||
if len(aesKey) != 32 {
|
||||
return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
|
||||
}
|
||||
return aesKey, nil
|
||||
}
|
||||
|
||||
// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring
|
||||
// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the
|
||||
// plaintext to a multiple of aes.BlockSize before calling.
|
||||
func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create cipher: %w", err)
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
if len(cipherText) < aes.BlockSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
// IV is the first 16 bytes of AESKey
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
plainText := make([]byte, len(cipherText))
|
||||
mode.CryptBlocks(plainText, cipherText)
|
||||
ciphertext := make([]byte, len(plaintext))
|
||||
cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Remove PKCS7 padding
|
||||
plainText, err = pkcs7Unpad(plainText)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unpad: %w", err)
|
||||
// packWeComFrame builds the WeCom wire format:
|
||||
//
|
||||
// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid
|
||||
func packWeComFrame(msg, receiveid string) ([]byte, error) {
|
||||
randomBytes := make([]byte, 16)
|
||||
for i := range 16 {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(10))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random: %w", err)
|
||||
}
|
||||
randomBytes[i] = byte('0' + n.Int64())
|
||||
}
|
||||
msgBytes := []byte(msg)
|
||||
msgLenBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes)))
|
||||
var buf bytes.Buffer
|
||||
buf.Write(randomBytes)
|
||||
buf.Write(msgLenBytes)
|
||||
buf.Write(msgBytes)
|
||||
buf.WriteString(receiveid)
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Parse message structure
|
||||
// Format: random(16) + msg_len(4) + msg + receiveid
|
||||
if len(plainText) < 20 {
|
||||
return "", fmt.Errorf("decrypted message too short")
|
||||
// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame.
|
||||
// If receiveid is non-empty it verifies the frame's trailing receiveid field.
|
||||
func unpackWeComFrame(data []byte, receiveid string) (string, error) {
|
||||
if len(data) < 20 {
|
||||
return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data))
|
||||
}
|
||||
|
||||
msgLen := binary.BigEndian.Uint32(plainText[16:20])
|
||||
if int(msgLen) > len(plainText)-20 {
|
||||
return "", fmt.Errorf("invalid message length")
|
||||
msgLen := binary.BigEndian.Uint32(data[16:20])
|
||||
if int(msgLen) > len(data)-20 {
|
||||
return "", fmt.Errorf("invalid message length: %d", msgLen)
|
||||
}
|
||||
|
||||
msg := plainText[20 : 20+msgLen]
|
||||
|
||||
// Verify receiveid if provided
|
||||
if receiveid != "" && len(plainText) > 20+int(msgLen) {
|
||||
actualReceiveID := string(plainText[20+msgLen:])
|
||||
msg := data[20 : 20+msgLen]
|
||||
if receiveid != "" && len(data) > 20+int(msgLen) {
|
||||
actualReceiveID := string(data[20+msgLen:])
|
||||
if actualReceiveID != receiveid {
|
||||
return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
|
||||
}
|
||||
}
|
||||
|
||||
return string(msg), nil
|
||||
}
|
||||
|
||||
// decryptAESCBC decrypts ciphertext using AES-CBC with the given key.
|
||||
// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext.
|
||||
func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) == 0 {
|
||||
return nil, fmt.Errorf("ciphertext is empty")
|
||||
}
|
||||
if len(ciphertext)%aes.BlockSize != 0 {
|
||||
return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
|
||||
}
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
|
||||
plaintext, err = pkcs7Unpad(plaintext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unpad: %w", err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// pkcs7Pad adds PKCS7 padding
|
||||
func pkcs7Pad(data []byte, blockSize int) []byte {
|
||||
padding := blockSize - (len(data) % blockSize)
|
||||
if padding == 0 {
|
||||
padding = blockSize
|
||||
}
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(data, padText...)
|
||||
}
|
||||
|
||||
// pkcs7Unpad removes PKCS7 padding with validation
|
||||
func pkcs7Unpad(data []byte) ([]byte, error) {
|
||||
if len(data) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user