mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(wecom): rebuild ai bot channel
This commit is contained in:
+11
-13
@@ -1495,18 +1495,17 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
|
||||
t.Fatalf("Failed to create channel manager: %v", err)
|
||||
}
|
||||
for name, id := range map[string]string{
|
||||
"whatsapp": "rid-whatsapp",
|
||||
"telegram": "rid-telegram",
|
||||
"feishu": "rid-feishu",
|
||||
"discord": "rid-discord",
|
||||
"maixcam": "rid-maixcam",
|
||||
"qq": "rid-qq",
|
||||
"dingtalk": "rid-dingtalk",
|
||||
"slack": "rid-slack",
|
||||
"line": "rid-line",
|
||||
"onebot": "rid-onebot",
|
||||
"wecom": "rid-wecom",
|
||||
"wecom_app": "rid-wecom-app",
|
||||
"whatsapp": "rid-whatsapp",
|
||||
"telegram": "rid-telegram",
|
||||
"feishu": "rid-feishu",
|
||||
"discord": "rid-discord",
|
||||
"maixcam": "rid-maixcam",
|
||||
"qq": "rid-qq",
|
||||
"dingtalk": "rid-dingtalk",
|
||||
"slack": "rid-slack",
|
||||
"line": "rid-line",
|
||||
"onebot": "rid-onebot",
|
||||
"wecom": "rid-wecom",
|
||||
} {
|
||||
chManager.RegisterChannel(name, &fakeChannel{id: id})
|
||||
}
|
||||
@@ -1526,7 +1525,6 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
|
||||
{channel: "line", wantID: "rid-line"},
|
||||
{channel: "onebot", wantID: "rid-onebot"},
|
||||
{channel: "wecom", wantID: "rid-wecom"},
|
||||
{channel: "wecom_app", wantID: "rid-wecom-app"},
|
||||
{channel: "unknown", wantID: ""},
|
||||
}
|
||||
|
||||
|
||||
+1
-10
@@ -371,19 +371,10 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
m.initChannel("onebot", "OneBot")
|
||||
}
|
||||
|
||||
if channels.WeCom.Enabled && channels.WeCom.Token() != "" {
|
||||
if channels.WeCom.Enabled && channels.WeCom.BotID != "" && channels.WeCom.Secret() != "" {
|
||||
m.initChannel("wecom", "WeCom")
|
||||
}
|
||||
|
||||
if channels.WeComAIBot.Enabled && (channels.WeComAIBot.Token() != "" ||
|
||||
(channels.WeComAIBot.Secret() != "" && channels.WeComAIBot.BotID != "")) {
|
||||
m.initChannel("wecom_aibot", "WeCom AI Bot")
|
||||
}
|
||||
|
||||
if channels.WeComApp.Enabled && channels.WeComApp.CorpID != "" {
|
||||
m.initChannel("wecom_app", "WeCom App")
|
||||
}
|
||||
|
||||
if channels.Weixin.Enabled && channels.Weixin.Token() != "" {
|
||||
m.initChannel("weixin", "Weixin")
|
||||
}
|
||||
|
||||
@@ -49,15 +49,7 @@ func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) {
|
||||
value["token"] = ch.LINE.ChannelAccessToken()
|
||||
value["secret"] = ch.LINE.ChannelSecret()
|
||||
case "wecom":
|
||||
value["token"] = ch.WeCom.Token()
|
||||
value["key"] = ch.WeCom.EncodingAESKey()
|
||||
case "wecom_app":
|
||||
value["token"] = ch.WeComApp.Token()
|
||||
value["secret"] = ch.WeComApp.CorpSecret()
|
||||
case "wecom_aibot":
|
||||
value["token"] = ch.WeComAIBot.Token()
|
||||
value["key"] = ch.WeComAIBot.EncodingAESKey()
|
||||
value["secret"] = ch.WeComAIBot.Secret()
|
||||
value["secret"] = ch.WeCom.Secret()
|
||||
case "dingtalk":
|
||||
value["secret"] = ch.QQ.AppSecret()
|
||||
case "qq":
|
||||
@@ -156,16 +148,7 @@ func updateKeys(newcfg, old *config.ChannelsConfig) {
|
||||
newcfg.LINE.SetChannelSecret(old.LINE.ChannelSecret())
|
||||
}
|
||||
if newcfg.WeCom.Enabled {
|
||||
newcfg.WeCom.SetToken(old.WeCom.Token())
|
||||
newcfg.WeCom.SetEncodingAESKey(old.WeCom.EncodingAESKey())
|
||||
}
|
||||
if newcfg.WeComApp.Enabled {
|
||||
newcfg.WeComApp.SetToken(old.WeComApp.Token())
|
||||
newcfg.WeComApp.SetCorpSecret(old.WeComApp.CorpSecret())
|
||||
}
|
||||
if newcfg.WeComAIBot.Enabled {
|
||||
newcfg.WeComAIBot.SetToken(old.WeComAIBot.Token())
|
||||
newcfg.WeComAIBot.SetEncodingAESKey(old.WeComAIBot.EncodingAESKey())
|
||||
newcfg.WeCom.SetSecret(old.WeCom.Secret())
|
||||
}
|
||||
if newcfg.DingTalk.Enabled {
|
||||
newcfg.DingTalk.SetClientSecret(old.DingTalk.ClientSecret())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,559 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// ---- Webhook mode tests ----
|
||||
|
||||
func TestNewWeComAIBotChannel_WebhookMode(t *testing.T) {
|
||||
t.Run("success with valid config", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
cfg.WebhookPath = "/webhook/test"
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if ch == nil {
|
||||
t.Fatal("Expected channel to be created")
|
||||
}
|
||||
if ch.Name() != "wecom_aibot" {
|
||||
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
|
||||
}
|
||||
// Webhook mode must implement WebhookHandler.
|
||||
if _, ok := ch.(channels.WebhookHandler); !ok {
|
||||
t.Error("Webhook mode channel should implement WebhookHandler")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error with missing token", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
_, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error with missing encoding key", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
_, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing encoding key, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComAIBotWebhookChannelStartStop(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
if err := ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Failed to start channel: %v", err)
|
||||
}
|
||||
if !ch.IsRunning() {
|
||||
t.Error("Expected channel to be running after Start")
|
||||
}
|
||||
|
||||
if err := ch.Stop(ctx); err != nil {
|
||||
t.Fatalf("Failed to stop channel: %v", err)
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("Expected channel to be stopped after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComAIBotChannelWebhookPath(t *testing.T) {
|
||||
t.Run("default path", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
|
||||
wh, ok := ch.(channels.WebhookHandler)
|
||||
if !ok {
|
||||
t.Fatal("Expected channel to implement WebhookHandler")
|
||||
}
|
||||
expectedPath := "/webhook/wecom-aibot"
|
||||
if wh.WebhookPath() != expectedPath {
|
||||
t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, wh.WebhookPath())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom path", func(t *testing.T) {
|
||||
customPath := "/custom/webhook"
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
cfg.WebhookPath = customPath
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
|
||||
wh, ok := ch.(channels.WebhookHandler)
|
||||
if !ok {
|
||||
t.Fatal("Expected channel to implement WebhookHandler")
|
||||
}
|
||||
if wh.WebhookPath() != customPath {
|
||||
t.Errorf("Expected webhook path '%s', got '%s'", customPath, wh.WebhookPath())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComAIBotChannelGetStreamResponseProcessingMessage(t *testing.T) {
|
||||
validAESKey := "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG"
|
||||
|
||||
t.Run("uses default processing message", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey(validAESKey)
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
channel, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
ch, ok := channel.(*WeComAIBotChannel)
|
||||
if !ok {
|
||||
t.Fatal("Expected webhook mode channel")
|
||||
}
|
||||
|
||||
task := &streamTask{
|
||||
StreamID: "stream-default",
|
||||
ChatID: "chat-default",
|
||||
Deadline: time.Now().Add(-time.Second),
|
||||
}
|
||||
ch.streamTasks[task.StreamID] = task
|
||||
ch.chatTasks[task.ChatID] = []*streamTask{task}
|
||||
|
||||
resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce"))
|
||||
|
||||
if !resp.Stream.Finish {
|
||||
t.Fatal("Expected finished stream response after deadline")
|
||||
}
|
||||
if resp.Stream.Content != config.DefaultWeComAIBotProcessingMessage {
|
||||
t.Fatalf("Expected default processing message %q, got %q",
|
||||
config.DefaultWeComAIBotProcessingMessage, resp.Stream.Content)
|
||||
}
|
||||
if !task.StreamClosed {
|
||||
t.Fatal("Expected task stream to be marked closed")
|
||||
}
|
||||
if _, ok := ch.streamTasks[task.StreamID]; ok {
|
||||
t.Fatal("Expected closed stream task to be removed from streamTasks")
|
||||
}
|
||||
if len(ch.chatTasks[task.ChatID]) != 1 {
|
||||
t.Fatalf("Expected task to remain queued for response_url delivery, got %d entries",
|
||||
len(ch.chatTasks[task.ChatID]))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses custom processing message", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
ProcessingMessage: "Please wait a moment. The result will be delivered in a follow-up message.",
|
||||
}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey(validAESKey)
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
channel, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
ch, ok := channel.(*WeComAIBotChannel)
|
||||
if !ok {
|
||||
t.Fatal("Expected webhook mode channel")
|
||||
}
|
||||
|
||||
task := &streamTask{
|
||||
StreamID: "stream-custom",
|
||||
ChatID: "chat-custom",
|
||||
Deadline: time.Now().Add(-time.Second),
|
||||
}
|
||||
|
||||
resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce"))
|
||||
|
||||
if resp.Stream.Content != cfg.ProcessingMessage {
|
||||
t.Fatalf("Expected custom processing message %q, got %q", cfg.ProcessingMessage, resp.Stream.Content)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateStreamID(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
webhookCh, ok := ch.(*WeComAIBotChannel)
|
||||
if !ok {
|
||||
t.Fatal("Expected webhook mode channel")
|
||||
}
|
||||
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
id := webhookCh.generateStreamID()
|
||||
if len(id) != 10 {
|
||||
t.Errorf("Expected stream ID length 10, got %d", len(id))
|
||||
}
|
||||
if ids[id] {
|
||||
t.Errorf("Duplicate stream ID generated: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
// Use a valid 43-character base64 key (企业微信标准格式)
|
||||
cfg := config.WeComAIBotConfig{}
|
||||
cfg.Enabled = true
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG") // 43 characters
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
|
||||
webhookCh, ok := ch.(*WeComAIBotChannel)
|
||||
if !ok {
|
||||
t.Fatal("Expected webhook mode channel")
|
||||
}
|
||||
|
||||
plaintext := "Hello, World!"
|
||||
receiveid := ""
|
||||
|
||||
encrypted, err := webhookCh.encryptMessage(plaintext, receiveid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt message: %v", err)
|
||||
}
|
||||
if encrypted == "" {
|
||||
t.Fatal("Encrypted message is empty")
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey(), receiveid)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt message: %v", err)
|
||||
}
|
||||
if decrypted != plaintext {
|
||||
t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSignature(t *testing.T) {
|
||||
token := "test_token"
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
encrypt := "encrypted_msg"
|
||||
|
||||
signature := computeSignature(token, timestamp, nonce, encrypt)
|
||||
if signature == "" {
|
||||
t.Error("Generated signature is empty")
|
||||
}
|
||||
if !verifySignature(token, signature, timestamp, nonce, encrypt) {
|
||||
t.Error("Generated signature does not verify correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func decodeStreamResponse(t *testing.T, ch *WeComAIBotChannel, encryptedResponse string) WeComAIBotStreamResponse {
|
||||
t.Helper()
|
||||
|
||||
var wrapped WeComAIBotEncryptedResponse
|
||||
if err := json.Unmarshal([]byte(encryptedResponse), &wrapped); err != nil {
|
||||
t.Fatalf("Failed to unmarshal encrypted response: %v", err)
|
||||
}
|
||||
|
||||
plaintext, err := decryptMessageWithVerify(wrapped.Encrypt, ch.config.EncodingAESKey(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt response: %v", err)
|
||||
}
|
||||
|
||||
var resp WeComAIBotStreamResponse
|
||||
if err := json.Unmarshal([]byte(plaintext), &resp); err != nil {
|
||||
t.Fatalf("Failed to unmarshal decrypted response: %v", err)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// ---- WebSocket long-connection mode tests ----
|
||||
|
||||
func TestNewWeComAIBotChannel_WSMode(t *testing.T) {
|
||||
t.Run("success with bot_id and secret", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
BotID: "test_bot_id",
|
||||
}
|
||||
cfg.SetSecret("test_secret")
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if ch == nil {
|
||||
t.Fatal("Expected channel to be created")
|
||||
}
|
||||
if ch.Name() != "wecom_aibot" {
|
||||
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
|
||||
}
|
||||
// WebSocket mode must NOT implement WebhookHandler.
|
||||
if _, ok := ch.(channels.WebhookHandler); ok {
|
||||
t.Error("WebSocket mode channel should NOT implement WebhookHandler")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ws mode takes priority over webhook fields", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
BotID: "test_bot_id",
|
||||
}
|
||||
cfg.SetSecret("test_secret")
|
||||
cfg.SetToken("also_set")
|
||||
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if _, ok := ch.(*WeComAIBotWSChannel); !ok {
|
||||
t.Error("Expected WebSocket mode channel when both BotID+secret and Token+Key are set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error with missing bot_id", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
cfg.SetSecret("test_secret")
|
||||
messageBus := bus.NewMessageBus()
|
||||
_, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
// Missing bot_id alone means neither WS mode nor webhook mode is fully configured.
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing bot_id, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error with missing secret", func(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
BotID: "test_bot_id",
|
||||
}
|
||||
messageBus := bus.NewMessageBus()
|
||||
_, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing secret, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComAIBotWSChannelStartStop(t *testing.T) {
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
BotID: "test_bot_id",
|
||||
}
|
||||
cfg.SetSecret("test_secret")
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch, err := NewWeComAIBotChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create channel: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Start launches a background goroutine; it should not block or return an error.
|
||||
if err := ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Failed to start channel: %v", err)
|
||||
}
|
||||
if !ch.IsRunning() {
|
||||
t.Error("Expected channel to be running after Start")
|
||||
}
|
||||
|
||||
// Stop should work regardless of whether the WebSocket actually connected.
|
||||
if err := ch.Stop(ctx); err != nil {
|
||||
t.Fatalf("Failed to stop channel: %v", err)
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("Expected channel to be stopped after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID(t *testing.T) {
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < 200; i++ {
|
||||
id := generateRandomID(10)
|
||||
if len(id) != 10 {
|
||||
t.Errorf("Expected ID length 10, got %d", len(id))
|
||||
}
|
||||
if ids[id] {
|
||||
t.Errorf("Duplicate ID generated: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestWSGenerateID(t *testing.T) {
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < 200; i++ {
|
||||
id := wsGenerateID()
|
||||
if len(id) != 10 {
|
||||
t.Errorf("Expected ID length 10, got %d", len(id))
|
||||
}
|
||||
if ids[id] {
|
||||
t.Errorf("Duplicate wsGenerateID result: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Webhook streaming fallback tests ----
|
||||
|
||||
// makeWebhookChannel creates a started WeComAIBotChannel for testing.
|
||||
func makeWebhookChannel(t *testing.T) *WeComAIBotChannel {
|
||||
t.Helper()
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG")
|
||||
ch, err := NewWeComAIBotChannel(cfg, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatalf("create channel: %v", err)
|
||||
}
|
||||
wc := ch.(*WeComAIBotChannel)
|
||||
wc.ctx, wc.cancel = context.WithCancel(context.Background())
|
||||
return wc
|
||||
}
|
||||
|
||||
// makeStreamTask creates and registers a streamTask for testing.
|
||||
func makeStreamTask(t *testing.T, ch *WeComAIBotChannel, streamID, chatID string, deadline time.Time) *streamTask {
|
||||
t.Helper()
|
||||
task := &streamTask{
|
||||
StreamID: streamID,
|
||||
ChatID: chatID,
|
||||
Deadline: deadline,
|
||||
answerCh: make(chan string, 1),
|
||||
}
|
||||
task.ctx, task.cancel = context.WithCancel(ch.ctx)
|
||||
ch.taskMu.Lock()
|
||||
ch.streamTasks[streamID] = task
|
||||
ch.chatTasks[chatID] = append(ch.chatTasks[chatID], task)
|
||||
ch.taskMu.Unlock()
|
||||
return task
|
||||
}
|
||||
|
||||
// TestGetStreamResponse_ImmediateAnswer verifies that when the agent has already
|
||||
// placed its answer in answerCh, getStreamResponse returns a finish=true response
|
||||
// and fully removes the task.
|
||||
func TestGetStreamResponse_ImmediateAnswer(t *testing.T) {
|
||||
ch := makeWebhookChannel(t)
|
||||
defer ch.cancel()
|
||||
|
||||
task := makeStreamTask(t, ch, "stream-1", "chat-1", time.Now().Add(30*time.Second))
|
||||
task.answerCh <- "hello from agent"
|
||||
|
||||
result := ch.getStreamResponse(task, "ts123", "nonce123")
|
||||
if result == "" {
|
||||
t.Fatal("expected non-empty encrypted response")
|
||||
}
|
||||
|
||||
ch.taskMu.RLock()
|
||||
_, exists := ch.streamTasks["stream-1"]
|
||||
ch.taskMu.RUnlock()
|
||||
if exists {
|
||||
t.Error("task should have been removed from streamTasks after normal finish")
|
||||
}
|
||||
if !task.Finished {
|
||||
t.Error("task.Finished should be true after normal finish")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetStreamResponse_DeadlinePassed verifies that when the stream deadline has
|
||||
// elapsed (no agent reply yet), getStreamResponse closes the stream but keeps the
|
||||
// task alive so the response_url fallback can still deliver the answer.
|
||||
func TestGetStreamResponse_DeadlinePassed(t *testing.T) {
|
||||
ch := makeWebhookChannel(t)
|
||||
defer ch.cancel()
|
||||
|
||||
task := makeStreamTask(t, ch, "stream-2", "chat-2", time.Now().Add(-time.Millisecond))
|
||||
|
||||
result := ch.getStreamResponse(task, "ts456", "nonce456")
|
||||
if result == "" {
|
||||
t.Fatal("expected non-empty encrypted response")
|
||||
}
|
||||
|
||||
ch.taskMu.RLock()
|
||||
_, stillStreaming := ch.streamTasks["stream-2"]
|
||||
ch.taskMu.RUnlock()
|
||||
if stillStreaming {
|
||||
t.Error("task should have been removed from streamTasks after deadline")
|
||||
}
|
||||
if !task.StreamClosed {
|
||||
t.Error("task.StreamClosed should be true after deadline")
|
||||
}
|
||||
if task.Finished {
|
||||
t.Error("task.Finished must remain false: agent reply still expected via response_url")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetStreamResponse_StillPending verifies that when neither the agent has
|
||||
// replied nor the deadline has passed, getStreamResponse returns without altering
|
||||
// task state (client should poll again).
|
||||
func TestGetStreamResponse_StillPending(t *testing.T) {
|
||||
ch := makeWebhookChannel(t)
|
||||
defer ch.cancel()
|
||||
|
||||
task := makeStreamTask(t, ch, "stream-3", "chat-3", time.Now().Add(30*time.Second))
|
||||
|
||||
result := ch.getStreamResponse(task, "ts789", "nonce789")
|
||||
if result == "" {
|
||||
t.Fatal("expected non-empty encrypted response")
|
||||
}
|
||||
|
||||
ch.taskMu.RLock()
|
||||
_, exists := ch.streamTasks["stream-3"]
|
||||
ch.taskMu.RUnlock()
|
||||
if !exists {
|
||||
t.Error("pending task should still be in streamTasks")
|
||||
}
|
||||
if task.Finished || task.StreamClosed {
|
||||
t.Error("pending task should not be finished or stream-closed")
|
||||
}
|
||||
// Cleanup.
|
||||
ch.removeTask(task)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,295 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// newTestWSChannel creates a WeComAIBotWSChannel ready for unit testing.
|
||||
func newTestWSChannel(t *testing.T) *WeComAIBotWSChannel {
|
||||
t.Helper()
|
||||
cfg := config.WeComAIBotConfig{
|
||||
Enabled: true,
|
||||
BotID: "test_bot_id",
|
||||
}
|
||||
cfg.SetSecret("test_secret")
|
||||
ch, err := newWeComAIBotWSChannel(cfg, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatalf("create WS channel: %v", err)
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_NilStore verifies that storeWSMedia returns an error when no
|
||||
// MediaStore has been injected.
|
||||
func TestStoreWSMedia_NilStore(t *testing.T) {
|
||||
ch := newTestWSChannel(t)
|
||||
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://any", "", ".jpg")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when no MediaStore is set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_HTTPError verifies that storeWSMedia propagates HTTP errors
|
||||
// from the media server.
|
||||
func TestStoreWSMedia_HTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ch := newTestWSChannel(t)
|
||||
ch.SetMediaStore(media.NewFileMediaStore())
|
||||
|
||||
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP 404")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_ServerUnavailable verifies that storeWSMedia returns a clear
|
||||
// error when the media server cannot be reached.
|
||||
func TestStoreWSMedia_ServerUnavailable(t *testing.T) {
|
||||
ch := newTestWSChannel(t)
|
||||
ch.SetMediaStore(media.NewFileMediaStore())
|
||||
|
||||
// Port 1 is reserved and will refuse the connection immediately.
|
||||
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://127.0.0.1:1", "", ".jpg")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unreachable server")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_Success_NoAES verifies the happy path: the media is downloaded,
|
||||
// a media ref is returned, and the file persists and is readable via Resolve until
|
||||
// ReleaseAll is called. The server returns no Content-Type, so the defaultExt is used.
|
||||
func TestStoreWSMedia_Success_NoAES(t *testing.T) {
|
||||
imageData := bytes.Repeat([]byte("x"), 256)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(imageData)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ch := newTestWSChannel(t)
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
ref, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if ref == "" {
|
||||
t.Fatal("expected non-empty ref")
|
||||
}
|
||||
|
||||
// File must be accessible after storeWSMedia returns (no premature deletion).
|
||||
path, err := store.Resolve(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("ref should resolve: %v", err)
|
||||
}
|
||||
got, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("file should exist at %s: %v", path, err)
|
||||
}
|
||||
if !bytes.Equal(got, imageData) {
|
||||
t.Errorf("content mismatch: got len=%d, want len=%d", len(got), len(imageData))
|
||||
}
|
||||
|
||||
// ReleaseAll must delete the file (store owns lifecycle).
|
||||
scope := channels.BuildMediaScope("wecom_aibot", "chat1", "msg1")
|
||||
if err := store.ReleaseAll(scope); err != nil {
|
||||
t.Fatalf("ReleaseAll failed: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
t.Errorf("file should have been deleted by ReleaseAll, stat err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_MultipleMessages verifies that concurrent media messages with
|
||||
// different msgIDs do not collide and each resolve to distinct files.
|
||||
func TestStoreWSMedia_MultipleMessages(t *testing.T) {
|
||||
imageA := bytes.Repeat([]byte("a"), 64)
|
||||
imageB := bytes.Repeat([]byte("b"), 64)
|
||||
|
||||
srvA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(imageA)
|
||||
}))
|
||||
defer srvA.Close()
|
||||
srvB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(imageB)
|
||||
}))
|
||||
defer srvB.Close()
|
||||
|
||||
ch := newTestWSChannel(t)
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
refA, err := ch.storeWSMedia(context.Background(), "chat1", "msgA", srvA.URL, "", ".jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("storeWSMedia A: %v", err)
|
||||
}
|
||||
refB, err := ch.storeWSMedia(context.Background(), "chat1", "msgB", srvB.URL, "", ".jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("storeWSMedia B: %v", err)
|
||||
}
|
||||
if refA == refB {
|
||||
t.Fatal("distinct messages must produce distinct refs")
|
||||
}
|
||||
|
||||
pathA, _ := store.Resolve(refA)
|
||||
pathB, _ := store.Resolve(refB)
|
||||
if pathA == pathB {
|
||||
t.Fatal("distinct messages must be stored at distinct paths")
|
||||
}
|
||||
|
||||
gotA, _ := os.ReadFile(pathA)
|
||||
gotB, _ := os.ReadFile(pathB)
|
||||
if !bytes.Equal(gotA, imageA) {
|
||||
t.Errorf("content mismatch for message A")
|
||||
}
|
||||
if !bytes.Equal(gotB, imageB) {
|
||||
t.Errorf("content mismatch for message B")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStoreWSMedia_ContentTypeExt verifies that the file extension is inferred
|
||||
// from the HTTP Content-Type header and the defaultExt fallback is used when the
|
||||
// type is absent or unrecognized.
|
||||
func TestStoreWSMedia_ContentTypeExt(t *testing.T) {
|
||||
tests := []struct {
|
||||
contentType string
|
||||
wantExt string
|
||||
}{
|
||||
{"image/jpeg", ".jpg"},
|
||||
{"image/png", ".png"},
|
||||
{"video/mp4", ".mp4"},
|
||||
{"application/pdf", ".pdf"},
|
||||
{"application/zip", ".zip"},
|
||||
// With parameters stripped.
|
||||
{"video/mp4; codecs=avc1", ".mp4"},
|
||||
// Unknown type → falls back to defaultExt.
|
||||
{"", ""},
|
||||
{"application/octet-stream", ""},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := wsMediaExtFromContentType(tc.contentType)
|
||||
if got != tc.wantExt {
|
||||
t.Errorf("wsMediaExtFromContentType(%q) = %q, want %q", tc.contentType, got, tc.wantExt)
|
||||
}
|
||||
}
|
||||
|
||||
// End-to-end: server returns Content-Type: video/mp4, defaultExt is .bin.
|
||||
// The stored file should carry the .mp4 extension, not .bin.
|
||||
payload := bytes.Repeat([]byte("v"), 128)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "video/mp4")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(payload)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ch := newTestWSChannel(t)
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
ref, err := ch.storeWSMedia(context.Background(), "chat1", "vid1", srv.URL, "", ".bin")
|
||||
if err != nil {
|
||||
t.Fatalf("storeWSMedia: %v", err)
|
||||
}
|
||||
path, err := store.Resolve(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve: %v", err)
|
||||
}
|
||||
if ext := path[len(path)-4:]; ext != ".mp4" {
|
||||
t.Errorf("expected .mp4 extension from Content-Type, got %q", ext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSplitWSContent verifies byte-aware splitting of stream content.
|
||||
func TestSplitWSContent(t *testing.T) {
|
||||
t.Run("short content is not split", func(t *testing.T) {
|
||||
chunks := splitWSContent("hello", 20480)
|
||||
if len(chunks) != 1 || chunks[0] != "hello" {
|
||||
t.Fatalf("unexpected chunks: %v", chunks)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ASCII content split at byte boundary", func(t *testing.T) {
|
||||
// Build a string just over the limit.
|
||||
content := strings.Repeat("a", 20481)
|
||||
chunks := splitWSContent(content, 20480)
|
||||
if len(chunks) < 2 {
|
||||
t.Fatalf("expected >= 2 chunks, got %d", len(chunks))
|
||||
}
|
||||
for i, c := range chunks {
|
||||
if len(c) > 20480 {
|
||||
t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c))
|
||||
}
|
||||
}
|
||||
// Reassembled content must equal the original (possibly without leading
|
||||
// whitespace that splitWSContent trims between chunks).
|
||||
joined := strings.Join(chunks, "")
|
||||
if len(joined) < len(content)-len(chunks) {
|
||||
t.Errorf("joined length %d too short (original %d)", len(joined), len(content))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CJK content split within byte limit", func(t *testing.T) {
|
||||
// Each CJK rune is 3 bytes in UTF-8.
|
||||
// 7000 CJK chars = 21000 bytes, which exceeds 20480.
|
||||
content := strings.Repeat("\u4e2d", 7000)
|
||||
chunks := splitWSContent(content, 20480)
|
||||
if len(chunks) < 2 {
|
||||
t.Fatalf("expected >= 2 chunks for 21000-byte CJK content, got %d", len(chunks))
|
||||
}
|
||||
for i, c := range chunks {
|
||||
if len(c) > 20480 {
|
||||
t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c))
|
||||
}
|
||||
// Every chunk must be valid UTF-8.
|
||||
if !strings.ContainsRune(c, '\u4e2d') && len(c) > 0 {
|
||||
// quick plausibility check — content was pure CJK
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSplitAtByteBoundary verifies the last-resort byte-boundary splitter.
|
||||
func TestSplitAtByteBoundary(t *testing.T) {
|
||||
t.Run("ASCII fits in one chunk", func(t *testing.T) {
|
||||
parts := splitAtByteBoundary("hello world", 100)
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 part, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("splits at byte boundary, never mid-rune", func(t *testing.T) {
|
||||
// 10 CJK characters = 30 bytes; split at 20 bytes.
|
||||
s := strings.Repeat("\u6587", 10) // 10 × 3 bytes = 30 bytes
|
||||
parts := splitAtByteBoundary(s, 20)
|
||||
for i, p := range parts {
|
||||
if len(p) > 20 {
|
||||
t.Errorf("part %d has %d bytes, want <= 20", i, len(p))
|
||||
}
|
||||
// Must be valid UTF-8 (no torn multi-byte sequences).
|
||||
for j, r := range p {
|
||||
if r == '\uFFFD' {
|
||||
t.Errorf("part %d has replacement rune at position %d: torn UTF-8", i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,756 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
wecomAPIBase = "https://qyapi.weixin.qq.com"
|
||||
)
|
||||
|
||||
// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用)
|
||||
type WeComAppChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComAppConfig
|
||||
client *http.Client
|
||||
accessToken string
|
||||
tokenExpiry time.Time
|
||||
tokenMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs *MessageDeduplicator
|
||||
}
|
||||
|
||||
// WeComXMLMessage represents the XML message structure from WeCom
|
||||
type WeComXMLMessage struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
FromUserName string `xml:"FromUserName"`
|
||||
CreateTime int64 `xml:"CreateTime"`
|
||||
MsgType string `xml:"MsgType"`
|
||||
Content string `xml:"Content"`
|
||||
MsgId int64 `xml:"MsgId"`
|
||||
AgentID int64 `xml:"AgentID"`
|
||||
PicUrl string `xml:"PicUrl"`
|
||||
MediaId string `xml:"MediaId"`
|
||||
Format string `xml:"Format"`
|
||||
ThumbMediaId string `xml:"ThumbMediaId"`
|
||||
LocationX float64 `xml:"Location_X"`
|
||||
LocationY float64 `xml:"Location_Y"`
|
||||
Scale int `xml:"Scale"`
|
||||
Label string `xml:"Label"`
|
||||
Title string `xml:"Title"`
|
||||
Description string `xml:"Description"`
|
||||
Url string `xml:"Url"`
|
||||
Event string `xml:"Event"`
|
||||
EventKey string `xml:"EventKey"`
|
||||
}
|
||||
|
||||
// WeComTextMessage represents text message for sending
|
||||
type WeComTextMessage struct {
|
||||
ToUser string `json:"touser"`
|
||||
MsgType string `json:"msgtype"`
|
||||
AgentID int64 `json:"agentid"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
Safe int `json:"safe,omitempty"`
|
||||
}
|
||||
|
||||
// WeComMarkdownMessage represents markdown message for sending
|
||||
type WeComMarkdownMessage struct {
|
||||
ToUser string `json:"touser"`
|
||||
MsgType string `json:"msgtype"`
|
||||
AgentID int64 `json:"agentid"`
|
||||
Markdown struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"markdown"`
|
||||
}
|
||||
|
||||
// WeComImageMessage represents image message for sending
|
||||
type WeComImageMessage struct {
|
||||
ToUser string `json:"touser"`
|
||||
MsgType string `json:"msgtype"`
|
||||
AgentID int64 `json:"agentid"`
|
||||
Image struct {
|
||||
MediaID string `json:"media_id"`
|
||||
} `json:"image"`
|
||||
}
|
||||
|
||||
// WeComAccessTokenResponse represents the access token API response
|
||||
type WeComAccessTokenResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// WeComSendMessageResponse represents the send message API response
|
||||
type WeComSendMessageResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
InvalidUser string `json:"invaliduser"`
|
||||
InvalidParty string `json:"invalidparty"`
|
||||
InvalidTag string `json:"invalidtag"`
|
||||
}
|
||||
|
||||
// PKCS7Padding adds PKCS7 padding
|
||||
type PKCS7Padding struct{}
|
||||
|
||||
// NewWeComAppChannel creates a new WeCom App channel instance
|
||||
func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) {
|
||||
if cfg.CorpID == "" || cfg.CorpSecret() == "" || cfg.AgentID == 0 {
|
||||
return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required")
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(2048),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
// Client timeout must be >= the configured ReplyTimeout so the
|
||||
// per-request context deadline is always the effective limit.
|
||||
clientTimeout := 30 * time.Second
|
||||
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
|
||||
clientTimeout = d
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &WeComAppChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: &http.Client{Timeout: clientTimeout},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the channel name
|
||||
func (c *WeComAppChannel) Name() string {
|
||||
return "wecom_app"
|
||||
}
|
||||
|
||||
// Start initializes the WeCom App channel
|
||||
func (c *WeComAppChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom_app", "Starting WeCom App channel...")
|
||||
|
||||
// Cancel the context created in the constructor to avoid a resource leak.
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Get initial access token
|
||||
if err := c.refreshAccessToken(); err != nil {
|
||||
logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Start token refresh goroutine
|
||||
go c.tokenRefreshLoop()
|
||||
|
||||
c.SetRunning(true)
|
||||
logger.InfoC("wecom_app", "WeCom App channel started")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the WeCom App channel
|
||||
func (c *WeComAppChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("wecom_app", "Stopping WeCom App channel...")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.SetRunning(false)
|
||||
logger.InfoC("wecom_app", "WeCom App channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends a message to WeCom user proactively using access token
|
||||
func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
accessToken := c.getAccessToken()
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("no valid access token available")
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom_app", "Sending message", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content)
|
||||
}
|
||||
|
||||
// SendMedia implements the channels.MediaSender interface.
|
||||
func (c *WeComAppChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
accessToken := c.getAccessToken()
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
for _, part := range msg.Parts {
|
||||
localPath, err := store.Resolve(part.Ref)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to resolve media ref", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Map part type to WeCom media type
|
||||
var mediaType string
|
||||
switch part.Type {
|
||||
case "image":
|
||||
mediaType = "image"
|
||||
case "audio":
|
||||
mediaType = "voice"
|
||||
case "video":
|
||||
mediaType = "video"
|
||||
default:
|
||||
mediaType = "file"
|
||||
}
|
||||
|
||||
// Upload media to get media_id
|
||||
mediaID, err := c.uploadMedia(ctx, accessToken, mediaType, localPath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to upload media", map[string]any{
|
||||
"type": mediaType,
|
||||
"error": err.Error(),
|
||||
})
|
||||
// Fallback: send caption as text
|
||||
if part.Caption != "" {
|
||||
_ = c.sendTextMessage(ctx, accessToken, msg.ChatID, part.Caption)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Send media message using the media_id
|
||||
if mediaType == "image" {
|
||||
err = c.sendImageMessage(ctx, accessToken, msg.ChatID, mediaID)
|
||||
} else {
|
||||
// For non-image types, send as text fallback with caption
|
||||
caption := part.Caption
|
||||
if caption == "" {
|
||||
caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename)
|
||||
}
|
||||
err = c.sendTextMessage(ctx, accessToken, msg.ChatID, caption)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// uploadMedia uploads a local file to WeCom temporary media storage.
|
||||
func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaType, localPath string) (string, error) {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/media/upload?access_token=%s&type=%s",
|
||||
wecomAPIBase, url.QueryEscape(accessToken), url.QueryEscape(mediaType))
|
||||
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
filename := filepath.Base(localPath)
|
||||
formFile, err := writer.CreateFormFile("media", filename)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create form file: %w", err)
|
||||
}
|
||||
|
||||
if _, err = io.Copy(formFile, file); err != nil {
|
||||
return "", fmt.Errorf("failed to copy file content: %w", err)
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return "", channels.ClassifyNetError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return "", channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading wecom upload error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return "", channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("wecom upload error: %s", string(respBody)),
|
||||
)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
MediaID string `json:"media_id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse upload response: %w", err)
|
||||
}
|
||||
|
||||
if result.ErrCode != 0 {
|
||||
return "", fmt.Errorf("upload API error: %s (code: %d)", result.ErrMsg, result.ErrCode)
|
||||
}
|
||||
|
||||
return result.MediaID, nil
|
||||
}
|
||||
|
||||
// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
|
||||
func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
|
||||
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
timeout := c.config.ReplyTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading wecom_app error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("wecom_app API error: %s", string(respBody)),
|
||||
)
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var sendResp WeComSendMessageResponse
|
||||
if err := json.Unmarshal(respBody, &sendResp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if sendResp.ErrCode != 0 {
|
||||
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendImageMessage sends an image message using a media_id.
|
||||
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
|
||||
msg := WeComImageMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "image",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Image.MediaID = mediaID
|
||||
return c.sendWeComMessage(ctx, accessToken, msg)
|
||||
}
|
||||
|
||||
// WebhookPath returns the path for registering on the shared HTTP server.
|
||||
func (c *WeComAppChannel) WebhookPath() string {
|
||||
if c.config.WebhookPath != "" {
|
||||
return c.config.WebhookPath
|
||||
}
|
||||
return "/webhook/wecom-app"
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler for the shared HTTP server.
|
||||
func (c *WeComAppChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
c.handleWebhook(w, r)
|
||||
}
|
||||
|
||||
// HealthPath returns the health check endpoint path.
|
||||
func (c *WeComAppChannel) HealthPath() string {
|
||||
return "/health/wecom-app"
|
||||
}
|
||||
|
||||
// HealthHandler handles health check requests.
|
||||
func (c *WeComAppChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
c.handleHealth(w, r)
|
||||
}
|
||||
|
||||
// handleWebhook handles incoming webhook requests from WeCom
|
||||
func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Log all incoming requests for debugging
|
||||
logger.DebugCF("wecom_app", "Received webhook request", map[string]any{
|
||||
"method": r.Method,
|
||||
"url": r.URL.String(),
|
||||
"path": r.URL.Path,
|
||||
"query": r.URL.RawQuery,
|
||||
})
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
// Handle verification request
|
||||
c.handleVerification(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
// Handle message callback
|
||||
c.handleMessageCallback(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
logger.WarnCF("wecom_app", "Method not allowed", map[string]any{
|
||||
"method": r.Method,
|
||||
})
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handleVerification handles the URL verification request from WeCom
|
||||
func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
echostr := query.Get("echostr")
|
||||
|
||||
logger.DebugCF("wecom_app", "Handling verification request", map[string]any{
|
||||
"msg_signature": msgSignature,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
"echostr": echostr,
|
||||
"corp_id": c.config.CorpID,
|
||||
})
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
|
||||
logger.ErrorC("wecom_app", "Missing parameters in verification request")
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) {
|
||||
logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{
|
||||
"token": c.config.Token(),
|
||||
"msg_signature": msgSignature,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
})
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugC("wecom_app", "Signature verification passed")
|
||||
|
||||
// Decrypt echostr with CorpID verification
|
||||
// For WeCom App (自建应用), receiveid should be corp_id
|
||||
logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{
|
||||
"encoding_aes_key": c.config.EncodingAESKey(),
|
||||
"corp_id": c.config.CorpID,
|
||||
})
|
||||
decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), c.config.CorpID)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{
|
||||
"error": err.Error(),
|
||||
"encoding_aes_key": c.config.EncodingAESKey,
|
||||
"corp_id": c.config.CorpID,
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{
|
||||
"decrypted": decryptedEchoStr,
|
||||
})
|
||||
|
||||
// Remove BOM and whitespace as per WeCom documentation
|
||||
// The response must be plain text without quotes, BOM, or newlines
|
||||
decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
|
||||
decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
|
||||
w.Write([]byte(decryptedEchoStr))
|
||||
}
|
||||
|
||||
// handleMessageCallback handles incoming messages from WeCom
|
||||
func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" {
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// Parse XML to get encrypted message
|
||||
var encryptedMsg struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
AgentID string `xml:"AgentID"`
|
||||
}
|
||||
|
||||
if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid XML", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
|
||||
logger.WarnC("wecom_app", "Message signature verification failed")
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt message with CorpID verification
|
||||
// For WeCom App (自建应用), receiveid should be corp_id
|
||||
decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), c.config.CorpID)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse decrypted XML message
|
||||
var msg WeComXMLMessage
|
||||
if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid message format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the message with the channel's long-lived context (not the HTTP
|
||||
// request context, which is canceled as soon as we return the response).
|
||||
go c.processMessage(c.ctx, msg)
|
||||
|
||||
// Return success response immediately
|
||||
// WeCom App requires response within configured timeout (default 5 seconds)
|
||||
w.Write([]byte("success"))
|
||||
}
|
||||
|
||||
// processMessage processes the received message
|
||||
func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) {
|
||||
// Skip non-text messages for now (can be extended)
|
||||
if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" {
|
||||
logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{
|
||||
"msg_type": msg.MsgType,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Message deduplication: Use msg_id to prevent duplicate processing
|
||||
// As per WeCom documentation, use msg_id for deduplication
|
||||
msgID := fmt.Sprintf("%d", msg.MsgId)
|
||||
if !c.processedMsgs.MarkMessageProcessed(msgID) {
|
||||
logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
|
||||
"msg_id": msgID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := msg.FromUserName
|
||||
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
|
||||
|
||||
// Build metadata
|
||||
// WeCom App only supports direct messages (private chat)
|
||||
peer := bus.Peer{Kind: "direct", ID: senderID}
|
||||
messageID := fmt.Sprintf("%d", msg.MsgId)
|
||||
|
||||
metadata := map[string]string{
|
||||
"msg_type": msg.MsgType,
|
||||
"msg_id": fmt.Sprintf("%d", msg.MsgId),
|
||||
"agent_id": fmt.Sprintf("%d", msg.AgentID),
|
||||
"platform": "wecom_app",
|
||||
"media_id": msg.MediaId,
|
||||
"create_time": fmt.Sprintf("%d", msg.CreateTime),
|
||||
}
|
||||
|
||||
content := msg.Content
|
||||
|
||||
logger.DebugCF("wecom_app", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"msg_type": msg.MsgType,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Build sender info
|
||||
appSender := bus.SenderInfo{
|
||||
Platform: "wecom",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
|
||||
}
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, appSender)
|
||||
}
|
||||
|
||||
// tokenRefreshLoop periodically refreshes the access token
|
||||
func (c *WeComAppChannel) tokenRefreshLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.refreshAccessToken(); err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// refreshAccessToken gets a new access token from WeCom API
|
||||
func (c *WeComAppChannel) refreshAccessToken() error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s",
|
||||
wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret()))
|
||||
|
||||
resp, err := http.Get(apiURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to request access token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var tokenResp WeComAccessTokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.ErrCode != 0 {
|
||||
return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode)
|
||||
}
|
||||
|
||||
c.tokenMu.Lock()
|
||||
c.accessToken = tokenResp.AccessToken
|
||||
c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early
|
||||
c.tokenMu.Unlock()
|
||||
|
||||
logger.DebugC("wecom_app", "Access token refreshed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccessToken returns the current valid access token
|
||||
func (c *WeComAppChannel) getAccessToken() string {
|
||||
c.tokenMu.RLock()
|
||||
defer c.tokenMu.RUnlock()
|
||||
|
||||
if time.Now().After(c.tokenExpiry) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return c.accessToken
|
||||
}
|
||||
|
||||
// sendTextMessage sends a text message to a user.
|
||||
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
|
||||
msg := WeComTextMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "text",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Text.Content = content
|
||||
return c.sendWeComMessage(ctx, accessToken, msg)
|
||||
}
|
||||
|
||||
// handleHealth handles health check requests
|
||||
func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
status := map[string]any{
|
||||
"status": "ok",
|
||||
"running": c.IsRunning(),
|
||||
"has_token": c.getAccessToken() != "",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,499 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人)
|
||||
// Uses webhook callback mode - simpler than WeCom App but only supports passive replies
|
||||
type WeComBotChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComConfig
|
||||
client *http.Client
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs *MessageDeduplicator
|
||||
}
|
||||
|
||||
// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT)
|
||||
type WeComBotMessage struct {
|
||||
MsgID string `json:"msgid"`
|
||||
AIBotID string `json:"aibotid"`
|
||||
ChatID string `json:"chatid"` // Session ID, only present for group chats
|
||||
ChatType string `json:"chattype"` // "single" for DM, "group" for group chat
|
||||
From struct {
|
||||
UserID string `json:"userid"`
|
||||
} `json:"from"`
|
||||
ResponseURL string `json:"response_url"`
|
||||
MsgType string `json:"msgtype"` // text, image, voice, file, mixed
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
Image struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"image"`
|
||||
Voice struct {
|
||||
Content string `json:"content"` // Voice to text content
|
||||
} `json:"voice"`
|
||||
File struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"file"`
|
||||
Mixed struct {
|
||||
MsgItem []struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
Image struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"image"`
|
||||
} `json:"msg_item"`
|
||||
} `json:"mixed"`
|
||||
Quote struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
} `json:"quote"`
|
||||
}
|
||||
|
||||
// WeComBotReplyMessage represents the reply message structure
|
||||
type WeComBotReplyMessage struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// NewWeComBotChannel creates a new WeCom Bot channel instance
|
||||
func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) {
|
||||
if cfg.Token() == "" || cfg.WebhookURL == "" {
|
||||
return nil, fmt.Errorf("wecom token and webhook_url are required")
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(2048),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
// Client timeout must be >= the configured ReplyTimeout so the
|
||||
// per-request context deadline is always the effective limit.
|
||||
clientTimeout := 30 * time.Second
|
||||
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
|
||||
clientTimeout = d
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &WeComBotChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: &http.Client{Timeout: clientTimeout},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the channel name
|
||||
func (c *WeComBotChannel) Name() string {
|
||||
return "wecom"
|
||||
}
|
||||
|
||||
// Start initializes the WeCom Bot channel
|
||||
func (c *WeComBotChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom", "Starting WeCom Bot channel...")
|
||||
|
||||
// Cancel the context created in the constructor to avoid a resource leak.
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
c.SetRunning(true)
|
||||
logger.InfoC("wecom", "WeCom Bot channel started")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the WeCom Bot channel
|
||||
func (c *WeComBotChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("wecom", "Stopping WeCom Bot channel...")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.SetRunning(false)
|
||||
logger.InfoC("wecom", "WeCom Bot channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends a message to WeCom user via webhook API
|
||||
// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message
|
||||
// For delayed responses, we use the webhook URL
|
||||
func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom", "Sending message via webhook", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
return c.sendWebhookReply(ctx, msg.ChatID, msg.Content)
|
||||
}
|
||||
|
||||
// WebhookPath returns the path for registering on the shared HTTP server.
|
||||
func (c *WeComBotChannel) WebhookPath() string {
|
||||
if c.config.WebhookPath != "" {
|
||||
return c.config.WebhookPath
|
||||
}
|
||||
return "/webhook/wecom"
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler for the shared HTTP server.
|
||||
func (c *WeComBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
c.handleWebhook(w, r)
|
||||
}
|
||||
|
||||
// HealthPath returns the health check endpoint path.
|
||||
func (c *WeComBotChannel) HealthPath() string {
|
||||
return "/health/wecom"
|
||||
}
|
||||
|
||||
// HealthHandler handles health check requests.
|
||||
func (c *WeComBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
c.handleHealth(w, r)
|
||||
}
|
||||
|
||||
// handleWebhook handles incoming webhook requests from WeCom
|
||||
func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
// Handle verification request
|
||||
c.handleVerification(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
// Handle message callback
|
||||
c.handleMessageCallback(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handleVerification handles the URL verification request from WeCom
|
||||
func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
echostr := query.Get("echostr")
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) {
|
||||
logger.WarnC("wecom", "Signature verification failed")
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt echostr
|
||||
// For AIBOT (智能机器人), receiveid should be empty string ""
|
||||
// Reference: https://developer.work.weixin.qq.com/document/path/101033
|
||||
decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), "")
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove BOM and whitespace as per WeCom documentation
|
||||
// The response must be plain text without quotes, BOM, or newlines
|
||||
decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
|
||||
decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
|
||||
w.Write([]byte(decryptedEchoStr))
|
||||
}
|
||||
|
||||
// handleMessageCallback handles incoming messages from WeCom
|
||||
func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" {
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// Parse XML to get encrypted message
|
||||
var encryptedMsg struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
AgentID string `xml:"AgentID"`
|
||||
}
|
||||
|
||||
if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to parse XML", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid XML", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
|
||||
logger.WarnC("wecom", "Message signature verification failed")
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt message
|
||||
// For AIBOT (智能机器人), receiveid should be empty string ""
|
||||
// Reference: https://developer.work.weixin.qq.com/document/path/101033
|
||||
decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), "")
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse decrypted JSON message (AIBOT uses JSON format)
|
||||
var msg WeComBotMessage
|
||||
if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid message format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the message with the channel's long-lived context (not the HTTP
|
||||
// request context, which is canceled as soon as we return the response).
|
||||
go c.processMessage(c.ctx, msg)
|
||||
|
||||
// Return success response immediately
|
||||
// WeCom Bot requires response within configured timeout (default 5 seconds)
|
||||
w.Write([]byte("success"))
|
||||
}
|
||||
|
||||
// processMessage processes the received message
|
||||
func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) {
|
||||
// Skip unsupported message types
|
||||
if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" &&
|
||||
msg.MsgType != "mixed" {
|
||||
logger.DebugCF("wecom", "Skipping non-supported message type", map[string]any{
|
||||
"msg_type": msg.MsgType,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Message deduplication: Use msg_id to prevent duplicate processing
|
||||
msgID := msg.MsgID
|
||||
if !c.processedMsgs.MarkMessageProcessed(msgID) {
|
||||
logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{
|
||||
"msg_id": msgID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := msg.From.UserID
|
||||
|
||||
// Determine if this is a group chat or direct message
|
||||
// ChatType: "single" for DM, "group" for group chat
|
||||
isGroupChat := msg.ChatType == "group"
|
||||
|
||||
var chatID, peerKind, peerID string
|
||||
if isGroupChat {
|
||||
// Group chat: use ChatID as chatID and peer_id
|
||||
chatID = msg.ChatID
|
||||
peerKind = "group"
|
||||
peerID = msg.ChatID
|
||||
} else {
|
||||
// Direct message: use senderID as chatID and peer_id
|
||||
chatID = senderID
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
// Extract content based on message type
|
||||
var content string
|
||||
switch msg.MsgType {
|
||||
case "text":
|
||||
content = msg.Text.Content
|
||||
case "voice":
|
||||
content = msg.Voice.Content // Voice to text content
|
||||
case "mixed":
|
||||
// For mixed messages, concatenate text items
|
||||
for _, item := range msg.Mixed.MsgItem {
|
||||
if item.MsgType == "text" {
|
||||
content += item.Text.Content
|
||||
}
|
||||
}
|
||||
case "image", "file":
|
||||
// For image and file, we don't have text content
|
||||
content = ""
|
||||
}
|
||||
|
||||
// Build metadata
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
|
||||
// In group chats, apply unified group trigger filtering
|
||||
if isGroupChat {
|
||||
respond, cleaned := c.ShouldRespondInGroup(false, content)
|
||||
if !respond {
|
||||
return
|
||||
}
|
||||
content = cleaned
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"msg_type": msg.MsgType,
|
||||
"msg_id": msg.MsgID,
|
||||
"platform": "wecom",
|
||||
"response_url": msg.ResponseURL,
|
||||
}
|
||||
if isGroupChat {
|
||||
metadata["chat_id"] = msg.ChatID
|
||||
metadata["sender_id"] = senderID
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"msg_type": msg.MsgType,
|
||||
"peer_kind": peerKind,
|
||||
"is_group_chat": isGroupChat,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Build sender info
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "wecom",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(sender) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata, sender)
|
||||
}
|
||||
|
||||
// sendWebhookReply sends a reply using the webhook URL
|
||||
func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error {
|
||||
reply := WeComBotReplyMessage{
|
||||
MsgType: "text",
|
||||
}
|
||||
reply.Text.Content = content
|
||||
|
||||
jsonData, err := json.Marshal(reply)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal reply: %w", err)
|
||||
}
|
||||
|
||||
// Use configurable timeout (default 5 seconds)
|
||||
timeout := c.config.ReplyTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading webhook error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("webhook API error: %s", string(body)),
|
||||
)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Check response
|
||||
var result struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.ErrCode != 0 {
|
||||
return fmt.Errorf("webhook API error: %s (code: %d)", result.ErrMsg, result.ErrCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleHealth handles health check requests
|
||||
func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
status := map[string]any{
|
||||
"status": "ok",
|
||||
"running": c.IsRunning(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
@@ -1,734 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// generateTestAESKey generates a valid test AES key
|
||||
func generateTestAESKey() string {
|
||||
// AES key needs to be 32 bytes (256 bits) for AES-256
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
// Return base64 encoded key without padding
|
||||
return base64.StdEncoding.EncodeToString(key)[:43]
|
||||
}
|
||||
|
||||
// encryptTestMessage encrypts a message for testing (AIBOT JSON format)
|
||||
func encryptTestMessage(message, aesKey string) (string, error) {
|
||||
// Decode AES key
|
||||
key, err := base64.StdEncoding.DecodeString(aesKey + "=")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Prepare message: random(16) + msg_len(4) + msg + receiveid
|
||||
random := make([]byte, 0, 16)
|
||||
for i := range 16 {
|
||||
random = append(random, byte(i))
|
||||
}
|
||||
|
||||
msgBytes := []byte(message)
|
||||
receiveID := []byte("test_aibot_id")
|
||||
|
||||
msgLen := uint32(len(msgBytes))
|
||||
lenBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(lenBytes, msgLen)
|
||||
|
||||
plainText := append(random, lenBytes...)
|
||||
plainText = append(plainText, msgBytes...)
|
||||
plainText = append(plainText, receiveID...)
|
||||
|
||||
// PKCS7 padding
|
||||
blockSize := aes.BlockSize
|
||||
padding := blockSize - len(plainText)%blockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
plainText = append(plainText, padText...)
|
||||
|
||||
// Encrypt
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize])
|
||||
cipherText := make([]byte, len(plainText))
|
||||
mode.CryptBlocks(cipherText, plainText)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||
}
|
||||
|
||||
// generateSignature generates a signature for testing
|
||||
func generateSignature(token, timestamp, nonce, msgEncrypt string) string {
|
||||
params := []string{token, timestamp, nonce, msgEncrypt}
|
||||
sort.Strings(params)
|
||||
str := strings.Join(params, "")
|
||||
hash := sha1.Sum([]byte(str))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
func TestNewWeComBotChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("missing token", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
_, err := NewWeComBotChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing webhook_url", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = ""
|
||||
_, err := NewWeComBotChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing webhook_url, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.AllowFrom = []string{"user1", "user2"}
|
||||
ch, err := NewWeComBotChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ch.Name() != "wecom" {
|
||||
t.Errorf("Name() = %q, want %q", ch.Name(), "wecom")
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("new channel should not be running")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotChannelIsAllowed(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("empty allowlist allows all", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.AllowFrom = []string{}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
if !ch.IsAllowed("any_user") {
|
||||
t.Error("empty allowlist should allow all users")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allowlist restricts users", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.AllowFrom = []string{"allowed_user"}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
if !ch.IsAllowed("allowed_user") {
|
||||
t.Error("allowed user should pass allowlist check")
|
||||
}
|
||||
if ch.IsAllowed("blocked_user") {
|
||||
t.Error("non-allowed user should be blocked")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotVerifySignature(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("valid signature", func(t *testing.T) {
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
msgEncrypt := "test_message"
|
||||
expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt)
|
||||
|
||||
if !verifySignature(ch.config.Token(), expectedSig, timestamp, nonce, msgEncrypt) {
|
||||
t.Error("valid signature should pass verification")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid signature", func(t *testing.T) {
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
msgEncrypt := "test_message"
|
||||
|
||||
if verifySignature(ch.config.Token(), "invalid_sig", timestamp, nonce, msgEncrypt) {
|
||||
t.Error("invalid signature should fail verification")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) {
|
||||
cfgEmpty := config.WeComConfig{}
|
||||
cfgEmpty.SetToken("")
|
||||
cfgEmpty.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
chEmpty := &WeComBotChannel{
|
||||
config: cfgEmpty,
|
||||
}
|
||||
|
||||
if verifySignature(chEmpty.config.Token(), "any_sig", "any_ts", "any_nonce", "any_msg") {
|
||||
t.Error("empty token should reject verification (fail-closed)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotDecryptMessage(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("decrypt without AES key", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.SetEncodingAESKey("")
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
// Without AES key, message should be base64 decoded only
|
||||
plainText := "hello world"
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(plainText))
|
||||
|
||||
result, err := decryptMessage(encoded, ch.config.EncodingAESKey())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != plainText {
|
||||
t.Errorf("decryptMessage() = %q, want %q", result, plainText)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decrypt with AES key", func(t *testing.T) {
|
||||
aesKey := generateTestAESKey()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.SetEncodingAESKey(aesKey)
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
originalMsg := "<xml><Content>Hello</Content></xml>"
|
||||
encrypted, err := encryptTestMessage(originalMsg, aesKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to encrypt test message: %v", err)
|
||||
}
|
||||
|
||||
result, err := decryptMessage(encrypted, ch.config.EncodingAESKey())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != originalMsg {
|
||||
t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid base64", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.SetEncodingAESKey("")
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
_, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey())
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid base64, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid AES key", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
cfg.SetEncodingAESKey("invalid_key")
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
_, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey())
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid AES key, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotPKCS7Unpad(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "valid padding 3 bytes",
|
||||
input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
|
||||
expected: []byte("hello"),
|
||||
},
|
||||
{
|
||||
name: "valid padding 16 bytes (full block)",
|
||||
input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
|
||||
expected: []byte("123456789012345"),
|
||||
},
|
||||
{
|
||||
name: "invalid padding larger than data",
|
||||
input: []byte{20},
|
||||
expected: nil, // should return error
|
||||
},
|
||||
{
|
||||
name: "invalid padding zero",
|
||||
input: append([]byte("test"), byte(0)),
|
||||
expected: nil, // should return error
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := pkcs7Unpad(tt.input)
|
||||
if tt.expected == nil {
|
||||
// This case should return an error
|
||||
if err == nil {
|
||||
t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("pkcs7Unpad() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(result, tt.expected) {
|
||||
t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComBotHandleVerification(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
aesKey := generateTestAESKey()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey(aesKey)
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("valid verification request", func(t *testing.T) {
|
||||
echostr := "test_echostr_123"
|
||||
encryptedEchostr, _ := encryptTestMessage(echostr, aesKey)
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
|
||||
nil,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleVerification(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != echostr {
|
||||
t.Errorf("response body = %q, want %q", w.Body.String(), echostr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing parameters", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig×tamp=ts", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleVerification(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid signature", func(t *testing.T) {
|
||||
echostr := "test_echostr"
|
||||
encryptedEchostr, _ := encryptTestMessage(echostr, aesKey)
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
|
||||
nil,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleVerification(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
aesKey := generateTestAESKey()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.SetEncodingAESKey(aesKey)
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: encrypted,
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encrypted)
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
t.Run("valid direct message callback", func(t *testing.T) {
|
||||
w := runBotMessageCallback(t, `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != "success" {
|
||||
t.Errorf("response body = %q, want %q", w.Body.String(), "success")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid group message callback", func(t *testing.T) {
|
||||
w := runBotMessageCallback(t, `{
|
||||
"msgid": "test_msg_id_456",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chatid": "group_chat_id_123",
|
||||
"chattype": "group",
|
||||
"from": {"userid": "user456"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello Group"}
|
||||
}`)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != "success" {
|
||||
t.Errorf("response body = %q, want %q", w.Body.String(), "success")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing parameters", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid XML", func(t *testing.T) {
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, "")
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
strings.NewReader("invalid xml"),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid signature", func(t *testing.T) {
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: "encrypted_data",
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotProcessMessage(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("process direct text message", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_123",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatType: "single",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "text",
|
||||
}
|
||||
msg.From.UserID = "user123"
|
||||
msg.Text.Content = "Hello World"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
|
||||
t.Run("process group text message", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_456",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatID: "group_chat_id_123",
|
||||
ChatType: "group",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "text",
|
||||
}
|
||||
msg.From.UserID = "user456"
|
||||
msg.Text.Content = "Hello Group"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
|
||||
t.Run("process voice message", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_789",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatType: "single",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "voice",
|
||||
}
|
||||
msg.From.UserID = "user123"
|
||||
msg.Voice.Content = "Voice message text"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
|
||||
t.Run("skip unsupported message type", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_000",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatType: "single",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "video",
|
||||
}
|
||||
msg.From.UserID = "user123"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotHandleWebhook(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("GET request calls verification", func(t *testing.T) {
|
||||
echostr := "test_echostr"
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(echostr))
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encoded)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded,
|
||||
nil,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleWebhook(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("POST request calls message callback", func(t *testing.T) {
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: base64.StdEncoding.EncodeToString([]byte("test")),
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleWebhook(w, req)
|
||||
|
||||
// Should not be method not allowed
|
||||
if w.Code == http.StatusMethodNotAllowed {
|
||||
t.Error("POST request should not return Method Not Allowed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported method", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleWebhook(w, req)
|
||||
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotHandleHealth(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{}
|
||||
cfg.SetToken("test_token")
|
||||
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleHealth(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want %q", contentType, "application/json")
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "status") || !strings.Contains(body, "running") {
|
||||
t.Errorf("response body should contain status and running fields, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComBotReplyMessage(t *testing.T) {
|
||||
msg := WeComBotReplyMessage{
|
||||
MsgType: "text",
|
||||
}
|
||||
msg.Text.Content = "Hello World"
|
||||
|
||||
if msg.MsgType != "text" {
|
||||
t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
|
||||
}
|
||||
if msg.Text.Content != "Hello World" {
|
||||
t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComBotMessageStructure(t *testing.T) {
|
||||
jsonData := `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chatid": "group_chat_id_123",
|
||||
"chattype": "group",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`
|
||||
|
||||
var msg WeComBotMessage
|
||||
err := json.Unmarshal([]byte(jsonData), &msg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
|
||||
if msg.MsgID != "test_msg_id_123" {
|
||||
t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123")
|
||||
}
|
||||
if msg.AIBotID != "test_aibot_id" {
|
||||
t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id")
|
||||
}
|
||||
if msg.ChatID != "group_chat_id_123" {
|
||||
t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123")
|
||||
}
|
||||
if msg.ChatType != "group" {
|
||||
t.Errorf("ChatType = %q, want %q", msg.ChatType, "group")
|
||||
}
|
||||
if msg.From.UserID != "user123" {
|
||||
t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123")
|
||||
}
|
||||
if msg.MsgType != "text" {
|
||||
t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
|
||||
}
|
||||
if msg.Text.Content != "Hello World" {
|
||||
t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
|
||||
}
|
||||
}
|
||||
@@ -1,199 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// blockSize is the PKCS7 block size used by WeCom (32)
|
||||
const blockSize = 32
|
||||
|
||||
// computeSignature computes the WeCom message signature from the given parameters.
|
||||
// It sorts [token, timestamp, nonce, encrypt], concatenates them and returns the SHA1 hex digest.
|
||||
func computeSignature(token, timestamp, nonce, encrypt string) string {
|
||||
params := []string{token, timestamp, nonce, encrypt}
|
||||
sort.Strings(params)
|
||||
str := strings.Join(params, "")
|
||||
hash := sha1.Sum([]byte(str))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
// verifySignature verifies the message signature for WeCom
|
||||
// This is a common function used by both WeCom Bot and WeCom App
|
||||
func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature
|
||||
}
|
||||
|
||||
// decryptMessage decrypts the encrypted message using AES
|
||||
// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id
|
||||
func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) {
|
||||
return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "")
|
||||
}
|
||||
|
||||
// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid
|
||||
// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification.
|
||||
func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) {
|
||||
if encodingAESKey == "" {
|
||||
// No encryption, return as is (base64 decode)
|
||||
decoded, err := base64.StdEncoding.DecodeString(encryptedMsg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(decoded), nil
|
||||
}
|
||||
|
||||
aesKey, err := decodeWeComAESKey(encodingAESKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode message: %w", err)
|
||||
}
|
||||
|
||||
plainText, err := decryptAESCBC(aesKey, cipherText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return unpackWeComFrame(plainText, receiveid)
|
||||
}
|
||||
|
||||
// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is
|
||||
// appended automatically) and validates that the result is exactly 32 bytes.
|
||||
// It is the single place that handles this repeated pattern in both encrypt and decrypt paths.
|
||||
func decodeWeComAESKey(encodingAESKey string) ([]byte, error) {
|
||||
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode AES key: %w", err)
|
||||
}
|
||||
if len(aesKey) != 32 {
|
||||
return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
|
||||
}
|
||||
return aesKey, nil
|
||||
}
|
||||
|
||||
// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring
|
||||
// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the
|
||||
// plaintext to a multiple of aes.BlockSize before calling.
|
||||
func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
ciphertext := make([]byte, len(plaintext))
|
||||
cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// packWeComFrame builds the WeCom wire format:
|
||||
//
|
||||
// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid
|
||||
func packWeComFrame(msg, receiveid string) ([]byte, error) {
|
||||
randomBytes := make([]byte, 16)
|
||||
for i := range 16 {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(10))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random: %w", err)
|
||||
}
|
||||
randomBytes[i] = byte('0' + n.Int64())
|
||||
}
|
||||
msgBytes := []byte(msg)
|
||||
msgLenBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes)))
|
||||
var buf bytes.Buffer
|
||||
buf.Write(randomBytes)
|
||||
buf.Write(msgLenBytes)
|
||||
buf.Write(msgBytes)
|
||||
buf.WriteString(receiveid)
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame.
|
||||
// If receiveid is non-empty it verifies the frame's trailing receiveid field.
|
||||
func unpackWeComFrame(data []byte, receiveid string) (string, error) {
|
||||
if len(data) < 20 {
|
||||
return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data))
|
||||
}
|
||||
msgLen := binary.BigEndian.Uint32(data[16:20])
|
||||
if int(msgLen) > len(data)-20 {
|
||||
return "", fmt.Errorf("invalid message length: %d", msgLen)
|
||||
}
|
||||
msg := data[20 : 20+msgLen]
|
||||
if receiveid != "" && len(data) > 20+int(msgLen) {
|
||||
actualReceiveID := string(data[20+msgLen:])
|
||||
if actualReceiveID != receiveid {
|
||||
return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
|
||||
}
|
||||
}
|
||||
return string(msg), nil
|
||||
}
|
||||
|
||||
// decryptAESCBC decrypts ciphertext using AES-CBC with the given key.
|
||||
// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext.
|
||||
func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) == 0 {
|
||||
return nil, fmt.Errorf("ciphertext is empty")
|
||||
}
|
||||
if len(ciphertext)%aes.BlockSize != 0 {
|
||||
return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
|
||||
}
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
|
||||
plaintext, err = pkcs7Unpad(plaintext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unpad: %w", err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// pkcs7Pad adds PKCS7 padding
|
||||
func pkcs7Pad(data []byte, blockSize int) []byte {
|
||||
padding := blockSize - (len(data) % blockSize)
|
||||
if padding == 0 {
|
||||
padding = blockSize
|
||||
}
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(data, padText...)
|
||||
}
|
||||
|
||||
// pkcs7Unpad removes PKCS7 padding with validation
|
||||
func pkcs7Unpad(data []byte) ([]byte, error) {
|
||||
if len(data) == 0 {
|
||||
return data, nil
|
||||
}
|
||||
padding := int(data[len(data)-1])
|
||||
// WeCom uses 32-byte block size for PKCS7 padding
|
||||
if padding == 0 || padding > blockSize {
|
||||
return nil, fmt.Errorf("invalid padding size: %d", padding)
|
||||
}
|
||||
if padding > len(data) {
|
||||
return nil, fmt.Errorf("padding size larger than data")
|
||||
}
|
||||
// Verify all padding bytes
|
||||
for i := range padding {
|
||||
if data[len(data)-1-i] != byte(padding) {
|
||||
return nil, fmt.Errorf("invalid padding byte at position %d", i)
|
||||
}
|
||||
}
|
||||
return data[:len(data)-padding], nil
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import "sync"
|
||||
|
||||
const wecomMaxProcessedMessages = 1000
|
||||
|
||||
// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer)
|
||||
// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest
|
||||
// messages without causing "amnesia cliffs" when the limit is reached.
|
||||
type MessageDeduplicator struct {
|
||||
mu sync.Mutex
|
||||
msgs map[string]bool
|
||||
ring []string
|
||||
idx int
|
||||
max int
|
||||
}
|
||||
|
||||
// NewMessageDeduplicator creates a new deduplicator with the specified capacity.
|
||||
func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator {
|
||||
if maxEntries <= 0 {
|
||||
maxEntries = wecomMaxProcessedMessages
|
||||
}
|
||||
return &MessageDeduplicator{
|
||||
msgs: make(map[string]bool, maxEntries),
|
||||
ring: make([]string, maxEntries),
|
||||
max: maxEntries,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkMessageProcessed marks msgID as processed and returns false for duplicates.
|
||||
func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// 1. Check for duplicate
|
||||
if d.msgs[msgID] {
|
||||
return false
|
||||
}
|
||||
|
||||
// 2. Evict the oldest message at our current ring position (if any)
|
||||
oldestID := d.ring[d.idx]
|
||||
if oldestID != "" {
|
||||
delete(d.msgs, oldestID)
|
||||
}
|
||||
|
||||
// 3. Store the new message
|
||||
d.msgs[msgID] = true
|
||||
d.ring[d.idx] = msgID
|
||||
|
||||
// 4. Advance the circle queue index
|
||||
d.idx = (d.idx + 1) % d.max
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageDeduplicator_DuplicateDetection(t *testing.T) {
|
||||
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
|
||||
|
||||
if ok := d.MarkMessageProcessed("msg-1"); !ok {
|
||||
t.Fatalf("first message should be accepted")
|
||||
}
|
||||
|
||||
if ok := d.MarkMessageProcessed("msg-1"); ok {
|
||||
t.Fatalf("duplicate message should be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) {
|
||||
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
|
||||
|
||||
const goroutines = 64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
results := make(chan bool, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results <- d.MarkMessageProcessed("msg-concurrent")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successes := 0
|
||||
for ok := range results {
|
||||
if ok {
|
||||
successes++
|
||||
}
|
||||
}
|
||||
|
||||
if successes != 1 {
|
||||
t.Fatalf("expected exactly 1 successful mark, got %d", successes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) {
|
||||
// Create a deduplicator with a very small capacity to test eviction easily.
|
||||
capacity := 3
|
||||
d := NewMessageDeduplicator(capacity)
|
||||
|
||||
// Fill the queue.
|
||||
d.MarkMessageProcessed("msg-1")
|
||||
d.MarkMessageProcessed("msg-2")
|
||||
d.MarkMessageProcessed("msg-3")
|
||||
|
||||
// At this point, the queue is full. msg-1 is the oldest.
|
||||
if len(d.msgs) != 3 {
|
||||
t.Fatalf("expected map size to be 3, got %d", len(d.msgs))
|
||||
}
|
||||
|
||||
// This should evict msg-1 and add msg-4.
|
||||
if ok := d.MarkMessageProcessed("msg-4"); !ok {
|
||||
t.Fatalf("msg-4 should be accepted")
|
||||
}
|
||||
|
||||
if len(d.msgs) != 3 {
|
||||
t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs))
|
||||
}
|
||||
|
||||
// msg-1 should now be forgotten (evicted).
|
||||
if ok := d.MarkMessageProcessed("msg-1"); !ok {
|
||||
t.Fatalf("msg-1 should be accepted again because it was evicted")
|
||||
}
|
||||
|
||||
// msg-2 should have been evicted when we added msg-1 back.
|
||||
if ok := d.MarkMessageProcessed("msg-2"); !ok {
|
||||
t.Fatalf("msg-2 should be accepted again because it was evicted")
|
||||
}
|
||||
}
|
||||
@@ -8,12 +8,6 @@ import (
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWeComBotChannel(cfg.Channels.WeCom, b)
|
||||
})
|
||||
channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWeComAppChannel(cfg.Channels.WeComApp, b)
|
||||
})
|
||||
channels.RegisterFactory("wecom_aibot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWeComAIBotChannel(cfg.Channels.WeComAIBot, b)
|
||||
return NewChannel(cfg.Channels.WeCom, b)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func decodeMediaAESKey(value string) ([]byte, error) {
|
||||
if value == "" {
|
||||
return nil, nil
|
||||
}
|
||||
key, err := base64.StdEncoding.DecodeString(value)
|
||||
if err == nil && len(key) == 32 {
|
||||
return key, nil
|
||||
}
|
||||
key, err = base64.StdEncoding.DecodeString(value + "=")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode AES key: %w", err)
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("invalid AES key length %d", len(key))
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func decryptAESCBC(key, ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) == 0 {
|
||||
return nil, fmt.Errorf("ciphertext is empty")
|
||||
}
|
||||
if len(ciphertext)%aes.BlockSize != 0 {
|
||||
return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
iv := key[:aes.BlockSize]
|
||||
cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
|
||||
return pkcs7Unpad(plaintext)
|
||||
}
|
||||
|
||||
func pkcs7Unpad(data []byte) ([]byte, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, fmt.Errorf("empty plaintext")
|
||||
}
|
||||
padding := int(data[len(data)-1])
|
||||
if padding == 0 || padding > 32 || padding > len(data) {
|
||||
return nil, fmt.Errorf("invalid padding size %d", padding)
|
||||
}
|
||||
for i := 0; i < padding; i++ {
|
||||
if data[len(data)-1-i] != byte(padding) {
|
||||
return nil, fmt.Errorf("invalid padding byte")
|
||||
}
|
||||
}
|
||||
return data[:len(data)-padding], nil
|
||||
}
|
||||
|
||||
func inferMediaExt(contentType, fallback string) string {
|
||||
contentType = normalizeWeComContentType(contentType)
|
||||
switch contentType {
|
||||
case "image/jpeg", "image/jpg":
|
||||
return ".jpg"
|
||||
case "image/png":
|
||||
return ".png"
|
||||
case "image/gif":
|
||||
return ".gif"
|
||||
case "image/webp":
|
||||
return ".webp"
|
||||
case "application/pdf":
|
||||
return ".pdf"
|
||||
case "video/mp4":
|
||||
return ".mp4"
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeWeComContentType(value string) string {
|
||||
value = strings.ToLower(strings.TrimSpace(value))
|
||||
if idx := strings.Index(value, ";"); idx >= 0 {
|
||||
value = strings.TrimSpace(value[:idx])
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func isGenericWeComContentType(value string) bool {
|
||||
switch normalizeWeComContentType(value) {
|
||||
case "", "application/octet-stream", "binary/octet-stream", "application/unknown", "application/binary":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeWeComFilename(name string) string {
|
||||
name = filepath.Base(strings.TrimSpace(name))
|
||||
if name == "." || name == "/" || name == "" {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func candidateWeComFilename(resourceURL, contentDisposition, fallbackName string) string {
|
||||
if _, params, err := mime.ParseMediaType(contentDisposition); err == nil {
|
||||
if name := sanitizeWeComFilename(params["filename"]); name != "" {
|
||||
return name
|
||||
}
|
||||
if name := sanitizeWeComFilename(params["filename*"]); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
if parsed, err := url.Parse(resourceURL); err == nil {
|
||||
query := parsed.Query()
|
||||
for _, key := range []string{"filename", "file_name", "name"} {
|
||||
if name := sanitizeWeComFilename(query.Get(key)); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
if name := sanitizeWeComFilename(parsed.Path); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
return sanitizeWeComFilename(fallbackName)
|
||||
}
|
||||
|
||||
func detectWeComFiletype(data []byte) (string, string) {
|
||||
kind, err := filetype.Match(data)
|
||||
if err != nil || kind == filetype.Unknown {
|
||||
return "", ""
|
||||
}
|
||||
ext := ""
|
||||
if kind.Extension != "" {
|
||||
ext = "." + strings.ToLower(kind.Extension)
|
||||
}
|
||||
return normalizeWeComContentType(kind.MIME.Value), ext
|
||||
}
|
||||
|
||||
func detectWeComMediaMetadata(data []byte, fallbackName, fallbackContentType, resourceURL, contentDisposition string) (string, string) {
|
||||
filename := candidateWeComFilename(resourceURL, contentDisposition, fallbackName)
|
||||
if filename == "" {
|
||||
filename = "media"
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
contentType := normalizeWeComContentType(fallbackContentType)
|
||||
detectedType, detectedExt := detectWeComFiletype(data)
|
||||
|
||||
if ext != "" && isGenericWeComContentType(contentType) {
|
||||
if byExt := normalizeWeComContentType(mime.TypeByExtension(ext)); byExt != "" {
|
||||
contentType = byExt
|
||||
}
|
||||
}
|
||||
|
||||
if detectedType != "" {
|
||||
switch {
|
||||
case contentType == "":
|
||||
contentType = detectedType
|
||||
case isGenericWeComContentType(contentType):
|
||||
contentType = detectedType
|
||||
case strings.HasPrefix(detectedType, "image/") && !strings.HasPrefix(contentType, "image/"):
|
||||
contentType = detectedType
|
||||
case strings.HasPrefix(detectedType, "audio/") && !strings.HasPrefix(contentType, "audio/"):
|
||||
contentType = detectedType
|
||||
case strings.HasPrefix(detectedType, "video/") && !strings.HasPrefix(contentType, "video/"):
|
||||
contentType = detectedType
|
||||
}
|
||||
}
|
||||
|
||||
if contentType == "" && ext != "" {
|
||||
contentType = normalizeWeComContentType(mime.TypeByExtension(ext))
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = normalizeWeComContentType(http.DetectContentType(data))
|
||||
}
|
||||
|
||||
if ext == "" {
|
||||
ext = detectedExt
|
||||
}
|
||||
if ext == "" && contentType != "" {
|
||||
if exts, err := mime.ExtensionsByType(contentType); err == nil && len(exts) > 0 {
|
||||
ext = strings.ToLower(exts[0])
|
||||
}
|
||||
}
|
||||
|
||||
if filepath.Ext(filename) == "" && ext != "" {
|
||||
filename += ext
|
||||
}
|
||||
return filename, contentType
|
||||
}
|
||||
|
||||
func (c *WeComChannel) storeRemoteMedia(
|
||||
ctx context.Context,
|
||||
scope, msgID, resourceURL, aesKey, fallbackExt string,
|
||||
) (string, error) {
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("no media store available")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
resp, err := c.mediaClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("download media: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("download media returned HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
const maxSize = 20 << 20
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, maxSize+1))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read media: %w", err)
|
||||
}
|
||||
if len(data) > maxSize {
|
||||
return "", fmt.Errorf("media too large")
|
||||
}
|
||||
|
||||
if aesKey != "" {
|
||||
key, keyErr := decodeMediaAESKey(aesKey)
|
||||
if keyErr != nil {
|
||||
return "", keyErr
|
||||
}
|
||||
data, err = decryptAESCBC(key, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt media: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
filename, contentType := detectWeComMediaMetadata(
|
||||
data,
|
||||
msgID+fallbackExt,
|
||||
resp.Header.Get("Content-Type"),
|
||||
resourceURL,
|
||||
resp.Header.Get("Content-Disposition"),
|
||||
)
|
||||
ext := filepath.Ext(filename)
|
||||
if ext == "" {
|
||||
ext = inferMediaExt(contentType, fallbackExt)
|
||||
}
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
|
||||
return "", fmt.Errorf("mkdir media dir: %w", mkdirErr)
|
||||
}
|
||||
tmpFile, err := os.CreateTemp(mediaDir, msgID+"-*"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
if _, writeErr := tmpFile.Write(data); writeErr != nil {
|
||||
tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("write temp file: %w", writeErr)
|
||||
}
|
||||
if closeErr := tmpFile.Close(); closeErr != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("close temp file: %w", closeErr)
|
||||
}
|
||||
|
||||
ref, err := store.Store(tmpPath, media.MediaMeta{
|
||||
Filename: filename,
|
||||
ContentType: contentType,
|
||||
Source: "wecom",
|
||||
CleanupPolicy: media.CleanupPolicyDeleteOnCleanup,
|
||||
}, scope)
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", err
|
||||
}
|
||||
return ref, nil
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
basechannels "github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestStoreRemoteMedia_DetectsJPEGContentTypeFromBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const jpegBase64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAP//////////////////////////////////////////////////////////////////////////////////////" +
|
||||
"//////////////////////////////////////////////////////////////////////////////////////////////2wBDAf//////////////////////////////////////////////////////////////////////////////////////" +
|
||||
"//////////////////////////////////////////////////////////////////////////////////////////////wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAVEQEBAAAAAAAAAAAAAAAAAAAABf/aAAwDAQACEAMQAAAB6A//xAAVEAEBAAAAAAAAAAAAAAAAAAAAEf/aAAgBAQABBQJf/8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAwEBPwF//8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAgEBPwF//8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQAGPwJf/8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQABPyFf/9k="
|
||||
|
||||
jpegData := decodeTestBase64(t, jpegBase64)
|
||||
store := media.NewFileMediaStore()
|
||||
ch := &WeComChannel{
|
||||
BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil),
|
||||
mediaClient: &http.Client{
|
||||
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/octet-stream"}},
|
||||
Body: io.NopCloser(bytes.NewReader(jpegData)),
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
ref, err := ch.storeRemoteMedia(context.Background(), "test-scope", "msg-1", "https://wecom.example/media", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("storeRemoteMedia returned error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = store.ReleaseAll("test-scope")
|
||||
})
|
||||
|
||||
_, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve media ref: %v", err)
|
||||
}
|
||||
if meta.ContentType != "image/jpeg" {
|
||||
t.Fatalf("expected image/jpeg content type, got %q", meta.ContentType)
|
||||
}
|
||||
if !strings.HasSuffix(meta.Filename, ".jpg") && !strings.HasSuffix(meta.Filename, ".jpeg") {
|
||||
t.Fatalf("expected jpeg filename, got %q", meta.Filename)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectWeComMediaMetadata_UsesFallbackExtensionWhenBodyUnknown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
filename, contentType := detectWeComMediaMetadata([]byte("not a real image"), "msg-2.pdf", "", "", "")
|
||||
if filename != "msg-2.pdf" {
|
||||
t.Fatalf("expected fallback filename to be preserved, got %q", filename)
|
||||
}
|
||||
if contentType != "application/pdf" {
|
||||
t.Fatalf("expected application/pdf from fallback extension, got %q", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRemoteMedia_PreservesSuffixFromURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
docxLikeData := []byte("PK\x03\x04fake office payload")
|
||||
store := media.NewFileMediaStore()
|
||||
ch := &WeComChannel{
|
||||
BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil),
|
||||
mediaClient: &http.Client{
|
||||
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/octet-stream"}},
|
||||
Body: io.NopCloser(bytes.NewReader(docxLikeData)),
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
ref, err := ch.storeRemoteMedia(
|
||||
context.Background(),
|
||||
"test-scope",
|
||||
"msg-docx",
|
||||
"https://wecom.example/media/report.docx?signature=1",
|
||||
"",
|
||||
".bin",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("storeRemoteMedia returned error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = store.ReleaseAll("test-scope")
|
||||
})
|
||||
|
||||
localPath, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve media ref: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(meta.Filename, ".docx") {
|
||||
t.Fatalf("expected docx filename, got %q", meta.Filename)
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(localPath), ".docx") {
|
||||
t.Fatalf("expected docx temp path, got %q", localPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRemoteMedia_PreservesSuffixFromContentDisposition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pptxLikeData := []byte("PK\x03\x04fake office payload")
|
||||
store := media.NewFileMediaStore()
|
||||
ch := &WeComChannel{
|
||||
BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil),
|
||||
mediaClient: &http.Client{
|
||||
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/octet-stream"},
|
||||
"Content-Disposition": []string{`attachment; filename="slides.pptx"`},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewReader(pptxLikeData)),
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
ref, err := ch.storeRemoteMedia(
|
||||
context.Background(),
|
||||
"test-scope",
|
||||
"msg-pptx",
|
||||
"https://wecom.example/media/download",
|
||||
"",
|
||||
".bin",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("storeRemoteMedia returned error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = store.ReleaseAll("test-scope")
|
||||
})
|
||||
|
||||
localPath, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve media ref: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(meta.Filename, ".pptx") {
|
||||
t.Fatalf("expected pptx filename, got %q", meta.Filename)
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(localPath), ".pptx") {
|
||||
t.Fatalf("expected pptx temp path, got %q", localPath)
|
||||
}
|
||||
}
|
||||
|
||||
func decodeTestBase64(t *testing.T, value string) []byte {
|
||||
t.Helper()
|
||||
|
||||
data, err := io.ReadAll(base64.NewDecoder(base64.StdEncoding, strings.NewReader(value)))
|
||||
if err != nil {
|
||||
t.Fatalf("decode base64 fixture: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package wecom
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
const (
|
||||
wecomDefaultWebSocketURL = "wss://openws.work.weixin.qq.com"
|
||||
wecomCmdSubscribe = "aibot_subscribe"
|
||||
wecomCmdPing = "ping"
|
||||
wecomCmdMsgCallback = "aibot_msg_callback"
|
||||
wecomCmdEventCallback = "aibot_event_callback"
|
||||
wecomCmdRespondMsg = "aibot_respond_msg"
|
||||
wecomCmdSendMsg = "aibot_send_msg"
|
||||
wecomMaxContentBytes = 20480
|
||||
)
|
||||
|
||||
type wecomEnvelope struct {
|
||||
Cmd string `json:"cmd,omitempty"`
|
||||
Headers wecomHeaders `json:"headers"`
|
||||
Body json.RawMessage `json:"body,omitempty"`
|
||||
ErrCode int `json:"errcode,omitempty"`
|
||||
ErrMsg string `json:"errmsg,omitempty"`
|
||||
}
|
||||
|
||||
type wecomHeaders struct {
|
||||
ReqID string `json:"req_id,omitempty"`
|
||||
}
|
||||
|
||||
type wecomCommand struct {
|
||||
Cmd string `json:"cmd"`
|
||||
Headers wecomHeaders `json:"headers"`
|
||||
Body any `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
type wecomSendMsgBody struct {
|
||||
ChatID string `json:"chatid"`
|
||||
ChatType uint32 `json:"chat_type,omitempty"`
|
||||
MsgType string `json:"msgtype"`
|
||||
Markdown *wecomMarkdownContent `json:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
type wecomRespondMsgBody struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Stream *wecomStreamContent `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type wecomStreamContent struct {
|
||||
ID string `json:"id"`
|
||||
Finish bool `json:"finish"`
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type wecomMarkdownContent struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type wecomIncomingMessage struct {
|
||||
MsgID string `json:"msgid"`
|
||||
AIBotID string `json:"aibotid"`
|
||||
ChatID string `json:"chatid,omitempty"`
|
||||
ChatType string `json:"chattype,omitempty"`
|
||||
From struct {
|
||||
UserID string `json:"userid"`
|
||||
} `json:"from"`
|
||||
MsgType string `json:"msgtype"`
|
||||
Text *struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text,omitempty"`
|
||||
Image *struct {
|
||||
URL string `json:"url"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
} `json:"image,omitempty"`
|
||||
File *struct {
|
||||
URL string `json:"url"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
} `json:"file,omitempty"`
|
||||
Video *struct {
|
||||
URL string `json:"url"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
} `json:"video,omitempty"`
|
||||
Voice *struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"voice,omitempty"`
|
||||
Mixed *struct {
|
||||
MsgItem []struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text *struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text,omitempty"`
|
||||
Image *struct {
|
||||
URL string `json:"url"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
} `json:"image,omitempty"`
|
||||
File *struct {
|
||||
URL string `json:"url"`
|
||||
AESKey string `json:"aeskey,omitempty"`
|
||||
} `json:"file,omitempty"`
|
||||
} `json:"msg_item"`
|
||||
} `json:"mixed,omitempty"`
|
||||
Quote *struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text *struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text,omitempty"`
|
||||
} `json:"quote,omitempty"`
|
||||
Event *struct {
|
||||
EventType string `json:"eventtype"`
|
||||
} `json:"event,omitempty"`
|
||||
}
|
||||
|
||||
func incomingChatID(msg wecomIncomingMessage) string {
|
||||
if msg.ChatID != "" {
|
||||
return msg.ChatID
|
||||
}
|
||||
return msg.From.UserID
|
||||
}
|
||||
|
||||
func incomingChatTypeCode(kind string) uint32 {
|
||||
if kind == "group" {
|
||||
return 2
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type wecomRoute struct {
|
||||
ReqID string `json:"req_id"`
|
||||
ChatID string `json:"chat_id"`
|
||||
ChatType uint32 `json:"chat_type"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type reqIDStore struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
routes map[string]wecomRoute
|
||||
}
|
||||
|
||||
func newReqIDStore(path string) *reqIDStore {
|
||||
if path == "" {
|
||||
path = defaultReqIDStorePath()
|
||||
}
|
||||
s := &reqIDStore{
|
||||
path: path,
|
||||
routes: make(map[string]wecomRoute),
|
||||
}
|
||||
_ = s.load()
|
||||
return s
|
||||
}
|
||||
|
||||
func defaultReqIDStorePath() string {
|
||||
if home, err := os.UserHomeDir(); err == nil && home != "" {
|
||||
return filepath.Join(home, ".picoclaw", "wecom", "reqid-store.json")
|
||||
}
|
||||
return filepath.Join(os.TempDir(), "picoclaw-wecom-reqid-store.json")
|
||||
}
|
||||
|
||||
func (s *reqIDStore) Put(chatID, reqID string, chatType uint32, ttl time.Duration) error {
|
||||
if reqID == "" || chatID == "" {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.deleteExpiredLocked(time.Now())
|
||||
s.routes[chatID] = wecomRoute{
|
||||
ReqID: reqID,
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *reqIDStore) Get(chatID string) (wecomRoute, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.deleteExpiredLocked(time.Now())
|
||||
route, ok := s.routes[chatID]
|
||||
return route, ok
|
||||
}
|
||||
|
||||
func (s *reqIDStore) Delete(chatID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.routes, chatID)
|
||||
return s.saveLocked()
|
||||
}
|
||||
|
||||
func (s *reqIDStore) load() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var routes map[string]wecomRoute
|
||||
if err := json.Unmarshal(data, &routes); err != nil {
|
||||
return err
|
||||
}
|
||||
s.routes = routes
|
||||
s.deleteExpiredLocked(time.Now())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *reqIDStore) deleteExpiredLocked(now time.Time) {
|
||||
for chatID, route := range s.routes {
|
||||
if !route.ExpiresAt.IsZero() && now.After(route.ExpiresAt) {
|
||||
delete(s.routes, chatID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *reqIDStore) saveLocked() error {
|
||||
if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.MarshalIndent(s.routes, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(s.path, data, 0o600)
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestReqIDStorePersistsRoutes(t *testing.T) {
|
||||
storePath := filepath.Join(t.TempDir(), "reqids.json")
|
||||
store := newReqIDStore(storePath)
|
||||
if err := store.Put("chat-1", "req-1", 2, time.Hour); err != nil {
|
||||
t.Fatalf("Put() error = %v", err)
|
||||
}
|
||||
|
||||
reloaded := newReqIDStore(storePath)
|
||||
route, ok := reloaded.Get("chat-1")
|
||||
if !ok {
|
||||
t.Fatal("expected persisted route to be loaded")
|
||||
}
|
||||
if route.ChatID != "chat-1" || route.ReqID != "req-1" || route.ChatType != 2 {
|
||||
t.Fatalf("loaded route = %+v", route)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,777 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
wecomConnectTimeout = 15 * time.Second
|
||||
wecomCommandTimeout = 10 * time.Second
|
||||
wecomHeartbeatInterval = 30 * time.Second
|
||||
wecomStreamMaxDuration = 5*time.Minute + 30*time.Second
|
||||
wecomRouteTTL = 30 * time.Minute
|
||||
wecomMediaTimeout = 30 * time.Second
|
||||
wecomRecentMessageMax = 1000
|
||||
)
|
||||
|
||||
type WeComChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.WeComConfig
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
conn *websocket.Conn
|
||||
connMu sync.Mutex
|
||||
|
||||
pendingMu sync.Mutex
|
||||
pending map[string]chan wecomEnvelope
|
||||
|
||||
turnsMu sync.Mutex
|
||||
turns map[string][]wecomTurn
|
||||
|
||||
recent *recentMessageSet
|
||||
routes *reqIDStore
|
||||
mediaClient *http.Client
|
||||
commandSend func(wecomCommand, time.Duration) error
|
||||
}
|
||||
|
||||
type wecomTurn struct {
|
||||
ReqID string
|
||||
ChatID string
|
||||
ChatType uint32
|
||||
StreamID string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type recentMessageSet struct {
|
||||
mu sync.Mutex
|
||||
seen map[string]struct{}
|
||||
ring []string
|
||||
idx int
|
||||
}
|
||||
|
||||
func newRecentMessageSet(capacity int) *recentMessageSet {
|
||||
if capacity <= 0 {
|
||||
capacity = wecomRecentMessageMax
|
||||
}
|
||||
return &recentMessageSet{
|
||||
seen: make(map[string]struct{}, capacity),
|
||||
ring: make([]string, capacity),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *recentMessageSet) Mark(id string) bool {
|
||||
if id == "" {
|
||||
return true
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.seen[id]; ok {
|
||||
return false
|
||||
}
|
||||
if old := s.ring[s.idx]; old != "" {
|
||||
delete(s.seen, old)
|
||||
}
|
||||
s.ring[s.idx] = id
|
||||
s.idx = (s.idx + 1) % len(s.ring)
|
||||
s.seen[id] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChannel, error) {
|
||||
if cfg.BotID == "" || cfg.Secret() == "" {
|
||||
return nil, fmt.Errorf("wecom bot_id and secret are required")
|
||||
}
|
||||
if cfg.WebSocketURL == "" {
|
||||
cfg.WebSocketURL = wecomDefaultWebSocketURL
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel(
|
||||
"wecom",
|
||||
cfg,
|
||||
messageBus,
|
||||
cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(wecomMaxContentBytes),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
ch := &WeComChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
pending: make(map[string]chan wecomEnvelope),
|
||||
turns: make(map[string][]wecomTurn),
|
||||
recent: newRecentMessageSet(wecomRecentMessageMax),
|
||||
routes: newReqIDStore(""),
|
||||
mediaClient: &http.Client{Timeout: wecomMediaTimeout},
|
||||
}
|
||||
ch.SetOwner(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) Name() string { return "wecom" }
|
||||
|
||||
func (c *WeComChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom", "Starting WeCom channel...")
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
c.SetRunning(true)
|
||||
go c.connectLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) Stop(_ context.Context) error {
|
||||
logger.InfoC("wecom", "Stopping WeCom channel...")
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
c.connMu.Lock()
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
c.clearTurns()
|
||||
c.SetRunning(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
if content == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if turn, ok := c.getTurn(msg.ChatID); ok {
|
||||
if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration {
|
||||
if err := c.sendStreamReply(turn, content); err == nil {
|
||||
c.deleteTurn(msg.ChatID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
c.deleteTurn(msg.ChatID)
|
||||
}
|
||||
|
||||
if route, ok := c.routes.Get(msg.ChatID); ok {
|
||||
if err := c.sendActivePush(route.ChatID, route.ChatType, content); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.sendActivePush(msg.ChatID, 0, content); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
var parts []string
|
||||
for _, part := range msg.Parts {
|
||||
switch {
|
||||
case part.Caption != "":
|
||||
parts = append(parts, part.Caption)
|
||||
case part.Filename != "":
|
||||
parts = append(parts, fmt.Sprintf("[media: %s]", part.Filename))
|
||||
default:
|
||||
parts = append(parts, "[media attachments are not yet supported]")
|
||||
}
|
||||
}
|
||||
return c.Send(ctx, bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: strings.Join(parts, "\n"),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *WeComChannel) connectLoop() {
|
||||
backoff := time.Second
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err := c.runConnection(); err != nil {
|
||||
logger.WarnCF("wecom", "WeCom connection lost", map[string]any{
|
||||
"error": err.Error(),
|
||||
"backoff": backoff.String(),
|
||||
})
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
if backoff < time.Minute {
|
||||
backoff *= 2
|
||||
if backoff > time.Minute {
|
||||
backoff = time.Minute
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) runConnection() error {
|
||||
dialCtx, cancel := context.WithTimeout(c.ctx, wecomConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, resp, err := websocket.DefaultDialer.DialContext(dialCtx, c.config.WebSocketURL, nil)
|
||||
if resp != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
|
||||
}
|
||||
|
||||
c.connMu.Lock()
|
||||
c.conn = conn
|
||||
c.connMu.Unlock()
|
||||
defer func() {
|
||||
c.connMu.Lock()
|
||||
if c.conn == conn {
|
||||
c.conn = nil
|
||||
}
|
||||
c.connMu.Unlock()
|
||||
_ = conn.Close()
|
||||
c.clearTurns()
|
||||
}()
|
||||
|
||||
readErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
readErrCh <- c.readLoop(conn)
|
||||
}()
|
||||
|
||||
if writeErr := c.writeAndWait(conn, wecomCommand{
|
||||
Cmd: wecomCmdSubscribe,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
Body: map[string]string{
|
||||
"bot_id": c.config.BotID,
|
||||
"secret": c.config.Secret(),
|
||||
},
|
||||
}, wecomCommandTimeout); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
|
||||
heartbeatDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(heartbeatDone)
|
||||
c.heartbeatLoop(conn)
|
||||
}()
|
||||
|
||||
err = <-readErrCh
|
||||
_ = conn.Close()
|
||||
<-heartbeatDone
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *WeComChannel) heartbeatLoop(conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(wecomHeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := c.writeAndWait(conn, wecomCommand{
|
||||
Cmd: wecomCmdPing,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
}, wecomCommandTimeout); err != nil {
|
||||
logger.WarnCF("wecom", "Heartbeat failed", map[string]any{"error": err.Error()})
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) readLoop(conn *websocket.Conn) error {
|
||||
for {
|
||||
_, raw, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
|
||||
}
|
||||
}
|
||||
|
||||
var env wecomEnvelope
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
logger.WarnCF("wecom", "Failed to parse WebSocket message", map[string]any{"error": err.Error()})
|
||||
continue
|
||||
}
|
||||
|
||||
if env.Cmd == "" && env.Headers.ReqID != "" {
|
||||
c.pendingMu.Lock()
|
||||
ch, ok := c.pending[env.Headers.ReqID]
|
||||
if ok {
|
||||
delete(c.pending, env.Headers.ReqID)
|
||||
}
|
||||
c.pendingMu.Unlock()
|
||||
if ok {
|
||||
ch <- env
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleEnvelope(env)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) handleEnvelope(env wecomEnvelope) {
|
||||
switch env.Cmd {
|
||||
case wecomCmdMsgCallback:
|
||||
c.handleMessageCallback(env)
|
||||
case wecomCmdEventCallback:
|
||||
c.handleEventCallback(env)
|
||||
default:
|
||||
logger.DebugCF("wecom", "Ignoring unsupported WeCom command", map[string]any{"cmd": env.Cmd})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) handleEventCallback(env wecomEnvelope) {
|
||||
var msg wecomIncomingMessage
|
||||
if err := json.Unmarshal(env.Body, &msg); err != nil {
|
||||
logger.WarnCF("wecom", "Failed to parse WeCom event callback", map[string]any{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) handleMessageCallback(env wecomEnvelope) {
|
||||
var msg wecomIncomingMessage
|
||||
if err := json.Unmarshal(env.Body, &msg); err != nil {
|
||||
logger.WarnCF("wecom", "Failed to parse WeCom message callback", map[string]any{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !c.recent.Mark(msg.MsgID) {
|
||||
return
|
||||
}
|
||||
|
||||
reqID := env.Headers.ReqID
|
||||
if reqID == "" {
|
||||
logger.WarnC("wecom", "WeCom message callback missing req_id")
|
||||
return
|
||||
}
|
||||
if msg.Event != nil && msg.Event.EventType != "" {
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.dispatchIncoming(reqID, msg); err != nil {
|
||||
logger.WarnCF("wecom", "Failed to dispatch WeCom message", map[string]any{
|
||||
"req_id": reqID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
_ = c.respondImmediate(reqID, "The WeCom message could not be processed.")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage) error {
|
||||
senderID := msg.From.UserID
|
||||
if senderID == "" {
|
||||
senderID = "unknown"
|
||||
}
|
||||
actualChatID := incomingChatID(msg)
|
||||
chatType := incomingChatTypeCode(msg.ChatType)
|
||||
peerKind := "direct"
|
||||
if msg.ChatType == "group" {
|
||||
peerKind = "group"
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "wecom",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
|
||||
DisplayName: senderID,
|
||||
}
|
||||
|
||||
var (
|
||||
content string
|
||||
quoteText string
|
||||
mediaRefs []string
|
||||
err error
|
||||
)
|
||||
scope := channels.BuildMediaScope("wecom", actualChatID, msg.MsgID)
|
||||
switch msg.MsgType {
|
||||
case "text":
|
||||
if msg.Text != nil {
|
||||
content = strings.TrimSpace(msg.Text.Content)
|
||||
}
|
||||
case "voice":
|
||||
if msg.Voice != nil {
|
||||
content = strings.TrimSpace(msg.Voice.Content)
|
||||
}
|
||||
case "image":
|
||||
content = "[image]"
|
||||
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
|
||||
url: msg.Image.URL,
|
||||
aesKey: msg.Image.AESKey,
|
||||
}, "image", ".jpg")
|
||||
case "file":
|
||||
content = "[file]"
|
||||
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
|
||||
url: msg.File.URL,
|
||||
aesKey: msg.File.AESKey,
|
||||
}, "file", ".bin")
|
||||
case "video":
|
||||
content = "[video]"
|
||||
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
|
||||
url: msg.Video.URL,
|
||||
aesKey: msg.Video.AESKey,
|
||||
}, "video", ".mp4")
|
||||
case "mixed":
|
||||
content, mediaRefs, err = c.collectMixedMedia(c.ctx, scope, msg)
|
||||
default:
|
||||
return c.respondImmediate(reqID, "Unsupported WeCom message type: "+msg.MsgType)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Quote != nil && msg.Quote.Text != nil {
|
||||
quoteText = strings.TrimSpace(msg.Quote.Text.Content)
|
||||
if content == "" {
|
||||
content = quoteText
|
||||
}
|
||||
}
|
||||
if content == "" && len(mediaRefs) == 0 {
|
||||
return c.respondImmediate(reqID, "The WeCom message did not contain usable content.")
|
||||
}
|
||||
|
||||
turn := wecomTurn{
|
||||
ReqID: reqID,
|
||||
ChatID: actualChatID,
|
||||
ChatType: chatType,
|
||||
StreamID: randomID(10),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
c.queueTurn(actualChatID, turn)
|
||||
if err := c.routes.Put(actualChatID, reqID, chatType, wecomRouteTTL); err != nil {
|
||||
logger.WarnCF("wecom", "Failed to persist req_id route", map[string]any{
|
||||
"chat_id": actualChatID,
|
||||
"req_id": reqID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
opening := ""
|
||||
if c.config.SendThinkingMessage {
|
||||
opening = "Processing..."
|
||||
}
|
||||
if err := c.sendStreamChunk(turn, false, opening); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: actualChatID}
|
||||
metadata := map[string]string{
|
||||
"channel": "wecom",
|
||||
"req_id": reqID,
|
||||
"chat_id": actualChatID,
|
||||
"chat_type": msg.ChatType,
|
||||
"msg_id": msg.MsgID,
|
||||
"msg_type": msg.MsgType,
|
||||
}
|
||||
if quoteText != "" {
|
||||
metadata["quote_text"] = quoteText
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.MsgID, senderID, actualChatID, content, mediaRefs, metadata, sender)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) collectSingleMedia(
|
||||
ctx context.Context,
|
||||
scope, msgID string,
|
||||
payload interface {
|
||||
GetURL() string
|
||||
GetAESKey() string
|
||||
},
|
||||
label, fallbackExt string,
|
||||
) ([]string, error) {
|
||||
if payload == nil || payload.GetURL() == "" {
|
||||
return nil, fmt.Errorf("%s payload is empty", label)
|
||||
}
|
||||
ref, err := c.storeRemoteMedia(ctx, scope, msgID, payload.GetURL(), payload.GetAESKey(), fallbackExt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []string{ref}, nil
|
||||
}
|
||||
|
||||
type mediaPayload struct {
|
||||
url string
|
||||
aesKey string
|
||||
}
|
||||
|
||||
func (p *mediaPayload) GetURL() string { return p.url }
|
||||
func (p *mediaPayload) GetAESKey() string { return p.aesKey }
|
||||
|
||||
func (c *WeComChannel) collectMixedMedia(
|
||||
ctx context.Context,
|
||||
scope string,
|
||||
msg wecomIncomingMessage,
|
||||
) (string, []string, error) {
|
||||
if msg.Mixed == nil {
|
||||
return "", nil, fmt.Errorf("mixed message is empty")
|
||||
}
|
||||
|
||||
var textParts []string
|
||||
var refs []string
|
||||
for idx, item := range msg.Mixed.MsgItem {
|
||||
switch item.MsgType {
|
||||
case "text":
|
||||
if item.Text != nil && strings.TrimSpace(item.Text.Content) != "" {
|
||||
textParts = append(textParts, strings.TrimSpace(item.Text.Content))
|
||||
}
|
||||
case "image":
|
||||
if item.Image != nil && item.Image.URL != "" {
|
||||
ref, err := c.storeRemoteMedia(
|
||||
ctx,
|
||||
scope,
|
||||
fmt.Sprintf("%s-%d", msg.MsgID, idx),
|
||||
item.Image.URL,
|
||||
item.Image.AESKey,
|
||||
".jpg",
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
case "file":
|
||||
if item.File != nil && item.File.URL != "" {
|
||||
ref, err := c.storeRemoteMedia(
|
||||
ctx,
|
||||
scope,
|
||||
fmt.Sprintf("%s-%d", msg.MsgID, idx),
|
||||
item.File.URL,
|
||||
item.File.AESKey,
|
||||
".bin",
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content := strings.Join(textParts, "\n")
|
||||
if content == "" && len(refs) > 0 {
|
||||
content = "[media]"
|
||||
}
|
||||
return content, refs, nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) respondImmediate(reqID, content string) error {
|
||||
turn := wecomTurn{
|
||||
ReqID: reqID,
|
||||
StreamID: randomID(10),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
return c.sendStreamChunk(turn, true, content)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendStreamReply(turn wecomTurn, content string) error {
|
||||
chunks := splitContent(content, wecomMaxContentBytes)
|
||||
for idx, chunk := range chunks {
|
||||
if err := c.sendStreamChunk(turn, idx == len(chunks)-1, chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content string) error {
|
||||
return c.sendCommand(wecomCommand{
|
||||
Cmd: wecomCmdRespondMsg,
|
||||
Headers: wecomHeaders{ReqID: turn.ReqID},
|
||||
Body: wecomRespondMsgBody{
|
||||
MsgType: "stream",
|
||||
Stream: &wecomStreamContent{
|
||||
ID: turn.StreamID,
|
||||
Finish: finish,
|
||||
Content: content,
|
||||
},
|
||||
},
|
||||
}, wecomCommandTimeout)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content string) error {
|
||||
if strings.TrimSpace(chatID) == "" {
|
||||
return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed)
|
||||
}
|
||||
for _, chunk := range splitContent(content, wecomMaxContentBytes) {
|
||||
if err := c.sendCommand(wecomCommand{
|
||||
Cmd: wecomCmdSendMsg,
|
||||
Headers: wecomHeaders{ReqID: randomID(10)},
|
||||
Body: wecomSendMsgBody{
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
MsgType: "markdown",
|
||||
Markdown: &wecomMarkdownContent{Content: chunk},
|
||||
},
|
||||
}, wecomCommandTimeout); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeComChannel) sendCommand(cmd wecomCommand, timeout time.Duration) error {
|
||||
if c.commandSend != nil {
|
||||
return c.commandSend(cmd, timeout)
|
||||
}
|
||||
return c.writeCurrent(cmd, timeout)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) writeCurrent(cmd wecomCommand, timeout time.Duration) error {
|
||||
c.connMu.Lock()
|
||||
conn := c.conn
|
||||
c.connMu.Unlock()
|
||||
if conn == nil {
|
||||
return fmt.Errorf("wecom websocket not connected: %w", channels.ErrTemporary)
|
||||
}
|
||||
return c.writeAndWait(conn, cmd, timeout)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, timeout time.Duration) error {
|
||||
if cmd.Headers.ReqID == "" {
|
||||
cmd.Headers.ReqID = randomID(10)
|
||||
}
|
||||
waitCh := make(chan wecomEnvelope, 1)
|
||||
c.pendingMu.Lock()
|
||||
c.pending[cmd.Headers.ReqID] = waitCh
|
||||
c.pendingMu.Unlock()
|
||||
defer func() {
|
||||
c.pendingMu.Lock()
|
||||
delete(c.pending, cmd.Headers.ReqID)
|
||||
c.pendingMu.Unlock()
|
||||
}()
|
||||
|
||||
data, err := json.Marshal(cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", channels.ErrSendFailed, err)
|
||||
}
|
||||
c.connMu.Lock()
|
||||
err = conn.WriteMessage(websocket.TextMessage, data)
|
||||
c.connMu.Unlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case env := <-waitCh:
|
||||
if env.ErrCode != 0 {
|
||||
return fmt.Errorf("%w: wecom errcode=%d errmsg=%s", channels.ErrTemporary, env.ErrCode, env.ErrMsg)
|
||||
}
|
||||
return nil
|
||||
case <-timer.C:
|
||||
return fmt.Errorf("%w: timeout waiting for WeCom ack", channels.ErrTemporary)
|
||||
case <-c.ctx.Done():
|
||||
return c.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeComChannel) getTurn(chatID string) (wecomTurn, bool) {
|
||||
c.turnsMu.Lock()
|
||||
defer c.turnsMu.Unlock()
|
||||
queue := c.turns[chatID]
|
||||
if len(queue) == 0 {
|
||||
return wecomTurn{}, false
|
||||
}
|
||||
return queue[0], true
|
||||
}
|
||||
|
||||
func (c *WeComChannel) deleteTurn(chatID string) {
|
||||
c.turnsMu.Lock()
|
||||
defer c.turnsMu.Unlock()
|
||||
queue := c.turns[chatID]
|
||||
if len(queue) <= 1 {
|
||||
delete(c.turns, chatID)
|
||||
return
|
||||
}
|
||||
c.turns[chatID] = queue[1:]
|
||||
}
|
||||
|
||||
func (c *WeComChannel) queueTurn(chatID string, turn wecomTurn) {
|
||||
c.turnsMu.Lock()
|
||||
defer c.turnsMu.Unlock()
|
||||
c.turns[chatID] = append(c.turns[chatID], turn)
|
||||
}
|
||||
|
||||
func (c *WeComChannel) clearTurns() {
|
||||
c.turnsMu.Lock()
|
||||
c.turns = make(map[string][]wecomTurn)
|
||||
c.turnsMu.Unlock()
|
||||
}
|
||||
|
||||
func randomID(n int) string {
|
||||
const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
if n <= 0 {
|
||||
n = 10
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
for i := range buf {
|
||||
v, _ := rand.Int(rand.Reader, big.NewInt(int64(len(alphabet))))
|
||||
buf[i] = alphabet[v.Int64()]
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func splitContent(content string, maxBytes int) []string {
|
||||
if content == "" {
|
||||
return []string{""}
|
||||
}
|
||||
if len(content) <= maxBytes {
|
||||
return []string{content}
|
||||
}
|
||||
chunks := channels.SplitMessage(content, maxBytes)
|
||||
var result []string
|
||||
for _, chunk := range chunks {
|
||||
if len(chunk) <= maxBytes {
|
||||
result = append(result, chunk)
|
||||
continue
|
||||
}
|
||||
for len(chunk) > maxBytes {
|
||||
end := maxBytes
|
||||
for end > 0 && chunk[end]>>6 == 0b10 {
|
||||
end--
|
||||
}
|
||||
if end == 0 {
|
||||
end = maxBytes
|
||||
}
|
||||
result = append(result, chunk[:end])
|
||||
chunk = strings.TrimLeft(chunk[end:], " \t\r\n")
|
||||
}
|
||||
if chunk != "" {
|
||||
result = append(result, chunk)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := newTestWeComChannel(t, messageBus)
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) error {
|
||||
commands = append(commands, cmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
msg := wecomIncomingMessage{
|
||||
MsgID: "msg-1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: "direct",
|
||||
MsgType: "text",
|
||||
Text: &struct {
|
||||
Content string `json:"content"`
|
||||
}{Content: "hello"},
|
||||
}
|
||||
msg.From.UserID = "user-1"
|
||||
|
||||
if err := ch.dispatchIncoming("req-1", msg); err != nil {
|
||||
t.Fatalf("dispatchIncoming() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case inbound := <-messageBus.InboundChan():
|
||||
if inbound.ChatID != "chat-1" {
|
||||
t.Fatalf("inbound ChatID = %q, want chat-1", inbound.ChatID)
|
||||
}
|
||||
if inbound.MessageID != "msg-1" {
|
||||
t.Fatalf("inbound MessageID = %q, want msg-1", inbound.MessageID)
|
||||
}
|
||||
if inbound.Peer.ID != "chat-1" {
|
||||
t.Fatalf("inbound Peer.ID = %q, want chat-1", inbound.Peer.ID)
|
||||
}
|
||||
if inbound.Metadata["req_id"] != "req-1" {
|
||||
t.Fatalf("inbound req_id = %q, want req-1", inbound.Metadata["req_id"])
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected inbound message to be published")
|
||||
}
|
||||
|
||||
turn, ok := ch.getTurn("chat-1")
|
||||
if !ok {
|
||||
t.Fatal("expected queued turn for chat-1")
|
||||
}
|
||||
if turn.ReqID != "req-1" {
|
||||
t.Fatalf("turn.ReqID = %q, want req-1", turn.ReqID)
|
||||
}
|
||||
|
||||
route, ok := ch.routes.Get("chat-1")
|
||||
if !ok {
|
||||
t.Fatal("expected persisted route for chat-1")
|
||||
}
|
||||
if route.ReqID != "req-1" || route.ChatType != 1 {
|
||||
t.Fatalf("route = %+v", route)
|
||||
}
|
||||
|
||||
if len(commands) != 1 {
|
||||
t.Fatalf("expected 1 opening command, got %d", len(commands))
|
||||
}
|
||||
if commands[0].Cmd != wecomCmdRespondMsg {
|
||||
t.Fatalf("opening command = %q, want %q", commands[0].Cmd, wecomCmdRespondMsg)
|
||||
}
|
||||
if commands[0].Headers.ReqID != "req-1" {
|
||||
t.Fatalf("opening req_id = %q, want req-1", commands[0].Headers.ReqID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := newTestWeComChannel(t, bus.NewMessageBus())
|
||||
ch.SetRunning(true)
|
||||
ch.queueTurn("chat-1", wecomTurn{
|
||||
ReqID: "req-1",
|
||||
ChatID: "chat-1",
|
||||
ChatType: 1,
|
||||
StreamID: "stream-1",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
ch.queueTurn("chat-1", wecomTurn{
|
||||
ReqID: "req-2",
|
||||
ChatID: "chat-1",
|
||||
ChatType: 1,
|
||||
StreamID: "stream-2",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
if err := ch.routes.Put("chat-1", "req-2", 1, time.Hour); err != nil {
|
||||
t.Fatalf("Put() error = %v", err)
|
||||
}
|
||||
|
||||
var commands []wecomCommand
|
||||
ch.commandSend = func(cmd wecomCommand, _ time.Duration) error {
|
||||
commands = append(commands, cmd)
|
||||
if len(commands) == 1 && cmd.Cmd == wecomCmdRespondMsg {
|
||||
return errors.New("stream send failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
Channel: "wecom",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}); err != nil {
|
||||
t.Fatalf("Send() error = %v", err)
|
||||
}
|
||||
|
||||
if len(commands) != 2 {
|
||||
t.Fatalf("expected 2 commands, got %d", len(commands))
|
||||
}
|
||||
if commands[0].Cmd != wecomCmdRespondMsg || commands[0].Headers.ReqID != "req-1" {
|
||||
t.Fatalf("first command = %+v", commands[0])
|
||||
}
|
||||
if commands[1].Cmd != wecomCmdSendMsg {
|
||||
t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdSendMsg)
|
||||
}
|
||||
body, ok := commands[1].Body.(wecomSendMsgBody)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected send body type %T", commands[1].Body)
|
||||
}
|
||||
if body.ChatID != "chat-1" {
|
||||
t.Fatalf("send chatid = %q, want chat-1", body.ChatID)
|
||||
}
|
||||
if body.ChatType != 1 {
|
||||
t.Fatalf("send chat_type = %d, want 1", body.ChatType)
|
||||
}
|
||||
|
||||
nextTurn, ok := ch.getTurn("chat-1")
|
||||
if !ok {
|
||||
t.Fatal("expected second turn to remain queued")
|
||||
}
|
||||
if nextTurn.ReqID != "req-2" {
|
||||
t.Fatalf("next queued req_id = %q, want req-2", nextTurn.ReqID)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel {
|
||||
t.Helper()
|
||||
|
||||
cfg := config.WeComConfig{BotID: "bot-1"}
|
||||
cfg.SetSecret("secret-1")
|
||||
ch, err := NewChannel(cfg, messageBus)
|
||||
if err != nil {
|
||||
t.Fatalf("NewChannel() error = %v", err)
|
||||
}
|
||||
ch.ctx = context.Background()
|
||||
ch.routes = newReqIDStore(filepath.Join(t.TempDir(), "reqids.json"))
|
||||
return ch
|
||||
}
|
||||
+22
-181
@@ -321,10 +321,7 @@ type AgentDefaults struct {
|
||||
ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
DefaultWeComAIBotProcessingMessage = "⏳ Processing, please wait. The results will be sent shortly."
|
||||
)
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
|
||||
func (d *AgentDefaults) GetMaxMediaSize() int {
|
||||
if d.MaxMediaSize > 0 {
|
||||
@@ -364,9 +361,7 @@ type ChannelsConfig struct {
|
||||
Matrix MatrixConfig `json:"matrix"`
|
||||
LINE LINEConfig `json:"line"`
|
||||
OneBot OneBotConfig `json:"onebot"`
|
||||
WeCom WeComConfig `json:"wecom"`
|
||||
WeComApp WeComAppConfig `json:"wecom_app"`
|
||||
WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
|
||||
WeCom WeComConfig `json:"wecom" envPrefix:"PICOCLAW_CHANNELS_WECOM_"`
|
||||
Weixin WeixinConfig `json:"weixin"`
|
||||
Pico PicoConfig `json:"pico"`
|
||||
PicoClient PicoClientConfig `json:"pico_client"`
|
||||
@@ -678,136 +673,28 @@ func (c *OneBotConfig) SetAccessToken(token string) {
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
type WeComGroupConfig struct {
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from,omitempty"`
|
||||
}
|
||||
|
||||
type WeComConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"`
|
||||
token string
|
||||
encodingAESKey string
|
||||
WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"`
|
||||
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"`
|
||||
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"`
|
||||
secDirty bool
|
||||
Enabled bool `json:"enabled" env:"ENABLED"`
|
||||
BotID string `json:"bot_id" env:"BOT_ID"`
|
||||
secret string
|
||||
WebSocketURL string `json:"websocket_url,omitempty" env:"WEBSOCKET_URL"`
|
||||
SendThinkingMessage bool `json:"send_thinking_message" env:"SEND_THINKING_MESSAGE"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"ALLOW_FROM"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"REASONING_CHANNEL_ID"`
|
||||
secDirty bool
|
||||
}
|
||||
|
||||
// Token returns the WeCom token
|
||||
func (c *WeComConfig) Token() string {
|
||||
return c.token
|
||||
}
|
||||
|
||||
// SetToken sets the WeCom token
|
||||
func (c *WeComConfig) SetToken(token string) {
|
||||
c.token = token
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
// EncodingAESKey returns the WeCom encoding AES key
|
||||
func (c *WeComConfig) EncodingAESKey() string {
|
||||
return c.encodingAESKey
|
||||
}
|
||||
|
||||
// SetEncodingAESKey sets the WeCom encoding AES key
|
||||
func (c *WeComConfig) SetEncodingAESKey(key string) {
|
||||
c.encodingAESKey = key
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
type WeComAppConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"`
|
||||
CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"`
|
||||
corpSecret string
|
||||
AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"`
|
||||
token string
|
||||
encodingAESKey string
|
||||
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"`
|
||||
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"`
|
||||
secDirty bool
|
||||
}
|
||||
|
||||
// CorpSecret returns the corporate secret for WeCom app
|
||||
func (c *WeComAppConfig) CorpSecret() string {
|
||||
return c.corpSecret
|
||||
}
|
||||
|
||||
// SetCorpSecret sets the corporate secret for WeCom app
|
||||
func (c *WeComAppConfig) SetCorpSecret(secret string) {
|
||||
c.corpSecret = secret
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
// Token returns the webhook token for WeCom app
|
||||
func (c *WeComAppConfig) Token() string {
|
||||
return c.token
|
||||
}
|
||||
|
||||
// SetToken sets the webhook token for WeCom app
|
||||
func (c *WeComAppConfig) SetToken(token string) {
|
||||
c.token = token
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
// EncodingAESKey returns the encoding AES key for WeCom app
|
||||
func (c *WeComAppConfig) EncodingAESKey() string {
|
||||
return c.encodingAESKey
|
||||
}
|
||||
|
||||
// SetEncodingAESKey sets the encoding AES key for WeCom app
|
||||
func (c *WeComAppConfig) SetEncodingAESKey(key string) {
|
||||
c.encodingAESKey = key
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
type WeComAIBotConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
|
||||
BotID string `json:"bot_id,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_BOT_ID"`
|
||||
secret string
|
||||
token string
|
||||
encodingAESKey string
|
||||
WebhookPath string `json:"webhook_path,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
|
||||
MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps
|
||||
WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome
|
||||
ProcessingMessage string `json:"processing_message,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_PROCESSING_MESSAGE"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
|
||||
secDirty bool
|
||||
}
|
||||
|
||||
// Token returns the webhook token for WeCom AI bot
|
||||
func (c *WeComAIBotConfig) Token() string {
|
||||
return c.token
|
||||
}
|
||||
|
||||
// EncodingAESKey returns the encoding AES key for WeCom AI bot
|
||||
func (c *WeComAIBotConfig) EncodingAESKey() string {
|
||||
return c.encodingAESKey
|
||||
}
|
||||
|
||||
// SetToken sets the token for WeCom AI bot
|
||||
func (c *WeComAIBotConfig) SetToken(token string) {
|
||||
c.token = token
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
// SetEncodingAESKey sets the encoding AES key for WeCom AI bot
|
||||
func (c *WeComAIBotConfig) SetEncodingAESKey(key string) {
|
||||
c.encodingAESKey = key
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
func (c *WeComAIBotConfig) Secret() string {
|
||||
// Secret returns the WeCom bot secret.
|
||||
func (c *WeComConfig) Secret() string {
|
||||
return c.secret
|
||||
}
|
||||
|
||||
func (c *WeComAIBotConfig) SetSecret(secret string) {
|
||||
// SetSecret sets the WeCom bot secret.
|
||||
func (c *WeComConfig) SetSecret(secret string) {
|
||||
c.secret = secret
|
||||
c.secDirty = true
|
||||
}
|
||||
@@ -1623,39 +1510,10 @@ func applySecurityConfig(cfg *Config, sec *SecurityConfig) error {
|
||||
cfg.Channels.OneBot.accessToken = sec.Channels.OneBot.AccessToken
|
||||
}
|
||||
|
||||
// Handle WeCom token and encoding key
|
||||
// Handle WeCom bot secret
|
||||
if sec.Channels.WeCom != nil {
|
||||
if sec.Channels.WeCom.Token != "" {
|
||||
cfg.Channels.WeCom.token = sec.Channels.WeCom.Token
|
||||
}
|
||||
if sec.Channels.WeCom.EncodingAESKey != "" {
|
||||
cfg.Channels.WeCom.encodingAESKey = sec.Channels.WeCom.EncodingAESKey
|
||||
}
|
||||
}
|
||||
|
||||
// Handle WeCom App credentials
|
||||
if sec.Channels.WeComApp != nil {
|
||||
if sec.Channels.WeComApp.CorpSecret != "" {
|
||||
cfg.Channels.WeComApp.corpSecret = sec.Channels.WeComApp.CorpSecret
|
||||
}
|
||||
if sec.Channels.WeComApp.Token != "" {
|
||||
cfg.Channels.WeComApp.token = sec.Channels.WeComApp.Token
|
||||
}
|
||||
if sec.Channels.WeComApp.EncodingAESKey != "" {
|
||||
cfg.Channels.WeComApp.encodingAESKey = sec.Channels.WeComApp.EncodingAESKey
|
||||
}
|
||||
}
|
||||
|
||||
// Handle WeCom AI Bot credentials
|
||||
if sec.Channels.WeComAIBot != nil {
|
||||
if sec.Channels.WeComAIBot.Token != "" {
|
||||
cfg.Channels.WeComAIBot.token = sec.Channels.WeComAIBot.Token
|
||||
}
|
||||
if sec.Channels.WeComAIBot.EncodingAESKey != "" {
|
||||
cfg.Channels.WeComAIBot.encodingAESKey = sec.Channels.WeComAIBot.EncodingAESKey
|
||||
}
|
||||
if sec.Channels.WeComAIBot.Secret != "" {
|
||||
cfg.Channels.WeComAIBot.secret = sec.Channels.WeComAIBot.Secret
|
||||
if sec.Channels.WeCom.Secret != "" {
|
||||
cfg.Channels.WeCom.secret = sec.Channels.WeCom.Secret
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1879,27 +1737,10 @@ func SaveConfig(path string, cfg *Config) error {
|
||||
}
|
||||
if cfg.Channels.WeCom.secDirty {
|
||||
cfg.security.Channels.WeCom = &WeComSecurity{
|
||||
Token: cfg.Channels.WeCom.Token(),
|
||||
EncodingAESKey: cfg.Channels.WeCom.EncodingAESKey(),
|
||||
Secret: cfg.Channels.WeCom.Secret(),
|
||||
}
|
||||
cfg.Channels.WeCom.secDirty = false
|
||||
}
|
||||
if cfg.Channels.WeComApp.secDirty {
|
||||
cfg.security.Channels.WeComApp = &WeComAppSecurity{
|
||||
CorpSecret: cfg.Channels.WeComApp.CorpSecret(),
|
||||
Token: cfg.Channels.WeComApp.Token(),
|
||||
EncodingAESKey: cfg.Channels.WeComApp.EncodingAESKey(),
|
||||
}
|
||||
cfg.Channels.WeComApp.secDirty = false
|
||||
}
|
||||
if cfg.Channels.WeComAIBot.secDirty {
|
||||
cfg.security.Channels.WeComAIBot = &WeComAIBotSecurity{
|
||||
Token: cfg.Channels.WeComAIBot.Token(),
|
||||
EncodingAESKey: cfg.Channels.WeComAIBot.EncodingAESKey(),
|
||||
Secret: cfg.Channels.WeComAIBot.Secret(),
|
||||
}
|
||||
cfg.Channels.WeComAIBot.secDirty = false
|
||||
}
|
||||
if cfg.Tools.Web.Brave.secDirty {
|
||||
cfg.security.Web.Brave = &BraveSecurity{
|
||||
APIKeys: cfg.Tools.Web.Brave.APIKeys(),
|
||||
|
||||
+63
-153
@@ -85,23 +85,21 @@ type toolsConfigV0 struct {
|
||||
}
|
||||
|
||||
type channelsConfigV0 struct {
|
||||
WhatsApp WhatsAppConfig `json:"whatsapp"`
|
||||
Telegram telegramConfigV0 `json:"telegram"`
|
||||
Feishu feishuConfigV0 `json:"feishu"`
|
||||
Discord discordConfigV0 `json:"discord"`
|
||||
MaixCam maixcamConfigV0 `json:"maixcam"`
|
||||
Weixin weixinConfigV0 `json:"weixin"`
|
||||
QQ qqConfigV0 `json:"qq"`
|
||||
DingTalk dingtalkConfigV0 `json:"dingtalk"`
|
||||
Slack slackConfigV0 `json:"slack"`
|
||||
Matrix matrixConfigV0 `json:"matrix"`
|
||||
LINE lineConfigV0 `json:"line"`
|
||||
OneBot onebotConfigV0 `json:"onebot"`
|
||||
WeCom wecomConfigV0 `json:"wecom"`
|
||||
WeComApp wecomappConfigV0 `json:"wecom_app"`
|
||||
WeComAIBot wecomaibotConfigV0 `json:"wecom_aibot"`
|
||||
Pico picoConfigV0 `json:"pico"`
|
||||
IRC ircConfigV0 `json:"irc"`
|
||||
WhatsApp WhatsAppConfig `json:"whatsapp"`
|
||||
Telegram telegramConfigV0 `json:"telegram"`
|
||||
Feishu feishuConfigV0 `json:"feishu"`
|
||||
Discord discordConfigV0 `json:"discord"`
|
||||
MaixCam maixcamConfigV0 `json:"maixcam"`
|
||||
Weixin weixinConfigV0 `json:"weixin"`
|
||||
QQ qqConfigV0 `json:"qq"`
|
||||
DingTalk dingtalkConfigV0 `json:"dingtalk"`
|
||||
Slack slackConfigV0 `json:"slack"`
|
||||
Matrix matrixConfigV0 `json:"matrix"`
|
||||
LINE lineConfigV0 `json:"line"`
|
||||
OneBot onebotConfigV0 `json:"onebot"`
|
||||
WeCom wecomConfigV0 `json:"wecom" envPrefix:"PICOCLAW_CHANNELS_WECOM_"`
|
||||
Pico picoConfigV0 `json:"pico"`
|
||||
IRC ircConfigV0 `json:"irc"`
|
||||
}
|
||||
|
||||
func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity) {
|
||||
@@ -117,45 +115,39 @@ func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity)
|
||||
line, lineSecurity := v.LINE.ToLINEConfig()
|
||||
onebot, onebotSecurity := v.OneBot.ToOneBotConfig()
|
||||
wecom, wecomSecurity := v.WeCom.ToWeComConfig()
|
||||
wecomapp, wecomappSecurity := v.WeComApp.ToWeComAppConfig()
|
||||
wecomaibot, wecomaibotSecurity := v.WeComAIBot.ToWeComAIBotConfig()
|
||||
pico, picoSecurity := v.Pico.ToPicoConfig()
|
||||
irc, ircSecurity := v.IRC.ToIRCConfig()
|
||||
|
||||
return ChannelsConfig{
|
||||
WhatsApp: v.WhatsApp,
|
||||
Telegram: telegram,
|
||||
Feishu: feishu,
|
||||
Discord: discord,
|
||||
MaixCam: maixcam,
|
||||
QQ: qq,
|
||||
Weixin: weixin,
|
||||
DingTalk: dingtalk,
|
||||
Slack: slack,
|
||||
Matrix: matrix,
|
||||
LINE: line,
|
||||
OneBot: onebot,
|
||||
WeCom: wecom,
|
||||
WeComApp: wecomapp,
|
||||
WeComAIBot: wecomaibot,
|
||||
Pico: pico,
|
||||
IRC: irc,
|
||||
WhatsApp: v.WhatsApp,
|
||||
Telegram: telegram,
|
||||
Feishu: feishu,
|
||||
Discord: discord,
|
||||
MaixCam: maixcam,
|
||||
QQ: qq,
|
||||
Weixin: weixin,
|
||||
DingTalk: dingtalk,
|
||||
Slack: slack,
|
||||
Matrix: matrix,
|
||||
LINE: line,
|
||||
OneBot: onebot,
|
||||
WeCom: wecom,
|
||||
Pico: pico,
|
||||
IRC: irc,
|
||||
}, ChannelsSecurity{
|
||||
Telegram: telegramSecurity,
|
||||
Feishu: feishuSecurity,
|
||||
Discord: discordSecurity,
|
||||
QQ: qqSecurity,
|
||||
Weixin: weixinSecurity,
|
||||
DingTalk: dingtalkSecurity,
|
||||
Slack: slackSecurity,
|
||||
Matrix: matrixSecurity,
|
||||
LINE: lineSecurity,
|
||||
OneBot: onebotSecurity,
|
||||
WeCom: wecomSecurity,
|
||||
WeComApp: wecomappSecurity,
|
||||
WeComAIBot: wecomaibotSecurity,
|
||||
Pico: picoSecurity,
|
||||
IRC: ircSecurity,
|
||||
Telegram: telegramSecurity,
|
||||
Feishu: feishuSecurity,
|
||||
Discord: discordSecurity,
|
||||
QQ: qqSecurity,
|
||||
Weixin: weixinSecurity,
|
||||
DingTalk: dingtalkSecurity,
|
||||
Slack: slackSecurity,
|
||||
Matrix: matrixSecurity,
|
||||
LINE: lineSecurity,
|
||||
OneBot: onebotSecurity,
|
||||
WeCom: wecomSecurity,
|
||||
Pico: picoSecurity,
|
||||
IRC: ircSecurity,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -473,39 +465,32 @@ func (v *onebotConfigV0) ToOneBotConfig() (OneBotConfig, *OneBotSecurity) {
|
||||
}
|
||||
|
||||
type wecomConfigV0 struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"`
|
||||
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"`
|
||||
WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"`
|
||||
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"`
|
||||
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"`
|
||||
Enabled bool `json:"enabled" env:"ENABLED"`
|
||||
BotID string `json:"bot_id" env:"BOT_ID"`
|
||||
Secret string `json:"secret" env:"SECRET"`
|
||||
WebSocketURL string `json:"websocket_url,omitempty" env:"WEBSOCKET_URL"`
|
||||
SendThinkingMessage bool `json:"send_thinking_message" env:"SEND_THINKING_MESSAGE"`
|
||||
DMPolicy string `json:"dm_policy,omitempty" env:"DM_POLICY"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"ALLOW_FROM"`
|
||||
GroupPolicy string `json:"group_policy,omitempty" env:"GROUP_POLICY"`
|
||||
GroupAllowFrom FlexibleStringSlice `json:"group_allow_from,omitempty" env:"GROUP_ALLOW_FROM"`
|
||||
Groups map[string]WeComGroupConfig `json:"groups,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
func (v *wecomConfigV0) ToWeComConfig() (WeComConfig, *WeComSecurity) {
|
||||
var sec *WeComSecurity
|
||||
if v.Token != "" || v.EncodingAESKey != "" {
|
||||
sec = &WeComSecurity{
|
||||
Token: v.Token,
|
||||
EncodingAESKey: v.EncodingAESKey,
|
||||
}
|
||||
if v.Secret != "" {
|
||||
sec = &WeComSecurity{Secret: v.Secret}
|
||||
}
|
||||
return WeComConfig{
|
||||
Enabled: v.Enabled,
|
||||
token: v.Token,
|
||||
encodingAESKey: v.EncodingAESKey,
|
||||
WebhookURL: v.WebhookURL,
|
||||
WebhookHost: v.WebhookHost,
|
||||
WebhookPort: v.WebhookPort,
|
||||
WebhookPath: v.WebhookPath,
|
||||
AllowFrom: v.AllowFrom,
|
||||
ReplyTimeout: v.ReplyTimeout,
|
||||
GroupTrigger: v.GroupTrigger,
|
||||
ReasoningChannelID: v.ReasoningChannelID,
|
||||
Enabled: v.Enabled,
|
||||
BotID: v.BotID,
|
||||
secret: v.Secret,
|
||||
WebSocketURL: v.WebSocketURL,
|
||||
SendThinkingMessage: v.SendThinkingMessage,
|
||||
AllowFrom: v.AllowFrom,
|
||||
ReasoningChannelID: v.ReasoningChannelID,
|
||||
}, sec
|
||||
}
|
||||
|
||||
@@ -537,81 +522,6 @@ func (v *weixinConfigV0) ToWeiXinConfig() (WeixinConfig, *WeixinSecurity) {
|
||||
}, sec
|
||||
}
|
||||
|
||||
type wecomappConfigV0 struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"`
|
||||
CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"`
|
||||
CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"`
|
||||
AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"`
|
||||
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"`
|
||||
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"`
|
||||
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
func (v *wecomappConfigV0) ToWeComAppConfig() (WeComAppConfig, *WeComAppSecurity) {
|
||||
var sec *WeComAppSecurity
|
||||
if v.CorpSecret != "" || v.Token != "" || v.EncodingAESKey != "" {
|
||||
sec = &WeComAppSecurity{
|
||||
CorpSecret: v.CorpSecret,
|
||||
Token: v.Token,
|
||||
EncodingAESKey: v.EncodingAESKey,
|
||||
}
|
||||
}
|
||||
return WeComAppConfig{
|
||||
Enabled: v.Enabled,
|
||||
CorpID: v.CorpID,
|
||||
corpSecret: v.CorpSecret,
|
||||
AgentID: v.AgentID,
|
||||
token: v.Token,
|
||||
encodingAESKey: v.EncodingAESKey,
|
||||
WebhookHost: v.WebhookHost,
|
||||
WebhookPort: v.WebhookPort,
|
||||
WebhookPath: v.WebhookPath,
|
||||
AllowFrom: v.AllowFrom,
|
||||
ReplyTimeout: v.ReplyTimeout,
|
||||
GroupTrigger: v.GroupTrigger,
|
||||
ReasoningChannelID: v.ReasoningChannelID,
|
||||
}, sec
|
||||
}
|
||||
|
||||
type wecomaibotConfigV0 struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
|
||||
Secret string `json:"secret" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"`
|
||||
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
|
||||
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
|
||||
MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"`
|
||||
WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
func (v *wecomaibotConfigV0) ToWeComAIBotConfig() (WeComAIBotConfig, *WeComAIBotSecurity) {
|
||||
var sec *WeComAIBotSecurity
|
||||
if v.Token != "" || v.Secret != "" || v.EncodingAESKey != "" {
|
||||
sec = &WeComAIBotSecurity{
|
||||
Token: v.Token,
|
||||
Secret: v.Secret,
|
||||
EncodingAESKey: v.EncodingAESKey,
|
||||
}
|
||||
}
|
||||
return WeComAIBotConfig{
|
||||
Enabled: v.Enabled,
|
||||
WebhookPath: v.WebhookPath,
|
||||
AllowFrom: v.AllowFrom,
|
||||
ReplyTimeout: v.ReplyTimeout,
|
||||
MaxSteps: v.MaxSteps,
|
||||
WelcomeMessage: v.WelcomeMessage,
|
||||
ReasoningChannelID: v.ReasoningChannelID,
|
||||
}, sec
|
||||
}
|
||||
|
||||
type picoConfigV0 struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"`
|
||||
|
||||
@@ -1372,8 +1372,7 @@ func TestFilterSensitiveData_AllTokenTypes(t *testing.T) {
|
||||
Feishu: &FeishuSecurity{AppSecret: "feishu-app-secret-123", EncryptKey: "feishu-encrypt-key"},
|
||||
DingTalk: &DingTalkSecurity{ClientSecret: "dingtalk-client-secret"},
|
||||
OneBot: &OneBotSecurity{AccessToken: "onebot-access-token"},
|
||||
WeCom: &WeComSecurity{Token: "wecom-token", EncodingAESKey: "wecom-aes-key"},
|
||||
WeComApp: &WeComAppSecurity{CorpSecret: "wecom-app-secret", Token: "wecom-app-token"},
|
||||
WeCom: &WeComSecurity{Secret: "wecom-secret"},
|
||||
Pico: &PicoSecurity{Token: "pico-token-abc123"},
|
||||
IRC: &IRCSecurity{
|
||||
Password: "irc-password",
|
||||
|
||||
+5
-26
@@ -129,32 +129,11 @@ func DefaultConfig() *Config {
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
WeCom: WeComConfig{
|
||||
Enabled: false,
|
||||
WebhookURL: "",
|
||||
WebhookHost: "0.0.0.0",
|
||||
WebhookPort: 18793,
|
||||
WebhookPath: "/webhook/wecom",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
ReplyTimeout: 5,
|
||||
},
|
||||
WeComApp: WeComAppConfig{
|
||||
Enabled: false,
|
||||
CorpID: "",
|
||||
AgentID: 0,
|
||||
WebhookHost: "0.0.0.0",
|
||||
WebhookPort: 18792,
|
||||
WebhookPath: "/webhook/wecom-app",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
ReplyTimeout: 5,
|
||||
},
|
||||
WeComAIBot: WeComAIBotConfig{
|
||||
Enabled: false,
|
||||
WebhookPath: "/webhook/wecom-aibot",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
ReplyTimeout: 5,
|
||||
MaxSteps: 10,
|
||||
WelcomeMessage: "Hello! I'm your AI assistant. How can I help you today?",
|
||||
ProcessingMessage: DefaultWeComAIBotProcessingMessage,
|
||||
Enabled: false,
|
||||
BotID: "",
|
||||
WebSocketURL: "wss://openws.work.weixin.qq.com",
|
||||
SendThinkingMessage: true,
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
Weixin: WeixinConfig{
|
||||
Enabled: false,
|
||||
|
||||
+14
-29
@@ -69,21 +69,19 @@ type ModelSecurityEntry struct {
|
||||
|
||||
// ChannelsSecurity stores channel-related security data
|
||||
type ChannelsSecurity struct {
|
||||
Telegram *TelegramSecurity `yaml:"telegram,omitempty"`
|
||||
Feishu *FeishuSecurity `yaml:"feishu,omitempty"`
|
||||
Discord *DiscordSecurity `yaml:"discord,omitempty"`
|
||||
Weixin *WeixinSecurity `yaml:"weixin,omitempty"`
|
||||
QQ *QQSecurity `yaml:"qq,omitempty"`
|
||||
DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"`
|
||||
Slack *SlackSecurity `yaml:"slack,omitempty"`
|
||||
Matrix *MatrixSecurity `yaml:"matrix,omitempty"`
|
||||
LINE *LINESecurity `yaml:"line,omitempty"`
|
||||
OneBot *OneBotSecurity `yaml:"onebot,omitempty"`
|
||||
WeCom *WeComSecurity `yaml:"wecom,omitempty"`
|
||||
WeComApp *WeComAppSecurity `yaml:"wecom_app,omitempty"`
|
||||
WeComAIBot *WeComAIBotSecurity `yaml:"wecom_aibot,omitempty"`
|
||||
Pico *PicoSecurity `yaml:"pico,omitempty"`
|
||||
IRC *IRCSecurity `yaml:"irc,omitempty"`
|
||||
Telegram *TelegramSecurity `yaml:"telegram,omitempty"`
|
||||
Feishu *FeishuSecurity `yaml:"feishu,omitempty"`
|
||||
Discord *DiscordSecurity `yaml:"discord,omitempty"`
|
||||
Weixin *WeixinSecurity `yaml:"weixin,omitempty"`
|
||||
QQ *QQSecurity `yaml:"qq,omitempty"`
|
||||
DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"`
|
||||
Slack *SlackSecurity `yaml:"slack,omitempty"`
|
||||
Matrix *MatrixSecurity `yaml:"matrix,omitempty"`
|
||||
LINE *LINESecurity `yaml:"line,omitempty"`
|
||||
OneBot *OneBotSecurity `yaml:"onebot,omitempty"`
|
||||
WeCom *WeComSecurity `yaml:"wecom,omitempty"`
|
||||
Pico *PicoSecurity `yaml:"pico,omitempty"`
|
||||
IRC *IRCSecurity `yaml:"irc,omitempty"`
|
||||
}
|
||||
|
||||
type TelegramSecurity struct {
|
||||
@@ -131,20 +129,7 @@ type OneBotSecurity struct {
|
||||
}
|
||||
|
||||
type WeComSecurity struct {
|
||||
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"`
|
||||
EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"`
|
||||
}
|
||||
|
||||
type WeComAppSecurity struct {
|
||||
CorpSecret string `yaml:"corp_secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"`
|
||||
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"`
|
||||
EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"`
|
||||
}
|
||||
|
||||
type WeComAIBotSecurity struct {
|
||||
Secret string `yaml:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"`
|
||||
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
|
||||
EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
|
||||
Secret string `yaml:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_SECRET"`
|
||||
}
|
||||
|
||||
type PicoSecurity struct {
|
||||
|
||||
@@ -240,15 +240,7 @@ func TestAllSecurityKeysAccessible(t *testing.T) {
|
||||
},
|
||||
"wecom": {
|
||||
"enabled": true,
|
||||
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook"
|
||||
},
|
||||
"wecom_app": {
|
||||
"enabled": true,
|
||||
"corp_id": "test_corp_id",
|
||||
"agent_id": 123456
|
||||
},
|
||||
"wecom_aibot": {
|
||||
"enabled": true
|
||||
"bot_id": "test_wecom_bot_id"
|
||||
},
|
||||
"pico": {
|
||||
"enabled": true
|
||||
@@ -315,15 +307,7 @@ channels:
|
||||
onebot:
|
||||
access_token: "onebot_test_access_token"
|
||||
wecom:
|
||||
token: "wecom_test_webhook_token"
|
||||
encoding_aes_key: "wecom_test_aes_key"
|
||||
wecom_app:
|
||||
corp_secret: "wecom_app_test_corp_secret"
|
||||
token: "wecom_app_test_token"
|
||||
encoding_aes_key: "wecom_app_test_aes_key"
|
||||
wecom_aibot:
|
||||
token: "wecom_aibot_test_token"
|
||||
encoding_aes_key: "wecom_aibot_test_aes_key"
|
||||
secret: "wecom_test_secret"
|
||||
pico:
|
||||
token: "pico_test_token"
|
||||
irc:
|
||||
@@ -409,24 +393,10 @@ skills:
|
||||
t.Logf("OneBot AccessToken(): %s", cfg.Channels.OneBot.AccessToken())
|
||||
|
||||
// WeCom
|
||||
assert.Equal(t, "wecom_test_webhook_token", cfg.Channels.WeCom.Token())
|
||||
assert.Equal(t, "wecom_test_aes_key", cfg.Channels.WeCom.EncodingAESKey())
|
||||
t.Logf("WeCom Token(): %s", cfg.Channels.WeCom.Token())
|
||||
t.Logf("WeCom EncodingAESKey(): %s", cfg.Channels.WeCom.EncodingAESKey())
|
||||
|
||||
// WeCom App
|
||||
assert.Equal(t, "wecom_app_test_corp_secret", cfg.Channels.WeComApp.CorpSecret())
|
||||
assert.Equal(t, "wecom_app_test_token", cfg.Channels.WeComApp.Token())
|
||||
assert.Equal(t, "wecom_app_test_aes_key", cfg.Channels.WeComApp.EncodingAESKey())
|
||||
t.Logf("WeComApp CorpSecret(): %s", cfg.Channels.WeComApp.CorpSecret())
|
||||
t.Logf("WeComApp Token(): %s", cfg.Channels.WeComApp.Token())
|
||||
t.Logf("WeComApp EncodingAESKey(): %s", cfg.Channels.WeComApp.EncodingAESKey())
|
||||
|
||||
// WeCom AI Bot
|
||||
assert.Equal(t, "wecom_aibot_test_token", cfg.Channels.WeComAIBot.Token())
|
||||
assert.Equal(t, "wecom_aibot_test_aes_key", cfg.Channels.WeComAIBot.EncodingAESKey())
|
||||
t.Logf("WeComAIBot Token(): %s", cfg.Channels.WeComAIBot.Token())
|
||||
t.Logf("WeComAIBot EncodingAESKey(): %s", cfg.Channels.WeComAIBot.EncodingAESKey())
|
||||
assert.Equal(t, "test_wecom_bot_id", cfg.Channels.WeCom.BotID)
|
||||
assert.Equal(t, "wecom_test_secret", cfg.Channels.WeCom.Secret())
|
||||
t.Logf("WeCom BotID: %s", cfg.Channels.WeCom.BotID)
|
||||
t.Logf("WeCom Secret(): %s", cfg.Channels.WeCom.Secret())
|
||||
|
||||
// Pico
|
||||
assert.Equal(t, "pico_test_token", cfg.Channels.Pico.Token())
|
||||
|
||||
@@ -13,17 +13,16 @@ var migrateableDirs = []string{
|
||||
}
|
||||
|
||||
var supportedChannels = map[string]bool{
|
||||
"whatsapp": true,
|
||||
"telegram": true,
|
||||
"feishu": true,
|
||||
"discord": true,
|
||||
"maixcam": true,
|
||||
"qq": true,
|
||||
"dingtalk": true,
|
||||
"slack": true,
|
||||
"matrix": true,
|
||||
"line": true,
|
||||
"onebot": true,
|
||||
"wecom": true,
|
||||
"wecom_app": true,
|
||||
"whatsapp": true,
|
||||
"telegram": true,
|
||||
"feishu": true,
|
||||
"discord": true,
|
||||
"maixcam": true,
|
||||
"qq": true,
|
||||
"dingtalk": true,
|
||||
"slack": true,
|
||||
"matrix": true,
|
||||
"line": true,
|
||||
"onebot": true,
|
||||
"wecom": true,
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@ var channelCatalog = []channelCatalogItem{
|
||||
{Name: "qq", ConfigKey: "qq"},
|
||||
{Name: "onebot", ConfigKey: "onebot"},
|
||||
{Name: "wecom", ConfigKey: "wecom"},
|
||||
{Name: "wecom_app", ConfigKey: "wecom_app"},
|
||||
{Name: "wecom_aibot", ConfigKey: "wecom_aibot"},
|
||||
{Name: "whatsapp", ConfigKey: "whatsapp", Variant: "bridge"},
|
||||
{Name: "whatsapp_native", ConfigKey: "whatsapp", Variant: "native"},
|
||||
{Name: "pico", ConfigKey: "pico"},
|
||||
|
||||
@@ -209,6 +209,15 @@ func validateConfig(cfg *config.Config) []string {
|
||||
errs = append(errs, "channels.discord.token is required when discord channel is enabled")
|
||||
}
|
||||
|
||||
if cfg.Channels.WeCom.Enabled {
|
||||
if cfg.Channels.WeCom.BotID == "" {
|
||||
errs = append(errs, "channels.wecom.bot_id is required when wecom channel is enabled")
|
||||
}
|
||||
if cfg.Channels.WeCom.Secret() == "" {
|
||||
errs = append(errs, "channels.wecom.secret is required when wecom channel is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Tools.Exec.Enabled {
|
||||
if cfg.Tools.Exec.EnableDenyPatterns {
|
||||
errs = append(
|
||||
|
||||
@@ -146,13 +146,7 @@ function isConfigured(
|
||||
case "weixin":
|
||||
return asString(config.account_id) !== ""
|
||||
case "wecom":
|
||||
return asString(config.token) !== ""
|
||||
case "wecom_app":
|
||||
return (
|
||||
asString(config.corp_id) !== "" && asString(config.corp_secret) !== ""
|
||||
)
|
||||
case "wecom_aibot":
|
||||
return asString(config.token) !== ""
|
||||
return asString(config.bot_id) !== ""
|
||||
case "whatsapp":
|
||||
return asString(config.bridge_url) !== ""
|
||||
case "whatsapp_native":
|
||||
@@ -193,11 +187,7 @@ function getRequiredFieldKeys(channelName: string): string[] {
|
||||
case "onebot":
|
||||
return ["ws_url"]
|
||||
case "wecom":
|
||||
return ["token"]
|
||||
case "wecom_app":
|
||||
return ["corp_id", "corp_secret"]
|
||||
case "wecom_aibot":
|
||||
return ["token"]
|
||||
return ["bot_id", "secret"]
|
||||
case "whatsapp":
|
||||
return ["bridge_url"]
|
||||
case "pico":
|
||||
|
||||
@@ -28,6 +28,7 @@ const SECRET_FIELDS = new Set([
|
||||
"encoding_aes_key",
|
||||
"encrypt_key",
|
||||
"verification_token",
|
||||
"secret",
|
||||
"password",
|
||||
"nickserv_password",
|
||||
"sasl_password",
|
||||
@@ -44,6 +45,7 @@ const OBJECT_FIELDS = new Set([
|
||||
"allow_token_query",
|
||||
"allow_from",
|
||||
"allow_origins",
|
||||
"groups",
|
||||
])
|
||||
|
||||
function formatLabel(key: string): string {
|
||||
@@ -118,6 +120,14 @@ export function GenericForm({
|
||||
app_id: t("channels.form.desc.appId"),
|
||||
client_id: t("channels.form.desc.clientId"),
|
||||
corp_id: t("channels.form.desc.corpId"),
|
||||
bot_id: t("channels.form.desc.appId"),
|
||||
websocket_url: t("channels.form.desc.wsUrl"),
|
||||
dm_policy: t("channels.form.desc.genericField", { field: "DM policy" }),
|
||||
group_policy: t("channels.form.desc.genericField", { field: "group policy" }),
|
||||
group_allow_from: t("channels.form.desc.allowFrom"),
|
||||
send_thinking_message: t("channels.form.desc.genericField", {
|
||||
field: "thinking message behavior",
|
||||
}),
|
||||
agent_id: t("channels.form.desc.agentId"),
|
||||
webhook_url: t("channels.form.desc.webhookUrl"),
|
||||
webhook_host: t("channels.form.desc.webhookHost"),
|
||||
|
||||
@@ -35,8 +35,6 @@ const CHANNEL_IMPORTANCE_ORDER = [
|
||||
"slack",
|
||||
"line",
|
||||
"wecom",
|
||||
"wecom_app",
|
||||
"wecom_aibot",
|
||||
"dingtalk",
|
||||
"qq",
|
||||
"onebot",
|
||||
@@ -76,8 +74,6 @@ const CHANNEL_ICON_MAP: Record<
|
||||
line: IconBrandLine,
|
||||
qq: IconBrandQq,
|
||||
wecom: IconBrandWechat,
|
||||
wecom_app: IconBrandWechat,
|
||||
wecom_aibot: IconBrandWechat,
|
||||
whatsapp: IconBrandWhatsapp,
|
||||
whatsapp_native: IconBrandWhatsapp,
|
||||
matrix: IconBrandMatrix,
|
||||
|
||||
@@ -233,8 +233,6 @@
|
||||
"qq": "QQ",
|
||||
"onebot": "OneBot",
|
||||
"wecom": "WeCom",
|
||||
"wecom_app": "WeCom App",
|
||||
"wecom_aibot": "WeCom AI Bot",
|
||||
"whatsapp": "WhatsApp",
|
||||
"whatsapp_native": "WhatsApp Native",
|
||||
"pico": "Web",
|
||||
|
||||
@@ -233,8 +233,6 @@
|
||||
"qq": "QQ",
|
||||
"onebot": "OneBot",
|
||||
"wecom": "企业微信",
|
||||
"wecom_app": "企业微信应用",
|
||||
"wecom_aibot": "企业微信 AI 机器人",
|
||||
"whatsapp": "WhatsApp",
|
||||
"whatsapp_native": "WhatsApp Native",
|
||||
"pico": "Web",
|
||||
|
||||
Reference in New Issue
Block a user