refactor(wecom): rebuild ai bot channel

This commit is contained in:
Hoshina
2026-03-24 15:03:41 +08:00
parent 8b6cbd9909
commit a1f95f02bc
36 changed files with 1833 additions and 7196 deletions
+11 -13
View File
@@ -1495,18 +1495,17 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
t.Fatalf("Failed to create channel manager: %v", err)
}
for name, id := range map[string]string{
"whatsapp": "rid-whatsapp",
"telegram": "rid-telegram",
"feishu": "rid-feishu",
"discord": "rid-discord",
"maixcam": "rid-maixcam",
"qq": "rid-qq",
"dingtalk": "rid-dingtalk",
"slack": "rid-slack",
"line": "rid-line",
"onebot": "rid-onebot",
"wecom": "rid-wecom",
"wecom_app": "rid-wecom-app",
"whatsapp": "rid-whatsapp",
"telegram": "rid-telegram",
"feishu": "rid-feishu",
"discord": "rid-discord",
"maixcam": "rid-maixcam",
"qq": "rid-qq",
"dingtalk": "rid-dingtalk",
"slack": "rid-slack",
"line": "rid-line",
"onebot": "rid-onebot",
"wecom": "rid-wecom",
} {
chManager.RegisterChannel(name, &fakeChannel{id: id})
}
@@ -1526,7 +1525,6 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
{channel: "line", wantID: "rid-line"},
{channel: "onebot", wantID: "rid-onebot"},
{channel: "wecom", wantID: "rid-wecom"},
{channel: "wecom_app", wantID: "rid-wecom-app"},
{channel: "unknown", wantID: ""},
}
+1 -10
View File
@@ -371,19 +371,10 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
m.initChannel("onebot", "OneBot")
}
if channels.WeCom.Enabled && channels.WeCom.Token() != "" {
if channels.WeCom.Enabled && channels.WeCom.BotID != "" && channels.WeCom.Secret() != "" {
m.initChannel("wecom", "WeCom")
}
if channels.WeComAIBot.Enabled && (channels.WeComAIBot.Token() != "" ||
(channels.WeComAIBot.Secret() != "" && channels.WeComAIBot.BotID != "")) {
m.initChannel("wecom_aibot", "WeCom AI Bot")
}
if channels.WeComApp.Enabled && channels.WeComApp.CorpID != "" {
m.initChannel("wecom_app", "WeCom App")
}
if channels.Weixin.Enabled && channels.Weixin.Token() != "" {
m.initChannel("weixin", "Weixin")
}
+2 -19
View File
@@ -49,15 +49,7 @@ func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) {
value["token"] = ch.LINE.ChannelAccessToken()
value["secret"] = ch.LINE.ChannelSecret()
case "wecom":
value["token"] = ch.WeCom.Token()
value["key"] = ch.WeCom.EncodingAESKey()
case "wecom_app":
value["token"] = ch.WeComApp.Token()
value["secret"] = ch.WeComApp.CorpSecret()
case "wecom_aibot":
value["token"] = ch.WeComAIBot.Token()
value["key"] = ch.WeComAIBot.EncodingAESKey()
value["secret"] = ch.WeComAIBot.Secret()
value["secret"] = ch.WeCom.Secret()
case "dingtalk":
value["secret"] = ch.QQ.AppSecret()
case "qq":
@@ -156,16 +148,7 @@ func updateKeys(newcfg, old *config.ChannelsConfig) {
newcfg.LINE.SetChannelSecret(old.LINE.ChannelSecret())
}
if newcfg.WeCom.Enabled {
newcfg.WeCom.SetToken(old.WeCom.Token())
newcfg.WeCom.SetEncodingAESKey(old.WeCom.EncodingAESKey())
}
if newcfg.WeComApp.Enabled {
newcfg.WeComApp.SetToken(old.WeComApp.Token())
newcfg.WeComApp.SetCorpSecret(old.WeComApp.CorpSecret())
}
if newcfg.WeComAIBot.Enabled {
newcfg.WeComAIBot.SetToken(old.WeComAIBot.Token())
newcfg.WeComAIBot.SetEncodingAESKey(old.WeComAIBot.EncodingAESKey())
newcfg.WeCom.SetSecret(old.WeCom.Secret())
}
if newcfg.DingTalk.Enabled {
newcfg.DingTalk.SetClientSecret(old.DingTalk.ClientSecret())
File diff suppressed because it is too large Load Diff
-559
View File
@@ -1,559 +0,0 @@
package wecom
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
// ---- Webhook mode tests ----
func TestNewWeComAIBotChannel_WebhookMode(t *testing.T) {
t.Run("success with valid config", func(t *testing.T) {
cfg := config.WeComAIBotConfig{}
cfg.Enabled = true
cfg.SetToken("test_token")
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
cfg.WebhookPath = "/webhook/test"
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if ch == nil {
t.Fatal("Expected channel to be created")
}
if ch.Name() != "wecom_aibot" {
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
}
// Webhook mode must implement WebhookHandler.
if _, ok := ch.(channels.WebhookHandler); !ok {
t.Error("Webhook mode channel should implement WebhookHandler")
}
})
t.Run("error with missing token", func(t *testing.T) {
cfg := config.WeComAIBotConfig{}
cfg.Enabled = true
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing token, got nil")
}
})
t.Run("error with missing encoding key", func(t *testing.T) {
cfg := config.WeComAIBotConfig{}
cfg.Enabled = true
cfg.SetToken("test_token")
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing encoding key, got nil")
}
})
}
func TestWeComAIBotWebhookChannelStartStop(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
}
cfg.SetToken("test_token")
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
ctx := context.Background()
if err := ch.Start(ctx); err != nil {
t.Fatalf("Failed to start channel: %v", err)
}
if !ch.IsRunning() {
t.Error("Expected channel to be running after Start")
}
if err := ch.Stop(ctx); err != nil {
t.Fatalf("Failed to stop channel: %v", err)
}
if ch.IsRunning() {
t.Error("Expected channel to be stopped after Stop")
}
}
func TestWeComAIBotChannelWebhookPath(t *testing.T) {
t.Run("default path", func(t *testing.T) {
cfg := config.WeComAIBotConfig{}
cfg.Enabled = true
cfg.SetToken("test_token")
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
wh, ok := ch.(channels.WebhookHandler)
if !ok {
t.Fatal("Expected channel to implement WebhookHandler")
}
expectedPath := "/webhook/wecom-aibot"
if wh.WebhookPath() != expectedPath {
t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, wh.WebhookPath())
}
})
t.Run("custom path", func(t *testing.T) {
customPath := "/custom/webhook"
cfg := config.WeComAIBotConfig{}
cfg.Enabled = true
cfg.SetToken("test_token")
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
cfg.WebhookPath = customPath
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
wh, ok := ch.(channels.WebhookHandler)
if !ok {
t.Fatal("Expected channel to implement WebhookHandler")
}
if wh.WebhookPath() != customPath {
t.Errorf("Expected webhook path '%s', got '%s'", customPath, wh.WebhookPath())
}
})
}
func TestWeComAIBotChannelGetStreamResponseProcessingMessage(t *testing.T) {
validAESKey := "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG"
t.Run("uses default processing message", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
}
cfg.SetToken("test_token")
cfg.SetEncodingAESKey(validAESKey)
messageBus := bus.NewMessageBus()
channel, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
ch, ok := channel.(*WeComAIBotChannel)
if !ok {
t.Fatal("Expected webhook mode channel")
}
task := &streamTask{
StreamID: "stream-default",
ChatID: "chat-default",
Deadline: time.Now().Add(-time.Second),
}
ch.streamTasks[task.StreamID] = task
ch.chatTasks[task.ChatID] = []*streamTask{task}
resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce"))
if !resp.Stream.Finish {
t.Fatal("Expected finished stream response after deadline")
}
if resp.Stream.Content != config.DefaultWeComAIBotProcessingMessage {
t.Fatalf("Expected default processing message %q, got %q",
config.DefaultWeComAIBotProcessingMessage, resp.Stream.Content)
}
if !task.StreamClosed {
t.Fatal("Expected task stream to be marked closed")
}
if _, ok := ch.streamTasks[task.StreamID]; ok {
t.Fatal("Expected closed stream task to be removed from streamTasks")
}
if len(ch.chatTasks[task.ChatID]) != 1 {
t.Fatalf("Expected task to remain queued for response_url delivery, got %d entries",
len(ch.chatTasks[task.ChatID]))
}
})
t.Run("uses custom processing message", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
ProcessingMessage: "Please wait a moment. The result will be delivered in a follow-up message.",
}
cfg.SetToken("test_token")
cfg.SetEncodingAESKey(validAESKey)
messageBus := bus.NewMessageBus()
channel, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
ch, ok := channel.(*WeComAIBotChannel)
if !ok {
t.Fatal("Expected webhook mode channel")
}
task := &streamTask{
StreamID: "stream-custom",
ChatID: "chat-custom",
Deadline: time.Now().Add(-time.Second),
}
resp := decodeStreamResponse(t, ch, ch.getStreamResponse(task, "1234567890", "nonce"))
if resp.Stream.Content != cfg.ProcessingMessage {
t.Fatalf("Expected custom processing message %q, got %q", cfg.ProcessingMessage, resp.Stream.Content)
}
})
}
func TestGenerateStreamID(t *testing.T) {
cfg := config.WeComAIBotConfig{}
cfg.Enabled = true
cfg.SetToken("test_token")
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
webhookCh, ok := ch.(*WeComAIBotChannel)
if !ok {
t.Fatal("Expected webhook mode channel")
}
ids := make(map[string]bool)
for i := 0; i < 100; i++ {
id := webhookCh.generateStreamID()
if len(id) != 10 {
t.Errorf("Expected stream ID length 10, got %d", len(id))
}
if ids[id] {
t.Errorf("Duplicate stream ID generated: %s", id)
}
ids[id] = true
}
}
func TestEncryptDecrypt(t *testing.T) {
// Use a valid 43-character base64 key (企业微信标准格式)
cfg := config.WeComAIBotConfig{}
cfg.Enabled = true
cfg.SetToken("test_token")
cfg.SetEncodingAESKey("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG") // 43 characters
messageBus := bus.NewMessageBus()
ch, _ := NewWeComAIBotChannel(cfg, messageBus)
webhookCh, ok := ch.(*WeComAIBotChannel)
if !ok {
t.Fatal("Expected webhook mode channel")
}
plaintext := "Hello, World!"
receiveid := ""
encrypted, err := webhookCh.encryptMessage(plaintext, receiveid)
if err != nil {
t.Fatalf("Failed to encrypt message: %v", err)
}
if encrypted == "" {
t.Fatal("Encrypted message is empty")
}
// Decrypt
decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey(), receiveid)
if err != nil {
t.Fatalf("Failed to decrypt message: %v", err)
}
if decrypted != plaintext {
t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted)
}
}
func TestGenerateSignature(t *testing.T) {
token := "test_token"
timestamp := "1234567890"
nonce := "test_nonce"
encrypt := "encrypted_msg"
signature := computeSignature(token, timestamp, nonce, encrypt)
if signature == "" {
t.Error("Generated signature is empty")
}
if !verifySignature(token, signature, timestamp, nonce, encrypt) {
t.Error("Generated signature does not verify correctly")
}
}
func decodeStreamResponse(t *testing.T, ch *WeComAIBotChannel, encryptedResponse string) WeComAIBotStreamResponse {
t.Helper()
var wrapped WeComAIBotEncryptedResponse
if err := json.Unmarshal([]byte(encryptedResponse), &wrapped); err != nil {
t.Fatalf("Failed to unmarshal encrypted response: %v", err)
}
plaintext, err := decryptMessageWithVerify(wrapped.Encrypt, ch.config.EncodingAESKey(), "")
if err != nil {
t.Fatalf("Failed to decrypt response: %v", err)
}
var resp WeComAIBotStreamResponse
if err := json.Unmarshal([]byte(plaintext), &resp); err != nil {
t.Fatalf("Failed to unmarshal decrypted response: %v", err)
}
return resp
}
// ---- WebSocket long-connection mode tests ----
func TestNewWeComAIBotChannel_WSMode(t *testing.T) {
t.Run("success with bot_id and secret", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
}
cfg.SetSecret("test_secret")
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if ch == nil {
t.Fatal("Expected channel to be created")
}
if ch.Name() != "wecom_aibot" {
t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
}
// WebSocket mode must NOT implement WebhookHandler.
if _, ok := ch.(channels.WebhookHandler); ok {
t.Error("WebSocket mode channel should NOT implement WebhookHandler")
}
})
t.Run("ws mode takes priority over webhook fields", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
}
cfg.SetSecret("test_secret")
cfg.SetToken("also_set")
cfg.SetEncodingAESKey("testkey1234567890123456789012345678901234567")
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if _, ok := ch.(*WeComAIBotWSChannel); !ok {
t.Error("Expected WebSocket mode channel when both BotID+secret and Token+Key are set")
}
})
t.Run("error with missing bot_id", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
}
cfg.SetSecret("test_secret")
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
// Missing bot_id alone means neither WS mode nor webhook mode is fully configured.
if err == nil {
t.Fatal("Expected error for missing bot_id, got nil")
}
})
t.Run("error with missing secret", func(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
}
messageBus := bus.NewMessageBus()
_, err := NewWeComAIBotChannel(cfg, messageBus)
if err == nil {
t.Fatal("Expected error for missing secret, got nil")
}
})
}
func TestWeComAIBotWSChannelStartStop(t *testing.T) {
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
}
cfg.SetSecret("test_secret")
messageBus := bus.NewMessageBus()
ch, err := NewWeComAIBotChannel(cfg, messageBus)
if err != nil {
t.Fatalf("Failed to create channel: %v", err)
}
ctx := context.Background()
// Start launches a background goroutine; it should not block or return an error.
if err := ch.Start(ctx); err != nil {
t.Fatalf("Failed to start channel: %v", err)
}
if !ch.IsRunning() {
t.Error("Expected channel to be running after Start")
}
// Stop should work regardless of whether the WebSocket actually connected.
if err := ch.Stop(ctx); err != nil {
t.Fatalf("Failed to stop channel: %v", err)
}
if ch.IsRunning() {
t.Error("Expected channel to be stopped after Stop")
}
}
func TestGenerateRandomID(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 200; i++ {
id := generateRandomID(10)
if len(id) != 10 {
t.Errorf("Expected ID length 10, got %d", len(id))
}
if ids[id] {
t.Errorf("Duplicate ID generated: %s", id)
}
ids[id] = true
}
}
func TestWSGenerateID(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 200; i++ {
id := wsGenerateID()
if len(id) != 10 {
t.Errorf("Expected ID length 10, got %d", len(id))
}
if ids[id] {
t.Errorf("Duplicate wsGenerateID result: %s", id)
}
ids[id] = true
}
}
// ---- Webhook streaming fallback tests ----
// makeWebhookChannel creates a started WeComAIBotChannel for testing.
func makeWebhookChannel(t *testing.T) *WeComAIBotChannel {
t.Helper()
cfg := config.WeComAIBotConfig{
Enabled: true,
}
cfg.SetToken("test_token")
cfg.SetEncodingAESKey("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG")
ch, err := NewWeComAIBotChannel(cfg, bus.NewMessageBus())
if err != nil {
t.Fatalf("create channel: %v", err)
}
wc := ch.(*WeComAIBotChannel)
wc.ctx, wc.cancel = context.WithCancel(context.Background())
return wc
}
// makeStreamTask creates and registers a streamTask for testing.
func makeStreamTask(t *testing.T, ch *WeComAIBotChannel, streamID, chatID string, deadline time.Time) *streamTask {
t.Helper()
task := &streamTask{
StreamID: streamID,
ChatID: chatID,
Deadline: deadline,
answerCh: make(chan string, 1),
}
task.ctx, task.cancel = context.WithCancel(ch.ctx)
ch.taskMu.Lock()
ch.streamTasks[streamID] = task
ch.chatTasks[chatID] = append(ch.chatTasks[chatID], task)
ch.taskMu.Unlock()
return task
}
// TestGetStreamResponse_ImmediateAnswer verifies that when the agent has already
// placed its answer in answerCh, getStreamResponse returns a finish=true response
// and fully removes the task.
func TestGetStreamResponse_ImmediateAnswer(t *testing.T) {
ch := makeWebhookChannel(t)
defer ch.cancel()
task := makeStreamTask(t, ch, "stream-1", "chat-1", time.Now().Add(30*time.Second))
task.answerCh <- "hello from agent"
result := ch.getStreamResponse(task, "ts123", "nonce123")
if result == "" {
t.Fatal("expected non-empty encrypted response")
}
ch.taskMu.RLock()
_, exists := ch.streamTasks["stream-1"]
ch.taskMu.RUnlock()
if exists {
t.Error("task should have been removed from streamTasks after normal finish")
}
if !task.Finished {
t.Error("task.Finished should be true after normal finish")
}
}
// TestGetStreamResponse_DeadlinePassed verifies that when the stream deadline has
// elapsed (no agent reply yet), getStreamResponse closes the stream but keeps the
// task alive so the response_url fallback can still deliver the answer.
func TestGetStreamResponse_DeadlinePassed(t *testing.T) {
ch := makeWebhookChannel(t)
defer ch.cancel()
task := makeStreamTask(t, ch, "stream-2", "chat-2", time.Now().Add(-time.Millisecond))
result := ch.getStreamResponse(task, "ts456", "nonce456")
if result == "" {
t.Fatal("expected non-empty encrypted response")
}
ch.taskMu.RLock()
_, stillStreaming := ch.streamTasks["stream-2"]
ch.taskMu.RUnlock()
if stillStreaming {
t.Error("task should have been removed from streamTasks after deadline")
}
if !task.StreamClosed {
t.Error("task.StreamClosed should be true after deadline")
}
if task.Finished {
t.Error("task.Finished must remain false: agent reply still expected via response_url")
}
}
// TestGetStreamResponse_StillPending verifies that when neither the agent has
// replied nor the deadline has passed, getStreamResponse returns without altering
// task state (client should poll again).
func TestGetStreamResponse_StillPending(t *testing.T) {
ch := makeWebhookChannel(t)
defer ch.cancel()
task := makeStreamTask(t, ch, "stream-3", "chat-3", time.Now().Add(30*time.Second))
result := ch.getStreamResponse(task, "ts789", "nonce789")
if result == "" {
t.Fatal("expected non-empty encrypted response")
}
ch.taskMu.RLock()
_, exists := ch.streamTasks["stream-3"]
ch.taskMu.RUnlock()
if !exists {
t.Error("pending task should still be in streamTasks")
}
if task.Finished || task.StreamClosed {
t.Error("pending task should not be finished or stream-closed")
}
// Cleanup.
ch.removeTask(task)
}
File diff suppressed because it is too large Load Diff
-295
View File
@@ -1,295 +0,0 @@
package wecom
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
// newTestWSChannel creates a WeComAIBotWSChannel ready for unit testing.
func newTestWSChannel(t *testing.T) *WeComAIBotWSChannel {
t.Helper()
cfg := config.WeComAIBotConfig{
Enabled: true,
BotID: "test_bot_id",
}
cfg.SetSecret("test_secret")
ch, err := newWeComAIBotWSChannel(cfg, bus.NewMessageBus())
if err != nil {
t.Fatalf("create WS channel: %v", err)
}
return ch
}
// TestStoreWSMedia_NilStore verifies that storeWSMedia returns an error when no
// MediaStore has been injected.
func TestStoreWSMedia_NilStore(t *testing.T) {
ch := newTestWSChannel(t)
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://any", "", ".jpg")
if err == nil {
t.Fatal("expected error when no MediaStore is set")
}
}
// TestStoreWSMedia_HTTPError verifies that storeWSMedia propagates HTTP errors
// from the media server.
func TestStoreWSMedia_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "not found", http.StatusNotFound)
}))
defer srv.Close()
ch := newTestWSChannel(t)
ch.SetMediaStore(media.NewFileMediaStore())
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg")
if err == nil {
t.Fatal("expected error for HTTP 404")
}
}
// TestStoreWSMedia_ServerUnavailable verifies that storeWSMedia returns a clear
// error when the media server cannot be reached.
func TestStoreWSMedia_ServerUnavailable(t *testing.T) {
ch := newTestWSChannel(t)
ch.SetMediaStore(media.NewFileMediaStore())
// Port 1 is reserved and will refuse the connection immediately.
_, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://127.0.0.1:1", "", ".jpg")
if err == nil {
t.Fatal("expected error for unreachable server")
}
}
// TestStoreWSMedia_Success_NoAES verifies the happy path: the media is downloaded,
// a media ref is returned, and the file persists and is readable via Resolve until
// ReleaseAll is called. The server returns no Content-Type, so the defaultExt is used.
func TestStoreWSMedia_Success_NoAES(t *testing.T) {
imageData := bytes.Repeat([]byte("x"), 256)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(imageData)
}))
defer srv.Close()
ch := newTestWSChannel(t)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
ref, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg")
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if ref == "" {
t.Fatal("expected non-empty ref")
}
// File must be accessible after storeWSMedia returns (no premature deletion).
path, err := store.Resolve(ref)
if err != nil {
t.Fatalf("ref should resolve: %v", err)
}
got, err := os.ReadFile(path)
if err != nil {
t.Fatalf("file should exist at %s: %v", path, err)
}
if !bytes.Equal(got, imageData) {
t.Errorf("content mismatch: got len=%d, want len=%d", len(got), len(imageData))
}
// ReleaseAll must delete the file (store owns lifecycle).
scope := channels.BuildMediaScope("wecom_aibot", "chat1", "msg1")
if err := store.ReleaseAll(scope); err != nil {
t.Fatalf("ReleaseAll failed: %v", err)
}
if _, err := os.Stat(path); !os.IsNotExist(err) {
t.Errorf("file should have been deleted by ReleaseAll, stat err: %v", err)
}
}
// TestStoreWSMedia_MultipleMessages verifies that concurrent media messages with
// different msgIDs do not collide and each resolve to distinct files.
func TestStoreWSMedia_MultipleMessages(t *testing.T) {
imageA := bytes.Repeat([]byte("a"), 64)
imageB := bytes.Repeat([]byte("b"), 64)
srvA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(imageA)
}))
defer srvA.Close()
srvB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(imageB)
}))
defer srvB.Close()
ch := newTestWSChannel(t)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
refA, err := ch.storeWSMedia(context.Background(), "chat1", "msgA", srvA.URL, "", ".jpg")
if err != nil {
t.Fatalf("storeWSMedia A: %v", err)
}
refB, err := ch.storeWSMedia(context.Background(), "chat1", "msgB", srvB.URL, "", ".jpg")
if err != nil {
t.Fatalf("storeWSMedia B: %v", err)
}
if refA == refB {
t.Fatal("distinct messages must produce distinct refs")
}
pathA, _ := store.Resolve(refA)
pathB, _ := store.Resolve(refB)
if pathA == pathB {
t.Fatal("distinct messages must be stored at distinct paths")
}
gotA, _ := os.ReadFile(pathA)
gotB, _ := os.ReadFile(pathB)
if !bytes.Equal(gotA, imageA) {
t.Errorf("content mismatch for message A")
}
if !bytes.Equal(gotB, imageB) {
t.Errorf("content mismatch for message B")
}
}
// TestStoreWSMedia_ContentTypeExt verifies that the file extension is inferred
// from the HTTP Content-Type header and the defaultExt fallback is used when the
// type is absent or unrecognized.
func TestStoreWSMedia_ContentTypeExt(t *testing.T) {
tests := []struct {
contentType string
wantExt string
}{
{"image/jpeg", ".jpg"},
{"image/png", ".png"},
{"video/mp4", ".mp4"},
{"application/pdf", ".pdf"},
{"application/zip", ".zip"},
// With parameters stripped.
{"video/mp4; codecs=avc1", ".mp4"},
// Unknown type → falls back to defaultExt.
{"", ""},
{"application/octet-stream", ""},
}
for _, tc := range tests {
got := wsMediaExtFromContentType(tc.contentType)
if got != tc.wantExt {
t.Errorf("wsMediaExtFromContentType(%q) = %q, want %q", tc.contentType, got, tc.wantExt)
}
}
// End-to-end: server returns Content-Type: video/mp4, defaultExt is .bin.
// The stored file should carry the .mp4 extension, not .bin.
payload := bytes.Repeat([]byte("v"), 128)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "video/mp4")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(payload)
}))
defer srv.Close()
ch := newTestWSChannel(t)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
ref, err := ch.storeWSMedia(context.Background(), "chat1", "vid1", srv.URL, "", ".bin")
if err != nil {
t.Fatalf("storeWSMedia: %v", err)
}
path, err := store.Resolve(ref)
if err != nil {
t.Fatalf("resolve: %v", err)
}
if ext := path[len(path)-4:]; ext != ".mp4" {
t.Errorf("expected .mp4 extension from Content-Type, got %q", ext)
}
}
// TestSplitWSContent verifies byte-aware splitting of stream content.
func TestSplitWSContent(t *testing.T) {
t.Run("short content is not split", func(t *testing.T) {
chunks := splitWSContent("hello", 20480)
if len(chunks) != 1 || chunks[0] != "hello" {
t.Fatalf("unexpected chunks: %v", chunks)
}
})
t.Run("ASCII content split at byte boundary", func(t *testing.T) {
// Build a string just over the limit.
content := strings.Repeat("a", 20481)
chunks := splitWSContent(content, 20480)
if len(chunks) < 2 {
t.Fatalf("expected >= 2 chunks, got %d", len(chunks))
}
for i, c := range chunks {
if len(c) > 20480 {
t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c))
}
}
// Reassembled content must equal the original (possibly without leading
// whitespace that splitWSContent trims between chunks).
joined := strings.Join(chunks, "")
if len(joined) < len(content)-len(chunks) {
t.Errorf("joined length %d too short (original %d)", len(joined), len(content))
}
})
t.Run("CJK content split within byte limit", func(t *testing.T) {
// Each CJK rune is 3 bytes in UTF-8.
// 7000 CJK chars = 21000 bytes, which exceeds 20480.
content := strings.Repeat("\u4e2d", 7000)
chunks := splitWSContent(content, 20480)
if len(chunks) < 2 {
t.Fatalf("expected >= 2 chunks for 21000-byte CJK content, got %d", len(chunks))
}
for i, c := range chunks {
if len(c) > 20480 {
t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c))
}
// Every chunk must be valid UTF-8.
if !strings.ContainsRune(c, '\u4e2d') && len(c) > 0 {
// quick plausibility check — content was pure CJK
}
}
})
}
// TestSplitAtByteBoundary verifies the last-resort byte-boundary splitter.
func TestSplitAtByteBoundary(t *testing.T) {
t.Run("ASCII fits in one chunk", func(t *testing.T) {
parts := splitAtByteBoundary("hello world", 100)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d", len(parts))
}
})
t.Run("splits at byte boundary, never mid-rune", func(t *testing.T) {
// 10 CJK characters = 30 bytes; split at 20 bytes.
s := strings.Repeat("\u6587", 10) // 10 × 3 bytes = 30 bytes
parts := splitAtByteBoundary(s, 20)
for i, p := range parts {
if len(p) > 20 {
t.Errorf("part %d has %d bytes, want <= 20", i, len(p))
}
// Must be valid UTF-8 (no torn multi-byte sequences).
for j, r := range p {
if r == '\uFFFD' {
t.Errorf("part %d has replacement rune at position %d: torn UTF-8", i, j)
}
}
}
})
}
-756
View File
@@ -1,756 +0,0 @@
package wecom
import (
"bytes"
"context"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
const (
wecomAPIBase = "https://qyapi.weixin.qq.com"
)
// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用)
type WeComAppChannel struct {
*channels.BaseChannel
config config.WeComAppConfig
client *http.Client
accessToken string
tokenExpiry time.Time
tokenMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
processedMsgs *MessageDeduplicator
}
// WeComXMLMessage represents the XML message structure from WeCom
type WeComXMLMessage struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType string `xml:"MsgType"`
Content string `xml:"Content"`
MsgId int64 `xml:"MsgId"`
AgentID int64 `xml:"AgentID"`
PicUrl string `xml:"PicUrl"`
MediaId string `xml:"MediaId"`
Format string `xml:"Format"`
ThumbMediaId string `xml:"ThumbMediaId"`
LocationX float64 `xml:"Location_X"`
LocationY float64 `xml:"Location_Y"`
Scale int `xml:"Scale"`
Label string `xml:"Label"`
Title string `xml:"Title"`
Description string `xml:"Description"`
Url string `xml:"Url"`
Event string `xml:"Event"`
EventKey string `xml:"EventKey"`
}
// WeComTextMessage represents text message for sending
type WeComTextMessage struct {
ToUser string `json:"touser"`
MsgType string `json:"msgtype"`
AgentID int64 `json:"agentid"`
Text struct {
Content string `json:"content"`
} `json:"text"`
Safe int `json:"safe,omitempty"`
}
// WeComMarkdownMessage represents markdown message for sending
type WeComMarkdownMessage struct {
ToUser string `json:"touser"`
MsgType string `json:"msgtype"`
AgentID int64 `json:"agentid"`
Markdown struct {
Content string `json:"content"`
} `json:"markdown"`
}
// WeComImageMessage represents image message for sending
type WeComImageMessage struct {
ToUser string `json:"touser"`
MsgType string `json:"msgtype"`
AgentID int64 `json:"agentid"`
Image struct {
MediaID string `json:"media_id"`
} `json:"image"`
}
// WeComAccessTokenResponse represents the access token API response
type WeComAccessTokenResponse struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
}
// WeComSendMessageResponse represents the send message API response
type WeComSendMessageResponse struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
InvalidUser string `json:"invaliduser"`
InvalidParty string `json:"invalidparty"`
InvalidTag string `json:"invalidtag"`
}
// PKCS7Padding adds PKCS7 padding
type PKCS7Padding struct{}
// NewWeComAppChannel creates a new WeCom App channel instance
func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) {
if cfg.CorpID == "" || cfg.CorpSecret() == "" || cfg.AgentID == 0 {
return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required")
}
base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom,
channels.WithMaxMessageLength(2048),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
// Client timeout must be >= the configured ReplyTimeout so the
// per-request context deadline is always the effective limit.
clientTimeout := 30 * time.Second
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
clientTimeout = d
}
ctx, cancel := context.WithCancel(context.Background())
return &WeComAppChannel{
BaseChannel: base,
config: cfg,
client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
}, nil
}
// Name returns the channel name
func (c *WeComAppChannel) Name() string {
return "wecom_app"
}
// Start initializes the WeCom App channel
func (c *WeComAppChannel) Start(ctx context.Context) error {
logger.InfoC("wecom_app", "Starting WeCom App channel...")
// Cancel the context created in the constructor to avoid a resource leak.
if c.cancel != nil {
c.cancel()
}
c.ctx, c.cancel = context.WithCancel(ctx)
// Get initial access token
if err := c.refreshAccessToken(); err != nil {
logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{
"error": err.Error(),
})
}
// Start token refresh goroutine
go c.tokenRefreshLoop()
c.SetRunning(true)
logger.InfoC("wecom_app", "WeCom App channel started")
return nil
}
// Stop gracefully stops the WeCom App channel
func (c *WeComAppChannel) Stop(ctx context.Context) error {
logger.InfoC("wecom_app", "Stopping WeCom App channel...")
if c.cancel != nil {
c.cancel()
}
c.SetRunning(false)
logger.InfoC("wecom_app", "WeCom App channel stopped")
return nil
}
// Send sends a message to WeCom user proactively using access token
func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
accessToken := c.getAccessToken()
if accessToken == "" {
return fmt.Errorf("no valid access token available")
}
logger.DebugCF("wecom_app", "Sending message", map[string]any{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content)
}
// SendMedia implements the channels.MediaSender interface.
func (c *WeComAppChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
accessToken := c.getAccessToken()
if accessToken == "" {
return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary)
}
store := c.GetMediaStore()
if store == nil {
return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
}
for _, part := range msg.Parts {
localPath, err := store.Resolve(part.Ref)
if err != nil {
logger.ErrorCF("wecom_app", "Failed to resolve media ref", map[string]any{
"ref": part.Ref,
"error": err.Error(),
})
continue
}
// Map part type to WeCom media type
var mediaType string
switch part.Type {
case "image":
mediaType = "image"
case "audio":
mediaType = "voice"
case "video":
mediaType = "video"
default:
mediaType = "file"
}
// Upload media to get media_id
mediaID, err := c.uploadMedia(ctx, accessToken, mediaType, localPath)
if err != nil {
logger.ErrorCF("wecom_app", "Failed to upload media", map[string]any{
"type": mediaType,
"error": err.Error(),
})
// Fallback: send caption as text
if part.Caption != "" {
_ = c.sendTextMessage(ctx, accessToken, msg.ChatID, part.Caption)
}
continue
}
// Send media message using the media_id
if mediaType == "image" {
err = c.sendImageMessage(ctx, accessToken, msg.ChatID, mediaID)
} else {
// For non-image types, send as text fallback with caption
caption := part.Caption
if caption == "" {
caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename)
}
err = c.sendTextMessage(ctx, accessToken, msg.ChatID, caption)
}
if err != nil {
return err
}
}
return nil
}
// uploadMedia uploads a local file to WeCom temporary media storage.
func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaType, localPath string) (string, error) {
apiURL := fmt.Sprintf("%s/cgi-bin/media/upload?access_token=%s&type=%s",
wecomAPIBase, url.QueryEscape(accessToken), url.QueryEscape(mediaType))
file, err := os.Open(localPath)
if err != nil {
return "", fmt.Errorf("failed to open file: %w", err)
}
defer file.Close()
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
filename := filepath.Base(localPath)
formFile, err := writer.CreateFormFile("media", filename)
if err != nil {
return "", fmt.Errorf("failed to create form file: %w", err)
}
if _, err = io.Copy(formFile, file); err != nil {
return "", fmt.Errorf("failed to copy file content: %w", err)
}
writer.Close()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, err := c.client.Do(req)
if err != nil {
return "", channels.ClassifyNetError(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return "", channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading wecom upload error response: %w", readErr),
)
}
return "", channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("wecom upload error: %s", string(respBody)),
)
}
var result struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
MediaID string `json:"media_id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to parse upload response: %w", err)
}
if result.ErrCode != 0 {
return "", fmt.Errorf("upload API error: %s (code: %d)", result.ErrMsg, result.ErrCode)
}
return result.MediaID, nil
}
// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
jsonData, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
timeout := c.config.ReplyTimeout
if timeout <= 0 {
timeout = 5
}
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading wecom_app error response: %w", readErr),
)
}
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("wecom_app API error: %s", string(respBody)),
)
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
var sendResp WeComSendMessageResponse
if err := json.Unmarshal(respBody, &sendResp); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
if sendResp.ErrCode != 0 {
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
}
return nil
}
// sendImageMessage sends an image message using a media_id.
func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
msg := WeComImageMessage{
ToUser: userID,
MsgType: "image",
AgentID: c.config.AgentID,
}
msg.Image.MediaID = mediaID
return c.sendWeComMessage(ctx, accessToken, msg)
}
// WebhookPath returns the path for registering on the shared HTTP server.
func (c *WeComAppChannel) WebhookPath() string {
if c.config.WebhookPath != "" {
return c.config.WebhookPath
}
return "/webhook/wecom-app"
}
// ServeHTTP implements http.Handler for the shared HTTP server.
func (c *WeComAppChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.handleWebhook(w, r)
}
// HealthPath returns the health check endpoint path.
func (c *WeComAppChannel) HealthPath() string {
return "/health/wecom-app"
}
// HealthHandler handles health check requests.
func (c *WeComAppChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
c.handleHealth(w, r)
}
// handleWebhook handles incoming webhook requests from WeCom
func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Log all incoming requests for debugging
logger.DebugCF("wecom_app", "Received webhook request", map[string]any{
"method": r.Method,
"url": r.URL.String(),
"path": r.URL.Path,
"query": r.URL.RawQuery,
})
if r.Method == http.MethodGet {
// Handle verification request
c.handleVerification(ctx, w, r)
return
}
if r.Method == http.MethodPost {
// Handle message callback
c.handleMessageCallback(ctx, w, r)
return
}
logger.WarnCF("wecom_app", "Method not allowed", map[string]any{
"method": r.Method,
})
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// handleVerification handles the URL verification request from WeCom
func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
msgSignature := query.Get("msg_signature")
timestamp := query.Get("timestamp")
nonce := query.Get("nonce")
echostr := query.Get("echostr")
logger.DebugCF("wecom_app", "Handling verification request", map[string]any{
"msg_signature": msgSignature,
"timestamp": timestamp,
"nonce": nonce,
"echostr": echostr,
"corp_id": c.config.CorpID,
})
if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
logger.ErrorC("wecom_app", "Missing parameters in verification request")
http.Error(w, "Missing parameters", http.StatusBadRequest)
return
}
// Verify signature
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) {
logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{
"token": c.config.Token(),
"msg_signature": msgSignature,
"timestamp": timestamp,
"nonce": nonce,
})
http.Error(w, "Invalid signature", http.StatusForbidden)
return
}
logger.DebugC("wecom_app", "Signature verification passed")
// Decrypt echostr with CorpID verification
// For WeCom App (自建应用), receiveid should be corp_id
logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{
"encoding_aes_key": c.config.EncodingAESKey(),
"corp_id": c.config.CorpID,
})
decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), c.config.CorpID)
if err != nil {
logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{
"error": err.Error(),
"encoding_aes_key": c.config.EncodingAESKey,
"corp_id": c.config.CorpID,
})
http.Error(w, "Decryption failed", http.StatusInternalServerError)
return
}
logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{
"decrypted": decryptedEchoStr,
})
// Remove BOM and whitespace as per WeCom documentation
// The response must be plain text without quotes, BOM, or newlines
decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
w.Write([]byte(decryptedEchoStr))
}
// handleMessageCallback handles incoming messages from WeCom
func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
msgSignature := query.Get("msg_signature")
timestamp := query.Get("timestamp")
nonce := query.Get("nonce")
if msgSignature == "" || timestamp == "" || nonce == "" {
http.Error(w, "Missing parameters", http.StatusBadRequest)
return
}
// Read request body
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read body", http.StatusBadRequest)
return
}
defer r.Body.Close()
// Parse XML to get encrypted message
var encryptedMsg struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
Encrypt string `xml:"Encrypt"`
AgentID string `xml:"AgentID"`
}
if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{
"error": err.Error(),
})
http.Error(w, "Invalid XML", http.StatusBadRequest)
return
}
// Verify signature
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
logger.WarnC("wecom_app", "Message signature verification failed")
http.Error(w, "Invalid signature", http.StatusForbidden)
return
}
// Decrypt message with CorpID verification
// For WeCom App (自建应用), receiveid should be corp_id
decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), c.config.CorpID)
if err != nil {
logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{
"error": err.Error(),
})
http.Error(w, "Decryption failed", http.StatusInternalServerError)
return
}
// Parse decrypted XML message
var msg WeComXMLMessage
if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{
"error": err.Error(),
})
http.Error(w, "Invalid message format", http.StatusBadRequest)
return
}
// Process the message with the channel's long-lived context (not the HTTP
// request context, which is canceled as soon as we return the response).
go c.processMessage(c.ctx, msg)
// Return success response immediately
// WeCom App requires response within configured timeout (default 5 seconds)
w.Write([]byte("success"))
}
// processMessage processes the received message
func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) {
// Skip non-text messages for now (can be extended)
if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" {
logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{
"msg_type": msg.MsgType,
})
return
}
// Message deduplication: Use msg_id to prevent duplicate processing
// As per WeCom documentation, use msg_id for deduplication
msgID := fmt.Sprintf("%d", msg.MsgId)
if !c.processedMsgs.MarkMessageProcessed(msgID) {
logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
"msg_id": msgID,
})
return
}
senderID := msg.FromUserName
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
// Build metadata
// WeCom App only supports direct messages (private chat)
peer := bus.Peer{Kind: "direct", ID: senderID}
messageID := fmt.Sprintf("%d", msg.MsgId)
metadata := map[string]string{
"msg_type": msg.MsgType,
"msg_id": fmt.Sprintf("%d", msg.MsgId),
"agent_id": fmt.Sprintf("%d", msg.AgentID),
"platform": "wecom_app",
"media_id": msg.MediaId,
"create_time": fmt.Sprintf("%d", msg.CreateTime),
}
content := msg.Content
logger.DebugCF("wecom_app", "Received message", map[string]any{
"sender_id": senderID,
"msg_type": msg.MsgType,
"preview": utils.Truncate(content, 50),
})
// Build sender info
appSender := bus.SenderInfo{
Platform: "wecom",
PlatformID: senderID,
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
}
// Handle the message through the base channel
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, appSender)
}
// tokenRefreshLoop periodically refreshes the access token
func (c *WeComAppChannel) tokenRefreshLoop() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
if err := c.refreshAccessToken(); err != nil {
logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{
"error": err.Error(),
})
}
}
}
}
// refreshAccessToken gets a new access token from WeCom API
func (c *WeComAppChannel) refreshAccessToken() error {
apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s",
wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret()))
resp, err := http.Get(apiURL)
if err != nil {
return fmt.Errorf("failed to request access token: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
var tokenResp WeComAccessTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
if tokenResp.ErrCode != 0 {
return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode)
}
c.tokenMu.Lock()
c.accessToken = tokenResp.AccessToken
c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early
c.tokenMu.Unlock()
logger.DebugC("wecom_app", "Access token refreshed successfully")
return nil
}
// getAccessToken returns the current valid access token
func (c *WeComAppChannel) getAccessToken() string {
c.tokenMu.RLock()
defer c.tokenMu.RUnlock()
if time.Now().After(c.tokenExpiry) {
return ""
}
return c.accessToken
}
// sendTextMessage sends a text message to a user.
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
msg := WeComTextMessage{
ToUser: userID,
MsgType: "text",
AgentID: c.config.AgentID,
}
msg.Text.Content = content
return c.sendWeComMessage(ctx, accessToken, msg)
}
// handleHealth handles health check requests
func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
status := map[string]any{
"status": "ok",
"running": c.IsRunning(),
"has_token": c.getAccessToken() != "",
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(status)
}
File diff suppressed because it is too large Load Diff
-499
View File
@@ -1,499 +0,0 @@
package wecom
import (
"bytes"
"context"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人)
// Uses webhook callback mode - simpler than WeCom App but only supports passive replies
type WeComBotChannel struct {
*channels.BaseChannel
config config.WeComConfig
client *http.Client
ctx context.Context
cancel context.CancelFunc
processedMsgs *MessageDeduplicator
}
// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT)
type WeComBotMessage struct {
MsgID string `json:"msgid"`
AIBotID string `json:"aibotid"`
ChatID string `json:"chatid"` // Session ID, only present for group chats
ChatType string `json:"chattype"` // "single" for DM, "group" for group chat
From struct {
UserID string `json:"userid"`
} `json:"from"`
ResponseURL string `json:"response_url"`
MsgType string `json:"msgtype"` // text, image, voice, file, mixed
Text struct {
Content string `json:"content"`
} `json:"text"`
Image struct {
URL string `json:"url"`
} `json:"image"`
Voice struct {
Content string `json:"content"` // Voice to text content
} `json:"voice"`
File struct {
URL string `json:"url"`
} `json:"file"`
Mixed struct {
MsgItem []struct {
MsgType string `json:"msgtype"`
Text struct {
Content string `json:"content"`
} `json:"text"`
Image struct {
URL string `json:"url"`
} `json:"image"`
} `json:"msg_item"`
} `json:"mixed"`
Quote struct {
MsgType string `json:"msgtype"`
Text struct {
Content string `json:"content"`
} `json:"text"`
} `json:"quote"`
}
// WeComBotReplyMessage represents the reply message structure
type WeComBotReplyMessage struct {
MsgType string `json:"msgtype"`
Text struct {
Content string `json:"content"`
} `json:"text,omitempty"`
}
// NewWeComBotChannel creates a new WeCom Bot channel instance
func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) {
if cfg.Token() == "" || cfg.WebhookURL == "" {
return nil, fmt.Errorf("wecom token and webhook_url are required")
}
base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom,
channels.WithMaxMessageLength(2048),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
// Client timeout must be >= the configured ReplyTimeout so the
// per-request context deadline is always the effective limit.
clientTimeout := 30 * time.Second
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
clientTimeout = d
}
ctx, cancel := context.WithCancel(context.Background())
return &WeComBotChannel{
BaseChannel: base,
config: cfg,
client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
}, nil
}
// Name returns the channel name
func (c *WeComBotChannel) Name() string {
return "wecom"
}
// Start initializes the WeCom Bot channel
func (c *WeComBotChannel) Start(ctx context.Context) error {
logger.InfoC("wecom", "Starting WeCom Bot channel...")
// Cancel the context created in the constructor to avoid a resource leak.
if c.cancel != nil {
c.cancel()
}
c.ctx, c.cancel = context.WithCancel(ctx)
c.SetRunning(true)
logger.InfoC("wecom", "WeCom Bot channel started")
return nil
}
// Stop gracefully stops the WeCom Bot channel
func (c *WeComBotChannel) Stop(ctx context.Context) error {
logger.InfoC("wecom", "Stopping WeCom Bot channel...")
if c.cancel != nil {
c.cancel()
}
c.SetRunning(false)
logger.InfoC("wecom", "WeCom Bot channel stopped")
return nil
}
// Send sends a message to WeCom user via webhook API
// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message
// For delayed responses, we use the webhook URL
func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
logger.DebugCF("wecom", "Sending message via webhook", map[string]any{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
return c.sendWebhookReply(ctx, msg.ChatID, msg.Content)
}
// WebhookPath returns the path for registering on the shared HTTP server.
func (c *WeComBotChannel) WebhookPath() string {
if c.config.WebhookPath != "" {
return c.config.WebhookPath
}
return "/webhook/wecom"
}
// ServeHTTP implements http.Handler for the shared HTTP server.
func (c *WeComBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.handleWebhook(w, r)
}
// HealthPath returns the health check endpoint path.
func (c *WeComBotChannel) HealthPath() string {
return "/health/wecom"
}
// HealthHandler handles health check requests.
func (c *WeComBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
c.handleHealth(w, r)
}
// handleWebhook handles incoming webhook requests from WeCom
func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if r.Method == http.MethodGet {
// Handle verification request
c.handleVerification(ctx, w, r)
return
}
if r.Method == http.MethodPost {
// Handle message callback
c.handleMessageCallback(ctx, w, r)
return
}
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// handleVerification handles the URL verification request from WeCom
func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
msgSignature := query.Get("msg_signature")
timestamp := query.Get("timestamp")
nonce := query.Get("nonce")
echostr := query.Get("echostr")
if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
http.Error(w, "Missing parameters", http.StatusBadRequest)
return
}
// Verify signature
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, echostr) {
logger.WarnC("wecom", "Signature verification failed")
http.Error(w, "Invalid signature", http.StatusForbidden)
return
}
// Decrypt echostr
// For AIBOT (智能机器人), receiveid should be empty string ""
// Reference: https://developer.work.weixin.qq.com/document/path/101033
decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey(), "")
if err != nil {
logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{
"error": err.Error(),
})
http.Error(w, "Decryption failed", http.StatusInternalServerError)
return
}
// Remove BOM and whitespace as per WeCom documentation
// The response must be plain text without quotes, BOM, or newlines
decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
w.Write([]byte(decryptedEchoStr))
}
// handleMessageCallback handles incoming messages from WeCom
func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
msgSignature := query.Get("msg_signature")
timestamp := query.Get("timestamp")
nonce := query.Get("nonce")
if msgSignature == "" || timestamp == "" || nonce == "" {
http.Error(w, "Missing parameters", http.StatusBadRequest)
return
}
// Read request body
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read body", http.StatusBadRequest)
return
}
defer r.Body.Close()
// Parse XML to get encrypted message
var encryptedMsg struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
Encrypt string `xml:"Encrypt"`
AgentID string `xml:"AgentID"`
}
if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
logger.ErrorCF("wecom", "Failed to parse XML", map[string]any{
"error": err.Error(),
})
http.Error(w, "Invalid XML", http.StatusBadRequest)
return
}
// Verify signature
if !verifySignature(c.config.Token(), msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
logger.WarnC("wecom", "Message signature verification failed")
http.Error(w, "Invalid signature", http.StatusForbidden)
return
}
// Decrypt message
// For AIBOT (智能机器人), receiveid should be empty string ""
// Reference: https://developer.work.weixin.qq.com/document/path/101033
decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey(), "")
if err != nil {
logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{
"error": err.Error(),
})
http.Error(w, "Decryption failed", http.StatusInternalServerError)
return
}
// Parse decrypted JSON message (AIBOT uses JSON format)
var msg WeComBotMessage
if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]any{
"error": err.Error(),
})
http.Error(w, "Invalid message format", http.StatusBadRequest)
return
}
// Process the message with the channel's long-lived context (not the HTTP
// request context, which is canceled as soon as we return the response).
go c.processMessage(c.ctx, msg)
// Return success response immediately
// WeCom Bot requires response within configured timeout (default 5 seconds)
w.Write([]byte("success"))
}
// processMessage processes the received message
func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) {
// Skip unsupported message types
if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" &&
msg.MsgType != "mixed" {
logger.DebugCF("wecom", "Skipping non-supported message type", map[string]any{
"msg_type": msg.MsgType,
})
return
}
// Message deduplication: Use msg_id to prevent duplicate processing
msgID := msg.MsgID
if !c.processedMsgs.MarkMessageProcessed(msgID) {
logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{
"msg_id": msgID,
})
return
}
senderID := msg.From.UserID
// Determine if this is a group chat or direct message
// ChatType: "single" for DM, "group" for group chat
isGroupChat := msg.ChatType == "group"
var chatID, peerKind, peerID string
if isGroupChat {
// Group chat: use ChatID as chatID and peer_id
chatID = msg.ChatID
peerKind = "group"
peerID = msg.ChatID
} else {
// Direct message: use senderID as chatID and peer_id
chatID = senderID
peerKind = "direct"
peerID = senderID
}
// Extract content based on message type
var content string
switch msg.MsgType {
case "text":
content = msg.Text.Content
case "voice":
content = msg.Voice.Content // Voice to text content
case "mixed":
// For mixed messages, concatenate text items
for _, item := range msg.Mixed.MsgItem {
if item.MsgType == "text" {
content += item.Text.Content
}
}
case "image", "file":
// For image and file, we don't have text content
content = ""
}
// Build metadata
peer := bus.Peer{Kind: peerKind, ID: peerID}
// In group chats, apply unified group trigger filtering
if isGroupChat {
respond, cleaned := c.ShouldRespondInGroup(false, content)
if !respond {
return
}
content = cleaned
}
metadata := map[string]string{
"msg_type": msg.MsgType,
"msg_id": msg.MsgID,
"platform": "wecom",
"response_url": msg.ResponseURL,
}
if isGroupChat {
metadata["chat_id"] = msg.ChatID
metadata["sender_id"] = senderID
}
logger.DebugCF("wecom", "Received message", map[string]any{
"sender_id": senderID,
"msg_type": msg.MsgType,
"peer_kind": peerKind,
"is_group_chat": isGroupChat,
"preview": utils.Truncate(content, 50),
})
// Build sender info
sender := bus.SenderInfo{
Platform: "wecom",
PlatformID: senderID,
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
}
if !c.IsAllowedSender(sender) {
return
}
// Handle the message through the base channel
c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata, sender)
}
// sendWebhookReply sends a reply using the webhook URL
func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error {
reply := WeComBotReplyMessage{
MsgType: "text",
}
reply.Text.Content = content
jsonData, err := json.Marshal(reply)
if err != nil {
return fmt.Errorf("failed to marshal reply: %w", err)
}
// Use configurable timeout (default 5 seconds)
timeout := c.config.ReplyTimeout
if timeout <= 0 {
timeout = 5
}
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading webhook error response: %w", readErr),
)
}
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("webhook API error: %s", string(body)),
)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
// Check response
var result struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.Unmarshal(body, &result); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
if result.ErrCode != 0 {
return fmt.Errorf("webhook API error: %s (code: %d)", result.ErrMsg, result.ErrCode)
}
return nil
}
// handleHealth handles health check requests
func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
status := map[string]any{
"status": "ok",
"running": c.IsRunning(),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(status)
}
-734
View File
@@ -1,734 +0,0 @@
package wecom
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"encoding/json"
"encoding/xml"
"fmt"
"net/http"
"net/http/httptest"
"sort"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
// generateTestAESKey generates a valid test AES key
func generateTestAESKey() string {
// AES key needs to be 32 bytes (256 bits) for AES-256
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
// Return base64 encoded key without padding
return base64.StdEncoding.EncodeToString(key)[:43]
}
// encryptTestMessage encrypts a message for testing (AIBOT JSON format)
func encryptTestMessage(message, aesKey string) (string, error) {
// Decode AES key
key, err := base64.StdEncoding.DecodeString(aesKey + "=")
if err != nil {
return "", err
}
// Prepare message: random(16) + msg_len(4) + msg + receiveid
random := make([]byte, 0, 16)
for i := range 16 {
random = append(random, byte(i))
}
msgBytes := []byte(message)
receiveID := []byte("test_aibot_id")
msgLen := uint32(len(msgBytes))
lenBytes := make([]byte, 4)
binary.BigEndian.PutUint32(lenBytes, msgLen)
plainText := append(random, lenBytes...)
plainText = append(plainText, msgBytes...)
plainText = append(plainText, receiveID...)
// PKCS7 padding
blockSize := aes.BlockSize
padding := blockSize - len(plainText)%blockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
plainText = append(plainText, padText...)
// Encrypt
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize])
cipherText := make([]byte, len(plainText))
mode.CryptBlocks(cipherText, plainText)
return base64.StdEncoding.EncodeToString(cipherText), nil
}
// generateSignature generates a signature for testing
func generateSignature(token, timestamp, nonce, msgEncrypt string) string {
params := []string{token, timestamp, nonce, msgEncrypt}
sort.Strings(params)
str := strings.Join(params, "")
hash := sha1.Sum([]byte(str))
return fmt.Sprintf("%x", hash)
}
func TestNewWeComBotChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("missing token", func(t *testing.T) {
cfg := config.WeComConfig{}
cfg.SetToken("")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
_, err := NewWeComBotChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing token, got nil")
}
})
t.Run("missing webhook_url", func(t *testing.T) {
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = ""
_, err := NewWeComBotChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing webhook_url, got nil")
}
})
t.Run("valid config", func(t *testing.T) {
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
cfg.AllowFrom = []string{"user1", "user2"}
ch, err := NewWeComBotChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "wecom" {
t.Errorf("Name() = %q, want %q", ch.Name(), "wecom")
}
if ch.IsRunning() {
t.Error("new channel should not be running")
}
})
}
func TestWeComBotChannelIsAllowed(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("empty allowlist allows all", func(t *testing.T) {
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
cfg.AllowFrom = []string{}
ch, _ := NewWeComBotChannel(cfg, msgBus)
if !ch.IsAllowed("any_user") {
t.Error("empty allowlist should allow all users")
}
})
t.Run("allowlist restricts users", func(t *testing.T) {
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
cfg.AllowFrom = []string{"allowed_user"}
ch, _ := NewWeComBotChannel(cfg, msgBus)
if !ch.IsAllowed("allowed_user") {
t.Error("allowed user should pass allowlist check")
}
if ch.IsAllowed("blocked_user") {
t.Error("non-allowed user should be blocked")
}
})
}
func TestWeComBotVerifySignature(t *testing.T) {
msgBus := bus.NewMessageBus()
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
ch, _ := NewWeComBotChannel(cfg, msgBus)
t.Run("valid signature", func(t *testing.T) {
timestamp := "1234567890"
nonce := "test_nonce"
msgEncrypt := "test_message"
expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt)
if !verifySignature(ch.config.Token(), expectedSig, timestamp, nonce, msgEncrypt) {
t.Error("valid signature should pass verification")
}
})
t.Run("invalid signature", func(t *testing.T) {
timestamp := "1234567890"
nonce := "test_nonce"
msgEncrypt := "test_message"
if verifySignature(ch.config.Token(), "invalid_sig", timestamp, nonce, msgEncrypt) {
t.Error("invalid signature should fail verification")
}
})
t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) {
cfgEmpty := config.WeComConfig{}
cfgEmpty.SetToken("")
cfgEmpty.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
chEmpty := &WeComBotChannel{
config: cfgEmpty,
}
if verifySignature(chEmpty.config.Token(), "any_sig", "any_ts", "any_nonce", "any_msg") {
t.Error("empty token should reject verification (fail-closed)")
}
})
}
func TestWeComBotDecryptMessage(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("decrypt without AES key", func(t *testing.T) {
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
cfg.SetEncodingAESKey("")
ch, _ := NewWeComBotChannel(cfg, msgBus)
// Without AES key, message should be base64 decoded only
plainText := "hello world"
encoded := base64.StdEncoding.EncodeToString([]byte(plainText))
result, err := decryptMessage(encoded, ch.config.EncodingAESKey())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != plainText {
t.Errorf("decryptMessage() = %q, want %q", result, plainText)
}
})
t.Run("decrypt with AES key", func(t *testing.T) {
aesKey := generateTestAESKey()
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
cfg.SetEncodingAESKey(aesKey)
ch, _ := NewWeComBotChannel(cfg, msgBus)
originalMsg := "<xml><Content>Hello</Content></xml>"
encrypted, err := encryptTestMessage(originalMsg, aesKey)
if err != nil {
t.Fatalf("failed to encrypt test message: %v", err)
}
result, err := decryptMessage(encrypted, ch.config.EncodingAESKey())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != originalMsg {
t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg)
}
})
t.Run("invalid base64", func(t *testing.T) {
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
cfg.SetEncodingAESKey("")
ch, _ := NewWeComBotChannel(cfg, msgBus)
_, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey())
if err == nil {
t.Error("expected error for invalid base64, got nil")
}
})
t.Run("invalid AES key", func(t *testing.T) {
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
cfg.SetEncodingAESKey("invalid_key")
ch, _ := NewWeComBotChannel(cfg, msgBus)
_, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey())
if err == nil {
t.Error("expected error for invalid AES key, got nil")
}
})
}
func TestWeComBotPKCS7Unpad(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
}{
{
name: "empty input",
input: []byte{},
expected: []byte{},
},
{
name: "valid padding 3 bytes",
input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
expected: []byte("hello"),
},
{
name: "valid padding 16 bytes (full block)",
input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
expected: []byte("123456789012345"),
},
{
name: "invalid padding larger than data",
input: []byte{20},
expected: nil, // should return error
},
{
name: "invalid padding zero",
input: append([]byte("test"), byte(0)),
expected: nil, // should return error
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := pkcs7Unpad(tt.input)
if tt.expected == nil {
// This case should return an error
if err == nil {
t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
}
return
}
if err != nil {
t.Errorf("pkcs7Unpad() unexpected error: %v", err)
return
}
if !bytes.Equal(result, tt.expected) {
t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
}
})
}
}
func TestWeComBotHandleVerification(t *testing.T) {
msgBus := bus.NewMessageBus()
aesKey := generateTestAESKey()
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.SetEncodingAESKey(aesKey)
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
ch, _ := NewWeComBotChannel(cfg, msgBus)
t.Run("valid verification request", func(t *testing.T) {
echostr := "test_echostr_123"
encryptedEchostr, _ := encryptTestMessage(echostr, aesKey)
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr)
req := httptest.NewRequest(
http.MethodGet,
"/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
nil,
)
w := httptest.NewRecorder()
ch.handleVerification(context.Background(), w, req)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
if w.Body.String() != echostr {
t.Errorf("response body = %q, want %q", w.Body.String(), echostr)
}
})
t.Run("missing parameters", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig&timestamp=ts", nil)
w := httptest.NewRecorder()
ch.handleVerification(context.Background(), w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
}
})
t.Run("invalid signature", func(t *testing.T) {
echostr := "test_echostr"
encryptedEchostr, _ := encryptTestMessage(echostr, aesKey)
timestamp := "1234567890"
nonce := "test_nonce"
req := httptest.NewRequest(
http.MethodGet,
"/webhook/wecom?msg_signature=invalid_sig&timestamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
nil,
)
w := httptest.NewRecorder()
ch.handleVerification(context.Background(), w, req)
if w.Code != http.StatusForbidden {
t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
}
})
}
func TestWeComBotHandleMessageCallback(t *testing.T) {
msgBus := bus.NewMessageBus()
aesKey := generateTestAESKey()
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.SetEncodingAESKey(aesKey)
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
ch, _ := NewWeComBotChannel(cfg, msgBus)
runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder {
t.Helper()
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
encryptedWrapper := struct {
XMLName xml.Name `xml:"xml"`
Encrypt string `xml:"Encrypt"`
}{
Encrypt: encrypted,
}
wrapperData, _ := xml.Marshal(encryptedWrapper)
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, encrypted)
req := httptest.NewRequest(
http.MethodPost,
"/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce,
bytes.NewReader(wrapperData),
)
w := httptest.NewRecorder()
ch.handleMessageCallback(context.Background(), w, req)
return w
}
t.Run("valid direct message callback", func(t *testing.T) {
w := runBotMessageCallback(t, `{
"msgid": "test_msg_id_123",
"aibotid": "test_aibot_id",
"chattype": "single",
"from": {"userid": "user123"},
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello World"}
}`)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
if w.Body.String() != "success" {
t.Errorf("response body = %q, want %q", w.Body.String(), "success")
}
})
t.Run("valid group message callback", func(t *testing.T) {
w := runBotMessageCallback(t, `{
"msgid": "test_msg_id_456",
"aibotid": "test_aibot_id",
"chatid": "group_chat_id_123",
"chattype": "group",
"from": {"userid": "user456"},
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello Group"}
}`)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
if w.Body.String() != "success" {
t.Errorf("response body = %q, want %q", w.Body.String(), "success")
}
})
t.Run("missing parameters", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil)
w := httptest.NewRecorder()
ch.handleMessageCallback(context.Background(), w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
}
})
t.Run("invalid XML", func(t *testing.T) {
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, "")
req := httptest.NewRequest(
http.MethodPost,
"/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce,
strings.NewReader("invalid xml"),
)
w := httptest.NewRecorder()
ch.handleMessageCallback(context.Background(), w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
}
})
t.Run("invalid signature", func(t *testing.T) {
encryptedWrapper := struct {
XMLName xml.Name `xml:"xml"`
Encrypt string `xml:"Encrypt"`
}{
Encrypt: "encrypted_data",
}
wrapperData, _ := xml.Marshal(encryptedWrapper)
timestamp := "1234567890"
nonce := "test_nonce"
req := httptest.NewRequest(
http.MethodPost,
"/webhook/wecom?msg_signature=invalid_sig&timestamp="+timestamp+"&nonce="+nonce,
bytes.NewReader(wrapperData),
)
w := httptest.NewRecorder()
ch.handleMessageCallback(context.Background(), w, req)
if w.Code != http.StatusForbidden {
t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
}
})
}
func TestWeComBotProcessMessage(t *testing.T) {
msgBus := bus.NewMessageBus()
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
ch, _ := NewWeComBotChannel(cfg, msgBus)
t.Run("process direct text message", func(t *testing.T) {
msg := WeComBotMessage{
MsgID: "test_msg_id_123",
AIBotID: "test_aibot_id",
ChatType: "single",
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
MsgType: "text",
}
msg.From.UserID = "user123"
msg.Text.Content = "Hello World"
// Should not panic
ch.processMessage(context.Background(), msg)
})
t.Run("process group text message", func(t *testing.T) {
msg := WeComBotMessage{
MsgID: "test_msg_id_456",
AIBotID: "test_aibot_id",
ChatID: "group_chat_id_123",
ChatType: "group",
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
MsgType: "text",
}
msg.From.UserID = "user456"
msg.Text.Content = "Hello Group"
// Should not panic
ch.processMessage(context.Background(), msg)
})
t.Run("process voice message", func(t *testing.T) {
msg := WeComBotMessage{
MsgID: "test_msg_id_789",
AIBotID: "test_aibot_id",
ChatType: "single",
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
MsgType: "voice",
}
msg.From.UserID = "user123"
msg.Voice.Content = "Voice message text"
// Should not panic
ch.processMessage(context.Background(), msg)
})
t.Run("skip unsupported message type", func(t *testing.T) {
msg := WeComBotMessage{
MsgID: "test_msg_id_000",
AIBotID: "test_aibot_id",
ChatType: "single",
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
MsgType: "video",
}
msg.From.UserID = "user123"
// Should not panic
ch.processMessage(context.Background(), msg)
})
}
func TestWeComBotHandleWebhook(t *testing.T) {
msgBus := bus.NewMessageBus()
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
ch, _ := NewWeComBotChannel(cfg, msgBus)
t.Run("GET request calls verification", func(t *testing.T) {
echostr := "test_echostr"
encoded := base64.StdEncoding.EncodeToString([]byte(echostr))
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, encoded)
req := httptest.NewRequest(
http.MethodGet,
"/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded,
nil,
)
w := httptest.NewRecorder()
ch.handleWebhook(w, req)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
})
t.Run("POST request calls message callback", func(t *testing.T) {
encryptedWrapper := struct {
XMLName xml.Name `xml:"xml"`
Encrypt string `xml:"Encrypt"`
}{
Encrypt: base64.StdEncoding.EncodeToString([]byte("test")),
}
wrapperData, _ := xml.Marshal(encryptedWrapper)
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt)
req := httptest.NewRequest(
http.MethodPost,
"/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce,
bytes.NewReader(wrapperData),
)
w := httptest.NewRecorder()
ch.handleWebhook(w, req)
// Should not be method not allowed
if w.Code == http.StatusMethodNotAllowed {
t.Error("POST request should not return Method Not Allowed")
}
})
t.Run("unsupported method", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil)
w := httptest.NewRecorder()
ch.handleWebhook(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed)
}
})
}
func TestWeComBotHandleHealth(t *testing.T) {
msgBus := bus.NewMessageBus()
cfg := config.WeComConfig{}
cfg.SetToken("test_token")
cfg.WebhookURL = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test"
ch, _ := NewWeComBotChannel(cfg, msgBus)
req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil)
w := httptest.NewRecorder()
ch.handleHealth(w, req)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
contentType := w.Header().Get("Content-Type")
if contentType != "application/json" {
t.Errorf("Content-Type = %q, want %q", contentType, "application/json")
}
body := w.Body.String()
if !strings.Contains(body, "status") || !strings.Contains(body, "running") {
t.Errorf("response body should contain status and running fields, got: %s", body)
}
}
func TestWeComBotReplyMessage(t *testing.T) {
msg := WeComBotReplyMessage{
MsgType: "text",
}
msg.Text.Content = "Hello World"
if msg.MsgType != "text" {
t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
}
if msg.Text.Content != "Hello World" {
t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
}
}
func TestWeComBotMessageStructure(t *testing.T) {
jsonData := `{
"msgid": "test_msg_id_123",
"aibotid": "test_aibot_id",
"chatid": "group_chat_id_123",
"chattype": "group",
"from": {"userid": "user123"},
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello World"}
}`
var msg WeComBotMessage
err := json.Unmarshal([]byte(jsonData), &msg)
if err != nil {
t.Fatalf("failed to unmarshal JSON: %v", err)
}
if msg.MsgID != "test_msg_id_123" {
t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123")
}
if msg.AIBotID != "test_aibot_id" {
t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id")
}
if msg.ChatID != "group_chat_id_123" {
t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123")
}
if msg.ChatType != "group" {
t.Errorf("ChatType = %q, want %q", msg.ChatType, "group")
}
if msg.From.UserID != "user123" {
t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123")
}
if msg.MsgType != "text" {
t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
}
if msg.Text.Content != "Hello World" {
t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
}
}
-199
View File
@@ -1,199 +0,0 @@
package wecom
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"fmt"
"math/big"
"sort"
"strings"
)
// blockSize is the PKCS7 block size used by WeCom (32)
const blockSize = 32
// computeSignature computes the WeCom message signature from the given parameters.
// It sorts [token, timestamp, nonce, encrypt], concatenates them and returns the SHA1 hex digest.
func computeSignature(token, timestamp, nonce, encrypt string) string {
params := []string{token, timestamp, nonce, encrypt}
sort.Strings(params)
str := strings.Join(params, "")
hash := sha1.Sum([]byte(str))
return fmt.Sprintf("%x", hash)
}
// verifySignature verifies the message signature for WeCom
// This is a common function used by both WeCom Bot and WeCom App
func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
if token == "" {
return false
}
return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature
}
// decryptMessage decrypts the encrypted message using AES
// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id
func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) {
return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "")
}
// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid
// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification.
func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) {
if encodingAESKey == "" {
// No encryption, return as is (base64 decode)
decoded, err := base64.StdEncoding.DecodeString(encryptedMsg)
if err != nil {
return "", err
}
return string(decoded), nil
}
aesKey, err := decodeWeComAESKey(encodingAESKey)
if err != nil {
return "", err
}
cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
if err != nil {
return "", fmt.Errorf("failed to decode message: %w", err)
}
plainText, err := decryptAESCBC(aesKey, cipherText)
if err != nil {
return "", err
}
return unpackWeComFrame(plainText, receiveid)
}
// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is
// appended automatically) and validates that the result is exactly 32 bytes.
// It is the single place that handles this repeated pattern in both encrypt and decrypt paths.
func decodeWeComAESKey(encodingAESKey string) ([]byte, error) {
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
if err != nil {
return nil, fmt.Errorf("failed to decode AES key: %w", err)
}
if len(aesKey) != 32 {
return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
}
return aesKey, nil
}
// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring
// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the
// plaintext to a multiple of aes.BlockSize before calling.
func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
iv := aesKey[:aes.BlockSize]
ciphertext := make([]byte, len(plaintext))
cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
return ciphertext, nil
}
// packWeComFrame builds the WeCom wire format:
//
// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid
func packWeComFrame(msg, receiveid string) ([]byte, error) {
randomBytes := make([]byte, 16)
for i := range 16 {
n, err := rand.Int(rand.Reader, big.NewInt(10))
if err != nil {
return nil, fmt.Errorf("failed to generate random: %w", err)
}
randomBytes[i] = byte('0' + n.Int64())
}
msgBytes := []byte(msg)
msgLenBytes := make([]byte, 4)
binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes)))
var buf bytes.Buffer
buf.Write(randomBytes)
buf.Write(msgLenBytes)
buf.Write(msgBytes)
buf.WriteString(receiveid)
return buf.Bytes(), nil
}
// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame.
// If receiveid is non-empty it verifies the frame's trailing receiveid field.
func unpackWeComFrame(data []byte, receiveid string) (string, error) {
if len(data) < 20 {
return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data))
}
msgLen := binary.BigEndian.Uint32(data[16:20])
if int(msgLen) > len(data)-20 {
return "", fmt.Errorf("invalid message length: %d", msgLen)
}
msg := data[20 : 20+msgLen]
if receiveid != "" && len(data) > 20+int(msgLen) {
actualReceiveID := string(data[20+msgLen:])
if actualReceiveID != receiveid {
return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
}
}
return string(msg), nil
}
// decryptAESCBC decrypts ciphertext using AES-CBC with the given key.
// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext.
func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) {
if len(ciphertext) == 0 {
return nil, fmt.Errorf("ciphertext is empty")
}
if len(ciphertext)%aes.BlockSize != 0 {
return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
}
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
iv := aesKey[:aes.BlockSize]
plaintext := make([]byte, len(ciphertext))
cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
plaintext, err = pkcs7Unpad(plaintext)
if err != nil {
return nil, fmt.Errorf("failed to unpad: %w", err)
}
return plaintext, nil
}
// pkcs7Pad adds PKCS7 padding
func pkcs7Pad(data []byte, blockSize int) []byte {
padding := blockSize - (len(data) % blockSize)
if padding == 0 {
padding = blockSize
}
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(data, padText...)
}
// pkcs7Unpad removes PKCS7 padding with validation
func pkcs7Unpad(data []byte) ([]byte, error) {
if len(data) == 0 {
return data, nil
}
padding := int(data[len(data)-1])
// WeCom uses 32-byte block size for PKCS7 padding
if padding == 0 || padding > blockSize {
return nil, fmt.Errorf("invalid padding size: %d", padding)
}
if padding > len(data) {
return nil, fmt.Errorf("padding size larger than data")
}
// Verify all padding bytes
for i := range padding {
if data[len(data)-1-i] != byte(padding) {
return nil, fmt.Errorf("invalid padding byte at position %d", i)
}
}
return data[:len(data)-padding], nil
}
-54
View File
@@ -1,54 +0,0 @@
package wecom
import "sync"
const wecomMaxProcessedMessages = 1000
// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer)
// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest
// messages without causing "amnesia cliffs" when the limit is reached.
type MessageDeduplicator struct {
mu sync.Mutex
msgs map[string]bool
ring []string
idx int
max int
}
// NewMessageDeduplicator creates a new deduplicator with the specified capacity.
func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator {
if maxEntries <= 0 {
maxEntries = wecomMaxProcessedMessages
}
return &MessageDeduplicator{
msgs: make(map[string]bool, maxEntries),
ring: make([]string, maxEntries),
max: maxEntries,
}
}
// MarkMessageProcessed marks msgID as processed and returns false for duplicates.
func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool {
d.mu.Lock()
defer d.mu.Unlock()
// 1. Check for duplicate
if d.msgs[msgID] {
return false
}
// 2. Evict the oldest message at our current ring position (if any)
oldestID := d.ring[d.idx]
if oldestID != "" {
delete(d.msgs, oldestID)
}
// 3. Store the new message
d.msgs[msgID] = true
d.ring[d.idx] = msgID
// 4. Advance the circle queue index
d.idx = (d.idx + 1) % d.max
return true
}
-83
View File
@@ -1,83 +0,0 @@
package wecom
import (
"sync"
"testing"
)
func TestMessageDeduplicator_DuplicateDetection(t *testing.T) {
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
if ok := d.MarkMessageProcessed("msg-1"); !ok {
t.Fatalf("first message should be accepted")
}
if ok := d.MarkMessageProcessed("msg-1"); ok {
t.Fatalf("duplicate message should be rejected")
}
}
func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) {
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
const goroutines = 64
var wg sync.WaitGroup
wg.Add(goroutines)
results := make(chan bool, goroutines)
for i := 0; i < goroutines; i++ {
go func() {
defer wg.Done()
results <- d.MarkMessageProcessed("msg-concurrent")
}()
}
wg.Wait()
close(results)
successes := 0
for ok := range results {
if ok {
successes++
}
}
if successes != 1 {
t.Fatalf("expected exactly 1 successful mark, got %d", successes)
}
}
func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) {
// Create a deduplicator with a very small capacity to test eviction easily.
capacity := 3
d := NewMessageDeduplicator(capacity)
// Fill the queue.
d.MarkMessageProcessed("msg-1")
d.MarkMessageProcessed("msg-2")
d.MarkMessageProcessed("msg-3")
// At this point, the queue is full. msg-1 is the oldest.
if len(d.msgs) != 3 {
t.Fatalf("expected map size to be 3, got %d", len(d.msgs))
}
// This should evict msg-1 and add msg-4.
if ok := d.MarkMessageProcessed("msg-4"); !ok {
t.Fatalf("msg-4 should be accepted")
}
if len(d.msgs) != 3 {
t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs))
}
// msg-1 should now be forgotten (evicted).
if ok := d.MarkMessageProcessed("msg-1"); !ok {
t.Fatalf("msg-1 should be accepted again because it was evicted")
}
// msg-2 should have been evicted when we added msg-1 back.
if ok := d.MarkMessageProcessed("msg-2"); !ok {
t.Fatalf("msg-2 should be accepted again because it was evicted")
}
}
+1 -7
View File
@@ -8,12 +8,6 @@ import (
func init() {
channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewWeComBotChannel(cfg.Channels.WeCom, b)
})
channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewWeComAppChannel(cfg.Channels.WeComApp, b)
})
channels.RegisterFactory("wecom_aibot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewWeComAIBotChannel(cfg.Channels.WeComAIBot, b)
return NewChannel(cfg.Channels.WeCom, b)
})
}
+291
View File
@@ -0,0 +1,291 @@
package wecom
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"fmt"
"io"
"mime"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/h2non/filetype"
"github.com/sipeed/picoclaw/pkg/media"
)
func decodeMediaAESKey(value string) ([]byte, error) {
if value == "" {
return nil, nil
}
key, err := base64.StdEncoding.DecodeString(value)
if err == nil && len(key) == 32 {
return key, nil
}
key, err = base64.StdEncoding.DecodeString(value + "=")
if err != nil {
return nil, fmt.Errorf("decode AES key: %w", err)
}
if len(key) != 32 {
return nil, fmt.Errorf("invalid AES key length %d", len(key))
}
return key, nil
}
func decryptAESCBC(key, ciphertext []byte) ([]byte, error) {
if len(ciphertext) == 0 {
return nil, fmt.Errorf("ciphertext is empty")
}
if len(ciphertext)%aes.BlockSize != 0 {
return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("create cipher: %w", err)
}
plaintext := make([]byte, len(ciphertext))
iv := key[:aes.BlockSize]
cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
return pkcs7Unpad(plaintext)
}
func pkcs7Unpad(data []byte) ([]byte, error) {
if len(data) == 0 {
return nil, fmt.Errorf("empty plaintext")
}
padding := int(data[len(data)-1])
if padding == 0 || padding > 32 || padding > len(data) {
return nil, fmt.Errorf("invalid padding size %d", padding)
}
for i := 0; i < padding; i++ {
if data[len(data)-1-i] != byte(padding) {
return nil, fmt.Errorf("invalid padding byte")
}
}
return data[:len(data)-padding], nil
}
func inferMediaExt(contentType, fallback string) string {
contentType = normalizeWeComContentType(contentType)
switch contentType {
case "image/jpeg", "image/jpg":
return ".jpg"
case "image/png":
return ".png"
case "image/gif":
return ".gif"
case "image/webp":
return ".webp"
case "application/pdf":
return ".pdf"
case "video/mp4":
return ".mp4"
default:
return fallback
}
}
func normalizeWeComContentType(value string) string {
value = strings.ToLower(strings.TrimSpace(value))
if idx := strings.Index(value, ";"); idx >= 0 {
value = strings.TrimSpace(value[:idx])
}
return value
}
func isGenericWeComContentType(value string) bool {
switch normalizeWeComContentType(value) {
case "", "application/octet-stream", "binary/octet-stream", "application/unknown", "application/binary":
return true
default:
return false
}
}
func sanitizeWeComFilename(name string) string {
name = filepath.Base(strings.TrimSpace(name))
if name == "." || name == "/" || name == "" {
return ""
}
return name
}
func candidateWeComFilename(resourceURL, contentDisposition, fallbackName string) string {
if _, params, err := mime.ParseMediaType(contentDisposition); err == nil {
if name := sanitizeWeComFilename(params["filename"]); name != "" {
return name
}
if name := sanitizeWeComFilename(params["filename*"]); name != "" {
return name
}
}
if parsed, err := url.Parse(resourceURL); err == nil {
query := parsed.Query()
for _, key := range []string{"filename", "file_name", "name"} {
if name := sanitizeWeComFilename(query.Get(key)); name != "" {
return name
}
}
if name := sanitizeWeComFilename(parsed.Path); name != "" {
return name
}
}
return sanitizeWeComFilename(fallbackName)
}
func detectWeComFiletype(data []byte) (string, string) {
kind, err := filetype.Match(data)
if err != nil || kind == filetype.Unknown {
return "", ""
}
ext := ""
if kind.Extension != "" {
ext = "." + strings.ToLower(kind.Extension)
}
return normalizeWeComContentType(kind.MIME.Value), ext
}
func detectWeComMediaMetadata(data []byte, fallbackName, fallbackContentType, resourceURL, contentDisposition string) (string, string) {
filename := candidateWeComFilename(resourceURL, contentDisposition, fallbackName)
if filename == "" {
filename = "media"
}
ext := strings.ToLower(filepath.Ext(filename))
contentType := normalizeWeComContentType(fallbackContentType)
detectedType, detectedExt := detectWeComFiletype(data)
if ext != "" && isGenericWeComContentType(contentType) {
if byExt := normalizeWeComContentType(mime.TypeByExtension(ext)); byExt != "" {
contentType = byExt
}
}
if detectedType != "" {
switch {
case contentType == "":
contentType = detectedType
case isGenericWeComContentType(contentType):
contentType = detectedType
case strings.HasPrefix(detectedType, "image/") && !strings.HasPrefix(contentType, "image/"):
contentType = detectedType
case strings.HasPrefix(detectedType, "audio/") && !strings.HasPrefix(contentType, "audio/"):
contentType = detectedType
case strings.HasPrefix(detectedType, "video/") && !strings.HasPrefix(contentType, "video/"):
contentType = detectedType
}
}
if contentType == "" && ext != "" {
contentType = normalizeWeComContentType(mime.TypeByExtension(ext))
}
if contentType == "" {
contentType = normalizeWeComContentType(http.DetectContentType(data))
}
if ext == "" {
ext = detectedExt
}
if ext == "" && contentType != "" {
if exts, err := mime.ExtensionsByType(contentType); err == nil && len(exts) > 0 {
ext = strings.ToLower(exts[0])
}
}
if filepath.Ext(filename) == "" && ext != "" {
filename += ext
}
return filename, contentType
}
func (c *WeComChannel) storeRemoteMedia(
ctx context.Context,
scope, msgID, resourceURL, aesKey, fallbackExt string,
) (string, error) {
store := c.GetMediaStore()
if store == nil {
return "", fmt.Errorf("no media store available")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil)
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
resp, err := c.mediaClient.Do(req)
if err != nil {
return "", fmt.Errorf("download media: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("download media returned HTTP %d", resp.StatusCode)
}
const maxSize = 20 << 20
data, err := io.ReadAll(io.LimitReader(resp.Body, maxSize+1))
if err != nil {
return "", fmt.Errorf("read media: %w", err)
}
if len(data) > maxSize {
return "", fmt.Errorf("media too large")
}
if aesKey != "" {
key, keyErr := decodeMediaAESKey(aesKey)
if keyErr != nil {
return "", keyErr
}
data, err = decryptAESCBC(key, data)
if err != nil {
return "", fmt.Errorf("decrypt media: %w", err)
}
}
filename, contentType := detectWeComMediaMetadata(
data,
msgID+fallbackExt,
resp.Header.Get("Content-Type"),
resourceURL,
resp.Header.Get("Content-Disposition"),
)
ext := filepath.Ext(filename)
if ext == "" {
ext = inferMediaExt(contentType, fallbackExt)
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
return "", fmt.Errorf("mkdir media dir: %w", mkdirErr)
}
tmpFile, err := os.CreateTemp(mediaDir, msgID+"-*"+ext)
if err != nil {
return "", fmt.Errorf("create temp file: %w", err)
}
tmpPath := tmpFile.Name()
if _, writeErr := tmpFile.Write(data); writeErr != nil {
tmpFile.Close()
_ = os.Remove(tmpPath)
return "", fmt.Errorf("write temp file: %w", writeErr)
}
if closeErr := tmpFile.Close(); closeErr != nil {
_ = os.Remove(tmpPath)
return "", fmt.Errorf("close temp file: %w", closeErr)
}
ref, err := store.Store(tmpPath, media.MediaMeta{
Filename: filename,
ContentType: contentType,
Source: "wecom",
CleanupPolicy: media.CleanupPolicyDeleteOnCleanup,
}, scope)
if err != nil {
_ = os.Remove(tmpPath)
return "", err
}
return ref, nil
}
+180
View File
@@ -0,0 +1,180 @@
package wecom
import (
"bytes"
"context"
"encoding/base64"
"io"
"net/http"
"strings"
"testing"
basechannels "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestStoreRemoteMedia_DetectsJPEGContentTypeFromBody(t *testing.T) {
t.Parallel()
const jpegBase64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAP//////////////////////////////////////////////////////////////////////////////////////" +
"//////////////////////////////////////////////////////////////////////////////////////////////2wBDAf//////////////////////////////////////////////////////////////////////////////////////" +
"//////////////////////////////////////////////////////////////////////////////////////////////wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAVEQEBAAAAAAAAAAAAAAAAAAAABf/aAAwDAQACEAMQAAAB6A//xAAVEAEBAAAAAAAAAAAAAAAAAAAAEf/aAAgBAQABBQJf/8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAwEBPwF//8QAFBEBAAAAAAAAAAAAAAAAAAAAEP/aAAgBAgEBPwF//8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQAGPwJf/8QAFBABAAAAAAAAAAAAAAAAAAAAEP/aAAgBAQABPyFf/9k="
jpegData := decodeTestBase64(t, jpegBase64)
store := media.NewFileMediaStore()
ch := &WeComChannel{
BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil),
mediaClient: &http.Client{
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/octet-stream"}},
Body: io.NopCloser(bytes.NewReader(jpegData)),
}, nil
}),
},
}
ch.SetMediaStore(store)
ref, err := ch.storeRemoteMedia(context.Background(), "test-scope", "msg-1", "https://wecom.example/media", "", "")
if err != nil {
t.Fatalf("storeRemoteMedia returned error: %v", err)
}
t.Cleanup(func() {
_ = store.ReleaseAll("test-scope")
})
_, meta, err := store.ResolveWithMeta(ref)
if err != nil {
t.Fatalf("resolve media ref: %v", err)
}
if meta.ContentType != "image/jpeg" {
t.Fatalf("expected image/jpeg content type, got %q", meta.ContentType)
}
if !strings.HasSuffix(meta.Filename, ".jpg") && !strings.HasSuffix(meta.Filename, ".jpeg") {
t.Fatalf("expected jpeg filename, got %q", meta.Filename)
}
}
func TestDetectWeComMediaMetadata_UsesFallbackExtensionWhenBodyUnknown(t *testing.T) {
t.Parallel()
filename, contentType := detectWeComMediaMetadata([]byte("not a real image"), "msg-2.pdf", "", "", "")
if filename != "msg-2.pdf" {
t.Fatalf("expected fallback filename to be preserved, got %q", filename)
}
if contentType != "application/pdf" {
t.Fatalf("expected application/pdf from fallback extension, got %q", contentType)
}
}
func TestStoreRemoteMedia_PreservesSuffixFromURL(t *testing.T) {
t.Parallel()
docxLikeData := []byte("PK\x03\x04fake office payload")
store := media.NewFileMediaStore()
ch := &WeComChannel{
BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil),
mediaClient: &http.Client{
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/octet-stream"}},
Body: io.NopCloser(bytes.NewReader(docxLikeData)),
}, nil
}),
},
}
ch.SetMediaStore(store)
ref, err := ch.storeRemoteMedia(
context.Background(),
"test-scope",
"msg-docx",
"https://wecom.example/media/report.docx?signature=1",
"",
".bin",
)
if err != nil {
t.Fatalf("storeRemoteMedia returned error: %v", err)
}
t.Cleanup(func() {
_ = store.ReleaseAll("test-scope")
})
localPath, meta, err := store.ResolveWithMeta(ref)
if err != nil {
t.Fatalf("resolve media ref: %v", err)
}
if !strings.HasSuffix(meta.Filename, ".docx") {
t.Fatalf("expected docx filename, got %q", meta.Filename)
}
if !strings.HasSuffix(strings.ToLower(localPath), ".docx") {
t.Fatalf("expected docx temp path, got %q", localPath)
}
}
func TestStoreRemoteMedia_PreservesSuffixFromContentDisposition(t *testing.T) {
t.Parallel()
pptxLikeData := []byte("PK\x03\x04fake office payload")
store := media.NewFileMediaStore()
ch := &WeComChannel{
BaseChannel: basechannels.NewBaseChannel("wecom", nil, nil, nil),
mediaClient: &http.Client{
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/octet-stream"},
"Content-Disposition": []string{`attachment; filename="slides.pptx"`},
},
Body: io.NopCloser(bytes.NewReader(pptxLikeData)),
}, nil
}),
},
}
ch.SetMediaStore(store)
ref, err := ch.storeRemoteMedia(
context.Background(),
"test-scope",
"msg-pptx",
"https://wecom.example/media/download",
"",
".bin",
)
if err != nil {
t.Fatalf("storeRemoteMedia returned error: %v", err)
}
t.Cleanup(func() {
_ = store.ReleaseAll("test-scope")
})
localPath, meta, err := store.ResolveWithMeta(ref)
if err != nil {
t.Fatalf("resolve media ref: %v", err)
}
if !strings.HasSuffix(meta.Filename, ".pptx") {
t.Fatalf("expected pptx filename, got %q", meta.Filename)
}
if !strings.HasSuffix(strings.ToLower(localPath), ".pptx") {
t.Fatalf("expected pptx temp path, got %q", localPath)
}
}
func decodeTestBase64(t *testing.T, value string) []byte {
t.Helper()
data, err := io.ReadAll(base64.NewDecoder(base64.StdEncoding, strings.NewReader(value)))
if err != nil {
t.Fatalf("decode base64 fixture: %v", err)
}
return data
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
+122
View File
@@ -0,0 +1,122 @@
package wecom
import "encoding/json"
const (
wecomDefaultWebSocketURL = "wss://openws.work.weixin.qq.com"
wecomCmdSubscribe = "aibot_subscribe"
wecomCmdPing = "ping"
wecomCmdMsgCallback = "aibot_msg_callback"
wecomCmdEventCallback = "aibot_event_callback"
wecomCmdRespondMsg = "aibot_respond_msg"
wecomCmdSendMsg = "aibot_send_msg"
wecomMaxContentBytes = 20480
)
type wecomEnvelope struct {
Cmd string `json:"cmd,omitempty"`
Headers wecomHeaders `json:"headers"`
Body json.RawMessage `json:"body,omitempty"`
ErrCode int `json:"errcode,omitempty"`
ErrMsg string `json:"errmsg,omitempty"`
}
type wecomHeaders struct {
ReqID string `json:"req_id,omitempty"`
}
type wecomCommand struct {
Cmd string `json:"cmd"`
Headers wecomHeaders `json:"headers"`
Body any `json:"body,omitempty"`
}
type wecomSendMsgBody struct {
ChatID string `json:"chatid"`
ChatType uint32 `json:"chat_type,omitempty"`
MsgType string `json:"msgtype"`
Markdown *wecomMarkdownContent `json:"markdown,omitempty"`
}
type wecomRespondMsgBody struct {
MsgType string `json:"msgtype"`
Stream *wecomStreamContent `json:"stream,omitempty"`
}
type wecomStreamContent struct {
ID string `json:"id"`
Finish bool `json:"finish"`
Content string `json:"content,omitempty"`
}
type wecomMarkdownContent struct {
Content string `json:"content"`
}
type wecomIncomingMessage struct {
MsgID string `json:"msgid"`
AIBotID string `json:"aibotid"`
ChatID string `json:"chatid,omitempty"`
ChatType string `json:"chattype,omitempty"`
From struct {
UserID string `json:"userid"`
} `json:"from"`
MsgType string `json:"msgtype"`
Text *struct {
Content string `json:"content"`
} `json:"text,omitempty"`
Image *struct {
URL string `json:"url"`
AESKey string `json:"aeskey,omitempty"`
} `json:"image,omitempty"`
File *struct {
URL string `json:"url"`
AESKey string `json:"aeskey,omitempty"`
} `json:"file,omitempty"`
Video *struct {
URL string `json:"url"`
AESKey string `json:"aeskey,omitempty"`
} `json:"video,omitempty"`
Voice *struct {
Content string `json:"content"`
} `json:"voice,omitempty"`
Mixed *struct {
MsgItem []struct {
MsgType string `json:"msgtype"`
Text *struct {
Content string `json:"content"`
} `json:"text,omitempty"`
Image *struct {
URL string `json:"url"`
AESKey string `json:"aeskey,omitempty"`
} `json:"image,omitempty"`
File *struct {
URL string `json:"url"`
AESKey string `json:"aeskey,omitempty"`
} `json:"file,omitempty"`
} `json:"msg_item"`
} `json:"mixed,omitempty"`
Quote *struct {
MsgType string `json:"msgtype"`
Text *struct {
Content string `json:"content"`
} `json:"text,omitempty"`
} `json:"quote,omitempty"`
Event *struct {
EventType string `json:"eventtype"`
} `json:"event,omitempty"`
}
func incomingChatID(msg wecomIncomingMessage) string {
if msg.ChatID != "" {
return msg.ChatID
}
return msg.From.UserID
}
func incomingChatTypeCode(kind string) uint32 {
if kind == "group" {
return 2
}
return 1
}
+113
View File
@@ -0,0 +1,113 @@
package wecom
import (
"encoding/json"
"errors"
"os"
"path/filepath"
"sync"
"time"
)
type wecomRoute struct {
ReqID string `json:"req_id"`
ChatID string `json:"chat_id"`
ChatType uint32 `json:"chat_type"`
ExpiresAt time.Time `json:"expires_at"`
}
type reqIDStore struct {
mu sync.Mutex
path string
routes map[string]wecomRoute
}
func newReqIDStore(path string) *reqIDStore {
if path == "" {
path = defaultReqIDStorePath()
}
s := &reqIDStore{
path: path,
routes: make(map[string]wecomRoute),
}
_ = s.load()
return s
}
func defaultReqIDStorePath() string {
if home, err := os.UserHomeDir(); err == nil && home != "" {
return filepath.Join(home, ".picoclaw", "wecom", "reqid-store.json")
}
return filepath.Join(os.TempDir(), "picoclaw-wecom-reqid-store.json")
}
func (s *reqIDStore) Put(chatID, reqID string, chatType uint32, ttl time.Duration) error {
if reqID == "" || chatID == "" {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
s.deleteExpiredLocked(time.Now())
s.routes[chatID] = wecomRoute{
ReqID: reqID,
ChatID: chatID,
ChatType: chatType,
ExpiresAt: time.Now().Add(ttl),
}
return s.saveLocked()
}
func (s *reqIDStore) Get(chatID string) (wecomRoute, bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.deleteExpiredLocked(time.Now())
route, ok := s.routes[chatID]
return route, ok
}
func (s *reqIDStore) Delete(chatID string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.routes, chatID)
return s.saveLocked()
}
func (s *reqIDStore) load() error {
s.mu.Lock()
defer s.mu.Unlock()
data, err := os.ReadFile(s.path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
var routes map[string]wecomRoute
if err := json.Unmarshal(data, &routes); err != nil {
return err
}
s.routes = routes
s.deleteExpiredLocked(time.Now())
return nil
}
func (s *reqIDStore) deleteExpiredLocked(now time.Time) {
for chatID, route := range s.routes {
if !route.ExpiresAt.IsZero() && now.After(route.ExpiresAt) {
delete(s.routes, chatID)
}
}
}
func (s *reqIDStore) saveLocked() error {
if err := os.MkdirAll(filepath.Dir(s.path), 0o700); err != nil {
return err
}
data, err := json.MarshalIndent(s.routes, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.path, data, 0o600)
}
+24
View File
@@ -0,0 +1,24 @@
package wecom
import (
"path/filepath"
"testing"
"time"
)
func TestReqIDStorePersistsRoutes(t *testing.T) {
storePath := filepath.Join(t.TempDir(), "reqids.json")
store := newReqIDStore(storePath)
if err := store.Put("chat-1", "req-1", 2, time.Hour); err != nil {
t.Fatalf("Put() error = %v", err)
}
reloaded := newReqIDStore(storePath)
route, ok := reloaded.Get("chat-1")
if !ok {
t.Fatal("expected persisted route to be loaded")
}
if route.ChatID != "chat-1" || route.ReqID != "req-1" || route.ChatType != 2 {
t.Fatalf("loaded route = %+v", route)
}
}
+777
View File
@@ -0,0 +1,777 @@
package wecom
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
"math/big"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
)
const (
wecomConnectTimeout = 15 * time.Second
wecomCommandTimeout = 10 * time.Second
wecomHeartbeatInterval = 30 * time.Second
wecomStreamMaxDuration = 5*time.Minute + 30*time.Second
wecomRouteTTL = 30 * time.Minute
wecomMediaTimeout = 30 * time.Second
wecomRecentMessageMax = 1000
)
type WeComChannel struct {
*channels.BaseChannel
config config.WeComConfig
ctx context.Context
cancel context.CancelFunc
conn *websocket.Conn
connMu sync.Mutex
pendingMu sync.Mutex
pending map[string]chan wecomEnvelope
turnsMu sync.Mutex
turns map[string][]wecomTurn
recent *recentMessageSet
routes *reqIDStore
mediaClient *http.Client
commandSend func(wecomCommand, time.Duration) error
}
type wecomTurn struct {
ReqID string
ChatID string
ChatType uint32
StreamID string
CreatedAt time.Time
}
type recentMessageSet struct {
mu sync.Mutex
seen map[string]struct{}
ring []string
idx int
}
func newRecentMessageSet(capacity int) *recentMessageSet {
if capacity <= 0 {
capacity = wecomRecentMessageMax
}
return &recentMessageSet{
seen: make(map[string]struct{}, capacity),
ring: make([]string, capacity),
}
}
func (s *recentMessageSet) Mark(id string) bool {
if id == "" {
return true
}
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.seen[id]; ok {
return false
}
if old := s.ring[s.idx]; old != "" {
delete(s.seen, old)
}
s.ring[s.idx] = id
s.idx = (s.idx + 1) % len(s.ring)
s.seen[id] = struct{}{}
return true
}
func NewChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComChannel, error) {
if cfg.BotID == "" || cfg.Secret() == "" {
return nil, fmt.Errorf("wecom bot_id and secret are required")
}
if cfg.WebSocketURL == "" {
cfg.WebSocketURL = wecomDefaultWebSocketURL
}
base := channels.NewBaseChannel(
"wecom",
cfg,
messageBus,
cfg.AllowFrom,
channels.WithMaxMessageLength(wecomMaxContentBytes),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
ch := &WeComChannel{
BaseChannel: base,
config: cfg,
pending: make(map[string]chan wecomEnvelope),
turns: make(map[string][]wecomTurn),
recent: newRecentMessageSet(wecomRecentMessageMax),
routes: newReqIDStore(""),
mediaClient: &http.Client{Timeout: wecomMediaTimeout},
}
ch.SetOwner(ch)
return ch, nil
}
func (c *WeComChannel) Name() string { return "wecom" }
func (c *WeComChannel) Start(ctx context.Context) error {
logger.InfoC("wecom", "Starting WeCom channel...")
c.ctx, c.cancel = context.WithCancel(ctx)
c.SetRunning(true)
go c.connectLoop()
return nil
}
func (c *WeComChannel) Stop(_ context.Context) error {
logger.InfoC("wecom", "Stopping WeCom channel...")
if c.cancel != nil {
c.cancel()
}
c.connMu.Lock()
if c.conn != nil {
_ = c.conn.Close()
c.conn = nil
}
c.connMu.Unlock()
c.clearTurns()
c.SetRunning(false)
return nil
}
func (c *WeComChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
content := strings.TrimSpace(msg.Content)
if content == "" {
return nil
}
if turn, ok := c.getTurn(msg.ChatID); ok {
if time.Since(turn.CreatedAt) <= wecomStreamMaxDuration {
if err := c.sendStreamReply(turn, content); err == nil {
c.deleteTurn(msg.ChatID)
return nil
}
}
c.deleteTurn(msg.ChatID)
}
if route, ok := c.routes.Get(msg.ChatID); ok {
if err := c.sendActivePush(route.ChatID, route.ChatType, content); err != nil {
return err
}
return nil
}
if err := c.sendActivePush(msg.ChatID, 0, content); err != nil {
return err
}
return nil
}
func (c *WeComChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
var parts []string
for _, part := range msg.Parts {
switch {
case part.Caption != "":
parts = append(parts, part.Caption)
case part.Filename != "":
parts = append(parts, fmt.Sprintf("[media: %s]", part.Filename))
default:
parts = append(parts, "[media attachments are not yet supported]")
}
}
return c.Send(ctx, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: strings.Join(parts, "\n"),
})
}
func (c *WeComChannel) connectLoop() {
backoff := time.Second
for {
select {
case <-c.ctx.Done():
return
default:
}
if err := c.runConnection(); err != nil {
logger.WarnCF("wecom", "WeCom connection lost", map[string]any{
"error": err.Error(),
"backoff": backoff.String(),
})
select {
case <-time.After(backoff):
case <-c.ctx.Done():
return
}
if backoff < time.Minute {
backoff *= 2
if backoff > time.Minute {
backoff = time.Minute
}
}
continue
}
return
}
}
func (c *WeComChannel) runConnection() error {
dialCtx, cancel := context.WithTimeout(c.ctx, wecomConnectTimeout)
defer cancel()
conn, resp, err := websocket.DefaultDialer.DialContext(dialCtx, c.config.WebSocketURL, nil)
if resp != nil {
_ = resp.Body.Close()
}
if err != nil {
return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
}
c.connMu.Lock()
c.conn = conn
c.connMu.Unlock()
defer func() {
c.connMu.Lock()
if c.conn == conn {
c.conn = nil
}
c.connMu.Unlock()
_ = conn.Close()
c.clearTurns()
}()
readErrCh := make(chan error, 1)
go func() {
readErrCh <- c.readLoop(conn)
}()
if writeErr := c.writeAndWait(conn, wecomCommand{
Cmd: wecomCmdSubscribe,
Headers: wecomHeaders{ReqID: randomID(10)},
Body: map[string]string{
"bot_id": c.config.BotID,
"secret": c.config.Secret(),
},
}, wecomCommandTimeout); writeErr != nil {
return writeErr
}
heartbeatDone := make(chan struct{})
go func() {
defer close(heartbeatDone)
c.heartbeatLoop(conn)
}()
err = <-readErrCh
_ = conn.Close()
<-heartbeatDone
return err
}
func (c *WeComChannel) heartbeatLoop(conn *websocket.Conn) {
ticker := time.NewTicker(wecomHeartbeatInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := c.writeAndWait(conn, wecomCommand{
Cmd: wecomCmdPing,
Headers: wecomHeaders{ReqID: randomID(10)},
}, wecomCommandTimeout); err != nil {
logger.WarnCF("wecom", "Heartbeat failed", map[string]any{"error": err.Error()})
_ = conn.Close()
return
}
case <-c.ctx.Done():
return
}
}
}
func (c *WeComChannel) readLoop(conn *websocket.Conn) error {
for {
_, raw, err := conn.ReadMessage()
if err != nil {
select {
case <-c.ctx.Done():
return nil
default:
return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
}
}
var env wecomEnvelope
if err := json.Unmarshal(raw, &env); err != nil {
logger.WarnCF("wecom", "Failed to parse WebSocket message", map[string]any{"error": err.Error()})
continue
}
if env.Cmd == "" && env.Headers.ReqID != "" {
c.pendingMu.Lock()
ch, ok := c.pending[env.Headers.ReqID]
if ok {
delete(c.pending, env.Headers.ReqID)
}
c.pendingMu.Unlock()
if ok {
ch <- env
}
continue
}
go c.handleEnvelope(env)
}
}
func (c *WeComChannel) handleEnvelope(env wecomEnvelope) {
switch env.Cmd {
case wecomCmdMsgCallback:
c.handleMessageCallback(env)
case wecomCmdEventCallback:
c.handleEventCallback(env)
default:
logger.DebugCF("wecom", "Ignoring unsupported WeCom command", map[string]any{"cmd": env.Cmd})
}
}
func (c *WeComChannel) handleEventCallback(env wecomEnvelope) {
var msg wecomIncomingMessage
if err := json.Unmarshal(env.Body, &msg); err != nil {
logger.WarnCF("wecom", "Failed to parse WeCom event callback", map[string]any{"error": err.Error()})
}
}
func (c *WeComChannel) handleMessageCallback(env wecomEnvelope) {
var msg wecomIncomingMessage
if err := json.Unmarshal(env.Body, &msg); err != nil {
logger.WarnCF("wecom", "Failed to parse WeCom message callback", map[string]any{"error": err.Error()})
return
}
if !c.recent.Mark(msg.MsgID) {
return
}
reqID := env.Headers.ReqID
if reqID == "" {
logger.WarnC("wecom", "WeCom message callback missing req_id")
return
}
if msg.Event != nil && msg.Event.EventType != "" {
return
}
if err := c.dispatchIncoming(reqID, msg); err != nil {
logger.WarnCF("wecom", "Failed to dispatch WeCom message", map[string]any{
"req_id": reqID,
"error": err.Error(),
})
_ = c.respondImmediate(reqID, "The WeCom message could not be processed.")
}
}
func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage) error {
senderID := msg.From.UserID
if senderID == "" {
senderID = "unknown"
}
actualChatID := incomingChatID(msg)
chatType := incomingChatTypeCode(msg.ChatType)
peerKind := "direct"
if msg.ChatType == "group" {
peerKind = "group"
}
sender := bus.SenderInfo{
Platform: "wecom",
PlatformID: senderID,
CanonicalID: identity.BuildCanonicalID("wecom", senderID),
DisplayName: senderID,
}
var (
content string
quoteText string
mediaRefs []string
err error
)
scope := channels.BuildMediaScope("wecom", actualChatID, msg.MsgID)
switch msg.MsgType {
case "text":
if msg.Text != nil {
content = strings.TrimSpace(msg.Text.Content)
}
case "voice":
if msg.Voice != nil {
content = strings.TrimSpace(msg.Voice.Content)
}
case "image":
content = "[image]"
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
url: msg.Image.URL,
aesKey: msg.Image.AESKey,
}, "image", ".jpg")
case "file":
content = "[file]"
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
url: msg.File.URL,
aesKey: msg.File.AESKey,
}, "file", ".bin")
case "video":
content = "[video]"
mediaRefs, err = c.collectSingleMedia(c.ctx, scope, msg.MsgID, &mediaPayload{
url: msg.Video.URL,
aesKey: msg.Video.AESKey,
}, "video", ".mp4")
case "mixed":
content, mediaRefs, err = c.collectMixedMedia(c.ctx, scope, msg)
default:
return c.respondImmediate(reqID, "Unsupported WeCom message type: "+msg.MsgType)
}
if err != nil {
return err
}
if msg.Quote != nil && msg.Quote.Text != nil {
quoteText = strings.TrimSpace(msg.Quote.Text.Content)
if content == "" {
content = quoteText
}
}
if content == "" && len(mediaRefs) == 0 {
return c.respondImmediate(reqID, "The WeCom message did not contain usable content.")
}
turn := wecomTurn{
ReqID: reqID,
ChatID: actualChatID,
ChatType: chatType,
StreamID: randomID(10),
CreatedAt: time.Now(),
}
c.queueTurn(actualChatID, turn)
if err := c.routes.Put(actualChatID, reqID, chatType, wecomRouteTTL); err != nil {
logger.WarnCF("wecom", "Failed to persist req_id route", map[string]any{
"chat_id": actualChatID,
"req_id": reqID,
"error": err.Error(),
})
}
opening := ""
if c.config.SendThinkingMessage {
opening = "Processing..."
}
if err := c.sendStreamChunk(turn, false, opening); err != nil {
return err
}
peer := bus.Peer{Kind: peerKind, ID: actualChatID}
metadata := map[string]string{
"channel": "wecom",
"req_id": reqID,
"chat_id": actualChatID,
"chat_type": msg.ChatType,
"msg_id": msg.MsgID,
"msg_type": msg.MsgType,
}
if quoteText != "" {
metadata["quote_text"] = quoteText
}
c.HandleMessage(c.ctx, peer, msg.MsgID, senderID, actualChatID, content, mediaRefs, metadata, sender)
return nil
}
func (c *WeComChannel) collectSingleMedia(
ctx context.Context,
scope, msgID string,
payload interface {
GetURL() string
GetAESKey() string
},
label, fallbackExt string,
) ([]string, error) {
if payload == nil || payload.GetURL() == "" {
return nil, fmt.Errorf("%s payload is empty", label)
}
ref, err := c.storeRemoteMedia(ctx, scope, msgID, payload.GetURL(), payload.GetAESKey(), fallbackExt)
if err != nil {
return nil, err
}
return []string{ref}, nil
}
type mediaPayload struct {
url string
aesKey string
}
func (p *mediaPayload) GetURL() string { return p.url }
func (p *mediaPayload) GetAESKey() string { return p.aesKey }
func (c *WeComChannel) collectMixedMedia(
ctx context.Context,
scope string,
msg wecomIncomingMessage,
) (string, []string, error) {
if msg.Mixed == nil {
return "", nil, fmt.Errorf("mixed message is empty")
}
var textParts []string
var refs []string
for idx, item := range msg.Mixed.MsgItem {
switch item.MsgType {
case "text":
if item.Text != nil && strings.TrimSpace(item.Text.Content) != "" {
textParts = append(textParts, strings.TrimSpace(item.Text.Content))
}
case "image":
if item.Image != nil && item.Image.URL != "" {
ref, err := c.storeRemoteMedia(
ctx,
scope,
fmt.Sprintf("%s-%d", msg.MsgID, idx),
item.Image.URL,
item.Image.AESKey,
".jpg",
)
if err != nil {
return "", nil, err
}
refs = append(refs, ref)
}
case "file":
if item.File != nil && item.File.URL != "" {
ref, err := c.storeRemoteMedia(
ctx,
scope,
fmt.Sprintf("%s-%d", msg.MsgID, idx),
item.File.URL,
item.File.AESKey,
".bin",
)
if err != nil {
return "", nil, err
}
refs = append(refs, ref)
}
}
}
content := strings.Join(textParts, "\n")
if content == "" && len(refs) > 0 {
content = "[media]"
}
return content, refs, nil
}
func (c *WeComChannel) respondImmediate(reqID, content string) error {
turn := wecomTurn{
ReqID: reqID,
StreamID: randomID(10),
CreatedAt: time.Now(),
}
return c.sendStreamChunk(turn, true, content)
}
func (c *WeComChannel) sendStreamReply(turn wecomTurn, content string) error {
chunks := splitContent(content, wecomMaxContentBytes)
for idx, chunk := range chunks {
if err := c.sendStreamChunk(turn, idx == len(chunks)-1, chunk); err != nil {
return err
}
}
return nil
}
func (c *WeComChannel) sendStreamChunk(turn wecomTurn, finish bool, content string) error {
return c.sendCommand(wecomCommand{
Cmd: wecomCmdRespondMsg,
Headers: wecomHeaders{ReqID: turn.ReqID},
Body: wecomRespondMsgBody{
MsgType: "stream",
Stream: &wecomStreamContent{
ID: turn.StreamID,
Finish: finish,
Content: content,
},
},
}, wecomCommandTimeout)
}
func (c *WeComChannel) sendActivePush(chatID string, chatType uint32, content string) error {
if strings.TrimSpace(chatID) == "" {
return fmt.Errorf("empty chat ID: %w", channels.ErrSendFailed)
}
for _, chunk := range splitContent(content, wecomMaxContentBytes) {
if err := c.sendCommand(wecomCommand{
Cmd: wecomCmdSendMsg,
Headers: wecomHeaders{ReqID: randomID(10)},
Body: wecomSendMsgBody{
ChatID: chatID,
ChatType: chatType,
MsgType: "markdown",
Markdown: &wecomMarkdownContent{Content: chunk},
},
}, wecomCommandTimeout); err != nil {
return err
}
}
return nil
}
func (c *WeComChannel) sendCommand(cmd wecomCommand, timeout time.Duration) error {
if c.commandSend != nil {
return c.commandSend(cmd, timeout)
}
return c.writeCurrent(cmd, timeout)
}
func (c *WeComChannel) writeCurrent(cmd wecomCommand, timeout time.Duration) error {
c.connMu.Lock()
conn := c.conn
c.connMu.Unlock()
if conn == nil {
return fmt.Errorf("wecom websocket not connected: %w", channels.ErrTemporary)
}
return c.writeAndWait(conn, cmd, timeout)
}
func (c *WeComChannel) writeAndWait(conn *websocket.Conn, cmd wecomCommand, timeout time.Duration) error {
if cmd.Headers.ReqID == "" {
cmd.Headers.ReqID = randomID(10)
}
waitCh := make(chan wecomEnvelope, 1)
c.pendingMu.Lock()
c.pending[cmd.Headers.ReqID] = waitCh
c.pendingMu.Unlock()
defer func() {
c.pendingMu.Lock()
delete(c.pending, cmd.Headers.ReqID)
c.pendingMu.Unlock()
}()
data, err := json.Marshal(cmd)
if err != nil {
return fmt.Errorf("%w: %v", channels.ErrSendFailed, err)
}
c.connMu.Lock()
err = conn.WriteMessage(websocket.TextMessage, data)
c.connMu.Unlock()
if err != nil {
return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
}
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case env := <-waitCh:
if env.ErrCode != 0 {
return fmt.Errorf("%w: wecom errcode=%d errmsg=%s", channels.ErrTemporary, env.ErrCode, env.ErrMsg)
}
return nil
case <-timer.C:
return fmt.Errorf("%w: timeout waiting for WeCom ack", channels.ErrTemporary)
case <-c.ctx.Done():
return c.ctx.Err()
}
}
func (c *WeComChannel) getTurn(chatID string) (wecomTurn, bool) {
c.turnsMu.Lock()
defer c.turnsMu.Unlock()
queue := c.turns[chatID]
if len(queue) == 0 {
return wecomTurn{}, false
}
return queue[0], true
}
func (c *WeComChannel) deleteTurn(chatID string) {
c.turnsMu.Lock()
defer c.turnsMu.Unlock()
queue := c.turns[chatID]
if len(queue) <= 1 {
delete(c.turns, chatID)
return
}
c.turns[chatID] = queue[1:]
}
func (c *WeComChannel) queueTurn(chatID string, turn wecomTurn) {
c.turnsMu.Lock()
defer c.turnsMu.Unlock()
c.turns[chatID] = append(c.turns[chatID], turn)
}
func (c *WeComChannel) clearTurns() {
c.turnsMu.Lock()
c.turns = make(map[string][]wecomTurn)
c.turnsMu.Unlock()
}
func randomID(n int) string {
const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
if n <= 0 {
n = 10
}
buf := make([]byte, n)
for i := range buf {
v, _ := rand.Int(rand.Reader, big.NewInt(int64(len(alphabet))))
buf[i] = alphabet[v.Int64()]
}
return string(buf)
}
func splitContent(content string, maxBytes int) []string {
if content == "" {
return []string{""}
}
if len(content) <= maxBytes {
return []string{content}
}
chunks := channels.SplitMessage(content, maxBytes)
var result []string
for _, chunk := range chunks {
if len(chunk) <= maxBytes {
result = append(result, chunk)
continue
}
for len(chunk) > maxBytes {
end := maxBytes
for end > 0 && chunk[end]>>6 == 0b10 {
end--
}
if end == 0 {
end = maxBytes
}
result = append(result, chunk[:end])
chunk = strings.TrimLeft(chunk[end:], " \t\r\n")
}
if chunk != "" {
result = append(result, chunk)
}
}
return result
}
+167
View File
@@ -0,0 +1,167 @@
package wecom
import (
"context"
"errors"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) {
t.Parallel()
messageBus := bus.NewMessageBus()
ch := newTestWeComChannel(t, messageBus)
var commands []wecomCommand
ch.commandSend = func(cmd wecomCommand, _ time.Duration) error {
commands = append(commands, cmd)
return nil
}
msg := wecomIncomingMessage{
MsgID: "msg-1",
ChatID: "chat-1",
ChatType: "direct",
MsgType: "text",
Text: &struct {
Content string `json:"content"`
}{Content: "hello"},
}
msg.From.UserID = "user-1"
if err := ch.dispatchIncoming("req-1", msg); err != nil {
t.Fatalf("dispatchIncoming() error = %v", err)
}
select {
case inbound := <-messageBus.InboundChan():
if inbound.ChatID != "chat-1" {
t.Fatalf("inbound ChatID = %q, want chat-1", inbound.ChatID)
}
if inbound.MessageID != "msg-1" {
t.Fatalf("inbound MessageID = %q, want msg-1", inbound.MessageID)
}
if inbound.Peer.ID != "chat-1" {
t.Fatalf("inbound Peer.ID = %q, want chat-1", inbound.Peer.ID)
}
if inbound.Metadata["req_id"] != "req-1" {
t.Fatalf("inbound req_id = %q, want req-1", inbound.Metadata["req_id"])
}
default:
t.Fatal("expected inbound message to be published")
}
turn, ok := ch.getTurn("chat-1")
if !ok {
t.Fatal("expected queued turn for chat-1")
}
if turn.ReqID != "req-1" {
t.Fatalf("turn.ReqID = %q, want req-1", turn.ReqID)
}
route, ok := ch.routes.Get("chat-1")
if !ok {
t.Fatal("expected persisted route for chat-1")
}
if route.ReqID != "req-1" || route.ChatType != 1 {
t.Fatalf("route = %+v", route)
}
if len(commands) != 1 {
t.Fatalf("expected 1 opening command, got %d", len(commands))
}
if commands[0].Cmd != wecomCmdRespondMsg {
t.Fatalf("opening command = %q, want %q", commands[0].Cmd, wecomCmdRespondMsg)
}
if commands[0].Headers.ReqID != "req-1" {
t.Fatalf("opening req_id = %q, want req-1", commands[0].Headers.ReqID)
}
}
func TestSend_StreamFailureFallsBackToActualChatID(t *testing.T) {
t.Parallel()
ch := newTestWeComChannel(t, bus.NewMessageBus())
ch.SetRunning(true)
ch.queueTurn("chat-1", wecomTurn{
ReqID: "req-1",
ChatID: "chat-1",
ChatType: 1,
StreamID: "stream-1",
CreatedAt: time.Now(),
})
ch.queueTurn("chat-1", wecomTurn{
ReqID: "req-2",
ChatID: "chat-1",
ChatType: 1,
StreamID: "stream-2",
CreatedAt: time.Now(),
})
if err := ch.routes.Put("chat-1", "req-2", 1, time.Hour); err != nil {
t.Fatalf("Put() error = %v", err)
}
var commands []wecomCommand
ch.commandSend = func(cmd wecomCommand, _ time.Duration) error {
commands = append(commands, cmd)
if len(commands) == 1 && cmd.Cmd == wecomCmdRespondMsg {
return errors.New("stream send failed")
}
return nil
}
if err := ch.Send(context.Background(), bus.OutboundMessage{
Channel: "wecom",
ChatID: "chat-1",
Content: "hello",
}); err != nil {
t.Fatalf("Send() error = %v", err)
}
if len(commands) != 2 {
t.Fatalf("expected 2 commands, got %d", len(commands))
}
if commands[0].Cmd != wecomCmdRespondMsg || commands[0].Headers.ReqID != "req-1" {
t.Fatalf("first command = %+v", commands[0])
}
if commands[1].Cmd != wecomCmdSendMsg {
t.Fatalf("second command = %q, want %q", commands[1].Cmd, wecomCmdSendMsg)
}
body, ok := commands[1].Body.(wecomSendMsgBody)
if !ok {
t.Fatalf("unexpected send body type %T", commands[1].Body)
}
if body.ChatID != "chat-1" {
t.Fatalf("send chatid = %q, want chat-1", body.ChatID)
}
if body.ChatType != 1 {
t.Fatalf("send chat_type = %d, want 1", body.ChatType)
}
nextTurn, ok := ch.getTurn("chat-1")
if !ok {
t.Fatal("expected second turn to remain queued")
}
if nextTurn.ReqID != "req-2" {
t.Fatalf("next queued req_id = %q, want req-2", nextTurn.ReqID)
}
}
func newTestWeComChannel(t *testing.T, messageBus *bus.MessageBus) *WeComChannel {
t.Helper()
cfg := config.WeComConfig{BotID: "bot-1"}
cfg.SetSecret("secret-1")
ch, err := NewChannel(cfg, messageBus)
if err != nil {
t.Fatalf("NewChannel() error = %v", err)
}
ch.ctx = context.Background()
ch.routes = newReqIDStore(filepath.Join(t.TempDir(), "reqids.json"))
return ch
}
+22 -181
View File
@@ -321,10 +321,7 @@ type AgentDefaults struct {
ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"`
}
const (
DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
DefaultWeComAIBotProcessingMessage = "⏳ Processing, please wait. The results will be sent shortly."
)
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
func (d *AgentDefaults) GetMaxMediaSize() int {
if d.MaxMediaSize > 0 {
@@ -364,9 +361,7 @@ type ChannelsConfig struct {
Matrix MatrixConfig `json:"matrix"`
LINE LINEConfig `json:"line"`
OneBot OneBotConfig `json:"onebot"`
WeCom WeComConfig `json:"wecom"`
WeComApp WeComAppConfig `json:"wecom_app"`
WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
WeCom WeComConfig `json:"wecom" envPrefix:"PICOCLAW_CHANNELS_WECOM_"`
Weixin WeixinConfig `json:"weixin"`
Pico PicoConfig `json:"pico"`
PicoClient PicoClientConfig `json:"pico_client"`
@@ -678,136 +673,28 @@ func (c *OneBotConfig) SetAccessToken(token string) {
c.secDirty = true
}
type WeComGroupConfig struct {
AllowFrom FlexibleStringSlice `json:"allow_from,omitempty"`
}
type WeComConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"`
token string
encodingAESKey string
WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"`
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"`
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"`
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"`
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"`
secDirty bool
Enabled bool `json:"enabled" env:"ENABLED"`
BotID string `json:"bot_id" env:"BOT_ID"`
secret string
WebSocketURL string `json:"websocket_url,omitempty" env:"WEBSOCKET_URL"`
SendThinkingMessage bool `json:"send_thinking_message" env:"SEND_THINKING_MESSAGE"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"ALLOW_FROM"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"REASONING_CHANNEL_ID"`
secDirty bool
}
// Token returns the WeCom token
func (c *WeComConfig) Token() string {
return c.token
}
// SetToken sets the WeCom token
func (c *WeComConfig) SetToken(token string) {
c.token = token
c.secDirty = true
}
// EncodingAESKey returns the WeCom encoding AES key
func (c *WeComConfig) EncodingAESKey() string {
return c.encodingAESKey
}
// SetEncodingAESKey sets the WeCom encoding AES key
func (c *WeComConfig) SetEncodingAESKey(key string) {
c.encodingAESKey = key
c.secDirty = true
}
type WeComAppConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"`
CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"`
corpSecret string
AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"`
token string
encodingAESKey string
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"`
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"`
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"`
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"`
secDirty bool
}
// CorpSecret returns the corporate secret for WeCom app
func (c *WeComAppConfig) CorpSecret() string {
return c.corpSecret
}
// SetCorpSecret sets the corporate secret for WeCom app
func (c *WeComAppConfig) SetCorpSecret(secret string) {
c.corpSecret = secret
c.secDirty = true
}
// Token returns the webhook token for WeCom app
func (c *WeComAppConfig) Token() string {
return c.token
}
// SetToken sets the webhook token for WeCom app
func (c *WeComAppConfig) SetToken(token string) {
c.token = token
c.secDirty = true
}
// EncodingAESKey returns the encoding AES key for WeCom app
func (c *WeComAppConfig) EncodingAESKey() string {
return c.encodingAESKey
}
// SetEncodingAESKey sets the encoding AES key for WeCom app
func (c *WeComAppConfig) SetEncodingAESKey(key string) {
c.encodingAESKey = key
c.secDirty = true
}
type WeComAIBotConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
BotID string `json:"bot_id,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_BOT_ID"`
secret string
token string
encodingAESKey string
WebhookPath string `json:"webhook_path,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps
WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome
ProcessingMessage string `json:"processing_message,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_PROCESSING_MESSAGE"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
secDirty bool
}
// Token returns the webhook token for WeCom AI bot
func (c *WeComAIBotConfig) Token() string {
return c.token
}
// EncodingAESKey returns the encoding AES key for WeCom AI bot
func (c *WeComAIBotConfig) EncodingAESKey() string {
return c.encodingAESKey
}
// SetToken sets the token for WeCom AI bot
func (c *WeComAIBotConfig) SetToken(token string) {
c.token = token
c.secDirty = true
}
// SetEncodingAESKey sets the encoding AES key for WeCom AI bot
func (c *WeComAIBotConfig) SetEncodingAESKey(key string) {
c.encodingAESKey = key
c.secDirty = true
}
func (c *WeComAIBotConfig) Secret() string {
// Secret returns the WeCom bot secret.
func (c *WeComConfig) Secret() string {
return c.secret
}
func (c *WeComAIBotConfig) SetSecret(secret string) {
// SetSecret sets the WeCom bot secret.
func (c *WeComConfig) SetSecret(secret string) {
c.secret = secret
c.secDirty = true
}
@@ -1623,39 +1510,10 @@ func applySecurityConfig(cfg *Config, sec *SecurityConfig) error {
cfg.Channels.OneBot.accessToken = sec.Channels.OneBot.AccessToken
}
// Handle WeCom token and encoding key
// Handle WeCom bot secret
if sec.Channels.WeCom != nil {
if sec.Channels.WeCom.Token != "" {
cfg.Channels.WeCom.token = sec.Channels.WeCom.Token
}
if sec.Channels.WeCom.EncodingAESKey != "" {
cfg.Channels.WeCom.encodingAESKey = sec.Channels.WeCom.EncodingAESKey
}
}
// Handle WeCom App credentials
if sec.Channels.WeComApp != nil {
if sec.Channels.WeComApp.CorpSecret != "" {
cfg.Channels.WeComApp.corpSecret = sec.Channels.WeComApp.CorpSecret
}
if sec.Channels.WeComApp.Token != "" {
cfg.Channels.WeComApp.token = sec.Channels.WeComApp.Token
}
if sec.Channels.WeComApp.EncodingAESKey != "" {
cfg.Channels.WeComApp.encodingAESKey = sec.Channels.WeComApp.EncodingAESKey
}
}
// Handle WeCom AI Bot credentials
if sec.Channels.WeComAIBot != nil {
if sec.Channels.WeComAIBot.Token != "" {
cfg.Channels.WeComAIBot.token = sec.Channels.WeComAIBot.Token
}
if sec.Channels.WeComAIBot.EncodingAESKey != "" {
cfg.Channels.WeComAIBot.encodingAESKey = sec.Channels.WeComAIBot.EncodingAESKey
}
if sec.Channels.WeComAIBot.Secret != "" {
cfg.Channels.WeComAIBot.secret = sec.Channels.WeComAIBot.Secret
if sec.Channels.WeCom.Secret != "" {
cfg.Channels.WeCom.secret = sec.Channels.WeCom.Secret
}
}
@@ -1879,27 +1737,10 @@ func SaveConfig(path string, cfg *Config) error {
}
if cfg.Channels.WeCom.secDirty {
cfg.security.Channels.WeCom = &WeComSecurity{
Token: cfg.Channels.WeCom.Token(),
EncodingAESKey: cfg.Channels.WeCom.EncodingAESKey(),
Secret: cfg.Channels.WeCom.Secret(),
}
cfg.Channels.WeCom.secDirty = false
}
if cfg.Channels.WeComApp.secDirty {
cfg.security.Channels.WeComApp = &WeComAppSecurity{
CorpSecret: cfg.Channels.WeComApp.CorpSecret(),
Token: cfg.Channels.WeComApp.Token(),
EncodingAESKey: cfg.Channels.WeComApp.EncodingAESKey(),
}
cfg.Channels.WeComApp.secDirty = false
}
if cfg.Channels.WeComAIBot.secDirty {
cfg.security.Channels.WeComAIBot = &WeComAIBotSecurity{
Token: cfg.Channels.WeComAIBot.Token(),
EncodingAESKey: cfg.Channels.WeComAIBot.EncodingAESKey(),
Secret: cfg.Channels.WeComAIBot.Secret(),
}
cfg.Channels.WeComAIBot.secDirty = false
}
if cfg.Tools.Web.Brave.secDirty {
cfg.security.Web.Brave = &BraveSecurity{
APIKeys: cfg.Tools.Web.Brave.APIKeys(),
+63 -153
View File
@@ -85,23 +85,21 @@ type toolsConfigV0 struct {
}
type channelsConfigV0 struct {
WhatsApp WhatsAppConfig `json:"whatsapp"`
Telegram telegramConfigV0 `json:"telegram"`
Feishu feishuConfigV0 `json:"feishu"`
Discord discordConfigV0 `json:"discord"`
MaixCam maixcamConfigV0 `json:"maixcam"`
Weixin weixinConfigV0 `json:"weixin"`
QQ qqConfigV0 `json:"qq"`
DingTalk dingtalkConfigV0 `json:"dingtalk"`
Slack slackConfigV0 `json:"slack"`
Matrix matrixConfigV0 `json:"matrix"`
LINE lineConfigV0 `json:"line"`
OneBot onebotConfigV0 `json:"onebot"`
WeCom wecomConfigV0 `json:"wecom"`
WeComApp wecomappConfigV0 `json:"wecom_app"`
WeComAIBot wecomaibotConfigV0 `json:"wecom_aibot"`
Pico picoConfigV0 `json:"pico"`
IRC ircConfigV0 `json:"irc"`
WhatsApp WhatsAppConfig `json:"whatsapp"`
Telegram telegramConfigV0 `json:"telegram"`
Feishu feishuConfigV0 `json:"feishu"`
Discord discordConfigV0 `json:"discord"`
MaixCam maixcamConfigV0 `json:"maixcam"`
Weixin weixinConfigV0 `json:"weixin"`
QQ qqConfigV0 `json:"qq"`
DingTalk dingtalkConfigV0 `json:"dingtalk"`
Slack slackConfigV0 `json:"slack"`
Matrix matrixConfigV0 `json:"matrix"`
LINE lineConfigV0 `json:"line"`
OneBot onebotConfigV0 `json:"onebot"`
WeCom wecomConfigV0 `json:"wecom" envPrefix:"PICOCLAW_CHANNELS_WECOM_"`
Pico picoConfigV0 `json:"pico"`
IRC ircConfigV0 `json:"irc"`
}
func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity) {
@@ -117,45 +115,39 @@ func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity)
line, lineSecurity := v.LINE.ToLINEConfig()
onebot, onebotSecurity := v.OneBot.ToOneBotConfig()
wecom, wecomSecurity := v.WeCom.ToWeComConfig()
wecomapp, wecomappSecurity := v.WeComApp.ToWeComAppConfig()
wecomaibot, wecomaibotSecurity := v.WeComAIBot.ToWeComAIBotConfig()
pico, picoSecurity := v.Pico.ToPicoConfig()
irc, ircSecurity := v.IRC.ToIRCConfig()
return ChannelsConfig{
WhatsApp: v.WhatsApp,
Telegram: telegram,
Feishu: feishu,
Discord: discord,
MaixCam: maixcam,
QQ: qq,
Weixin: weixin,
DingTalk: dingtalk,
Slack: slack,
Matrix: matrix,
LINE: line,
OneBot: onebot,
WeCom: wecom,
WeComApp: wecomapp,
WeComAIBot: wecomaibot,
Pico: pico,
IRC: irc,
WhatsApp: v.WhatsApp,
Telegram: telegram,
Feishu: feishu,
Discord: discord,
MaixCam: maixcam,
QQ: qq,
Weixin: weixin,
DingTalk: dingtalk,
Slack: slack,
Matrix: matrix,
LINE: line,
OneBot: onebot,
WeCom: wecom,
Pico: pico,
IRC: irc,
}, ChannelsSecurity{
Telegram: telegramSecurity,
Feishu: feishuSecurity,
Discord: discordSecurity,
QQ: qqSecurity,
Weixin: weixinSecurity,
DingTalk: dingtalkSecurity,
Slack: slackSecurity,
Matrix: matrixSecurity,
LINE: lineSecurity,
OneBot: onebotSecurity,
WeCom: wecomSecurity,
WeComApp: wecomappSecurity,
WeComAIBot: wecomaibotSecurity,
Pico: picoSecurity,
IRC: ircSecurity,
Telegram: telegramSecurity,
Feishu: feishuSecurity,
Discord: discordSecurity,
QQ: qqSecurity,
Weixin: weixinSecurity,
DingTalk: dingtalkSecurity,
Slack: slackSecurity,
Matrix: matrixSecurity,
LINE: lineSecurity,
OneBot: onebotSecurity,
WeCom: wecomSecurity,
Pico: picoSecurity,
IRC: ircSecurity,
}
}
@@ -473,39 +465,32 @@ func (v *onebotConfigV0) ToOneBotConfig() (OneBotConfig, *OneBotSecurity) {
}
type wecomConfigV0 struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"`
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"`
WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"`
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"`
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"`
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"`
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"`
Enabled bool `json:"enabled" env:"ENABLED"`
BotID string `json:"bot_id" env:"BOT_ID"`
Secret string `json:"secret" env:"SECRET"`
WebSocketURL string `json:"websocket_url,omitempty" env:"WEBSOCKET_URL"`
SendThinkingMessage bool `json:"send_thinking_message" env:"SEND_THINKING_MESSAGE"`
DMPolicy string `json:"dm_policy,omitempty" env:"DM_POLICY"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"ALLOW_FROM"`
GroupPolicy string `json:"group_policy,omitempty" env:"GROUP_POLICY"`
GroupAllowFrom FlexibleStringSlice `json:"group_allow_from,omitempty" env:"GROUP_ALLOW_FROM"`
Groups map[string]WeComGroupConfig `json:"groups,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"REASONING_CHANNEL_ID"`
}
func (v *wecomConfigV0) ToWeComConfig() (WeComConfig, *WeComSecurity) {
var sec *WeComSecurity
if v.Token != "" || v.EncodingAESKey != "" {
sec = &WeComSecurity{
Token: v.Token,
EncodingAESKey: v.EncodingAESKey,
}
if v.Secret != "" {
sec = &WeComSecurity{Secret: v.Secret}
}
return WeComConfig{
Enabled: v.Enabled,
token: v.Token,
encodingAESKey: v.EncodingAESKey,
WebhookURL: v.WebhookURL,
WebhookHost: v.WebhookHost,
WebhookPort: v.WebhookPort,
WebhookPath: v.WebhookPath,
AllowFrom: v.AllowFrom,
ReplyTimeout: v.ReplyTimeout,
GroupTrigger: v.GroupTrigger,
ReasoningChannelID: v.ReasoningChannelID,
Enabled: v.Enabled,
BotID: v.BotID,
secret: v.Secret,
WebSocketURL: v.WebSocketURL,
SendThinkingMessage: v.SendThinkingMessage,
AllowFrom: v.AllowFrom,
ReasoningChannelID: v.ReasoningChannelID,
}, sec
}
@@ -537,81 +522,6 @@ func (v *weixinConfigV0) ToWeiXinConfig() (WeixinConfig, *WeixinSecurity) {
}, sec
}
type wecomappConfigV0 struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"`
CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"`
CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"`
AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"`
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"`
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"`
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"`
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"`
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"`
}
func (v *wecomappConfigV0) ToWeComAppConfig() (WeComAppConfig, *WeComAppSecurity) {
var sec *WeComAppSecurity
if v.CorpSecret != "" || v.Token != "" || v.EncodingAESKey != "" {
sec = &WeComAppSecurity{
CorpSecret: v.CorpSecret,
Token: v.Token,
EncodingAESKey: v.EncodingAESKey,
}
}
return WeComAppConfig{
Enabled: v.Enabled,
CorpID: v.CorpID,
corpSecret: v.CorpSecret,
AgentID: v.AgentID,
token: v.Token,
encodingAESKey: v.EncodingAESKey,
WebhookHost: v.WebhookHost,
WebhookPort: v.WebhookPort,
WebhookPath: v.WebhookPath,
AllowFrom: v.AllowFrom,
ReplyTimeout: v.ReplyTimeout,
GroupTrigger: v.GroupTrigger,
ReasoningChannelID: v.ReasoningChannelID,
}, sec
}
type wecomaibotConfigV0 struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
Secret string `json:"secret" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"`
EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"`
WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
}
func (v *wecomaibotConfigV0) ToWeComAIBotConfig() (WeComAIBotConfig, *WeComAIBotSecurity) {
var sec *WeComAIBotSecurity
if v.Token != "" || v.Secret != "" || v.EncodingAESKey != "" {
sec = &WeComAIBotSecurity{
Token: v.Token,
Secret: v.Secret,
EncodingAESKey: v.EncodingAESKey,
}
}
return WeComAIBotConfig{
Enabled: v.Enabled,
WebhookPath: v.WebhookPath,
AllowFrom: v.AllowFrom,
ReplyTimeout: v.ReplyTimeout,
MaxSteps: v.MaxSteps,
WelcomeMessage: v.WelcomeMessage,
ReasoningChannelID: v.ReasoningChannelID,
}, sec
}
type picoConfigV0 struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"`
+1 -2
View File
@@ -1372,8 +1372,7 @@ func TestFilterSensitiveData_AllTokenTypes(t *testing.T) {
Feishu: &FeishuSecurity{AppSecret: "feishu-app-secret-123", EncryptKey: "feishu-encrypt-key"},
DingTalk: &DingTalkSecurity{ClientSecret: "dingtalk-client-secret"},
OneBot: &OneBotSecurity{AccessToken: "onebot-access-token"},
WeCom: &WeComSecurity{Token: "wecom-token", EncodingAESKey: "wecom-aes-key"},
WeComApp: &WeComAppSecurity{CorpSecret: "wecom-app-secret", Token: "wecom-app-token"},
WeCom: &WeComSecurity{Secret: "wecom-secret"},
Pico: &PicoSecurity{Token: "pico-token-abc123"},
IRC: &IRCSecurity{
Password: "irc-password",
+5 -26
View File
@@ -129,32 +129,11 @@ func DefaultConfig() *Config {
AllowFrom: FlexibleStringSlice{},
},
WeCom: WeComConfig{
Enabled: false,
WebhookURL: "",
WebhookHost: "0.0.0.0",
WebhookPort: 18793,
WebhookPath: "/webhook/wecom",
AllowFrom: FlexibleStringSlice{},
ReplyTimeout: 5,
},
WeComApp: WeComAppConfig{
Enabled: false,
CorpID: "",
AgentID: 0,
WebhookHost: "0.0.0.0",
WebhookPort: 18792,
WebhookPath: "/webhook/wecom-app",
AllowFrom: FlexibleStringSlice{},
ReplyTimeout: 5,
},
WeComAIBot: WeComAIBotConfig{
Enabled: false,
WebhookPath: "/webhook/wecom-aibot",
AllowFrom: FlexibleStringSlice{},
ReplyTimeout: 5,
MaxSteps: 10,
WelcomeMessage: "Hello! I'm your AI assistant. How can I help you today?",
ProcessingMessage: DefaultWeComAIBotProcessingMessage,
Enabled: false,
BotID: "",
WebSocketURL: "wss://openws.work.weixin.qq.com",
SendThinkingMessage: true,
AllowFrom: FlexibleStringSlice{},
},
Weixin: WeixinConfig{
Enabled: false,
+14 -29
View File
@@ -69,21 +69,19 @@ type ModelSecurityEntry struct {
// ChannelsSecurity stores channel-related security data
type ChannelsSecurity struct {
Telegram *TelegramSecurity `yaml:"telegram,omitempty"`
Feishu *FeishuSecurity `yaml:"feishu,omitempty"`
Discord *DiscordSecurity `yaml:"discord,omitempty"`
Weixin *WeixinSecurity `yaml:"weixin,omitempty"`
QQ *QQSecurity `yaml:"qq,omitempty"`
DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"`
Slack *SlackSecurity `yaml:"slack,omitempty"`
Matrix *MatrixSecurity `yaml:"matrix,omitempty"`
LINE *LINESecurity `yaml:"line,omitempty"`
OneBot *OneBotSecurity `yaml:"onebot,omitempty"`
WeCom *WeComSecurity `yaml:"wecom,omitempty"`
WeComApp *WeComAppSecurity `yaml:"wecom_app,omitempty"`
WeComAIBot *WeComAIBotSecurity `yaml:"wecom_aibot,omitempty"`
Pico *PicoSecurity `yaml:"pico,omitempty"`
IRC *IRCSecurity `yaml:"irc,omitempty"`
Telegram *TelegramSecurity `yaml:"telegram,omitempty"`
Feishu *FeishuSecurity `yaml:"feishu,omitempty"`
Discord *DiscordSecurity `yaml:"discord,omitempty"`
Weixin *WeixinSecurity `yaml:"weixin,omitempty"`
QQ *QQSecurity `yaml:"qq,omitempty"`
DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"`
Slack *SlackSecurity `yaml:"slack,omitempty"`
Matrix *MatrixSecurity `yaml:"matrix,omitempty"`
LINE *LINESecurity `yaml:"line,omitempty"`
OneBot *OneBotSecurity `yaml:"onebot,omitempty"`
WeCom *WeComSecurity `yaml:"wecom,omitempty"`
Pico *PicoSecurity `yaml:"pico,omitempty"`
IRC *IRCSecurity `yaml:"irc,omitempty"`
}
type TelegramSecurity struct {
@@ -131,20 +129,7 @@ type OneBotSecurity struct {
}
type WeComSecurity struct {
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"`
EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"`
}
type WeComAppSecurity struct {
CorpSecret string `yaml:"corp_secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"`
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"`
EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"`
}
type WeComAIBotSecurity struct {
Secret string `yaml:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"`
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
EncodingAESKey string `yaml:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
Secret string `yaml:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_SECRET"`
}
type PicoSecurity struct {
+6 -36
View File
@@ -240,15 +240,7 @@ func TestAllSecurityKeysAccessible(t *testing.T) {
},
"wecom": {
"enabled": true,
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook"
},
"wecom_app": {
"enabled": true,
"corp_id": "test_corp_id",
"agent_id": 123456
},
"wecom_aibot": {
"enabled": true
"bot_id": "test_wecom_bot_id"
},
"pico": {
"enabled": true
@@ -315,15 +307,7 @@ channels:
onebot:
access_token: "onebot_test_access_token"
wecom:
token: "wecom_test_webhook_token"
encoding_aes_key: "wecom_test_aes_key"
wecom_app:
corp_secret: "wecom_app_test_corp_secret"
token: "wecom_app_test_token"
encoding_aes_key: "wecom_app_test_aes_key"
wecom_aibot:
token: "wecom_aibot_test_token"
encoding_aes_key: "wecom_aibot_test_aes_key"
secret: "wecom_test_secret"
pico:
token: "pico_test_token"
irc:
@@ -409,24 +393,10 @@ skills:
t.Logf("OneBot AccessToken(): %s", cfg.Channels.OneBot.AccessToken())
// WeCom
assert.Equal(t, "wecom_test_webhook_token", cfg.Channels.WeCom.Token())
assert.Equal(t, "wecom_test_aes_key", cfg.Channels.WeCom.EncodingAESKey())
t.Logf("WeCom Token(): %s", cfg.Channels.WeCom.Token())
t.Logf("WeCom EncodingAESKey(): %s", cfg.Channels.WeCom.EncodingAESKey())
// WeCom App
assert.Equal(t, "wecom_app_test_corp_secret", cfg.Channels.WeComApp.CorpSecret())
assert.Equal(t, "wecom_app_test_token", cfg.Channels.WeComApp.Token())
assert.Equal(t, "wecom_app_test_aes_key", cfg.Channels.WeComApp.EncodingAESKey())
t.Logf("WeComApp CorpSecret(): %s", cfg.Channels.WeComApp.CorpSecret())
t.Logf("WeComApp Token(): %s", cfg.Channels.WeComApp.Token())
t.Logf("WeComApp EncodingAESKey(): %s", cfg.Channels.WeComApp.EncodingAESKey())
// WeCom AI Bot
assert.Equal(t, "wecom_aibot_test_token", cfg.Channels.WeComAIBot.Token())
assert.Equal(t, "wecom_aibot_test_aes_key", cfg.Channels.WeComAIBot.EncodingAESKey())
t.Logf("WeComAIBot Token(): %s", cfg.Channels.WeComAIBot.Token())
t.Logf("WeComAIBot EncodingAESKey(): %s", cfg.Channels.WeComAIBot.EncodingAESKey())
assert.Equal(t, "test_wecom_bot_id", cfg.Channels.WeCom.BotID)
assert.Equal(t, "wecom_test_secret", cfg.Channels.WeCom.Secret())
t.Logf("WeCom BotID: %s", cfg.Channels.WeCom.BotID)
t.Logf("WeCom Secret(): %s", cfg.Channels.WeCom.Secret())
// Pico
assert.Equal(t, "pico_test_token", cfg.Channels.Pico.Token())
+12 -13
View File
@@ -13,17 +13,16 @@ var migrateableDirs = []string{
}
var supportedChannels = map[string]bool{
"whatsapp": true,
"telegram": true,
"feishu": true,
"discord": true,
"maixcam": true,
"qq": true,
"dingtalk": true,
"slack": true,
"matrix": true,
"line": true,
"onebot": true,
"wecom": true,
"wecom_app": true,
"whatsapp": true,
"telegram": true,
"feishu": true,
"discord": true,
"maixcam": true,
"qq": true,
"dingtalk": true,
"slack": true,
"matrix": true,
"line": true,
"onebot": true,
"wecom": true,
}
-2
View File
@@ -22,8 +22,6 @@ var channelCatalog = []channelCatalogItem{
{Name: "qq", ConfigKey: "qq"},
{Name: "onebot", ConfigKey: "onebot"},
{Name: "wecom", ConfigKey: "wecom"},
{Name: "wecom_app", ConfigKey: "wecom_app"},
{Name: "wecom_aibot", ConfigKey: "wecom_aibot"},
{Name: "whatsapp", ConfigKey: "whatsapp", Variant: "bridge"},
{Name: "whatsapp_native", ConfigKey: "whatsapp", Variant: "native"},
{Name: "pico", ConfigKey: "pico"},
+9
View File
@@ -209,6 +209,15 @@ func validateConfig(cfg *config.Config) []string {
errs = append(errs, "channels.discord.token is required when discord channel is enabled")
}
if cfg.Channels.WeCom.Enabled {
if cfg.Channels.WeCom.BotID == "" {
errs = append(errs, "channels.wecom.bot_id is required when wecom channel is enabled")
}
if cfg.Channels.WeCom.Secret() == "" {
errs = append(errs, "channels.wecom.secret is required when wecom channel is enabled")
}
}
if cfg.Tools.Exec.Enabled {
if cfg.Tools.Exec.EnableDenyPatterns {
errs = append(
@@ -146,13 +146,7 @@ function isConfigured(
case "weixin":
return asString(config.account_id) !== ""
case "wecom":
return asString(config.token) !== ""
case "wecom_app":
return (
asString(config.corp_id) !== "" && asString(config.corp_secret) !== ""
)
case "wecom_aibot":
return asString(config.token) !== ""
return asString(config.bot_id) !== ""
case "whatsapp":
return asString(config.bridge_url) !== ""
case "whatsapp_native":
@@ -193,11 +187,7 @@ function getRequiredFieldKeys(channelName: string): string[] {
case "onebot":
return ["ws_url"]
case "wecom":
return ["token"]
case "wecom_app":
return ["corp_id", "corp_secret"]
case "wecom_aibot":
return ["token"]
return ["bot_id", "secret"]
case "whatsapp":
return ["bridge_url"]
case "pico":
@@ -28,6 +28,7 @@ const SECRET_FIELDS = new Set([
"encoding_aes_key",
"encrypt_key",
"verification_token",
"secret",
"password",
"nickserv_password",
"sasl_password",
@@ -44,6 +45,7 @@ const OBJECT_FIELDS = new Set([
"allow_token_query",
"allow_from",
"allow_origins",
"groups",
])
function formatLabel(key: string): string {
@@ -118,6 +120,14 @@ export function GenericForm({
app_id: t("channels.form.desc.appId"),
client_id: t("channels.form.desc.clientId"),
corp_id: t("channels.form.desc.corpId"),
bot_id: t("channels.form.desc.appId"),
websocket_url: t("channels.form.desc.wsUrl"),
dm_policy: t("channels.form.desc.genericField", { field: "DM policy" }),
group_policy: t("channels.form.desc.genericField", { field: "group policy" }),
group_allow_from: t("channels.form.desc.allowFrom"),
send_thinking_message: t("channels.form.desc.genericField", {
field: "thinking message behavior",
}),
agent_id: t("channels.form.desc.agentId"),
webhook_url: t("channels.form.desc.webhookUrl"),
webhook_host: t("channels.form.desc.webhookHost"),
@@ -35,8 +35,6 @@ const CHANNEL_IMPORTANCE_ORDER = [
"slack",
"line",
"wecom",
"wecom_app",
"wecom_aibot",
"dingtalk",
"qq",
"onebot",
@@ -76,8 +74,6 @@ const CHANNEL_ICON_MAP: Record<
line: IconBrandLine,
qq: IconBrandQq,
wecom: IconBrandWechat,
wecom_app: IconBrandWechat,
wecom_aibot: IconBrandWechat,
whatsapp: IconBrandWhatsapp,
whatsapp_native: IconBrandWhatsapp,
matrix: IconBrandMatrix,
-2
View File
@@ -233,8 +233,6 @@
"qq": "QQ",
"onebot": "OneBot",
"wecom": "WeCom",
"wecom_app": "WeCom App",
"wecom_aibot": "WeCom AI Bot",
"whatsapp": "WhatsApp",
"whatsapp_native": "WhatsApp Native",
"pico": "Web",
-2
View File
@@ -233,8 +233,6 @@
"qq": "QQ",
"onebot": "OneBot",
"wecom": "企业微信",
"wecom_app": "企业微信应用",
"wecom_aibot": "企业微信 AI 机器人",
"whatsapp": "WhatsApp",
"whatsapp_native": "WhatsApp Native",
"pico": "Web",