diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 28ef76ad3..29b31e071 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -16,6 +16,9 @@ import ( "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + dch "github.com/sipeed/picoclaw/pkg/channels/discord" + slackch "github.com/sipeed/picoclaw/pkg/channels/slack" + tgram "github.com/sipeed/picoclaw/pkg/channels/telegram" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/devices" @@ -26,6 +29,16 @@ import ( "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/voice" + + // Channel factory registrations (blank imports trigger init()) + _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk" + _ "github.com/sipeed/picoclaw/pkg/channels/feishu" + _ "github.com/sipeed/picoclaw/pkg/channels/line" + _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" + _ "github.com/sipeed/picoclaw/pkg/channels/onebot" + _ "github.com/sipeed/picoclaw/pkg/channels/qq" + _ "github.com/sipeed/picoclaw/pkg/channels/wecom" + _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp" ) func gatewayCmd() { @@ -138,19 +151,19 @@ func gatewayCmd() { if transcriber != nil { if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { - if tc, ok := telegramChannel.(*channels.TelegramChannel); ok { + if tc, ok := telegramChannel.(*tgram.TelegramChannel); ok { tc.SetTranscriber(transcriber) logger.InfoC("voice", "Groq transcription attached to Telegram channel") } } if discordChannel, ok := channelManager.GetChannel("discord"); ok { - if dc, ok := discordChannel.(*channels.DiscordChannel); ok { + if dc, ok := discordChannel.(*dch.DiscordChannel); ok { dc.SetTranscriber(transcriber) logger.InfoC("voice", "Groq transcription attached to Discord channel") } } if slackChannel, ok := channelManager.GetChannel("slack"); ok { - if sc, ok := slackChannel.(*channels.SlackChannel); ok { + if sc, ok := slackChannel.(*slackch.SlackChannel); ok { sc.SetTranscriber(transcriber) logger.InfoC("voice", "Groq transcription attached to Slack channel") } diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go new file mode 100644 index 000000000..0edb0023c --- /dev/null +++ b/pkg/channels/dingtalk/dingtalk.go @@ -0,0 +1,202 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// DingTalk channel implementation using Stream Mode + +package dingtalk + +import ( + "context" + "fmt" + "sync" + + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// DingTalkChannel implements the Channel interface for DingTalk (钉钉) +// It uses WebSocket for receiving messages via stream mode and API for sending +type DingTalkChannel struct { + *channels.BaseChannel + config config.DingTalkConfig + clientID string + clientSecret string + streamClient *client.StreamClient + ctx context.Context + cancel context.CancelFunc + // Map to store session webhooks for each chat + sessionWebhooks sync.Map // chatID -> sessionWebhook +} + +// NewDingTalkChannel creates a new DingTalk channel instance +func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (*DingTalkChannel, error) { + if cfg.ClientID == "" || cfg.ClientSecret == "" { + return nil, fmt.Errorf("dingtalk client_id and client_secret are required") + } + + base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom) + + return &DingTalkChannel{ + BaseChannel: base, + config: cfg, + clientID: cfg.ClientID, + clientSecret: cfg.ClientSecret, + }, nil +} + +// Start initializes the DingTalk channel with Stream Mode +func (c *DingTalkChannel) Start(ctx context.Context) error { + logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Create credential config + cred := client.NewAppCredentialConfig(c.clientID, c.clientSecret) + + // Create the stream client with options + c.streamClient = client.NewStreamClient( + client.WithAppCredential(cred), + client.WithAutoReconnect(true), + ) + + // Register chatbot callback handler (IChatBotMessageHandler is a function type) + c.streamClient.RegisterChatBotCallbackRouter(c.onChatBotMessageReceived) + + // Start the stream client + if err := c.streamClient.Start(c.ctx); err != nil { + return fmt.Errorf("failed to start stream client: %w", err) + } + + c.SetRunning(true) + logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") + return nil +} + +// Stop gracefully stops the DingTalk channel +func (c *DingTalkChannel) Stop(ctx context.Context) error { + logger.InfoC("dingtalk", "Stopping DingTalk channel...") + + if c.cancel != nil { + c.cancel() + } + + if c.streamClient != nil { + c.streamClient.Close() + } + + c.SetRunning(false) + logger.InfoC("dingtalk", "DingTalk channel stopped") + return nil +} + +// Send sends a message to DingTalk via the chatbot reply API +func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("dingtalk channel not running") + } + + // Get session webhook from storage + sessionWebhookRaw, ok := c.sessionWebhooks.Load(msg.ChatID) + if !ok { + return fmt.Errorf("no session_webhook found for chat %s, cannot send message", msg.ChatID) + } + + sessionWebhook, ok := sessionWebhookRaw.(string) + if !ok { + return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) + } + + logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) + + // Use the session webhook to send the reply + return c.SendDirectReply(ctx, sessionWebhook, msg.Content) +} + +// onChatBotMessageReceived implements the IChatBotMessageHandler function signature +// This is called by the Stream SDK when a new message arrives +// IChatBotMessageHandler is: func(c context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) +func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) { + // Extract message content from Text field + content := data.Text.Content + if content == "" { + // Try to extract from Content interface{} if Text is empty + if contentMap, ok := data.Content.(map[string]interface{}); ok { + if textContent, ok := contentMap["content"].(string); ok { + content = textContent + } + } + } + + if content == "" { + return nil, nil // Ignore empty messages + } + + senderID := data.SenderStaffId + senderNick := data.SenderNick + chatID := senderID + if data.ConversationType != "1" { + // For group chats + chatID = data.ConversationId + } + + // Store the session webhook for this chat so we can reply later + c.sessionWebhooks.Store(chatID, data.SessionWebhook) + + metadata := map[string]string{ + "sender_name": senderNick, + "conversation_id": data.ConversationId, + "conversation_type": data.ConversationType, + "platform": "dingtalk", + "session_webhook": data.SessionWebhook, + } + + if data.ConversationType == "1" { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } else { + metadata["peer_kind"] = "group" + metadata["peer_id"] = data.ConversationId + } + + logger.DebugCF("dingtalk", "Received message", map[string]interface{}{ + "sender_nick": senderNick, + "sender_id": senderID, + "preview": utils.Truncate(content, 50), + }) + + // Handle the message through the base channel + c.HandleMessage(senderID, chatID, content, nil, metadata) + + // Return nil to indicate we've handled the message asynchronously + // The response will be sent through the message bus + return nil, nil +} + +// SendDirectReply sends a direct reply using the session webhook +func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error { + replier := chatbot.NewChatbotReplier() + + // Convert string content to []byte for the API + contentBytes := []byte(content) + titleBytes := []byte("PicoClaw") + + // Send markdown formatted reply + err := replier.SimpleReplyMarkdown( + ctx, + sessionWebhook, + titleBytes, + contentBytes, + ) + + if err != nil { + return fmt.Errorf("failed to send reply: %w", err) + } + + return nil +} diff --git a/pkg/channels/dingtalk/init.go b/pkg/channels/dingtalk/init.go new file mode 100644 index 000000000..5f49bce8c --- /dev/null +++ b/pkg/channels/dingtalk/init.go @@ -0,0 +1,13 @@ +package dingtalk + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("dingtalk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewDingTalkChannel(cfg.Channels.DingTalk, b) + }) +} diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go new file mode 100644 index 000000000..6c4efd87c --- /dev/null +++ b/pkg/channels/discord/discord.go @@ -0,0 +1,373 @@ +package discord + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +const ( + transcriptionTimeout = 30 * time.Second + sendTimeout = 10 * time.Second +) + +type DiscordChannel struct { + *channels.BaseChannel + session *discordgo.Session + config config.DiscordConfig + transcriber *voice.GroqTranscriber + ctx context.Context + typingMu sync.Mutex + typingStop map[string]chan struct{} // chatID → stop signal + botUserID string // stored for mention checking +} + +func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { + session, err := discordgo.New("Bot " + cfg.Token) + if err != nil { + return nil, fmt.Errorf("failed to create discord session: %w", err) + } + + base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom) + + return &DiscordChannel{ + BaseChannel: base, + session: session, + config: cfg, + transcriber: nil, + ctx: context.Background(), + typingStop: make(map[string]chan struct{}), + }, nil +} + +func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *DiscordChannel) getContext() context.Context { + if c.ctx == nil { + return context.Background() + } + return c.ctx +} + +func (c *DiscordChannel) Start(ctx context.Context) error { + logger.InfoC("discord", "Starting Discord bot") + + c.ctx = ctx + + // Get bot user ID before opening session to avoid race condition + botUser, err := c.session.User("@me") + if err != nil { + return fmt.Errorf("failed to get bot user: %w", err) + } + c.botUserID = botUser.ID + + c.session.AddHandler(c.handleMessage) + + if err := c.session.Open(); err != nil { + return fmt.Errorf("failed to open discord session: %w", err) + } + + c.SetRunning(true) + + logger.InfoCF("discord", "Discord bot connected", map[string]any{ + "username": botUser.Username, + "user_id": botUser.ID, + }) + + return nil +} + +func (c *DiscordChannel) Stop(ctx context.Context) error { + logger.InfoC("discord", "Stopping Discord bot") + c.SetRunning(false) + + // Stop all typing goroutines before closing session + c.typingMu.Lock() + for chatID, stop := range c.typingStop { + close(stop) + delete(c.typingStop, chatID) + } + c.typingMu.Unlock() + + if err := c.session.Close(); err != nil { + return fmt.Errorf("failed to close discord session: %w", err) + } + + return nil +} + +func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + c.stopTyping(msg.ChatID) + + if !c.IsRunning() { + return fmt.Errorf("discord bot not running") + } + + channelID := msg.ChatID + if channelID == "" { + return fmt.Errorf("channel ID is empty") + } + + runes := []rune(msg.Content) + if len(runes) == 0 { + return nil + } + + chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars + + for _, chunk := range chunks { + if err := c.sendChunk(ctx, channelID, chunk); err != nil { + return err + } + } + + return nil +} + +func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { + // Use the passed ctx for timeout control + sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) + defer cancel() + + done := make(chan error, 1) + go func() { + _, err := c.session.ChannelMessageSend(channelID, content) + done <- err + }() + + select { + case err := <-done: + if err != nil { + return fmt.Errorf("failed to send discord message: %w", err) + } + return nil + case <-sendCtx.Done(): + return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + } +} + +// appendContent safely appends content to existing text +func appendContent(content, suffix string) string { + if content == "" { + return suffix + } + return content + "\n" + suffix +} + +func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { + if m == nil || m.Author == nil { + return + } + + if m.Author.ID == s.State.User.ID { + return + } + + // Check allowlist first to avoid downloading attachments and transcribing for rejected users + if !c.IsAllowed(m.Author.ID) { + logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ + "user_id": m.Author.ID, + }) + return + } + + // If configured to only respond to mentions, check if bot is mentioned + // Skip this check for DMs (GuildID is empty) - DMs should always be responded to + if c.config.MentionOnly && m.GuildID != "" { + isMentioned := false + for _, mention := range m.Mentions { + if mention.ID == c.botUserID { + isMentioned = true + break + } + } + if !isMentioned { + logger.DebugCF("discord", "Message ignored - bot not mentioned", map[string]any{ + "user_id": m.Author.ID, + }) + return + } + } + + senderID := m.Author.ID + senderName := m.Author.Username + if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { + senderName += "#" + m.Author.Discriminator + } + + content := m.Content + content = c.stripBotMention(content) + mediaPaths := make([]string, 0, len(m.Attachments)) + localFiles := make([]string, 0, len(m.Attachments)) + + // Ensure temp files are cleaned up when function returns + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + for _, attachment := range m.Attachments { + isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) + + if isAudio { + localPath := c.downloadAttachment(attachment.URL, attachment.Filename) + if localPath != "" { + localFiles = append(localFiles, localPath) + + transcribedText := "" + if c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) + result, err := c.transcriber.Transcribe(ctx, localPath) + cancel() // Release context resources immediately to avoid leaks in for loop + + if err != nil { + logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ + "error": err.Error(), + }) + transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) + } else { + transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) + logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ + "text": result.Text, + }) + } + } else { + transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) + } + + content = appendContent(content, transcribedText) + } else { + logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ + "url": attachment.URL, + "filename": attachment.Filename, + }) + mediaPaths = append(mediaPaths, attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) + } + } else { + mediaPaths = append(mediaPaths, attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) + } + } + + if content == "" && len(mediaPaths) == 0 { + return + } + + if content == "" { + content = "[media only]" + } + + // Start typing after all early returns — guaranteed to have a matching Send() + c.startTyping(m.ChannelID) + + logger.DebugCF("discord", "Received message", map[string]any{ + "sender_name": senderName, + "sender_id": senderID, + "preview": utils.Truncate(content, 50), + }) + + peerKind := "channel" + peerID := m.ChannelID + if m.GuildID == "" { + peerKind = "direct" + peerID = senderID + } + + metadata := map[string]string{ + "message_id": m.ID, + "user_id": senderID, + "username": m.Author.Username, + "display_name": senderName, + "guild_id": m.GuildID, + "channel_id": m.ChannelID, + "is_dm": fmt.Sprintf("%t", m.GuildID == ""), + "peer_kind": peerKind, + "peer_id": peerID, + } + + c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) +} + +// startTyping starts a continuous typing indicator loop for the given chatID. +// It stops any existing typing loop for that chatID before starting a new one. +func (c *DiscordChannel) startTyping(chatID string) { + c.typingMu.Lock() + // Stop existing loop for this chatID if any + if stop, ok := c.typingStop[chatID]; ok { + close(stop) + } + stop := make(chan struct{}) + c.typingStop[chatID] = stop + c.typingMu.Unlock() + + go func() { + if err := c.session.ChannelTyping(chatID); err != nil { + logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err}) + } + ticker := time.NewTicker(8 * time.Second) + defer ticker.Stop() + timeout := time.After(5 * time.Minute) + for { + select { + case <-stop: + return + case <-timeout: + return + case <-c.ctx.Done(): + return + case <-ticker.C: + if err := c.session.ChannelTyping(chatID); err != nil { + logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err}) + } + } + } + }() +} + +// stopTyping stops the typing indicator loop for the given chatID. +func (c *DiscordChannel) stopTyping(chatID string) { + c.typingMu.Lock() + defer c.typingMu.Unlock() + if stop, ok := c.typingStop[chatID]; ok { + close(stop) + delete(c.typingStop, chatID) + } +} + +func (c *DiscordChannel) downloadAttachment(url, filename string) string { + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "discord", + }) +} + +// stripBotMention removes the bot mention from the message content. +// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). +func (c *DiscordChannel) stripBotMention(text string) string { + if c.botUserID == "" { + return text + } + // Remove both regular mention <@USER_ID> and nickname mention <@!USER_ID> + text = strings.ReplaceAll(text, fmt.Sprintf("<@%s>", c.botUserID), "") + text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "") + return strings.TrimSpace(text) +} diff --git a/pkg/channels/discord/init.go b/pkg/channels/discord/init.go new file mode 100644 index 000000000..15a539804 --- /dev/null +++ b/pkg/channels/discord/init.go @@ -0,0 +1,13 @@ +package discord + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewDiscordChannel(cfg.Channels.Discord, b) + }) +} diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go new file mode 100644 index 000000000..e8a057741 --- /dev/null +++ b/pkg/channels/feishu/common.go @@ -0,0 +1,9 @@ +package feishu + +// stringValue safely dereferences a *string pointer. +func stringValue(v *string) string { + if v == nil { + return "" + } + return *v +} diff --git a/pkg/channels/feishu/feishu_32.go b/pkg/channels/feishu/feishu_32.go new file mode 100644 index 000000000..14711e49e --- /dev/null +++ b/pkg/channels/feishu/feishu_32.go @@ -0,0 +1,37 @@ +//go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64 + +package feishu + +import ( + "context" + "errors" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +// FeishuChannel is a stub implementation for 32-bit architectures +type FeishuChannel struct { + *channels.BaseChannel +} + +// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported +func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { + return nil, errors.New("feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config") +} + +// Start is a stub method to satisfy the Channel interface +func (c *FeishuChannel) Start(ctx context.Context) error { + return nil +} + +// Stop is a stub method to satisfy the Channel interface +func (c *FeishuChannel) Stop(ctx context.Context) error { + return nil +} + +// Send is a stub method to satisfy the Channel interface +func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + return errors.New("feishu channel is not supported on 32-bit architectures") +} diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go new file mode 100644 index 000000000..a49ee34cb --- /dev/null +++ b/pkg/channels/feishu/feishu_64.go @@ -0,0 +1,221 @@ +//go:build amd64 || arm64 || riscv64 || mips64 || ppc64 + +package feishu + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + lark "github.com/larksuite/oapi-sdk-go/v3" + larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + larkws "github.com/larksuite/oapi-sdk-go/v3/ws" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +type FeishuChannel struct { + *channels.BaseChannel + config config.FeishuConfig + client *lark.Client + wsClient *larkws.Client + + mu sync.Mutex + cancel context.CancelFunc +} + +func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { + base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom) + + return &FeishuChannel{ + BaseChannel: base, + config: cfg, + client: lark.NewClient(cfg.AppID, cfg.AppSecret), + }, nil +} + +func (c *FeishuChannel) Start(ctx context.Context) error { + if c.config.AppID == "" || c.config.AppSecret == "" { + return fmt.Errorf("feishu app_id or app_secret is empty") + } + + dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey). + OnP2MessageReceiveV1(c.handleMessageReceive) + + runCtx, cancel := context.WithCancel(ctx) + + c.mu.Lock() + c.cancel = cancel + c.wsClient = larkws.NewClient( + c.config.AppID, + c.config.AppSecret, + larkws.WithEventHandler(dispatcher), + ) + wsClient := c.wsClient + c.mu.Unlock() + + c.SetRunning(true) + logger.InfoC("feishu", "Feishu channel started (websocket mode)") + + go func() { + if err := wsClient.Start(runCtx); err != nil { + logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]interface{}{ + "error": err.Error(), + }) + } + }() + + return nil +} + +func (c *FeishuChannel) Stop(ctx context.Context) error { + c.mu.Lock() + if c.cancel != nil { + c.cancel() + c.cancel = nil + } + c.wsClient = nil + c.mu.Unlock() + + c.SetRunning(false) + logger.InfoC("feishu", "Feishu channel stopped") + return nil +} + +func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("feishu channel not running") + } + + if msg.ChatID == "" { + return fmt.Errorf("chat ID is empty") + } + + payload, err := json.Marshal(map[string]string{"text": msg.Content}) + if err != nil { + return fmt.Errorf("failed to marshal feishu content: %w", err) + } + + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(msg.ChatID). + MsgType(larkim.MsgTypeText). + Content(string(payload)). + Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())). + Build()). + Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("failed to send feishu message: %w", err) + } + + if !resp.Success() { + return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg) + } + + logger.DebugCF("feishu", "Feishu message sent", map[string]interface{}{ + "chat_id": msg.ChatID, + }) + + return nil +} + +func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2MessageReceiveV1) error { + if event == nil || event.Event == nil || event.Event.Message == nil { + return nil + } + + message := event.Event.Message + sender := event.Event.Sender + + chatID := stringValue(message.ChatId) + if chatID == "" { + return nil + } + + senderID := extractFeishuSenderID(sender) + if senderID == "" { + senderID = "unknown" + } + + content := extractFeishuMessageContent(message) + if content == "" { + content = "[empty message]" + } + + metadata := map[string]string{} + if messageID := stringValue(message.MessageId); messageID != "" { + metadata["message_id"] = messageID + } + if messageType := stringValue(message.MessageType); messageType != "" { + metadata["message_type"] = messageType + } + if chatType := stringValue(message.ChatType); chatType != "" { + metadata["chat_type"] = chatType + } + if sender != nil && sender.TenantKey != nil { + metadata["tenant_key"] = *sender.TenantKey + } + + chatType := stringValue(message.ChatType) + if chatType == "p2p" { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } else { + metadata["peer_kind"] = "group" + metadata["peer_id"] = chatID + } + + logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{ + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 80), + }) + + c.HandleMessage(senderID, chatID, content, nil, metadata) + return nil +} + +func extractFeishuSenderID(sender *larkim.EventSender) string { + if sender == nil || sender.SenderId == nil { + return "" + } + + if sender.SenderId.UserId != nil && *sender.SenderId.UserId != "" { + return *sender.SenderId.UserId + } + if sender.SenderId.OpenId != nil && *sender.SenderId.OpenId != "" { + return *sender.SenderId.OpenId + } + if sender.SenderId.UnionId != nil && *sender.SenderId.UnionId != "" { + return *sender.SenderId.UnionId + } + + return "" +} + +func extractFeishuMessageContent(message *larkim.EventMessage) string { + if message == nil || message.Content == nil || *message.Content == "" { + return "" + } + + if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText { + var textPayload struct { + Text string `json:"text"` + } + if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil { + return textPayload.Text + } + } + + return *message.Content +} diff --git a/pkg/channels/feishu/init.go b/pkg/channels/feishu/init.go new file mode 100644 index 000000000..7e5a62dae --- /dev/null +++ b/pkg/channels/feishu/init.go @@ -0,0 +1,13 @@ +package feishu + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("feishu", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewFeishuChannel(cfg.Channels.Feishu, b) + }) +} diff --git a/pkg/channels/line/init.go b/pkg/channels/line/init.go new file mode 100644 index 000000000..9265575cc --- /dev/null +++ b/pkg/channels/line/init.go @@ -0,0 +1,13 @@ +package line + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("line", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewLINEChannel(cfg.Channels.LINE, b) + }) +} diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go new file mode 100644 index 000000000..7df0491d9 --- /dev/null +++ b/pkg/channels/line/line.go @@ -0,0 +1,607 @@ +package line + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +const ( + lineAPIBase = "https://api.line.me/v2/bot" + lineDataAPIBase = "https://api-data.line.me/v2/bot" + lineReplyEndpoint = lineAPIBase + "/message/reply" + linePushEndpoint = lineAPIBase + "/message/push" + lineContentEndpoint = lineDataAPIBase + "/message/%s/content" + lineBotInfoEndpoint = lineAPIBase + "/info" + lineLoadingEndpoint = lineAPIBase + "/chat/loading/start" + lineReplyTokenMaxAge = 25 * time.Second +) + +type replyTokenEntry struct { + token string + timestamp time.Time +} + +// LINEChannel implements the Channel interface for LINE Official Account +// using the LINE Messaging API with HTTP webhook for receiving messages +// and REST API for sending messages. +type LINEChannel struct { + *channels.BaseChannel + config config.LINEConfig + httpServer *http.Server + botUserID string // Bot's user ID + botBasicID string // Bot's basic ID (e.g. @216ru...) + botDisplayName string // Bot's display name for text-based mention detection + replyTokens sync.Map // chatID -> replyTokenEntry + quoteTokens sync.Map // chatID -> quoteToken (string) + ctx context.Context + cancel context.CancelFunc +} + +// NewLINEChannel creates a new LINE channel instance. +func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINEChannel, error) { + if cfg.ChannelSecret == "" || cfg.ChannelAccessToken == "" { + return nil, fmt.Errorf("line channel_secret and channel_access_token are required") + } + + base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom) + + return &LINEChannel{ + BaseChannel: base, + config: cfg, + }, nil +} + +// Start launches the HTTP webhook server. +func (c *LINEChannel) Start(ctx context.Context) error { + logger.InfoC("line", "Starting LINE channel (Webhook Mode)") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Fetch bot profile to get bot's userId for mention detection + if err := c.fetchBotInfo(); err != nil { + logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]interface{}{ + "error": err.Error(), + }) + } else { + logger.InfoCF("line", "Bot info fetched", map[string]interface{}{ + "bot_user_id": c.botUserID, + "basic_id": c.botBasicID, + "display_name": c.botDisplayName, + }) + } + + mux := http.NewServeMux() + path := c.config.WebhookPath + if path == "" { + path = "/webhook/line" + } + mux.HandleFunc(path, c.webhookHandler) + + addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) + c.httpServer = &http.Server{ + Addr: addr, + Handler: mux, + } + + go func() { + logger.InfoCF("line", "LINE webhook server listening", map[string]interface{}{ + "addr": addr, + "path": path, + }) + if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("line", "Webhook server error", map[string]interface{}{ + "error": err.Error(), + }) + } + }() + + c.SetRunning(true) + logger.InfoC("line", "LINE channel started (Webhook Mode)") + return nil +} + +// fetchBotInfo retrieves the bot's userId, basicId, and displayName from the LINE API. +func (c *LINEChannel) fetchBotInfo() error { + req, err := http.NewRequest(http.MethodGet, lineBotInfoEndpoint, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bot info API returned status %d", resp.StatusCode) + } + + var info struct { + UserID string `json:"userId"` + BasicID string `json:"basicId"` + DisplayName string `json:"displayName"` + } + if err := json.NewDecoder(resp.Body).Decode(&info); err != nil { + return err + } + + c.botUserID = info.UserID + c.botBasicID = info.BasicID + c.botDisplayName = info.DisplayName + return nil +} + +// Stop gracefully shuts down the HTTP server. +func (c *LINEChannel) Stop(ctx context.Context) error { + logger.InfoC("line", "Stopping LINE channel") + + if c.cancel != nil { + c.cancel() + } + + if c.httpServer != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := c.httpServer.Shutdown(shutdownCtx); err != nil { + logger.ErrorCF("line", "Webhook server shutdown error", map[string]interface{}{ + "error": err.Error(), + }) + } + } + + c.SetRunning(false) + logger.InfoC("line", "LINE channel stopped") + return nil +} + +// webhookHandler handles incoming LINE webhook requests. +func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + logger.ErrorCF("line", "Failed to read request body", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + signature := r.Header.Get("X-Line-Signature") + if !c.verifySignature(body, signature) { + logger.WarnC("line", "Invalid webhook signature") + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + var payload struct { + Events []lineEvent `json:"events"` + } + if err := json.Unmarshal(body, &payload); err != nil { + logger.ErrorCF("line", "Failed to parse webhook payload", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Return 200 immediately, process events asynchronously + w.WriteHeader(http.StatusOK) + + for _, event := range payload.Events { + go c.processEvent(event) + } +} + +// verifySignature validates the X-Line-Signature using HMAC-SHA256. +func (c *LINEChannel) verifySignature(body []byte, signature string) bool { + if signature == "" { + return false + } + + mac := hmac.New(sha256.New, []byte(c.config.ChannelSecret)) + mac.Write(body) + expected := base64.StdEncoding.EncodeToString(mac.Sum(nil)) + + return hmac.Equal([]byte(expected), []byte(signature)) +} + +// LINE webhook event types +type lineEvent struct { + Type string `json:"type"` + ReplyToken string `json:"replyToken"` + Source lineSource `json:"source"` + Message json.RawMessage `json:"message"` + Timestamp int64 `json:"timestamp"` +} + +type lineSource struct { + Type string `json:"type"` // "user", "group", "room" + UserID string `json:"userId"` + GroupID string `json:"groupId"` + RoomID string `json:"roomId"` +} + +type lineMessage struct { + ID string `json:"id"` + Type string `json:"type"` // "text", "image", "video", "audio", "file", "sticker" + Text string `json:"text"` + QuoteToken string `json:"quoteToken"` + Mention *struct { + Mentionees []lineMentionee `json:"mentionees"` + } `json:"mention"` + ContentProvider struct { + Type string `json:"type"` + } `json:"contentProvider"` +} + +type lineMentionee struct { + Index int `json:"index"` + Length int `json:"length"` + Type string `json:"type"` // "user", "all" + UserID string `json:"userId"` +} + +func (c *LINEChannel) processEvent(event lineEvent) { + if event.Type != "message" { + logger.DebugCF("line", "Ignoring non-message event", map[string]interface{}{ + "type": event.Type, + }) + return + } + + senderID := event.Source.UserID + chatID := c.resolveChatID(event.Source) + isGroup := event.Source.Type == "group" || event.Source.Type == "room" + + var msg lineMessage + if err := json.Unmarshal(event.Message, &msg); err != nil { + logger.ErrorCF("line", "Failed to parse message", map[string]interface{}{ + "error": err.Error(), + }) + return + } + + // In group chats, only respond when the bot is mentioned + if isGroup && !c.isBotMentioned(msg) { + logger.DebugCF("line", "Ignoring group message without mention", map[string]interface{}{ + "chat_id": chatID, + }) + return + } + + // Store reply token for later use + if event.ReplyToken != "" { + c.replyTokens.Store(chatID, replyTokenEntry{ + token: event.ReplyToken, + timestamp: time.Now(), + }) + } + + // Store quote token for quoting the original message in reply + if msg.QuoteToken != "" { + c.quoteTokens.Store(chatID, msg.QuoteToken) + } + + var content string + var mediaPaths []string + localFiles := []string{} + + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("line", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + switch msg.Type { + case "text": + content = msg.Text + // Strip bot mention from text in group chats + if isGroup { + content = c.stripBotMention(content, msg) + } + case "image": + localPath := c.downloadContent(msg.ID, "image.jpg") + if localPath != "" { + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + content = "[image]" + } + case "audio": + localPath := c.downloadContent(msg.ID, "audio.m4a") + if localPath != "" { + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + content = "[audio]" + } + case "video": + localPath := c.downloadContent(msg.ID, "video.mp4") + if localPath != "" { + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + content = "[video]" + } + case "file": + content = "[file]" + case "sticker": + content = "[sticker]" + default: + content = fmt.Sprintf("[%s]", msg.Type) + } + + if strings.TrimSpace(content) == "" { + return + } + + metadata := map[string]string{ + "platform": "line", + "source_type": event.Source.Type, + "message_id": msg.ID, + } + + if isGroup { + metadata["peer_kind"] = "group" + metadata["peer_id"] = chatID + } else { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } + + logger.DebugCF("line", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": chatID, + "message_type": msg.Type, + "is_group": isGroup, + "preview": utils.Truncate(content, 50), + }) + + // Show typing/loading indicator (requires user ID, not group ID) + c.sendLoading(senderID) + + c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) +} + +// isBotMentioned checks if the bot is mentioned in the message. +// It first checks the mention metadata (userId match), then falls back +// to text-based detection using the bot's display name, since LINE may +// not include userId in mentionees for Official Accounts. +func (c *LINEChannel) isBotMentioned(msg lineMessage) bool { + // Check mention metadata + if msg.Mention != nil { + for _, m := range msg.Mention.Mentionees { + if m.Type == "all" { + return true + } + if c.botUserID != "" && m.UserID == c.botUserID { + return true + } + } + // Mention metadata exists with mentionees but bot not matched by userId. + // The bot IS likely mentioned (LINE includes mention struct when bot is @-ed), + // so check if any mentionee overlaps with bot display name in text. + if c.botDisplayName != "" { + for _, m := range msg.Mention.Mentionees { + if m.Index >= 0 && m.Length > 0 { + runes := []rune(msg.Text) + end := m.Index + m.Length + if end <= len(runes) { + mentionText := string(runes[m.Index:end]) + if strings.Contains(mentionText, c.botDisplayName) { + return true + } + } + } + } + } + } + + // Fallback: text-based detection with display name + if c.botDisplayName != "" && strings.Contains(msg.Text, "@"+c.botDisplayName) { + return true + } + + return false +} + +// stripBotMention removes the @BotName mention text from the message. +func (c *LINEChannel) stripBotMention(text string, msg lineMessage) string { + stripped := false + + // Try to strip using mention metadata indices + if msg.Mention != nil { + runes := []rune(text) + for i := len(msg.Mention.Mentionees) - 1; i >= 0; i-- { + m := msg.Mention.Mentionees[i] + // Strip if userId matches OR if the mention text contains the bot display name + shouldStrip := false + if c.botUserID != "" && m.UserID == c.botUserID { + shouldStrip = true + } else if c.botDisplayName != "" && m.Index >= 0 && m.Length > 0 { + end := m.Index + m.Length + if end <= len(runes) { + mentionText := string(runes[m.Index:end]) + if strings.Contains(mentionText, c.botDisplayName) { + shouldStrip = true + } + } + } + if shouldStrip { + start := m.Index + end := m.Index + m.Length + if start >= 0 && end <= len(runes) { + runes = append(runes[:start], runes[end:]...) + stripped = true + } + } + } + if stripped { + return strings.TrimSpace(string(runes)) + } + } + + // Fallback: strip @DisplayName from text + if c.botDisplayName != "" { + text = strings.ReplaceAll(text, "@"+c.botDisplayName, "") + } + + return strings.TrimSpace(text) +} + +// resolveChatID determines the chat ID from the event source. +// For group/room messages, use the group/room ID; for 1:1, use the user ID. +func (c *LINEChannel) resolveChatID(source lineSource) string { + switch source.Type { + case "group": + return source.GroupID + case "room": + return source.RoomID + default: + return source.UserID + } +} + +// Send sends a message to LINE. It first tries the Reply API (free) +// using a cached reply token, then falls back to the Push API. +func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("line channel not running") + } + + // Load and consume quote token for this chat + var quoteToken string + if qt, ok := c.quoteTokens.LoadAndDelete(msg.ChatID); ok { + quoteToken = qt.(string) + } + + // Try reply token first (free, valid for ~25 seconds) + if entry, ok := c.replyTokens.LoadAndDelete(msg.ChatID); ok { + tokenEntry := entry.(replyTokenEntry) + if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge { + if err := c.sendReply(ctx, tokenEntry.token, msg.Content, quoteToken); err == nil { + logger.DebugCF("line", "Message sent via Reply API", map[string]interface{}{ + "chat_id": msg.ChatID, + "quoted": quoteToken != "", + }) + return nil + } + logger.DebugC("line", "Reply API failed, falling back to Push API") + } + } + + // Fall back to Push API + return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken) +} + +// buildTextMessage creates a text message object, optionally with quoteToken. +func buildTextMessage(content, quoteToken string) map[string]string { + msg := map[string]string{ + "type": "text", + "text": content, + } + if quoteToken != "" { + msg["quoteToken"] = quoteToken + } + return msg +} + +// sendReply sends a message using the LINE Reply API. +func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteToken string) error { + payload := map[string]interface{}{ + "replyToken": replyToken, + "messages": []map[string]string{buildTextMessage(content, quoteToken)}, + } + + return c.callAPI(ctx, lineReplyEndpoint, payload) +} + +// sendPush sends a message using the LINE Push API. +func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken string) error { + payload := map[string]interface{}{ + "to": to, + "messages": []map[string]string{buildTextMessage(content, quoteToken)}, + } + + return c.callAPI(ctx, linePushEndpoint, payload) +} + +// sendLoading sends a loading animation indicator to the chat. +func (c *LINEChannel) sendLoading(chatID string) { + payload := map[string]interface{}{ + "chatId": chatID, + "loadingSeconds": 60, + } + if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil { + logger.DebugCF("line", "Failed to send loading indicator", map[string]interface{}{ + "error": err.Error(), + }) + } +} + +// callAPI makes an authenticated POST request to the LINE API. +func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload interface{}) error { + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("API request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +// downloadContent downloads media content from the LINE API. +func (c *LINEChannel) downloadContent(messageID, filename string) string { + url := fmt.Sprintf(lineContentEndpoint, messageID) + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "line", + ExtraHeaders: map[string]string{ + "Authorization": "Bearer " + c.config.ChannelAccessToken, + }, + }) +} diff --git a/pkg/channels/maixcam/init.go b/pkg/channels/maixcam/init.go new file mode 100644 index 000000000..5a269b22b --- /dev/null +++ b/pkg/channels/maixcam/init.go @@ -0,0 +1,13 @@ +package maixcam + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("maixcam", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewMaixCamChannel(cfg.Channels.MaixCam, b) + }) +} diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go new file mode 100644 index 000000000..d3c6662d7 --- /dev/null +++ b/pkg/channels/maixcam/maixcam.go @@ -0,0 +1,244 @@ +package maixcam + +import ( + "context" + "encoding/json" + "fmt" + "net" + "sync" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +type MaixCamChannel struct { + *channels.BaseChannel + config config.MaixCamConfig + listener net.Listener + clients map[net.Conn]bool + clientsMux sync.RWMutex +} + +type MaixCamMessage struct { + Type string `json:"type"` + Tips string `json:"tips"` + Timestamp float64 `json:"timestamp"` + Data map[string]interface{} `json:"data"` +} + +func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) { + base := channels.NewBaseChannel("maixcam", cfg, bus, cfg.AllowFrom) + + return &MaixCamChannel{ + BaseChannel: base, + config: cfg, + clients: make(map[net.Conn]bool), + }, nil +} + +func (c *MaixCamChannel) Start(ctx context.Context) error { + logger.InfoC("maixcam", "Starting MaixCam channel server") + + addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) + listener, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", addr, err) + } + + c.listener = listener + c.SetRunning(true) + + logger.InfoCF("maixcam", "MaixCam server listening", map[string]interface{}{ + "host": c.config.Host, + "port": c.config.Port, + }) + + go c.acceptConnections(ctx) + + return nil +} + +func (c *MaixCamChannel) acceptConnections(ctx context.Context) { + logger.DebugC("maixcam", "Starting connection acceptor") + + for { + select { + case <-ctx.Done(): + logger.InfoC("maixcam", "Stopping connection acceptor") + return + default: + conn, err := c.listener.Accept() + if err != nil { + if c.IsRunning() { + logger.ErrorCF("maixcam", "Failed to accept connection", map[string]interface{}{ + "error": err.Error(), + }) + } + return + } + + logger.InfoCF("maixcam", "New connection from MaixCam device", map[string]interface{}{ + "remote_addr": conn.RemoteAddr().String(), + }) + + c.clientsMux.Lock() + c.clients[conn] = true + c.clientsMux.Unlock() + + go c.handleConnection(conn, ctx) + } + } +} + +func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { + logger.DebugC("maixcam", "Handling MaixCam connection") + + defer func() { + conn.Close() + c.clientsMux.Lock() + delete(c.clients, conn) + c.clientsMux.Unlock() + logger.DebugC("maixcam", "Connection closed") + }() + + decoder := json.NewDecoder(conn) + + for { + select { + case <-ctx.Done(): + return + default: + var msg MaixCamMessage + if err := decoder.Decode(&msg); err != nil { + if err.Error() != "EOF" { + logger.ErrorCF("maixcam", "Failed to decode message", map[string]interface{}{ + "error": err.Error(), + }) + } + return + } + + c.processMessage(msg, conn) + } + } +} + +func (c *MaixCamChannel) processMessage(msg MaixCamMessage, conn net.Conn) { + switch msg.Type { + case "person_detected": + c.handlePersonDetection(msg) + case "heartbeat": + logger.DebugC("maixcam", "Received heartbeat") + case "status": + c.handleStatusUpdate(msg) + default: + logger.WarnCF("maixcam", "Unknown message type", map[string]interface{}{ + "type": msg.Type, + }) + } +} + +func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { + logger.InfoCF("maixcam", "", map[string]interface{}{ + "timestamp": msg.Timestamp, + "data": msg.Data, + }) + + senderID := "maixcam" + chatID := "default" + + classInfo, ok := msg.Data["class_name"].(string) + if !ok { + classInfo = "person" + } + + score, _ := msg.Data["score"].(float64) + x, _ := msg.Data["x"].(float64) + y, _ := msg.Data["y"].(float64) + w, _ := msg.Data["w"].(float64) + h, _ := msg.Data["h"].(float64) + + content := fmt.Sprintf("📷 Person detected!\nClass: %s\nConfidence: %.2f%%\nPosition: (%.0f, %.0f)\nSize: %.0fx%.0f", + classInfo, score*100, x, y, w, h) + + metadata := map[string]string{ + "timestamp": fmt.Sprintf("%.0f", msg.Timestamp), + "class_id": fmt.Sprintf("%.0f", msg.Data["class_id"]), + "score": fmt.Sprintf("%.2f", score), + "x": fmt.Sprintf("%.0f", x), + "y": fmt.Sprintf("%.0f", y), + "w": fmt.Sprintf("%.0f", w), + "h": fmt.Sprintf("%.0f", h), + "peer_kind": "channel", + "peer_id": "default", + } + + c.HandleMessage(senderID, chatID, content, []string{}, metadata) +} + +func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { + logger.InfoCF("maixcam", "Status update from MaixCam", map[string]interface{}{ + "status": msg.Data, + }) +} + +func (c *MaixCamChannel) Stop(ctx context.Context) error { + logger.InfoC("maixcam", "Stopping MaixCam channel") + c.SetRunning(false) + + if c.listener != nil { + c.listener.Close() + } + + c.clientsMux.Lock() + defer c.clientsMux.Unlock() + + for conn := range c.clients { + conn.Close() + } + c.clients = make(map[net.Conn]bool) + + logger.InfoC("maixcam", "MaixCam channel stopped") + return nil +} + +func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("maixcam channel not running") + } + + c.clientsMux.RLock() + defer c.clientsMux.RUnlock() + + if len(c.clients) == 0 { + logger.WarnC("maixcam", "No MaixCam devices connected") + return fmt.Errorf("no connected MaixCam devices") + } + + response := map[string]interface{}{ + "type": "command", + "timestamp": float64(0), + "message": msg.Content, + "chat_id": msg.ChatID, + } + + data, err := json.Marshal(response) + if err != nil { + return fmt.Errorf("failed to marshal response: %w", err) + } + + var sendErr error + for conn := range c.clients { + if _, err := conn.Write(data); err != nil { + logger.ErrorCF("maixcam", "Failed to send to client", map[string]interface{}{ + "client": conn.RemoteAddr().String(), + "error": err.Error(), + }) + sendErr = err + } + } + + return sendErr +} diff --git a/pkg/channels/onebot/init.go b/pkg/channels/onebot/init.go new file mode 100644 index 000000000..84c06dfd6 --- /dev/null +++ b/pkg/channels/onebot/init.go @@ -0,0 +1,13 @@ +package onebot + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("onebot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewOneBotChannel(cfg.Channels.OneBot, b) + }) +} diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go new file mode 100644 index 000000000..209f2dc00 --- /dev/null +++ b/pkg/channels/onebot/onebot.go @@ -0,0 +1,980 @@ +package onebot + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +type OneBotChannel struct { + *channels.BaseChannel + config config.OneBotConfig + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + dedup map[string]struct{} + dedupRing []string + dedupIdx int + mu sync.Mutex + writeMu sync.Mutex + echoCounter int64 + selfID int64 + pending map[string]chan json.RawMessage + pendingMu sync.Mutex + transcriber *voice.GroqTranscriber + lastMessageID sync.Map + pendingEmojiMsg sync.Map +} + +type oneBotRawEvent struct { + PostType string `json:"post_type"` + MessageType string `json:"message_type"` + SubType string `json:"sub_type"` + MessageID json.RawMessage `json:"message_id"` + UserID json.RawMessage `json:"user_id"` + GroupID json.RawMessage `json:"group_id"` + RawMessage string `json:"raw_message"` + Message json.RawMessage `json:"message"` + Sender json.RawMessage `json:"sender"` + SelfID json.RawMessage `json:"self_id"` + Time json.RawMessage `json:"time"` + MetaEventType string `json:"meta_event_type"` + NoticeType string `json:"notice_type"` + Echo string `json:"echo"` + RetCode json.RawMessage `json:"retcode"` + Status json.RawMessage `json:"status"` + Data json.RawMessage `json:"data"` +} + +type BotStatus struct { + Online bool `json:"online"` + Good bool `json:"good"` +} + +func isAPIResponse(raw json.RawMessage) bool { + if len(raw) == 0 { + return false + } + var s string + if json.Unmarshal(raw, &s) == nil { + return s == "ok" || s == "failed" + } + var bs BotStatus + if json.Unmarshal(raw, &bs) == nil { + return bs.Online || bs.Good + } + return false +} + +type oneBotSender struct { + UserID json.RawMessage `json:"user_id"` + Nickname string `json:"nickname"` + Card string `json:"card"` +} + +type oneBotAPIRequest struct { + Action string `json:"action"` + Params interface{} `json:"params"` + Echo string `json:"echo,omitempty"` +} + +type oneBotMessageSegment struct { + Type string `json:"type"` + Data map[string]interface{} `json:"data"` +} + +func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { + base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom) + + const dedupSize = 1024 + return &OneBotChannel{ + BaseChannel: base, + config: cfg, + dedup: make(map[string]struct{}, dedupSize), + dedupRing: make([]string, dedupSize), + dedupIdx: 0, + pending: make(map[string]chan json.RawMessage), + }, nil +} + +func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) { + go func() { + _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]interface{}{ + "message_id": messageID, + "emoji_id": emojiID, + "set": set, + }, 5*time.Second) + if err != nil { + logger.DebugCF("onebot", "Failed to set emoji like", map[string]interface{}{ + "message_id": messageID, + "error": err.Error(), + }) + } + }() +} + +func (c *OneBotChannel) Start(ctx context.Context) error { + if c.config.WSUrl == "" { + return fmt.Errorf("OneBot ws_url not configured") + } + + logger.InfoCF("onebot", "Starting OneBot channel", map[string]interface{}{ + "ws_url": c.config.WSUrl, + }) + + c.ctx, c.cancel = context.WithCancel(ctx) + + if err := c.connect(); err != nil { + logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]interface{}{ + "error": err.Error(), + }) + } else { + go c.listen() + c.fetchSelfID() + } + + if c.config.ReconnectInterval > 0 { + go c.reconnectLoop() + } else { + if c.conn == nil { + return fmt.Errorf("failed to connect to OneBot and reconnect is disabled") + } + } + + c.SetRunning(true) + logger.InfoC("onebot", "OneBot channel started successfully") + + return nil +} + +func (c *OneBotChannel) connect() error { + dialer := websocket.DefaultDialer + dialer.HandshakeTimeout = 10 * time.Second + + header := make(map[string][]string) + if c.config.AccessToken != "" { + header["Authorization"] = []string{"Bearer " + c.config.AccessToken} + } + + conn, _, err := dialer.Dial(c.config.WSUrl, header) + if err != nil { + return err + } + + conn.SetPongHandler(func(appData string) error { + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + + c.mu.Lock() + c.conn = conn + c.mu.Unlock() + + go c.pinger(conn) + + logger.InfoC("onebot", "WebSocket connected") + return nil +} + +func (c *OneBotChannel) pinger(conn *websocket.Conn) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + c.writeMu.Lock() + err := conn.WriteMessage(websocket.PingMessage, nil) + c.writeMu.Unlock() + if err != nil { + logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]interface{}{ + "error": err.Error(), + }) + return + } + } + } +} + +func (c *OneBotChannel) fetchSelfID() { + resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second) + if err != nil { + logger.WarnCF("onebot", "Failed to get_login_info", map[string]interface{}{ + "error": err.Error(), + }) + return + } + + type loginInfo struct { + UserID json.RawMessage `json:"user_id"` + Nickname string `json:"nickname"` + } + for _, extract := range []func() (*loginInfo, error){ + func() (*loginInfo, error) { + var w struct { + Data loginInfo `json:"data"` + } + err := json.Unmarshal(resp, &w) + return &w.Data, err + }, + func() (*loginInfo, error) { + var f loginInfo + err := json.Unmarshal(resp, &f) + return &f, err + }, + } { + info, err := extract() + if err != nil || len(info.UserID) == 0 { + continue + } + if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 { + atomic.StoreInt64(&c.selfID, uid) + logger.InfoCF("onebot", "Bot self ID retrieved", map[string]interface{}{ + "self_id": uid, + "nickname": info.Nickname, + }) + return + } + } + + logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]interface{}{ + "response": string(resp), + }) +} + +func (c *OneBotChannel) sendAPIRequest(action string, params interface{}, timeout time.Duration) (json.RawMessage, error) { + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + return nil, fmt.Errorf("WebSocket not connected") + } + + echo := fmt.Sprintf("api_%d_%d", time.Now().UnixNano(), atomic.AddInt64(&c.echoCounter, 1)) + + ch := make(chan json.RawMessage, 1) + c.pendingMu.Lock() + c.pending[echo] = ch + c.pendingMu.Unlock() + + defer func() { + c.pendingMu.Lock() + delete(c.pending, echo) + c.pendingMu.Unlock() + }() + + req := oneBotAPIRequest{ + Action: action, + Params: params, + Echo: echo, + } + + data, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal API request: %w", err) + } + + c.writeMu.Lock() + err = conn.WriteMessage(websocket.TextMessage, data) + c.writeMu.Unlock() + + if err != nil { + return nil, fmt.Errorf("failed to write API request: %w", err) + } + + select { + case resp := <-ch: + return resp, nil + case <-time.After(timeout): + return nil, fmt.Errorf("API request %s timed out after %v", action, timeout) + case <-c.ctx.Done(): + return nil, fmt.Errorf("context cancelled") + } +} + +func (c *OneBotChannel) reconnectLoop() { + interval := time.Duration(c.config.ReconnectInterval) * time.Second + if interval < 5*time.Second { + interval = 5 * time.Second + } + + for { + select { + case <-c.ctx.Done(): + return + case <-time.After(interval): + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + logger.InfoC("onebot", "Attempting to reconnect...") + if err := c.connect(); err != nil { + logger.ErrorCF("onebot", "Reconnect failed", map[string]interface{}{ + "error": err.Error(), + }) + } else { + go c.listen() + c.fetchSelfID() + } + } + } + } +} + +func (c *OneBotChannel) Stop(ctx context.Context) error { + logger.InfoC("onebot", "Stopping OneBot channel") + c.SetRunning(false) + + if c.cancel != nil { + c.cancel() + } + + c.pendingMu.Lock() + for echo, ch := range c.pending { + close(ch) + delete(c.pending, echo) + } + c.pendingMu.Unlock() + + c.mu.Lock() + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + c.mu.Unlock() + + return nil +} + +func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("OneBot channel not running") + } + + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + return fmt.Errorf("OneBot WebSocket not connected") + } + + action, params, err := c.buildSendRequest(msg) + if err != nil { + return err + } + + echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1)) + + req := oneBotAPIRequest{ + Action: action, + Params: params, + Echo: echo, + } + + data, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal OneBot request: %w", err) + } + + c.writeMu.Lock() + err = conn.WriteMessage(websocket.TextMessage, data) + c.writeMu.Unlock() + + if err != nil { + logger.ErrorCF("onebot", "Failed to send message", map[string]interface{}{ + "error": err.Error(), + }) + return err + } + + if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { + if mid, ok := msgID.(string); ok && mid != "" { + c.setMsgEmojiLike(mid, 289, false) + } + } + + return nil +} + +func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment { + var segments []oneBotMessageSegment + + if lastMsgID, ok := c.lastMessageID.Load(chatID); ok { + if msgID, ok := lastMsgID.(string); ok && msgID != "" { + segments = append(segments, oneBotMessageSegment{ + Type: "reply", + Data: map[string]interface{}{"id": msgID}, + }) + } + } + + segments = append(segments, oneBotMessageSegment{ + Type: "text", + Data: map[string]interface{}{"text": content}, + }) + + return segments +} + +func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) { + chatID := msg.ChatID + segments := c.buildMessageSegments(chatID, msg.Content) + + var action, idKey string + var rawID string + if rest, ok := strings.CutPrefix(chatID, "group:"); ok { + action, idKey, rawID = "send_group_msg", "group_id", rest + } else if rest, ok := strings.CutPrefix(chatID, "private:"); ok { + action, idKey, rawID = "send_private_msg", "user_id", rest + } else { + action, idKey, rawID = "send_private_msg", "user_id", chatID + } + + id, err := strconv.ParseInt(rawID, 10, 64) + if err != nil { + return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID) + } + return action, map[string]interface{}{idKey: id, "message": segments}, nil +} + +func (c *OneBotChannel) listen() { + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + logger.WarnC("onebot", "WebSocket connection is nil, listener exiting") + return + } + + for { + select { + case <-c.ctx.Done(): + return + default: + _, message, err := conn.ReadMessage() + if err != nil { + logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{ + "error": err.Error(), + }) + c.mu.Lock() + if c.conn == conn { + c.conn.Close() + c.conn = nil + } + c.mu.Unlock() + return + } + + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + + var raw oneBotRawEvent + if err := json.Unmarshal(message, &raw); err != nil { + logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]interface{}{ + "error": err.Error(), + "payload": string(message), + }) + continue + } + + logger.DebugCF("onebot", "WebSocket event", map[string]interface{}{ + "length": len(message), + "post_type": raw.PostType, + "sub_type": raw.SubType, + }) + + if raw.Echo != "" { + c.pendingMu.Lock() + ch, ok := c.pending[raw.Echo] + c.pendingMu.Unlock() + + if ok { + select { + case ch <- message: + default: + } + } else { + logger.DebugCF("onebot", "Received API response (no waiter)", map[string]interface{}{ + "echo": raw.Echo, + "status": string(raw.Status), + }) + } + continue + } + + if isAPIResponse(raw.Status) { + logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]interface{}{ + "status": string(raw.Status), + }) + continue + } + + c.handleRawEvent(&raw) + } + } +} + +func parseJSONInt64(raw json.RawMessage) (int64, error) { + if len(raw) == 0 { + return 0, nil + } + + var n int64 + if err := json.Unmarshal(raw, &n); err == nil { + return n, nil + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return strconv.ParseInt(s, 10, 64) + } + return 0, fmt.Errorf("cannot parse as int64: %s", string(raw)) +} + +func parseJSONString(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + + return string(raw) +} + +type parseMessageResult struct { + Text string + IsBotMentioned bool + Media []string + LocalFiles []string + ReplyTo string +} + +func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult { + if len(raw) == 0 { + return parseMessageResult{} + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + mentioned := false + if selfID > 0 { + cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID) + if strings.Contains(s, cqAt) { + mentioned = true + s = strings.ReplaceAll(s, cqAt, "") + s = strings.TrimSpace(s) + } + } + return parseMessageResult{Text: s, IsBotMentioned: mentioned} + } + + var segments []map[string]interface{} + if err := json.Unmarshal(raw, &segments); err != nil { + return parseMessageResult{} + } + + var textParts []string + mentioned := false + selfIDStr := strconv.FormatInt(selfID, 10) + var media []string + var localFiles []string + var replyTo string + + for _, seg := range segments { + segType, _ := seg["type"].(string) + data, _ := seg["data"].(map[string]interface{}) + + switch segType { + case "text": + if data != nil { + if t, ok := data["text"].(string); ok { + textParts = append(textParts, t) + } + } + + case "at": + if data != nil && selfID > 0 { + qqVal := fmt.Sprintf("%v", data["qq"]) + if qqVal == selfIDStr || qqVal == "all" { + mentioned = true + } + } + + case "image", "video", "file": + if data != nil { + url, _ := data["url"].(string) + if url != "" { + defaults := map[string]string{"image": "image.jpg", "video": "video.mp4", "file": "file"} + filename := defaults[segType] + if f, ok := data["file"].(string); ok && f != "" { + filename = f + } else if n, ok := data["name"].(string); ok && n != "" { + filename = n + } + localPath := utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "onebot", + }) + if localPath != "" { + media = append(media, localPath) + localFiles = append(localFiles, localPath) + textParts = append(textParts, fmt.Sprintf("[%s]", segType)) + } + } + } + + case "record": + if data != nil { + url, _ := data["url"].(string) + if url != "" { + localPath := utils.DownloadFile(url, "voice.amr", utils.DownloadOptions{ + LoggerPrefix: "onebot", + }) + if localPath != "" { + localFiles = append(localFiles, localPath) + if c.transcriber != nil && c.transcriber.IsAvailable() { + tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second) + result, err := c.transcriber.Transcribe(tctx, localPath) + tcancel() + if err != nil { + logger.WarnCF("onebot", "Voice transcription failed", map[string]interface{}{ + "error": err.Error(), + }) + textParts = append(textParts, "[voice (transcription failed)]") + media = append(media, localPath) + } else { + textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text)) + } + } else { + textParts = append(textParts, "[voice]") + media = append(media, localPath) + } + } + } + } + + case "reply": + if data != nil { + if id, ok := data["id"]; ok { + replyTo = fmt.Sprintf("%v", id) + } + } + + case "face": + if data != nil { + faceID, _ := data["id"] + textParts = append(textParts, fmt.Sprintf("[face:%v]", faceID)) + } + + case "forward": + textParts = append(textParts, "[forward message]") + + default: + + } + } + + return parseMessageResult{ + Text: strings.TrimSpace(strings.Join(textParts, "")), + IsBotMentioned: mentioned, + Media: media, + LocalFiles: localFiles, + ReplyTo: replyTo, + } +} + +func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { + switch raw.PostType { + case "message": + if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 { + if !c.IsAllowed(strconv.FormatInt(userID, 10)) { + logger.DebugCF("onebot", "Message rejected by allowlist", map[string]interface{}{ + "user_id": userID, + }) + return + } + } + c.handleMessage(raw) + + case "message_sent": + logger.DebugCF("onebot", "Bot sent message event", map[string]interface{}{ + "message_type": raw.MessageType, + "message_id": parseJSONString(raw.MessageID), + }) + + case "meta_event": + c.handleMetaEvent(raw) + + case "notice": + c.handleNoticeEvent(raw) + + case "request": + logger.DebugCF("onebot", "Request event received", map[string]interface{}{ + "sub_type": raw.SubType, + }) + + case "": + logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{ + "echo": raw.Echo, + "status": raw.Status, + }) + + default: + logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{ + "post_type": raw.PostType, + }) + } +} + +func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) { + if raw.MetaEventType == "lifecycle" { + logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{"sub_type": raw.SubType}) + } else if raw.MetaEventType != "heartbeat" { + logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil) + } +} + +func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) { + fields := map[string]interface{}{ + "notice_type": raw.NoticeType, + "sub_type": raw.SubType, + "group_id": parseJSONString(raw.GroupID), + "user_id": parseJSONString(raw.UserID), + "message_id": parseJSONString(raw.MessageID), + } + switch raw.NoticeType { + case "group_recall", "group_increase", "group_decrease", + "friend_add", "group_admin", "group_ban": + logger.InfoCF("onebot", "Notice: "+raw.NoticeType, fields) + default: + logger.DebugCF("onebot", "Notice: "+raw.NoticeType, fields) + } +} + +func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { + // Parse fields from raw event + userID, err := parseJSONInt64(raw.UserID) + if err != nil { + logger.WarnCF("onebot", "Failed to parse user_id", map[string]interface{}{ + "error": err.Error(), + "raw": string(raw.UserID), + }) + return + } + + groupID, _ := parseJSONInt64(raw.GroupID) + selfID, _ := parseJSONInt64(raw.SelfID) + messageID := parseJSONString(raw.MessageID) + + if selfID == 0 { + selfID = atomic.LoadInt64(&c.selfID) + } + + parsed := c.parseMessageSegments(raw.Message, selfID) + isBotMentioned := parsed.IsBotMentioned + + content := raw.RawMessage + if content == "" { + content = parsed.Text + } else if selfID > 0 { + cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID) + if strings.Contains(content, cqAt) { + isBotMentioned = true + content = strings.ReplaceAll(content, cqAt, "") + content = strings.TrimSpace(content) + } + } + + if parsed.Text != "" && content != parsed.Text && (len(parsed.Media) > 0 || parsed.ReplyTo != "") { + content = parsed.Text + } + + var sender oneBotSender + if len(raw.Sender) > 0 { + if err := json.Unmarshal(raw.Sender, &sender); err != nil { + logger.WarnCF("onebot", "Failed to parse sender", map[string]interface{}{ + "error": err.Error(), + "sender": string(raw.Sender), + }) + } + } + + // Clean up temp files when done + if len(parsed.LocalFiles) > 0 { + defer func() { + for _, f := range parsed.LocalFiles { + if err := os.Remove(f); err != nil { + logger.DebugCF("onebot", "Failed to remove temp file", map[string]interface{}{ + "path": f, + "error": err.Error(), + }) + } + } + }() + } + + if c.isDuplicate(messageID) { + logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{ + "message_id": messageID, + }) + return + } + + if content == "" { + logger.DebugCF("onebot", "Received empty message, ignoring", map[string]interface{}{ + "message_id": messageID, + }) + return + } + + senderID := strconv.FormatInt(userID, 10) + var chatID string + + metadata := map[string]string{ + "message_id": messageID, + } + + if parsed.ReplyTo != "" { + metadata["reply_to_message_id"] = parsed.ReplyTo + } + + switch raw.MessageType { + case "private": + chatID = "private:" + senderID + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + + case "group": + groupIDStr := strconv.FormatInt(groupID, 10) + chatID = "group:" + groupIDStr + metadata["peer_kind"] = "group" + metadata["peer_id"] = groupIDStr + metadata["group_id"] = groupIDStr + + senderUserID, _ := parseJSONInt64(sender.UserID) + if senderUserID > 0 { + metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10) + } + + if sender.Card != "" { + metadata["sender_name"] = sender.Card + } else if sender.Nickname != "" { + metadata["sender_name"] = sender.Nickname + } + + triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned) + if !triggered { + logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{ + "sender": senderID, + "group": groupIDStr, + "is_mentioned": isBotMentioned, + "content": truncate(content, 100), + }) + return + } + content = strippedContent + + default: + logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{ + "type": raw.MessageType, + "message_id": messageID, + "user_id": userID, + }) + return + } + + logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]interface{}{ + "sender": senderID, + "chat_id": chatID, + "message_id": messageID, + "length": len(content), + "content": truncate(content, 100), + "media_count": len(parsed.Media), + }) + + if sender.Nickname != "" { + metadata["nickname"] = sender.Nickname + } + + c.lastMessageID.Store(chatID, messageID) + + if raw.MessageType == "group" && messageID != "" && messageID != "0" { + c.setMsgEmojiLike(messageID, 289, true) + c.pendingEmojiMsg.Store(chatID, messageID) + } + + c.HandleMessage(senderID, chatID, content, parsed.Media, metadata) +} + +func (c *OneBotChannel) isDuplicate(messageID string) bool { + if messageID == "" || messageID == "0" { + return false + } + + c.mu.Lock() + defer c.mu.Unlock() + + if _, exists := c.dedup[messageID]; exists { + return true + } + + if old := c.dedupRing[c.dedupIdx]; old != "" { + delete(c.dedup, old) + } + c.dedupRing[c.dedupIdx] = messageID + c.dedup[messageID] = struct{}{} + c.dedupIdx = (c.dedupIdx + 1) % len(c.dedupRing) + + return false +} + +func truncate(s string, n int) string { + runes := []rune(s) + if len(runes) <= n { + return s + } + return string(runes[:n]) + "..." +} + +func (c *OneBotChannel) checkGroupTrigger(content string, isBotMentioned bool) (triggered bool, strippedContent string) { + if isBotMentioned { + return true, strings.TrimSpace(content) + } + + for _, prefix := range c.config.GroupTriggerPrefix { + if prefix == "" { + continue + } + if strings.HasPrefix(content, prefix) { + return true, strings.TrimSpace(strings.TrimPrefix(content, prefix)) + } + } + + return false, content +} diff --git a/pkg/channels/qq/init.go b/pkg/channels/qq/init.go new file mode 100644 index 000000000..15b955089 --- /dev/null +++ b/pkg/channels/qq/init.go @@ -0,0 +1,13 @@ +package qq + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("qq", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewQQChannel(cfg.Channels.QQ, b) + }) +} diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go new file mode 100644 index 000000000..9b07be0cc --- /dev/null +++ b/pkg/channels/qq/qq.go @@ -0,0 +1,248 @@ +package qq + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/tencent-connect/botgo" + "github.com/tencent-connect/botgo/dto" + "github.com/tencent-connect/botgo/event" + "github.com/tencent-connect/botgo/openapi" + "github.com/tencent-connect/botgo/token" + "golang.org/x/oauth2" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +type QQChannel struct { + *channels.BaseChannel + config config.QQConfig + api openapi.OpenAPI + tokenSource oauth2.TokenSource + ctx context.Context + cancel context.CancelFunc + sessionManager botgo.SessionManager + processedIDs map[string]bool + mu sync.RWMutex +} + +func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) { + base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom) + + return &QQChannel{ + BaseChannel: base, + config: cfg, + processedIDs: make(map[string]bool), + }, nil +} + +func (c *QQChannel) Start(ctx context.Context) error { + if c.config.AppID == "" || c.config.AppSecret == "" { + return fmt.Errorf("QQ app_id and app_secret not configured") + } + + logger.InfoC("qq", "Starting QQ bot (WebSocket mode)") + + // 创建 token source + credentials := &token.QQBotCredentials{ + AppID: c.config.AppID, + AppSecret: c.config.AppSecret, + } + c.tokenSource = token.NewQQBotTokenSource(credentials) + + // 创建子 context + c.ctx, c.cancel = context.WithCancel(ctx) + + // 启动自动刷新 token 协程 + if err := token.StartRefreshAccessToken(c.ctx, c.tokenSource); err != nil { + return fmt.Errorf("failed to start token refresh: %w", err) + } + + // 初始化 OpenAPI 客户端 + c.api = botgo.NewOpenAPI(c.config.AppID, c.tokenSource).WithTimeout(5 * time.Second) + + // 注册事件处理器 + intent := event.RegisterHandlers( + c.handleC2CMessage(), + c.handleGroupATMessage(), + ) + + // 获取 WebSocket 接入点 + wsInfo, err := c.api.WS(c.ctx, nil, "") + if err != nil { + return fmt.Errorf("failed to get websocket info: %w", err) + } + + logger.InfoCF("qq", "Got WebSocket info", map[string]interface{}{ + "shards": wsInfo.Shards, + }) + + // 创建并保存 sessionManager + c.sessionManager = botgo.NewSessionManager() + + // 在 goroutine 中启动 WebSocket 连接,避免阻塞 + go func() { + if err := c.sessionManager.Start(wsInfo, c.tokenSource, &intent); err != nil { + logger.ErrorCF("qq", "WebSocket session error", map[string]interface{}{ + "error": err.Error(), + }) + c.SetRunning(false) + } + }() + + c.SetRunning(true) + logger.InfoC("qq", "QQ bot started successfully") + + return nil +} + +func (c *QQChannel) Stop(ctx context.Context) error { + logger.InfoC("qq", "Stopping QQ bot") + c.SetRunning(false) + + if c.cancel != nil { + c.cancel() + } + + return nil +} + +func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("QQ bot not running") + } + + // 构造消息 + msgToCreate := &dto.MessageToCreate{ + Content: msg.Content, + } + + // C2C 消息发送 + _, err := c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate) + if err != nil { + logger.ErrorCF("qq", "Failed to send C2C message", map[string]interface{}{ + "error": err.Error(), + }) + return err + } + + return nil +} + +// handleC2CMessage 处理 QQ 私聊消息 +func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { + return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error { + // 去重检查 + if c.isDuplicate(data.ID) { + return nil + } + + // 提取用户信息 + var senderID string + if data.Author != nil && data.Author.ID != "" { + senderID = data.Author.ID + } else { + logger.WarnC("qq", "Received message with no sender ID") + return nil + } + + // 提取消息内容 + content := data.Content + if content == "" { + logger.DebugC("qq", "Received empty message, ignoring") + return nil + } + + logger.InfoCF("qq", "Received C2C message", map[string]interface{}{ + "sender": senderID, + "length": len(content), + }) + + // 转发到消息总线 + metadata := map[string]string{ + "message_id": data.ID, + "peer_kind": "direct", + "peer_id": senderID, + } + + c.HandleMessage(senderID, senderID, content, []string{}, metadata) + + return nil + } +} + +// handleGroupATMessage 处理群@消息 +func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { + return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error { + // 去重检查 + if c.isDuplicate(data.ID) { + return nil + } + + // 提取用户信息 + var senderID string + if data.Author != nil && data.Author.ID != "" { + senderID = data.Author.ID + } else { + logger.WarnC("qq", "Received group message with no sender ID") + return nil + } + + // 提取消息内容(去掉 @ 机器人部分) + content := data.Content + if content == "" { + logger.DebugC("qq", "Received empty group message, ignoring") + return nil + } + + logger.InfoCF("qq", "Received group AT message", map[string]interface{}{ + "sender": senderID, + "group": data.GroupID, + "length": len(content), + }) + + // 转发到消息总线(使用 GroupID 作为 ChatID) + metadata := map[string]string{ + "message_id": data.ID, + "group_id": data.GroupID, + "peer_kind": "group", + "peer_id": data.GroupID, + } + + c.HandleMessage(senderID, data.GroupID, content, []string{}, metadata) + + return nil + } +} + +// isDuplicate 检查消息是否重复 +func (c *QQChannel) isDuplicate(messageID string) bool { + c.mu.Lock() + defer c.mu.Unlock() + + if c.processedIDs[messageID] { + return true + } + + c.processedIDs[messageID] = true + + // 简单清理:限制 map 大小 + if len(c.processedIDs) > 10000 { + // 清空一半 + count := 0 + for id := range c.processedIDs { + if count >= 5000 { + break + } + delete(c.processedIDs, id) + count++ + } + } + + return false +} diff --git a/pkg/channels/slack/init.go b/pkg/channels/slack/init.go new file mode 100644 index 000000000..c131bb291 --- /dev/null +++ b/pkg/channels/slack/init.go @@ -0,0 +1,13 @@ +package slack + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("slack", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewSlackChannel(cfg.Channels.Slack, b) + }) +} diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go new file mode 100644 index 000000000..dc5190fc9 --- /dev/null +++ b/pkg/channels/slack/slack.go @@ -0,0 +1,444 @@ +package slack + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/slack-go/slack" + "github.com/slack-go/slack/slackevents" + "github.com/slack-go/slack/socketmode" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +type SlackChannel struct { + *channels.BaseChannel + config config.SlackConfig + api *slack.Client + socketClient *socketmode.Client + botUserID string + teamID string + transcriber *voice.GroqTranscriber + ctx context.Context + cancel context.CancelFunc + pendingAcks sync.Map +} + +type slackMessageRef struct { + ChannelID string + Timestamp string +} + +func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) { + if cfg.BotToken == "" || cfg.AppToken == "" { + return nil, fmt.Errorf("slack bot_token and app_token are required") + } + + api := slack.New( + cfg.BotToken, + slack.OptionAppLevelToken(cfg.AppToken), + ) + + socketClient := socketmode.New(api) + + base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) + + return &SlackChannel{ + BaseChannel: base, + config: cfg, + api: api, + socketClient: socketClient, + }, nil +} + +func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *SlackChannel) Start(ctx context.Context) error { + logger.InfoC("slack", "Starting Slack channel (Socket Mode)") + + c.ctx, c.cancel = context.WithCancel(ctx) + + authResp, err := c.api.AuthTest() + if err != nil { + return fmt.Errorf("slack auth test failed: %w", err) + } + c.botUserID = authResp.UserID + c.teamID = authResp.TeamID + + logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{ + "bot_user_id": c.botUserID, + "team": authResp.Team, + }) + + go c.eventLoop() + + go func() { + if err := c.socketClient.RunContext(c.ctx); err != nil { + if c.ctx.Err() == nil { + logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{ + "error": err.Error(), + }) + } + } + }() + + c.SetRunning(true) + logger.InfoC("slack", "Slack channel started (Socket Mode)") + return nil +} + +func (c *SlackChannel) Stop(ctx context.Context) error { + logger.InfoC("slack", "Stopping Slack channel") + + if c.cancel != nil { + c.cancel() + } + + c.SetRunning(false) + logger.InfoC("slack", "Slack channel stopped") + return nil +} + +func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("slack channel not running") + } + + channelID, threadTS := parseSlackChatID(msg.ChatID) + if channelID == "" { + return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) + } + + opts := []slack.MsgOption{ + slack.MsgOptionText(msg.Content, false), + } + + if threadTS != "" { + opts = append(opts, slack.MsgOptionTS(threadTS)) + } + + _, _, err := c.api.PostMessageContext(ctx, channelID, opts...) + if err != nil { + return fmt.Errorf("failed to send slack message: %w", err) + } + + if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { + msgRef := ref.(slackMessageRef) + c.api.AddReaction("white_check_mark", slack.ItemRef{ + Channel: msgRef.ChannelID, + Timestamp: msgRef.Timestamp, + }) + } + + logger.DebugCF("slack", "Message sent", map[string]interface{}{ + "channel_id": channelID, + "thread_ts": threadTS, + }) + + return nil +} + +func (c *SlackChannel) eventLoop() { + for { + select { + case <-c.ctx.Done(): + return + case event, ok := <-c.socketClient.Events: + if !ok { + return + } + switch event.Type { + case socketmode.EventTypeEventsAPI: + c.handleEventsAPI(event) + case socketmode.EventTypeSlashCommand: + c.handleSlashCommand(event) + case socketmode.EventTypeInteractive: + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + } + } + } +} + +func (c *SlackChannel) handleEventsAPI(event socketmode.Event) { + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + + eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent) + if !ok { + return + } + + switch ev := eventsAPIEvent.InnerEvent.Data.(type) { + case *slackevents.MessageEvent: + c.handleMessageEvent(ev) + case *slackevents.AppMentionEvent: + c.handleAppMention(ev) + } +} + +func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { + if ev.User == c.botUserID || ev.User == "" { + return + } + if ev.BotID != "" { + return + } + if ev.SubType != "" && ev.SubType != "file_share" { + return + } + + // 检查白名单,避免为被拒绝的用户下载附件 + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{ + "user_id": ev.User, + }) + return + } + + senderID := ev.User + channelID := ev.Channel + threadTS := ev.ThreadTimeStamp + messageTS := ev.TimeStamp + + chatID := channelID + if threadTS != "" { + chatID = channelID + "/" + threadTS + } + + c.api.AddReaction("eyes", slack.ItemRef{ + Channel: channelID, + Timestamp: messageTS, + }) + + c.pendingAcks.Store(chatID, slackMessageRef{ + ChannelID: channelID, + Timestamp: messageTS, + }) + + content := ev.Text + content = c.stripBotMention(content) + + var mediaPaths []string + localFiles := []string{} // 跟踪需要清理的本地文件 + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + if ev.Message != nil && len(ev.Message.Files) > 0 { + for _, file := range ev.Message.Files { + localPath := c.downloadSlackFile(file) + if localPath == "" { + continue + } + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + + if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) + defer cancel() + result, err := c.transcriber.Transcribe(ctx, localPath) + + if err != nil { + logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) + content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) + } else { + content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) + } + } else { + content += fmt.Sprintf("\n[file: %s]", file.Name) + } + } + } + + if strings.TrimSpace(content) == "" { + return + } + + peerKind := "channel" + peerID := channelID + if strings.HasPrefix(channelID, "D") { + peerKind = "direct" + peerID = senderID + } + + metadata := map[string]string{ + "message_ts": messageTS, + "channel_id": channelID, + "thread_ts": threadTS, + "platform": "slack", + "peer_kind": peerKind, + "peer_id": peerID, + "team_id": c.teamID, + } + + logger.DebugCF("slack", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 50), + "has_thread": threadTS != "", + }) + + c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) +} + +func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { + if ev.User == c.botUserID { + return + } + + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{ + "user_id": ev.User, + }) + return + } + + senderID := ev.User + channelID := ev.Channel + threadTS := ev.ThreadTimeStamp + messageTS := ev.TimeStamp + + var chatID string + if threadTS != "" { + chatID = channelID + "/" + threadTS + } else { + chatID = channelID + "/" + messageTS + } + + c.api.AddReaction("eyes", slack.ItemRef{ + Channel: channelID, + Timestamp: messageTS, + }) + + c.pendingAcks.Store(chatID, slackMessageRef{ + ChannelID: channelID, + Timestamp: messageTS, + }) + + content := c.stripBotMention(ev.Text) + + if strings.TrimSpace(content) == "" { + return + } + + mentionPeerKind := "channel" + mentionPeerID := channelID + if strings.HasPrefix(channelID, "D") { + mentionPeerKind = "direct" + mentionPeerID = senderID + } + + metadata := map[string]string{ + "message_ts": messageTS, + "channel_id": channelID, + "thread_ts": threadTS, + "platform": "slack", + "is_mention": "true", + "peer_kind": mentionPeerKind, + "peer_id": mentionPeerID, + "team_id": c.teamID, + } + + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { + cmd, ok := event.Data.(slack.SlashCommand) + if !ok { + return + } + + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + + if !c.IsAllowed(cmd.UserID) { + logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{ + "user_id": cmd.UserID, + }) + return + } + + senderID := cmd.UserID + channelID := cmd.ChannelID + chatID := channelID + content := cmd.Text + + if strings.TrimSpace(content) == "" { + content = "help" + } + + metadata := map[string]string{ + "channel_id": channelID, + "platform": "slack", + "is_command": "true", + "trigger_id": cmd.TriggerID, + "peer_kind": "channel", + "peer_id": channelID, + "team_id": c.teamID, + } + + logger.DebugCF("slack", "Slash command received", map[string]interface{}{ + "sender_id": senderID, + "command": cmd.Command, + "text": utils.Truncate(content, 50), + }) + + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +func (c *SlackChannel) downloadSlackFile(file slack.File) string { + downloadURL := file.URLPrivateDownload + if downloadURL == "" { + downloadURL = file.URLPrivate + } + if downloadURL == "" { + logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID}) + return "" + } + + return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{ + LoggerPrefix: "slack", + ExtraHeaders: map[string]string{ + "Authorization": "Bearer " + c.config.BotToken, + }, + }) +} + +func (c *SlackChannel) stripBotMention(text string) string { + mention := fmt.Sprintf("<@%s>", c.botUserID) + text = strings.ReplaceAll(text, mention, "") + return strings.TrimSpace(text) +} + +func parseSlackChatID(chatID string) (channelID, threadTS string) { + parts := strings.SplitN(chatID, "/", 2) + channelID = parts[0] + if len(parts) > 1 { + threadTS = parts[1] + } + return +} diff --git a/pkg/channels/slack/slack_test.go b/pkg/channels/slack/slack_test.go new file mode 100644 index 000000000..30e0d2d73 --- /dev/null +++ b/pkg/channels/slack/slack_test.go @@ -0,0 +1,174 @@ +package slack + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestParseSlackChatID(t *testing.T) { + tests := []struct { + name string + chatID string + wantChanID string + wantThread string + }{ + { + name: "channel only", + chatID: "C123456", + wantChanID: "C123456", + wantThread: "", + }, + { + name: "channel with thread", + chatID: "C123456/1234567890.123456", + wantChanID: "C123456", + wantThread: "1234567890.123456", + }, + { + name: "DM channel", + chatID: "D987654", + wantChanID: "D987654", + wantThread: "", + }, + { + name: "empty string", + chatID: "", + wantChanID: "", + wantThread: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chanID, threadTS := parseSlackChatID(tt.chatID) + if chanID != tt.wantChanID { + t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID) + } + if threadTS != tt.wantThread { + t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread) + } + }) + } +} + +func TestStripBotMention(t *testing.T) { + ch := &SlackChannel{botUserID: "U12345BOT"} + + tests := []struct { + name string + input string + want string + }{ + { + name: "mention at start", + input: "<@U12345BOT> hello there", + want: "hello there", + }, + { + name: "mention in middle", + input: "hey <@U12345BOT> can you help", + want: "hey can you help", + }, + { + name: "no mention", + input: "hello world", + want: "hello world", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "only mention", + input: "<@U12345BOT>", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ch.stripBotMention(tt.input) + if got != tt.want { + t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestNewSlackChannel(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("missing bot token", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "", + AppToken: "xapp-test", + } + _, err := NewSlackChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing bot_token, got nil") + } + }) + + t.Run("missing app token", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "", + } + _, err := NewSlackChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing app_token, got nil") + } + }) + + t.Run("valid config", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{"U123"}, + } + ch, err := NewSlackChannel(cfg, msgBus) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch.Name() != "slack" { + t.Errorf("Name() = %q, want %q", ch.Name(), "slack") + } + if ch.IsRunning() { + t.Error("new channel should not be running") + } + }) +} + +func TestSlackChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{}, + } + ch, _ := NewSlackChannel(cfg, msgBus) + if !ch.IsAllowed("U_ANYONE") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{"U_ALLOWED"}, + } + ch, _ := NewSlackChannel(cfg, msgBus) + if !ch.IsAllowed("U_ALLOWED") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("U_BLOCKED") { + t.Error("non-allowed user should be blocked") + } + }) +} diff --git a/pkg/channels/telegram/init.go b/pkg/channels/telegram/init.go new file mode 100644 index 000000000..ac87bb805 --- /dev/null +++ b/pkg/channels/telegram/init.go @@ -0,0 +1,13 @@ +package telegram + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewTelegramChannel(cfg, b) + }) +} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go new file mode 100644 index 000000000..f4c5108df --- /dev/null +++ b/pkg/channels/telegram/telegram.go @@ -0,0 +1,526 @@ +package telegram + +import ( + "context" + "fmt" + "net/http" + "net/url" + "os" + "regexp" + "strings" + "sync" + "time" + + th "github.com/mymmrac/telego/telegohandler" + + "github.com/mymmrac/telego" + "github.com/mymmrac/telego/telegohandler" + tu "github.com/mymmrac/telego/telegoutil" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +type TelegramChannel struct { + *channels.BaseChannel + bot *telego.Bot + commands TelegramCommander + config *config.Config + chatIDs map[string]int64 + transcriber *voice.GroqTranscriber + placeholders sync.Map // chatID -> messageID + stopThinking sync.Map // chatID -> thinkingCancel +} + +type thinkingCancel struct { + fn context.CancelFunc +} + +func (c *thinkingCancel) Cancel() { + if c != nil && c.fn != nil { + c.fn() + } +} + +func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { + var opts []telego.BotOption + telegramCfg := cfg.Channels.Telegram + + if telegramCfg.Proxy != "" { + proxyURL, parseErr := url.Parse(telegramCfg.Proxy) + if parseErr != nil { + return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr) + } + opts = append(opts, telego.WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + })) + } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" { + // Use environment proxy if configured + opts = append(opts, telego.WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + }, + })) + } + + bot, err := telego.NewBot(telegramCfg.Token, opts...) + if err != nil { + return nil, fmt.Errorf("failed to create telegram bot: %w", err) + } + + base := channels.NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom) + + return &TelegramChannel{ + BaseChannel: base, + commands: NewTelegramCommands(bot, cfg), + bot: bot, + config: cfg, + chatIDs: make(map[string]int64), + transcriber: nil, + placeholders: sync.Map{}, + stopThinking: sync.Map{}, + }, nil +} + +func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *TelegramChannel) Start(ctx context.Context) error { + logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") + + updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ + Timeout: 30, + }) + if err != nil { + return fmt.Errorf("failed to start long polling: %w", err) + } + + bh, err := telegohandler.NewBotHandler(c.bot, updates) + if err != nil { + return fmt.Errorf("failed to create bot handler: %w", err) + } + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + c.commands.Help(ctx, message) + return nil + }, th.CommandEqual("help")) + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.Start(ctx, message) + }, th.CommandEqual("start")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.Show(ctx, message) + }, th.CommandEqual("show")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.List(ctx, message) + }, th.CommandEqual("list")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.handleMessage(ctx, &message) + }, th.AnyMessage()) + + c.SetRunning(true) + logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{ + "username": c.bot.Username(), + }) + + go bh.Start() + + go func() { + <-ctx.Done() + bh.Stop() + }() + + return nil +} +func (c *TelegramChannel) Stop(ctx context.Context) error { + logger.InfoC("telegram", "Stopping Telegram bot...") + c.SetRunning(false) + return nil +} + +func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("telegram bot not running") + } + + chatID, err := parseChatID(msg.ChatID) + if err != nil { + return fmt.Errorf("invalid chat ID: %w", err) + } + + // Stop thinking animation + if stop, ok := c.stopThinking.Load(msg.ChatID); ok { + if cf, ok := stop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + c.stopThinking.Delete(msg.ChatID) + } + + htmlContent := markdownToTelegramHTML(msg.Content) + + // Try to edit placeholder + if pID, ok := c.placeholders.Load(msg.ChatID); ok { + c.placeholders.Delete(msg.ChatID) + editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) + editMsg.ParseMode = telego.ModeHTML + + if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { + return nil + } + // Fallback to new message if edit fails + } + + tgMsg := tu.Message(tu.ID(chatID), htmlContent) + tgMsg.ParseMode = telego.ModeHTML + + if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ + "error": err.Error(), + }) + tgMsg.ParseMode = "" + _, err = c.bot.SendMessage(ctx, tgMsg) + return err + } + + return nil +} + +func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error { + if message == nil { + return fmt.Errorf("message is nil") + } + + user := message.From + if user == nil { + return fmt.Errorf("message sender (user) is nil") + } + + senderID := fmt.Sprintf("%d", user.ID) + if user.Username != "" { + senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) + } + + // 检查白名单,避免为被拒绝的用户下载附件 + if !c.IsAllowed(senderID) { + logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{ + "user_id": senderID, + }) + return nil + } + + chatID := message.Chat.ID + c.chatIDs[senderID] = chatID + + content := "" + mediaPaths := []string{} + localFiles := []string{} // 跟踪需要清理的本地文件 + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + if message.Text != "" { + content += message.Text + } + + if message.Caption != "" { + if content != "" { + content += "\n" + } + content += message.Caption + } + + if len(message.Photo) > 0 { + photo := message.Photo[len(message.Photo)-1] + photoPath := c.downloadPhoto(ctx, photo.FileID) + if photoPath != "" { + localFiles = append(localFiles, photoPath) + mediaPaths = append(mediaPaths, photoPath) + if content != "" { + content += "\n" + } + content += "[image: photo]" + } + } + + if message.Voice != nil { + voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") + if voicePath != "" { + localFiles = append(localFiles, voicePath) + mediaPaths = append(mediaPaths, voicePath) + + transcribedText := "" + if c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + result, err := c.transcriber.Transcribe(ctx, voicePath) + if err != nil { + logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{ + "error": err.Error(), + "path": voicePath, + }) + transcribedText = "[voice (transcription failed)]" + } else { + transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) + logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{ + "text": result.Text, + }) + } + } else { + transcribedText = "[voice]" + } + + if content != "" { + content += "\n" + } + content += transcribedText + } + } + + if message.Audio != nil { + audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") + if audioPath != "" { + localFiles = append(localFiles, audioPath) + mediaPaths = append(mediaPaths, audioPath) + if content != "" { + content += "\n" + } + content += "[audio]" + } + } + + if message.Document != nil { + docPath := c.downloadFile(ctx, message.Document.FileID, "") + if docPath != "" { + localFiles = append(localFiles, docPath) + mediaPaths = append(mediaPaths, docPath) + if content != "" { + content += "\n" + } + content += "[file]" + } + } + + if content == "" { + content = "[empty message]" + } + + logger.DebugCF("telegram", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": fmt.Sprintf("%d", chatID), + "preview": utils.Truncate(content, 50), + }) + + // Thinking indicator + err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) + if err != nil { + logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{ + "error": err.Error(), + }) + } + + // Stop any previous thinking animation + chatIDStr := fmt.Sprintf("%d", chatID) + if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { + if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + } + + // Create cancel function for thinking state + _, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) + c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) + + pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭")) + if err == nil { + pID := pMsg.MessageID + c.placeholders.Store(chatIDStr, pID) + } + + peerKind := "direct" + peerID := fmt.Sprintf("%d", user.ID) + if message.Chat.Type != "private" { + peerKind = "group" + peerID = fmt.Sprintf("%d", chatID) + } + + metadata := map[string]string{ + "message_id": fmt.Sprintf("%d", message.MessageID), + "user_id": fmt.Sprintf("%d", user.ID), + "username": user.Username, + "first_name": user.FirstName, + "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), + "peer_kind": peerKind, + "peer_id": peerID, + } + + c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + return nil +} + +func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { + file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) + if err != nil { + logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + + return c.downloadFileWithInfo(file, ".jpg") +} + +func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string { + if file.FilePath == "" { + return "" + } + + url := c.bot.FileDownloadURL(file.FilePath) + logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url}) + + // Use FilePath as filename for better identification + filename := file.FilePath + ext + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "telegram", + }) +} + +func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { + file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) + if err != nil { + logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + + return c.downloadFileWithInfo(file, ext) +} + +func parseChatID(chatIDStr string) (int64, error) { + var id int64 + _, err := fmt.Sscanf(chatIDStr, "%d", &id) + return id, err +} + +func markdownToTelegramHTML(text string) string { + if text == "" { + return "" + } + + codeBlocks := extractCodeBlocks(text) + text = codeBlocks.text + + inlineCodes := extractInlineCodes(text) + text = inlineCodes.text + + text = regexp.MustCompile(`^#{1,6}\s+(.+)$`).ReplaceAllString(text, "$1") + + text = regexp.MustCompile(`^>\s*(.*)$`).ReplaceAllString(text, "$1") + + text = escapeHTML(text) + + text = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`).ReplaceAllString(text, `$1`) + + text = regexp.MustCompile(`\*\*(.+?)\*\*`).ReplaceAllString(text, "$1") + + text = regexp.MustCompile(`__(.+?)__`).ReplaceAllString(text, "$1") + + reItalic := regexp.MustCompile(`_([^_]+)_`) + text = reItalic.ReplaceAllStringFunc(text, func(s string) string { + match := reItalic.FindStringSubmatch(s) + if len(match) < 2 { + return s + } + return "" + match[1] + "" + }) + + text = regexp.MustCompile(`~~(.+?)~~`).ReplaceAllString(text, "$1") + + text = regexp.MustCompile(`^[-*]\s+`).ReplaceAllString(text, "• ") + + for i, code := range inlineCodes.codes { + escaped := escapeHTML(code) + text = strings.ReplaceAll(text, fmt.Sprintf("\x00IC%d\x00", i), fmt.Sprintf("%s", escaped)) + } + + for i, code := range codeBlocks.codes { + escaped := escapeHTML(code) + text = strings.ReplaceAll(text, fmt.Sprintf("\x00CB%d\x00", i), fmt.Sprintf("
%s
", escaped)) + } + + return text +} + +type codeBlockMatch struct { + text string + codes []string +} + +func extractCodeBlocks(text string) codeBlockMatch { + re := regexp.MustCompile("```[\\w]*\\n?([\\s\\S]*?)```") + matches := re.FindAllStringSubmatch(text, -1) + + codes := make([]string, 0, len(matches)) + for _, match := range matches { + codes = append(codes, match[1]) + } + + i := 0 + text = re.ReplaceAllStringFunc(text, func(m string) string { + placeholder := fmt.Sprintf("\x00CB%d\x00", i) + i++ + return placeholder + }) + + return codeBlockMatch{text: text, codes: codes} +} + +type inlineCodeMatch struct { + text string + codes []string +} + +func extractInlineCodes(text string) inlineCodeMatch { + re := regexp.MustCompile("`([^`]+)`") + matches := re.FindAllStringSubmatch(text, -1) + + codes := make([]string, 0, len(matches)) + for _, match := range matches { + codes = append(codes, match[1]) + } + + i := 0 + text = re.ReplaceAllStringFunc(text, func(m string) string { + placeholder := fmt.Sprintf("\x00IC%d\x00", i) + i++ + return placeholder + }) + + return inlineCodeMatch{text: text, codes: codes} +} + +func escapeHTML(text string) string { + text = strings.ReplaceAll(text, "&", "&") + text = strings.ReplaceAll(text, "<", "<") + text = strings.ReplaceAll(text, ">", ">") + return text +} diff --git a/pkg/channels/telegram/telegram_commands.go b/pkg/channels/telegram/telegram_commands.go new file mode 100644 index 000000000..4bf1b3aff --- /dev/null +++ b/pkg/channels/telegram/telegram_commands.go @@ -0,0 +1,153 @@ +package telegram + +import ( + "context" + "fmt" + "strings" + + "github.com/mymmrac/telego" + "github.com/sipeed/picoclaw/pkg/config" +) + +type TelegramCommander interface { + Help(ctx context.Context, message telego.Message) error + Start(ctx context.Context, message telego.Message) error + Show(ctx context.Context, message telego.Message) error + List(ctx context.Context, message telego.Message) error +} + +type cmd struct { + bot *telego.Bot + config *config.Config +} + +func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander { + return &cmd{ + bot: bot, + config: cfg, + } +} + +func commandArgs(text string) string { + parts := strings.SplitN(text, " ", 2) + if len(parts) < 2 { + return "" + } + return strings.TrimSpace(parts[1]) +} +func (c *cmd) Help(ctx context.Context, message telego.Message) error { + msg := `/start - Start the bot +/help - Show this help message +/show [model|channel] - Show current configuration +/list [models|channels] - List available options + ` + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: msg, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} + +func (c *cmd) Start(ctx context.Context, message telego.Message) error { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Hello! I am PicoClaw 🦞", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} + +func (c *cmd) Show(ctx context.Context, message telego.Message) error { + args := commandArgs(message.Text) + if args == "" { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Usage: /show [model|channel]", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err + } + + var response string + switch args { + case "model": + response = fmt.Sprintf("Current Model: %s (Provider: %s)", + c.config.Agents.Defaults.Model, + c.config.Agents.Defaults.Provider) + case "channel": + response = "Current Channel: telegram" + default: + response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args) + } + + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: response, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} +func (c *cmd) List(ctx context.Context, message telego.Message) error { + args := commandArgs(message.Text) + if args == "" { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Usage: /list [models|channels]", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err + } + + var response string + switch args { + case "models": + provider := c.config.Agents.Defaults.Provider + if provider == "" { + provider = "configured default" + } + response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml", + c.config.Agents.Defaults.Model, provider) + + case "channels": + var enabled []string + if c.config.Channels.Telegram.Enabled { + enabled = append(enabled, "telegram") + } + if c.config.Channels.WhatsApp.Enabled { + enabled = append(enabled, "whatsapp") + } + if c.config.Channels.Feishu.Enabled { + enabled = append(enabled, "feishu") + } + if c.config.Channels.Discord.Enabled { + enabled = append(enabled, "discord") + } + if c.config.Channels.Slack.Enabled { + enabled = append(enabled, "slack") + } + response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")) + + default: + response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args) + } + + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: response, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go new file mode 100644 index 000000000..85c017958 --- /dev/null +++ b/pkg/channels/wecom/app.go @@ -0,0 +1,636 @@ +package wecom + +import ( + "bytes" + "context" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +const ( + wecomAPIBase = "https://qyapi.weixin.qq.com" +) + +// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用) +type WeComAppChannel struct { + *channels.BaseChannel + config config.WeComAppConfig + server *http.Server + accessToken string + tokenExpiry time.Time + tokenMu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + processedMsgs map[string]bool // Message deduplication: msg_id -> processed + msgMu sync.RWMutex +} + +// WeComXMLMessage represents the XML message structure from WeCom +type WeComXMLMessage struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + FromUserName string `xml:"FromUserName"` + CreateTime int64 `xml:"CreateTime"` + MsgType string `xml:"MsgType"` + Content string `xml:"Content"` + MsgId int64 `xml:"MsgId"` + AgentID int64 `xml:"AgentID"` + PicUrl string `xml:"PicUrl"` + MediaId string `xml:"MediaId"` + Format string `xml:"Format"` + ThumbMediaId string `xml:"ThumbMediaId"` + LocationX float64 `xml:"Location_X"` + LocationY float64 `xml:"Location_Y"` + Scale int `xml:"Scale"` + Label string `xml:"Label"` + Title string `xml:"Title"` + Description string `xml:"Description"` + Url string `xml:"Url"` + Event string `xml:"Event"` + EventKey string `xml:"EventKey"` +} + +// WeComTextMessage represents text message for sending +type WeComTextMessage struct { + ToUser string `json:"touser"` + MsgType string `json:"msgtype"` + AgentID int64 `json:"agentid"` + Text struct { + Content string `json:"content"` + } `json:"text"` + Safe int `json:"safe,omitempty"` +} + +// WeComMarkdownMessage represents markdown message for sending +type WeComMarkdownMessage struct { + ToUser string `json:"touser"` + MsgType string `json:"msgtype"` + AgentID int64 `json:"agentid"` + Markdown struct { + Content string `json:"content"` + } `json:"markdown"` +} + +// WeComImageMessage represents image message for sending +type WeComImageMessage struct { + ToUser string `json:"touser"` + MsgType string `json:"msgtype"` + AgentID int64 `json:"agentid"` + Image struct { + MediaID string `json:"media_id"` + } `json:"image"` +} + +// WeComAccessTokenResponse represents the access token API response +type WeComAccessTokenResponse struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` +} + +// WeComSendMessageResponse represents the send message API response +type WeComSendMessageResponse struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + InvalidUser string `json:"invaliduser"` + InvalidParty string `json:"invalidparty"` + InvalidTag string `json:"invalidtag"` +} + +// PKCS7Padding adds PKCS7 padding +type PKCS7Padding struct{} + +// NewWeComAppChannel creates a new WeCom App channel instance +func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) { + if cfg.CorpID == "" || cfg.CorpSecret == "" || cfg.AgentID == 0 { + return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required") + } + + base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom) + + return &WeComAppChannel{ + BaseChannel: base, + config: cfg, + processedMsgs: make(map[string]bool), + }, nil +} + +// Name returns the channel name +func (c *WeComAppChannel) Name() string { + return "wecom_app" +} + +// Start initializes the WeCom App channel with HTTP webhook server +func (c *WeComAppChannel) Start(ctx context.Context) error { + logger.InfoC("wecom_app", "Starting WeCom App channel...") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Get initial access token + if err := c.refreshAccessToken(); err != nil { + logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]interface{}{ + "error": err.Error(), + }) + } + + // Start token refresh goroutine + go c.tokenRefreshLoop() + + // Setup HTTP server for webhook + mux := http.NewServeMux() + webhookPath := c.config.WebhookPath + if webhookPath == "" { + webhookPath = "/webhook/wecom-app" + } + mux.HandleFunc(webhookPath, c.handleWebhook) + + // Health check endpoint + mux.HandleFunc("/health/wecom-app", c.handleHealth) + + addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) + c.server = &http.Server{ + Addr: addr, + Handler: mux, + } + + c.SetRunning(true) + logger.InfoCF("wecom_app", "WeCom App channel started", map[string]interface{}{ + "address": addr, + "path": webhookPath, + }) + + // Start server in goroutine + go func() { + if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("wecom_app", "HTTP server error", map[string]interface{}{ + "error": err.Error(), + }) + } + }() + + return nil +} + +// Stop gracefully stops the WeCom App channel +func (c *WeComAppChannel) Stop(ctx context.Context) error { + logger.InfoC("wecom_app", "Stopping WeCom App channel...") + + if c.cancel != nil { + c.cancel() + } + + if c.server != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + c.server.Shutdown(shutdownCtx) + } + + c.SetRunning(false) + logger.InfoC("wecom_app", "WeCom App channel stopped") + return nil +} + +// Send sends a message to WeCom user proactively using access token +func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("wecom_app channel not running") + } + + accessToken := c.getAccessToken() + if accessToken == "" { + return fmt.Errorf("no valid access token available") + } + + logger.DebugCF("wecom_app", "Sending message", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) + + return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content) +} + +// handleWebhook handles incoming webhook requests from WeCom +func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Log all incoming requests for debugging + logger.DebugCF("wecom_app", "Received webhook request", map[string]interface{}{ + "method": r.Method, + "url": r.URL.String(), + "path": r.URL.Path, + "query": r.URL.RawQuery, + }) + + if r.Method == http.MethodGet { + // Handle verification request + c.handleVerification(ctx, w, r) + return + } + + if r.Method == http.MethodPost { + // Handle message callback + c.handleMessageCallback(ctx, w, r) + return + } + + logger.WarnCF("wecom_app", "Method not allowed", map[string]interface{}{ + "method": r.Method, + }) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} + +// handleVerification handles the URL verification request from WeCom +func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + echostr := query.Get("echostr") + + logger.DebugCF("wecom_app", "Handling verification request", map[string]interface{}{ + "msg_signature": msgSignature, + "timestamp": timestamp, + "nonce": nonce, + "echostr": echostr, + "corp_id": c.config.CorpID, + }) + + if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { + logger.ErrorC("wecom_app", "Missing parameters in verification request") + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Verify signature + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + logger.WarnCF("wecom_app", "Signature verification failed", map[string]interface{}{ + "token": c.config.Token, + "msg_signature": msgSignature, + "timestamp": timestamp, + "nonce": nonce, + }) + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + logger.DebugC("wecom_app", "Signature verification passed") + + // Decrypt echostr with CorpID verification + // For WeCom App (自建应用), receiveid should be corp_id + logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]interface{}{ + "encoding_aes_key": c.config.EncodingAESKey, + "corp_id": c.config.CorpID, + }) + decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]interface{}{ + "error": err.Error(), + "encoding_aes_key": c.config.EncodingAESKey, + "corp_id": c.config.CorpID, + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]interface{}{ + "decrypted": decryptedEchoStr, + }) + + // Remove BOM and whitespace as per WeCom documentation + // The response must be plain text without quotes, BOM, or newlines + decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) + decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM + w.Write([]byte(decryptedEchoStr)) +} + +// handleMessageCallback handles incoming messages from WeCom +func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + + if msgSignature == "" || timestamp == "" || nonce == "" { + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Parse XML to get encrypted message + var encryptedMsg struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + Encrypt string `xml:"Encrypt"` + AgentID string `xml:"AgentID"` + } + + if err := xml.Unmarshal(body, &encryptedMsg); err != nil { + logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Invalid XML", http.StatusBadRequest) + return + } + + // Verify signature + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + logger.WarnC("wecom_app", "Message signature verification failed") + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + // Decrypt message with CorpID verification + // For WeCom App (自建应用), receiveid should be corp_id + decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + // Parse decrypted XML message + var msg WeComXMLMessage + if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil { + logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Invalid message format", http.StatusBadRequest) + return + } + + // Process the message with context + go c.processMessage(ctx, msg) + + // Return success response immediately + // WeCom App requires response within configured timeout (default 5 seconds) + w.Write([]byte("success")) +} + +// processMessage processes the received message +func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) { + // Skip non-text messages for now (can be extended) + if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" { + logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]interface{}{ + "msg_type": msg.MsgType, + }) + return + } + + // Message deduplication: Use msg_id to prevent duplicate processing + // As per WeCom documentation, use msg_id for deduplication + msgID := fmt.Sprintf("%d", msg.MsgId) + c.msgMu.Lock() + if c.processedMsgs[msgID] { + c.msgMu.Unlock() + logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]interface{}{ + "msg_id": msgID, + }) + return + } + c.processedMsgs[msgID] = true + c.msgMu.Unlock() + + // Clean up old messages periodically (keep last 1000) + if len(c.processedMsgs) > 1000 { + c.msgMu.Lock() + c.processedMsgs = make(map[string]bool) + c.msgMu.Unlock() + } + + senderID := msg.FromUserName + chatID := senderID // WeCom App uses user ID as chat ID for direct messages + + // Build metadata + // WeCom App only supports direct messages (private chat) + metadata := map[string]string{ + "msg_type": msg.MsgType, + "msg_id": fmt.Sprintf("%d", msg.MsgId), + "agent_id": fmt.Sprintf("%d", msg.AgentID), + "platform": "wecom_app", + "media_id": msg.MediaId, + "create_time": fmt.Sprintf("%d", msg.CreateTime), + "peer_kind": "direct", + "peer_id": senderID, + } + + content := msg.Content + + logger.DebugCF("wecom_app", "Received message", map[string]interface{}{ + "sender_id": senderID, + "msg_type": msg.MsgType, + "preview": utils.Truncate(content, 50), + }) + + // Handle the message through the base channel + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +// tokenRefreshLoop periodically refreshes the access token +func (c *WeComAppChannel) tokenRefreshLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + if err := c.refreshAccessToken(); err != nil { + logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]interface{}{ + "error": err.Error(), + }) + } + } + } +} + +// refreshAccessToken gets a new access token from WeCom API +func (c *WeComAppChannel) refreshAccessToken() error { + apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s", + wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret)) + + resp, err := http.Get(apiURL) + if err != nil { + return fmt.Errorf("failed to request access token: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var tokenResp WeComAccessTokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if tokenResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode) + } + + c.tokenMu.Lock() + c.accessToken = tokenResp.AccessToken + c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early + c.tokenMu.Unlock() + + logger.DebugC("wecom_app", "Access token refreshed successfully") + return nil +} + +// getAccessToken returns the current valid access token +func (c *WeComAppChannel) getAccessToken() string { + c.tokenMu.RLock() + defer c.tokenMu.RUnlock() + + if time.Now().After(c.tokenExpiry) { + return "" + } + + return c.accessToken +} + +// sendTextMessage sends a text message to a user +func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error { + apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) + + msg := WeComTextMessage{ + ToUser: userID, + MsgType: "text", + AgentID: c.config.AgentID, + } + msg.Text.Content = content + + jsonData, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Use configurable timeout (default 5 seconds) + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var sendResp WeComSendMessageResponse + if err := json.Unmarshal(body, &sendResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if sendResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) + } + + return nil +} + +// sendMarkdownMessage sends a markdown message to a user +func (c *WeComAppChannel) sendMarkdownMessage(ctx context.Context, accessToken, userID, content string) error { + apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) + + msg := WeComMarkdownMessage{ + ToUser: userID, + MsgType: "markdown", + AgentID: c.config.AgentID, + } + msg.Markdown.Content = content + + jsonData, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Use configurable timeout (default 5 seconds) + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var sendResp WeComSendMessageResponse + if err := json.Unmarshal(body, &sendResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if sendResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) + } + + return nil +} + +// handleHealth handles health check requests +func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) { + status := map[string]interface{}{ + "status": "ok", + "running": c.IsRunning(), + "has_token": c.getAccessToken() != "", + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(status) +} diff --git a/pkg/channels/wecom/app_test.go b/pkg/channels/wecom/app_test.go new file mode 100644 index 000000000..d9817fd49 --- /dev/null +++ b/pkg/channels/wecom/app_test.go @@ -0,0 +1,1086 @@ +package wecom + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "encoding/json" + "encoding/xml" + "fmt" + "net/http" + "net/http/httptest" + "sort" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// generateTestAESKeyApp generates a valid test AES key for WeCom App +func generateTestAESKeyApp() string { + // AES key needs to be 32 bytes (256 bits) for AES-256 + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + // Return base64 encoded key without padding + return base64.StdEncoding.EncodeToString(key)[:43] +} + +// encryptTestMessageApp encrypts a message for testing WeCom App +func encryptTestMessageApp(message, aesKey string) (string, error) { + // Decode AES key + key, err := base64.StdEncoding.DecodeString(aesKey + "=") + if err != nil { + return "", err + } + + // Prepare message: random(16) + msg_len(4) + msg + corp_id + random := make([]byte, 0, 16) + for i := 0; i < 16; i++ { + random = append(random, byte(i+1)) + } + + msgBytes := []byte(message) + corpID := []byte("test_corp_id") + + msgLen := uint32(len(msgBytes)) + lenBytes := make([]byte, 4) + binary.BigEndian.PutUint32(lenBytes, msgLen) + + plainText := append(random, lenBytes...) + plainText = append(plainText, msgBytes...) + plainText = append(plainText, corpID...) + + // PKCS7 padding + blockSize := aes.BlockSize + padding := blockSize - len(plainText)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + plainText = append(plainText, padText...) + + // Encrypt + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize]) + cipherText := make([]byte, len(plainText)) + mode.CryptBlocks(cipherText, plainText) + + return base64.StdEncoding.EncodeToString(cipherText), nil +} + +// generateSignatureApp generates a signature for testing WeCom App +func generateSignatureApp(token, timestamp, nonce, msgEncrypt string) string { + params := []string{token, timestamp, nonce, msgEncrypt} + sort.Strings(params) + str := strings.Join(params, "") + hash := sha1.Sum([]byte(str)) + return fmt.Sprintf("%x", hash) +} + +func TestNewWeComAppChannel(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("missing corp_id", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "", + CorpSecret: "test_secret", + AgentID: 1000002, + } + _, err := NewWeComAppChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing corp_id, got nil") + } + }) + + t.Run("missing corp_secret", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "", + AgentID: 1000002, + } + _, err := NewWeComAppChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing corp_secret, got nil") + } + }) + + t.Run("missing agent_id", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 0, + } + _, err := NewWeComAppChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing agent_id, got nil") + } + }) + + t.Run("valid config", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + AllowFrom: []string{"user1", "user2"}, + } + ch, err := NewWeComAppChannel(cfg, msgBus) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch.Name() != "wecom_app" { + t.Errorf("Name() = %q, want %q", ch.Name(), "wecom_app") + } + if ch.IsRunning() { + t.Error("new channel should not be running") + } + }) +} + +func TestWeComAppChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + AllowFrom: []string{}, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + if !ch.IsAllowed("any_user") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + AllowFrom: []string{"allowed_user"}, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + if !ch.IsAllowed("allowed_user") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("blocked_user") { + t.Error("non-allowed user should be blocked") + } + }) +} + +func TestWeComAppVerifySignature(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("valid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt) + + if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + t.Error("valid signature should pass verification") + } + }) + + t.Run("invalid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + + if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + t.Error("invalid signature should fail verification") + } + }) + + t.Run("empty token skips verification", func(t *testing.T) { + cfgEmpty := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "", + } + chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) + + if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + t.Error("empty token should skip verification and return true") + } + }) +} + +func TestWeComAppDecryptMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("decrypt without AES key", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: "", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + // Without AES key, message should be base64 decoded only + plainText := "hello world" + encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) + + result, err := decryptMessage(encoded, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != plainText { + t.Errorf("decryptMessage() = %q, want %q", result, plainText) + } + }) + + t.Run("decrypt with AES key", func(t *testing.T) { + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + originalMsg := "Hello" + encrypted, err := encryptTestMessageApp(originalMsg, aesKey) + if err != nil { + t.Fatalf("failed to encrypt test message: %v", err) + } + + result, err := decryptMessage(encrypted, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != originalMsg { + t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) + } + }) + + t.Run("invalid base64", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: "", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid base64, got nil") + } + }) + + t.Run("invalid AES key", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: "invalid_key", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid AES key, got nil") + } + }) + + t.Run("ciphertext too short", func(t *testing.T) { + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + // Encrypt a very short message that results in ciphertext less than block size + shortData := make([]byte, 8) + _, err := decryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for short ciphertext, got nil") + } + }) +} + +func TestWeComAppPKCS7Unpad(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "empty input", + input: []byte{}, + expected: []byte{}, + }, + { + name: "valid padding 3 bytes", + input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), + expected: []byte("hello"), + }, + { + name: "valid padding 16 bytes (full block)", + input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("123456789012345"), + }, + { + name: "invalid padding larger than data", + input: []byte{20}, + expected: nil, // should return error + }, + { + name: "invalid padding zero", + input: append([]byte("test"), byte(0)), + expected: nil, // should return error + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs7Unpad(tt.input) + if tt.expected == nil { + // This case should return an error + if err == nil { + t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) + } + return + } + if err != nil { + t.Errorf("pkcs7Unpad() unexpected error: %v", err) + return + } + if !bytes.Equal(result, tt.expected) { + t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestWeComAppHandleVerification(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("valid verification request", func(t *testing.T) { + echostr := "test_echostr_123" + encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encryptedEchostr) + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != echostr { + t.Errorf("response body = %q, want %q", w.Body.String(), echostr) + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature=sig×tamp=ts", nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + echostr := "test_echostr" + encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComAppHandleMessageCallback(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("valid message callback", func(t *testing.T) { + // Create XML message + xmlMsg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "text", + Content: "Hello World", + MsgId: 123456, + AgentID: 1000002, + } + xmlData, _ := xml.Marshal(xmlMsg) + + // Encrypt message + encrypted, _ := encryptTestMessageApp(string(xmlData), aesKey) + + // Create encrypted XML wrapper + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: encrypted, + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encrypted) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "success" { + t.Errorf("response body = %q, want %q", w.Body.String(), "success") + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature=sig", nil) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid XML", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, "") + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, strings.NewReader("invalid xml")) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: "encrypted_data", + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComAppProcessMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("process text message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "text", + Content: "Hello World", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process image message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "image", + PicUrl: "https://example.com/image.jpg", + MediaId: "media_123", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process voice message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "voice", + MediaId: "media_123", + Format: "amr", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("skip unsupported message type", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "video", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process event message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "event", + Event: "subscribe", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) +} + +func TestWeComAppHandleWebhook(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("GET request calls verification", func(t *testing.T) { + echostr := "test_echostr" + encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encoded) + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, nil) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("POST request calls message callback", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encryptedWrapper.Encrypt) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + // Should not be method not allowed + if w.Code == http.StatusMethodNotAllowed { + t.Error("POST request should not return Method Not Allowed") + } + }) + + t.Run("unsupported method", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/webhook/wecom-app", nil) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } + }) +} + +func TestWeComAppHandleHealth(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + req := httptest.NewRequest(http.MethodGet, "/health/wecom-app", nil) + w := httptest.NewRecorder() + + ch.handleHealth(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Content-Type = %q, want %q", contentType, "application/json") + } + + body := w.Body.String() + if !strings.Contains(body, "status") || !strings.Contains(body, "running") || !strings.Contains(body, "has_token") { + t.Errorf("response body should contain status, running, and has_token fields, got: %s", body) + } +} + +func TestWeComAppAccessToken(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("get empty access token initially", func(t *testing.T) { + token := ch.getAccessToken() + if token != "" { + t.Errorf("getAccessToken() = %q, want empty string", token) + } + }) + + t.Run("set and get access token", func(t *testing.T) { + ch.tokenMu.Lock() + ch.accessToken = "test_token_123" + ch.tokenExpiry = time.Now().Add(1 * time.Hour) + ch.tokenMu.Unlock() + + token := ch.getAccessToken() + if token != "test_token_123" { + t.Errorf("getAccessToken() = %q, want %q", token, "test_token_123") + } + }) + + t.Run("expired token returns empty", func(t *testing.T) { + ch.tokenMu.Lock() + ch.accessToken = "expired_token" + ch.tokenExpiry = time.Now().Add(-1 * time.Hour) + ch.tokenMu.Unlock() + + token := ch.getAccessToken() + if token != "" { + t.Errorf("getAccessToken() = %q, want empty string for expired token", token) + } + }) +} + +func TestWeComAppMessageStructures(t *testing.T) { + t.Run("WeComTextMessage structure", func(t *testing.T) { + msg := WeComTextMessage{ + ToUser: "user123", + MsgType: "text", + AgentID: 1000002, + } + msg.Text.Content = "Hello World" + + if msg.ToUser != "user123" { + t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") + } + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.AgentID != 1000002 { + t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) + } + if msg.Text.Content != "Hello World" { + t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") + } + + // Test JSON marshaling + jsonData, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal JSON: %v", err) + } + + var unmarshaled WeComTextMessage + err = json.Unmarshal(jsonData, &unmarshaled) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if unmarshaled.ToUser != msg.ToUser { + t.Errorf("JSON round-trip failed for ToUser") + } + }) + + t.Run("WeComMarkdownMessage structure", func(t *testing.T) { + msg := WeComMarkdownMessage{ + ToUser: "user123", + MsgType: "markdown", + AgentID: 1000002, + } + msg.Markdown.Content = "# Hello\nWorld" + + if msg.Markdown.Content != "# Hello\nWorld" { + t.Errorf("Markdown.Content = %q, want %q", msg.Markdown.Content, "# Hello\nWorld") + } + + // Test JSON marshaling + jsonData, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal JSON: %v", err) + } + + if !bytes.Contains(jsonData, []byte("markdown")) { + t.Error("JSON should contain 'markdown' field") + } + }) + + t.Run("WeComImageMessage structure", func(t *testing.T) { + msg := WeComImageMessage{ + ToUser: "user123", + MsgType: "image", + AgentID: 1000002, + } + msg.Image.MediaID = "media_123456" + + if msg.Image.MediaID != "media_123456" { + t.Errorf("Image.MediaID = %q, want %q", msg.Image.MediaID, "media_123456") + } + }) + + t.Run("WeComAccessTokenResponse structure", func(t *testing.T) { + jsonData := `{ + "errcode": 0, + "errmsg": "ok", + "access_token": "test_access_token", + "expires_in": 7200 + }` + + var resp WeComAccessTokenResponse + err := json.Unmarshal([]byte(jsonData), &resp) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if resp.ErrCode != 0 { + t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) + } + if resp.ErrMsg != "ok" { + t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") + } + if resp.AccessToken != "test_access_token" { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test_access_token") + } + if resp.ExpiresIn != 7200 { + t.Errorf("ExpiresIn = %d, want %d", resp.ExpiresIn, 7200) + } + }) + + t.Run("WeComSendMessageResponse structure", func(t *testing.T) { + jsonData := `{ + "errcode": 0, + "errmsg": "ok", + "invaliduser": "", + "invalidparty": "", + "invalidtag": "" + }` + + var resp WeComSendMessageResponse + err := json.Unmarshal([]byte(jsonData), &resp) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if resp.ErrCode != 0 { + t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) + } + if resp.ErrMsg != "ok" { + t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") + } + }) +} + +func TestWeComAppXMLMessageStructure(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.ToUserName != "corp_id" { + t.Errorf("ToUserName = %q, want %q", msg.ToUserName, "corp_id") + } + if msg.FromUserName != "user123" { + t.Errorf("FromUserName = %q, want %q", msg.FromUserName, "user123") + } + if msg.CreateTime != 1234567890 { + t.Errorf("CreateTime = %d, want %d", msg.CreateTime, 1234567890) + } + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.Content != "Hello World" { + t.Errorf("Content = %q, want %q", msg.Content, "Hello World") + } + if msg.MsgId != 1234567890123456 { + t.Errorf("MsgId = %d, want %d", msg.MsgId, 1234567890123456) + } + if msg.AgentID != 1000002 { + t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) + } +} + +func TestWeComAppXMLMessageImage(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "image" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") + } + if msg.PicUrl != "https://example.com/image.jpg" { + t.Errorf("PicUrl = %q, want %q", msg.PicUrl, "https://example.com/image.jpg") + } + if msg.MediaId != "media_123" { + t.Errorf("MediaId = %q, want %q", msg.MediaId, "media_123") + } +} + +func TestWeComAppXMLMessageVoice(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "voice" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "voice") + } + if msg.Format != "amr" { + t.Errorf("Format = %q, want %q", msg.Format, "amr") + } +} + +func TestWeComAppXMLMessageLocation(t *testing.T) { + xmlData := ` + + + + 1234567890 + + 39.9042 + 116.4074 + 16 + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "location" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "location") + } + if msg.LocationX != 39.9042 { + t.Errorf("LocationX = %f, want %f", msg.LocationX, 39.9042) + } + if msg.LocationY != 116.4074 { + t.Errorf("LocationY = %f, want %f", msg.LocationY, 116.4074) + } + if msg.Scale != 16 { + t.Errorf("Scale = %d, want %d", msg.Scale, 16) + } + if msg.Label != "Beijing" { + t.Errorf("Label = %q, want %q", msg.Label, "Beijing") + } +} + +func TestWeComAppXMLMessageLink(t *testing.T) { + xmlData := ` + + + + 1234567890 + + <![CDATA[Link Title]]> + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "link" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "link") + } + if msg.Title != "Link Title" { + t.Errorf("Title = %q, want %q", msg.Title, "Link Title") + } + if msg.Description != "Link Description" { + t.Errorf("Description = %q, want %q", msg.Description, "Link Description") + } + if msg.Url != "https://example.com" { + t.Errorf("Url = %q, want %q", msg.Url, "https://example.com") + } +} + +func TestWeComAppXMLMessageEvent(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "event" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "event") + } + if msg.Event != "subscribe" { + t.Errorf("Event = %q, want %q", msg.Event, "subscribe") + } + if msg.EventKey != "event_key_123" { + t.Errorf("EventKey = %q, want %q", msg.EventKey, "event_key_123") + } +} diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go new file mode 100644 index 000000000..9683a308f --- /dev/null +++ b/pkg/channels/wecom/bot.go @@ -0,0 +1,469 @@ +package wecom + +import ( + "bytes" + "context" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人) +// Uses webhook callback mode - simpler than WeCom App but only supports passive replies +type WeComBotChannel struct { + *channels.BaseChannel + config config.WeComConfig + server *http.Server + ctx context.Context + cancel context.CancelFunc + processedMsgs map[string]bool // Message deduplication: msg_id -> processed + msgMu sync.RWMutex +} + +// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT) +type WeComBotMessage struct { + MsgID string `json:"msgid"` + AIBotID string `json:"aibotid"` + ChatID string `json:"chatid"` // Session ID, only present for group chats + ChatType string `json:"chattype"` // "single" for DM, "group" for group chat + From struct { + UserID string `json:"userid"` + } `json:"from"` + ResponseURL string `json:"response_url"` + MsgType string `json:"msgtype"` // text, image, voice, file, mixed + Text struct { + Content string `json:"content"` + } `json:"text"` + Image struct { + URL string `json:"url"` + } `json:"image"` + Voice struct { + Content string `json:"content"` // Voice to text content + } `json:"voice"` + File struct { + URL string `json:"url"` + } `json:"file"` + Mixed struct { + MsgItem []struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text"` + Image struct { + URL string `json:"url"` + } `json:"image"` + } `json:"msg_item"` + } `json:"mixed"` + Quote struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text"` + } `json:"quote"` +} + +// WeComBotReplyMessage represents the reply message structure +type WeComBotReplyMessage struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text,omitempty"` +} + +// NewWeComBotChannel creates a new WeCom Bot channel instance +func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) { + if cfg.Token == "" || cfg.WebhookURL == "" { + return nil, fmt.Errorf("wecom token and webhook_url are required") + } + + base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom) + + return &WeComBotChannel{ + BaseChannel: base, + config: cfg, + processedMsgs: make(map[string]bool), + }, nil +} + +// Name returns the channel name +func (c *WeComBotChannel) Name() string { + return "wecom" +} + +// Start initializes the WeCom Bot channel with HTTP webhook server +func (c *WeComBotChannel) Start(ctx context.Context) error { + logger.InfoC("wecom", "Starting WeCom Bot channel...") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Setup HTTP server for webhook + mux := http.NewServeMux() + webhookPath := c.config.WebhookPath + if webhookPath == "" { + webhookPath = "/webhook/wecom" + } + mux.HandleFunc(webhookPath, c.handleWebhook) + + // Health check endpoint + mux.HandleFunc("/health/wecom", c.handleHealth) + + addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) + c.server = &http.Server{ + Addr: addr, + Handler: mux, + } + + c.SetRunning(true) + logger.InfoCF("wecom", "WeCom Bot channel started", map[string]interface{}{ + "address": addr, + "path": webhookPath, + }) + + // Start server in goroutine + go func() { + if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("wecom", "HTTP server error", map[string]interface{}{ + "error": err.Error(), + }) + } + }() + + return nil +} + +// Stop gracefully stops the WeCom Bot channel +func (c *WeComBotChannel) Stop(ctx context.Context) error { + logger.InfoC("wecom", "Stopping WeCom Bot channel...") + + if c.cancel != nil { + c.cancel() + } + + if c.server != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + c.server.Shutdown(shutdownCtx) + } + + c.SetRunning(false) + logger.InfoC("wecom", "WeCom Bot channel stopped") + return nil +} + +// Send sends a message to WeCom user via webhook API +// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message +// For delayed responses, we use the webhook URL +func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("wecom channel not running") + } + + logger.DebugCF("wecom", "Sending message via webhook", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) + + return c.sendWebhookReply(ctx, msg.ChatID, msg.Content) +} + +// handleWebhook handles incoming webhook requests from WeCom +func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if r.Method == http.MethodGet { + // Handle verification request + c.handleVerification(ctx, w, r) + return + } + + if r.Method == http.MethodPost { + // Handle message callback + c.handleMessageCallback(ctx, w, r) + return + } + + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} + +// handleVerification handles the URL verification request from WeCom +func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + echostr := query.Get("echostr") + + if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Verify signature + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + logger.WarnC("wecom", "Signature verification failed") + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + // Decrypt echostr + // For AIBOT (智能机器人), receiveid should be empty string "" + // Reference: https://developer.work.weixin.qq.com/document/path/101033 + decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") + if err != nil { + logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + // Remove BOM and whitespace as per WeCom documentation + // The response must be plain text without quotes, BOM, or newlines + decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) + decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM + w.Write([]byte(decryptedEchoStr)) +} + +// handleMessageCallback handles incoming messages from WeCom +func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + + if msgSignature == "" || timestamp == "" || nonce == "" { + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Parse XML to get encrypted message + var encryptedMsg struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + Encrypt string `xml:"Encrypt"` + AgentID string `xml:"AgentID"` + } + + if err := xml.Unmarshal(body, &encryptedMsg); err != nil { + logger.ErrorCF("wecom", "Failed to parse XML", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Invalid XML", http.StatusBadRequest) + return + } + + // Verify signature + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + logger.WarnC("wecom", "Message signature verification failed") + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + // Decrypt message + // For AIBOT (智能机器人), receiveid should be empty string "" + // Reference: https://developer.work.weixin.qq.com/document/path/101033 + decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") + if err != nil { + logger.ErrorCF("wecom", "Failed to decrypt message", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + // Parse decrypted JSON message (AIBOT uses JSON format) + var msg WeComBotMessage + if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil { + logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Invalid message format", http.StatusBadRequest) + return + } + + // Process the message asynchronously with context + go c.processMessage(ctx, msg) + + // Return success response immediately + // WeCom Bot requires response within configured timeout (default 5 seconds) + w.Write([]byte("success")) +} + +// processMessage processes the received message +func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) { + // Skip unsupported message types + if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" && msg.MsgType != "mixed" { + logger.DebugCF("wecom", "Skipping non-supported message type", map[string]interface{}{ + "msg_type": msg.MsgType, + }) + return + } + + // Message deduplication: Use msg_id to prevent duplicate processing + msgID := msg.MsgID + c.msgMu.Lock() + if c.processedMsgs[msgID] { + c.msgMu.Unlock() + logger.DebugCF("wecom", "Skipping duplicate message", map[string]interface{}{ + "msg_id": msgID, + }) + return + } + c.processedMsgs[msgID] = true + c.msgMu.Unlock() + + // Clean up old messages periodically (keep last 1000) + if len(c.processedMsgs) > 1000 { + c.msgMu.Lock() + c.processedMsgs = make(map[string]bool) + c.msgMu.Unlock() + } + + senderID := msg.From.UserID + + // Determine if this is a group chat or direct message + // ChatType: "single" for DM, "group" for group chat + isGroupChat := msg.ChatType == "group" + + var chatID, peerKind, peerID string + if isGroupChat { + // Group chat: use ChatID as chatID and peer_id + chatID = msg.ChatID + peerKind = "group" + peerID = msg.ChatID + } else { + // Direct message: use senderID as chatID and peer_id + chatID = senderID + peerKind = "direct" + peerID = senderID + } + + // Extract content based on message type + var content string + switch msg.MsgType { + case "text": + content = msg.Text.Content + case "voice": + content = msg.Voice.Content // Voice to text content + case "mixed": + // For mixed messages, concatenate text items + for _, item := range msg.Mixed.MsgItem { + if item.MsgType == "text" { + content += item.Text.Content + } + } + case "image", "file": + // For image and file, we don't have text content + content = "" + } + + // Build metadata + metadata := map[string]string{ + "msg_type": msg.MsgType, + "msg_id": msg.MsgID, + "platform": "wecom", + "peer_kind": peerKind, + "peer_id": peerID, + "response_url": msg.ResponseURL, + } + if isGroupChat { + metadata["chat_id"] = msg.ChatID + metadata["sender_id"] = senderID + } + + logger.DebugCF("wecom", "Received message", map[string]interface{}{ + "sender_id": senderID, + "msg_type": msg.MsgType, + "peer_kind": peerKind, + "is_group_chat": isGroupChat, + "preview": utils.Truncate(content, 50), + }) + + // Handle the message through the base channel + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +// sendWebhookReply sends a reply using the webhook URL +func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error { + reply := WeComBotReplyMessage{ + MsgType: "text", + } + reply.Text.Content = content + + jsonData, err := json.Marshal(reply) + if err != nil { + return fmt.Errorf("failed to marshal reply: %w", err) + } + + // Use configurable timeout (default 5 seconds) + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to send webhook reply: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + // Check response + var result struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + } + if err := json.Unmarshal(body, &result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if result.ErrCode != 0 { + return fmt.Errorf("webhook API error: %s (code: %d)", result.ErrMsg, result.ErrCode) + } + + return nil +} + +// handleHealth handles health check requests +func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { + status := map[string]interface{}{ + "status": "ok", + "running": c.IsRunning(), + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(status) +} diff --git a/pkg/channels/wecom/bot_test.go b/pkg/channels/wecom/bot_test.go new file mode 100644 index 000000000..460e0058f --- /dev/null +++ b/pkg/channels/wecom/bot_test.go @@ -0,0 +1,753 @@ +package wecom + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "encoding/json" + "encoding/xml" + "fmt" + "net/http" + "net/http/httptest" + "sort" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +// generateTestAESKey generates a valid test AES key +func generateTestAESKey() string { + // AES key needs to be 32 bytes (256 bits) for AES-256 + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + // Return base64 encoded key without padding + return base64.StdEncoding.EncodeToString(key)[:43] +} + +// encryptTestMessage encrypts a message for testing (AIBOT JSON format) +func encryptTestMessage(message, aesKey string) (string, error) { + // Decode AES key + key, err := base64.StdEncoding.DecodeString(aesKey + "=") + if err != nil { + return "", err + } + + // Prepare message: random(16) + msg_len(4) + msg + receiveid + random := make([]byte, 0, 16) + for i := 0; i < 16; i++ { + random = append(random, byte(i)) + } + + msgBytes := []byte(message) + receiveID := []byte("test_aibot_id") + + msgLen := uint32(len(msgBytes)) + lenBytes := make([]byte, 4) + binary.BigEndian.PutUint32(lenBytes, msgLen) + + plainText := append(random, lenBytes...) + plainText = append(plainText, msgBytes...) + plainText = append(plainText, receiveID...) + + // PKCS7 padding + blockSize := aes.BlockSize + padding := blockSize - len(plainText)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + plainText = append(plainText, padText...) + + // Encrypt + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize]) + cipherText := make([]byte, len(plainText)) + mode.CryptBlocks(cipherText, plainText) + + return base64.StdEncoding.EncodeToString(cipherText), nil +} + +// generateSignature generates a signature for testing +func generateSignature(token, timestamp, nonce, msgEncrypt string) string { + params := []string{token, timestamp, nonce, msgEncrypt} + sort.Strings(params) + str := strings.Join(params, "") + hash := sha1.Sum([]byte(str)) + return fmt.Sprintf("%x", hash) +} + +func TestNewWeComBotChannel(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("missing token", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + _, err := NewWeComBotChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing token, got nil") + } + }) + + t.Run("missing webhook_url", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "", + } + _, err := NewWeComBotChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing webhook_url, got nil") + } + }) + + t.Run("valid config", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + AllowFrom: []string{"user1", "user2"}, + } + ch, err := NewWeComBotChannel(cfg, msgBus) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch.Name() != "wecom" { + t.Errorf("Name() = %q, want %q", ch.Name(), "wecom") + } + if ch.IsRunning() { + t.Error("new channel should not be running") + } + }) +} + +func TestWeComBotChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + AllowFrom: []string{}, + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + if !ch.IsAllowed("any_user") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + AllowFrom: []string{"allowed_user"}, + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + if !ch.IsAllowed("allowed_user") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("blocked_user") { + t.Error("non-allowed user should be blocked") + } + }) +} + +func TestWeComBotVerifySignature(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("valid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt) + + if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + t.Error("valid signature should pass verification") + } + }) + + t.Run("invalid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + + if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + t.Error("invalid signature should fail verification") + } + }) + + t.Run("empty token skips verification", func(t *testing.T) { + // Create a channel manually with empty token to test the behavior + cfgEmpty := config.WeComConfig{ + Token: "", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + base := channels.NewBaseChannel("wecom", cfgEmpty, msgBus, cfgEmpty.AllowFrom) + chEmpty := &WeComBotChannel{ + BaseChannel: base, + config: cfgEmpty, + } + + if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + t.Error("empty token should skip verification and return true") + } + }) +} + +func TestWeComBotDecryptMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("decrypt without AES key", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: "", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + // Without AES key, message should be base64 decoded only + plainText := "hello world" + encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) + + result, err := decryptMessage(encoded, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != plainText { + t.Errorf("decryptMessage() = %q, want %q", result, plainText) + } + }) + + t.Run("decrypt with AES key", func(t *testing.T) { + aesKey := generateTestAESKey() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: aesKey, + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + originalMsg := "Hello" + encrypted, err := encryptTestMessage(originalMsg, aesKey) + if err != nil { + t.Fatalf("failed to encrypt test message: %v", err) + } + + result, err := decryptMessage(encrypted, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != originalMsg { + t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) + } + }) + + t.Run("invalid base64", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: "", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid base64, got nil") + } + }) + + t.Run("invalid AES key", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: "invalid_key", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid AES key, got nil") + } + }) +} + +func TestWeComBotPKCS7Unpad(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "empty input", + input: []byte{}, + expected: []byte{}, + }, + { + name: "valid padding 3 bytes", + input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), + expected: []byte("hello"), + }, + { + name: "valid padding 16 bytes (full block)", + input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("123456789012345"), + }, + { + name: "invalid padding larger than data", + input: []byte{20}, + expected: nil, // should return error + }, + { + name: "invalid padding zero", + input: append([]byte("test"), byte(0)), + expected: nil, // should return error + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs7Unpad(tt.input) + if tt.expected == nil { + // This case should return an error + if err == nil { + t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) + } + return + } + if err != nil { + t.Errorf("pkcs7Unpad() unexpected error: %v", err) + return + } + if !bytes.Equal(result, tt.expected) { + t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestWeComBotHandleVerification(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKey() + cfg := config.WeComConfig{ + Token: "test_token", + EncodingAESKey: aesKey, + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("valid verification request", func(t *testing.T) { + echostr := "test_echostr_123" + encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr) + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != echostr { + t.Errorf("response body = %q, want %q", w.Body.String(), echostr) + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig×tamp=ts", nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + echostr := "test_echostr" + encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComBotHandleMessageCallback(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKey() + cfg := config.WeComConfig{ + Token: "test_token", + EncodingAESKey: aesKey, + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("valid direct message callback", func(t *testing.T) { + // Create JSON message for direct chat (single) + jsonMsg := `{ + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chattype": "single", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }` + + // Encrypt message + encrypted, _ := encryptTestMessage(jsonMsg, aesKey) + + // Create encrypted XML wrapper + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: encrypted, + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encrypted) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "success" { + t.Errorf("response body = %q, want %q", w.Body.String(), "success") + } + }) + + t.Run("valid group message callback", func(t *testing.T) { + // Create JSON message for group chat + jsonMsg := `{ + "msgid": "test_msg_id_456", + "aibotid": "test_aibot_id", + "chatid": "group_chat_id_123", + "chattype": "group", + "from": {"userid": "user456"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello Group"} + }` + + // Encrypt message + encrypted, _ := encryptTestMessage(jsonMsg, aesKey) + + // Create encrypted XML wrapper + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: encrypted, + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encrypted) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "success" { + t.Errorf("response body = %q, want %q", w.Body.String(), "success") + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid XML", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, "") + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, strings.NewReader("invalid xml")) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: "encrypted_data", + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComBotProcessMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("process direct text message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_123", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "text", + } + msg.From.UserID = "user123" + msg.Text.Content = "Hello World" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process group text message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_456", + AIBotID: "test_aibot_id", + ChatID: "group_chat_id_123", + ChatType: "group", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "text", + } + msg.From.UserID = "user456" + msg.Text.Content = "Hello Group" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process voice message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_789", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "voice", + } + msg.From.UserID = "user123" + msg.Voice.Content = "Voice message text" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("skip unsupported message type", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_000", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "video", + } + msg.From.UserID = "user123" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) +} + +func TestWeComBotHandleWebhook(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("GET request calls verification", func(t *testing.T) { + echostr := "test_echostr" + encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encoded) + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, nil) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("POST request calls message callback", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + // Should not be method not allowed + if w.Code == http.StatusMethodNotAllowed { + t.Error("POST request should not return Method Not Allowed") + } + }) + + t.Run("unsupported method", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } + }) +} + +func TestWeComBotHandleHealth(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil) + w := httptest.NewRecorder() + + ch.handleHealth(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Content-Type = %q, want %q", contentType, "application/json") + } + + body := w.Body.String() + if !strings.Contains(body, "status") || !strings.Contains(body, "running") { + t.Errorf("response body should contain status and running fields, got: %s", body) + } +} + +func TestWeComBotReplyMessage(t *testing.T) { + msg := WeComBotReplyMessage{ + MsgType: "text", + } + msg.Text.Content = "Hello World" + + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.Text.Content != "Hello World" { + t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") + } +} + +func TestWeComBotMessageStructure(t *testing.T) { + jsonData := `{ + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chatid": "group_chat_id_123", + "chattype": "group", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }` + + var msg WeComBotMessage + err := json.Unmarshal([]byte(jsonData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if msg.MsgID != "test_msg_id_123" { + t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123") + } + if msg.AIBotID != "test_aibot_id" { + t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id") + } + if msg.ChatID != "group_chat_id_123" { + t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123") + } + if msg.ChatType != "group" { + t.Errorf("ChatType = %q, want %q", msg.ChatType, "group") + } + if msg.From.UserID != "user123" { + t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123") + } + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.Text.Content != "Hello World" { + t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") + } +} diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go new file mode 100644 index 000000000..3c1629577 --- /dev/null +++ b/pkg/channels/wecom/common.go @@ -0,0 +1,134 @@ +package wecom + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "fmt" + "sort" + "strings" +) + +// blockSize is the PKCS7 block size used by WeCom (32) +const blockSize = 32 + +// verifySignature verifies the message signature for WeCom +// This is a common function used by both WeCom Bot and WeCom App +func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { + if token == "" { + return true // Skip verification if token is not set + } + + // Sort parameters + params := []string{token, timestamp, nonce, msgEncrypt} + sort.Strings(params) + + // Concatenate + str := strings.Join(params, "") + + // SHA1 hash + hash := sha1.Sum([]byte(str)) + expectedSignature := fmt.Sprintf("%x", hash) + + return expectedSignature == msgSignature +} + +// decryptMessage decrypts the encrypted message using AES +// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id +func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) { + return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "") +} + +// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid +// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. +func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { + if encodingAESKey == "" { + // No encryption, return as is (base64 decode) + decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", err + } + return string(decoded), nil + } + + // Decode AES key (base64) + aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + if err != nil { + return "", fmt.Errorf("failed to decode AES key: %w", err) + } + + // Decode encrypted message + cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", fmt.Errorf("failed to decode message: %w", err) + } + + // AES decrypt + block, err := aes.NewCipher(aesKey) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + if len(cipherText) < aes.BlockSize { + return "", fmt.Errorf("ciphertext too short") + } + + // IV is the first 16 bytes of AESKey + iv := aesKey[:aes.BlockSize] + mode := cipher.NewCBCDecrypter(block, iv) + plainText := make([]byte, len(cipherText)) + mode.CryptBlocks(plainText, cipherText) + + // Remove PKCS7 padding + plainText, err = pkcs7Unpad(plainText) + if err != nil { + return "", fmt.Errorf("failed to unpad: %w", err) + } + + // Parse message structure + // Format: random(16) + msg_len(4) + msg + receiveid + if len(plainText) < 20 { + return "", fmt.Errorf("decrypted message too short") + } + + msgLen := binary.BigEndian.Uint32(plainText[16:20]) + if int(msgLen) > len(plainText)-20 { + return "", fmt.Errorf("invalid message length") + } + + msg := plainText[20 : 20+msgLen] + + // Verify receiveid if provided + if receiveid != "" && len(plainText) > 20+int(msgLen) { + actualReceiveID := string(plainText[20+msgLen:]) + if actualReceiveID != receiveid { + return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) + } + } + + return string(msg), nil +} + +// pkcs7Unpad removes PKCS7 padding with validation +func pkcs7Unpad(data []byte) ([]byte, error) { + if len(data) == 0 { + return data, nil + } + padding := int(data[len(data)-1]) + // WeCom uses 32-byte block size for PKCS7 padding + if padding == 0 || padding > blockSize { + return nil, fmt.Errorf("invalid padding size: %d", padding) + } + if padding > len(data) { + return nil, fmt.Errorf("padding size larger than data") + } + // Verify all padding bytes + for i := 0; i < padding; i++ { + if data[len(data)-1-i] != byte(padding) { + return nil, fmt.Errorf("invalid padding byte at position %d", i) + } + } + return data[:len(data)-padding], nil +} diff --git a/pkg/channels/wecom/init.go b/pkg/channels/wecom/init.go new file mode 100644 index 000000000..3ef1ecdf3 --- /dev/null +++ b/pkg/channels/wecom/init.go @@ -0,0 +1,16 @@ +package wecom + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWeComBotChannel(cfg.Channels.WeCom, b) + }) + channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWeComAppChannel(cfg.Channels.WeComApp, b) + }) +} diff --git a/pkg/channels/whatsapp/init.go b/pkg/channels/whatsapp/init.go new file mode 100644 index 000000000..d9c2669c3 --- /dev/null +++ b/pkg/channels/whatsapp/init.go @@ -0,0 +1,13 @@ +package whatsapp + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("whatsapp", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWhatsAppChannel(cfg.Channels.WhatsApp, b) + }) +} diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go new file mode 100644 index 000000000..1ac256766 --- /dev/null +++ b/pkg/channels/whatsapp/whatsapp.go @@ -0,0 +1,193 @@ +package whatsapp + +import ( + "context" + "encoding/json" + "fmt" + "log" + "sync" + "time" + + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/utils" +) + +type WhatsAppChannel struct { + *channels.BaseChannel + conn *websocket.Conn + config config.WhatsAppConfig + url string + mu sync.Mutex + connected bool +} + +func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) { + base := channels.NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom) + + return &WhatsAppChannel{ + BaseChannel: base, + config: cfg, + url: cfg.BridgeURL, + connected: false, + }, nil +} + +func (c *WhatsAppChannel) Start(ctx context.Context) error { + log.Printf("Starting WhatsApp channel connecting to %s...", c.url) + + dialer := websocket.DefaultDialer + dialer.HandshakeTimeout = 10 * time.Second + + conn, _, err := dialer.Dial(c.url, nil) + if err != nil { + return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err) + } + + c.mu.Lock() + c.conn = conn + c.connected = true + c.mu.Unlock() + + c.SetRunning(true) + log.Println("WhatsApp channel connected") + + go c.listen(ctx) + + return nil +} + +func (c *WhatsAppChannel) Stop(ctx context.Context) error { + log.Println("Stopping WhatsApp channel...") + + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn != nil { + if err := c.conn.Close(); err != nil { + log.Printf("Error closing WhatsApp connection: %v", err) + } + c.conn = nil + } + + c.connected = false + c.SetRunning(false) + + return nil +} + +func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn == nil { + return fmt.Errorf("whatsapp connection not established") + } + + payload := map[string]interface{}{ + "type": "message", + "to": msg.ChatID, + "content": msg.Content, + } + + data, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + + return nil +} + +func (c *WhatsAppChannel) listen(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + default: + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + time.Sleep(1 * time.Second) + continue + } + + _, message, err := conn.ReadMessage() + if err != nil { + log.Printf("WhatsApp read error: %v", err) + time.Sleep(2 * time.Second) + continue + } + + var msg map[string]interface{} + if err := json.Unmarshal(message, &msg); err != nil { + log.Printf("Failed to unmarshal WhatsApp message: %v", err) + continue + } + + msgType, ok := msg["type"].(string) + if !ok { + continue + } + + if msgType == "message" { + c.handleIncomingMessage(msg) + } + } + } +} + +func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) { + senderID, ok := msg["from"].(string) + if !ok { + return + } + + chatID, ok := msg["chat"].(string) + if !ok { + chatID = senderID + } + + content, ok := msg["content"].(string) + if !ok { + content = "" + } + + var mediaPaths []string + if mediaData, ok := msg["media"].([]interface{}); ok { + mediaPaths = make([]string, 0, len(mediaData)) + for _, m := range mediaData { + if path, ok := m.(string); ok { + mediaPaths = append(mediaPaths, path) + } + } + } + + metadata := make(map[string]string) + if messageID, ok := msg["id"].(string); ok { + metadata["message_id"] = messageID + } + if userName, ok := msg["from_name"].(string); ok { + metadata["user_name"] = userName + } + + if chatID == senderID { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } else { + metadata["peer_kind"] = "group" + metadata["peer_id"] = chatID + } + + log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) + + c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) +}