mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(channels): remove old channel files from parent package
This commit is contained in:
@@ -1,204 +0,0 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// DingTalk channel implementation using Stream Mode
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
|
||||
// It uses WebSocket for receiving messages via stream mode and API for sending
|
||||
type DingTalkChannel struct {
|
||||
*BaseChannel
|
||||
config config.DingTalkConfig
|
||||
clientID string
|
||||
clientSecret string
|
||||
streamClient *client.StreamClient
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Map to store session webhooks for each chat
|
||||
sessionWebhooks sync.Map // chatID -> sessionWebhook
|
||||
}
|
||||
|
||||
// NewDingTalkChannel creates a new DingTalk channel instance
|
||||
func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (*DingTalkChannel, error) {
|
||||
if cfg.ClientID == "" || cfg.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("dingtalk client_id and client_secret are required")
|
||||
}
|
||||
|
||||
base := NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
return &DingTalkChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
clientID: cfg.ClientID,
|
||||
clientSecret: cfg.ClientSecret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start initializes the DingTalk channel with Stream Mode
|
||||
func (c *DingTalkChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Create credential config
|
||||
cred := client.NewAppCredentialConfig(c.clientID, c.clientSecret)
|
||||
|
||||
// Create the stream client with options
|
||||
c.streamClient = client.NewStreamClient(
|
||||
client.WithAppCredential(cred),
|
||||
client.WithAutoReconnect(true),
|
||||
)
|
||||
|
||||
// Register chatbot callback handler (IChatBotMessageHandler is a function type)
|
||||
c.streamClient.RegisterChatBotCallbackRouter(c.onChatBotMessageReceived)
|
||||
|
||||
// Start the stream client
|
||||
if err := c.streamClient.Start(c.ctx); err != nil {
|
||||
return fmt.Errorf("failed to start stream client: %w", err)
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the DingTalk channel
|
||||
func (c *DingTalkChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("dingtalk", "Stopping DingTalk channel...")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
if c.streamClient != nil {
|
||||
c.streamClient.Close()
|
||||
}
|
||||
|
||||
c.setRunning(false)
|
||||
logger.InfoC("dingtalk", "DingTalk channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends a message to DingTalk via the chatbot reply API
|
||||
func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("dingtalk channel not running")
|
||||
}
|
||||
|
||||
// Get session webhook from storage
|
||||
sessionWebhookRaw, ok := c.sessionWebhooks.Load(msg.ChatID)
|
||||
if !ok {
|
||||
return fmt.Errorf("no session_webhook found for chat %s, cannot send message", msg.ChatID)
|
||||
}
|
||||
|
||||
sessionWebhook, ok := sessionWebhookRaw.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
|
||||
}
|
||||
|
||||
logger.DebugCF("dingtalk", "Sending message", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
// Use the session webhook to send the reply
|
||||
return c.SendDirectReply(ctx, sessionWebhook, msg.Content)
|
||||
}
|
||||
|
||||
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature
|
||||
// This is called by the Stream SDK when a new message arrives
|
||||
// IChatBotMessageHandler is: func(c context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error)
|
||||
func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
ctx context.Context,
|
||||
data *chatbot.BotCallbackDataModel,
|
||||
) ([]byte, error) {
|
||||
// Extract message content from Text field
|
||||
content := data.Text.Content
|
||||
if content == "" {
|
||||
// Try to extract from Content interface{} if Text is empty
|
||||
if contentMap, ok := data.Content.(map[string]any); ok {
|
||||
if textContent, ok := contentMap["content"].(string); ok {
|
||||
content = textContent
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if content == "" {
|
||||
return nil, nil // Ignore empty messages
|
||||
}
|
||||
|
||||
senderID := data.SenderStaffId
|
||||
senderNick := data.SenderNick
|
||||
chatID := senderID
|
||||
if data.ConversationType != "1" {
|
||||
// For group chats
|
||||
chatID = data.ConversationId
|
||||
}
|
||||
|
||||
// Store the session webhook for this chat so we can reply later
|
||||
c.sessionWebhooks.Store(chatID, data.SessionWebhook)
|
||||
|
||||
metadata := map[string]string{
|
||||
"sender_name": senderNick,
|
||||
"conversation_id": data.ConversationId,
|
||||
"conversation_type": data.ConversationType,
|
||||
"platform": "dingtalk",
|
||||
"session_webhook": data.SessionWebhook,
|
||||
}
|
||||
|
||||
if data.ConversationType == "1" {
|
||||
metadata["peer_kind"] = "direct"
|
||||
metadata["peer_id"] = senderID
|
||||
} else {
|
||||
metadata["peer_kind"] = "group"
|
||||
metadata["peer_id"] = data.ConversationId
|
||||
}
|
||||
|
||||
logger.DebugCF("dingtalk", "Received message", map[string]any{
|
||||
"sender_nick": senderNick,
|
||||
"sender_id": senderID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
|
||||
// Return nil to indicate we've handled the message asynchronously
|
||||
// The response will be sent through the message bus
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// SendDirectReply sends a direct reply using the session webhook
|
||||
func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error {
|
||||
replier := chatbot.NewChatbotReplier()
|
||||
|
||||
// Convert string content to []byte for the API
|
||||
contentBytes := []byte(content)
|
||||
titleBytes := []byte("PicoClaw")
|
||||
|
||||
// Send markdown formatted reply
|
||||
err := replier.SimpleReplyMarkdown(
|
||||
ctx,
|
||||
sessionWebhook,
|
||||
titleBytes,
|
||||
contentBytes,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send reply: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,373 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
const (
|
||||
transcriptionTimeout = 30 * time.Second
|
||||
sendTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type DiscordChannel struct {
|
||||
*BaseChannel
|
||||
session *discordgo.Session
|
||||
config config.DiscordConfig
|
||||
transcriber *voice.GroqTranscriber
|
||||
ctx context.Context
|
||||
typingMu sync.Mutex
|
||||
typingStop map[string]chan struct{} // chatID → stop signal
|
||||
botUserID string // stored for mention checking
|
||||
}
|
||||
|
||||
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
|
||||
session, err := discordgo.New("Bot " + cfg.Token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create discord session: %w", err)
|
||||
}
|
||||
|
||||
base := NewBaseChannel("discord", cfg, bus, cfg.AllowFrom)
|
||||
|
||||
return &DiscordChannel{
|
||||
BaseChannel: base,
|
||||
session: session,
|
||||
config: cfg,
|
||||
transcriber: nil,
|
||||
ctx: context.Background(),
|
||||
typingStop: make(map[string]chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
c.transcriber = transcriber
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) getContext() context.Context {
|
||||
if c.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
return c.ctx
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("discord", "Starting Discord bot")
|
||||
|
||||
c.ctx = ctx
|
||||
|
||||
// Get bot user ID before opening session to avoid race condition
|
||||
botUser, err := c.session.User("@me")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get bot user: %w", err)
|
||||
}
|
||||
c.botUserID = botUser.ID
|
||||
|
||||
c.session.AddHandler(c.handleMessage)
|
||||
|
||||
if err := c.session.Open(); err != nil {
|
||||
return fmt.Errorf("failed to open discord session: %w", err)
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
|
||||
logger.InfoCF("discord", "Discord bot connected", map[string]any{
|
||||
"username": botUser.Username,
|
||||
"user_id": botUser.ID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("discord", "Stopping Discord bot")
|
||||
c.setRunning(false)
|
||||
|
||||
// Stop all typing goroutines before closing session
|
||||
c.typingMu.Lock()
|
||||
for chatID, stop := range c.typingStop {
|
||||
close(stop)
|
||||
delete(c.typingStop, chatID)
|
||||
}
|
||||
c.typingMu.Unlock()
|
||||
|
||||
if err := c.session.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close discord session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
c.stopTyping(msg.ChatID)
|
||||
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("discord bot not running")
|
||||
}
|
||||
|
||||
channelID := msg.ChatID
|
||||
if channelID == "" {
|
||||
return fmt.Errorf("channel ID is empty")
|
||||
}
|
||||
|
||||
runes := []rune(msg.Content)
|
||||
if len(runes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars
|
||||
|
||||
for _, chunk := range chunks {
|
||||
if err := c.sendChunk(ctx, channelID, chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
|
||||
// Use the passed ctx for timeout control
|
||||
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := c.session.ChannelMessageSend(channelID, content)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send discord message: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-sendCtx.Done():
|
||||
return fmt.Errorf("send message timeout: %w", sendCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// appendContent safely appends content to existing text
|
||||
func appendContent(content, suffix string) string {
|
||||
if content == "" {
|
||||
return suffix
|
||||
}
|
||||
return content + "\n" + suffix
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
if m == nil || m.Author == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if m.Author.ID == s.State.User.ID {
|
||||
return
|
||||
}
|
||||
|
||||
// Check allowlist first to avoid downloading attachments and transcribing for rejected users
|
||||
if !c.IsAllowed(m.Author.ID) {
|
||||
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
|
||||
"user_id": m.Author.ID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// If configured to only respond to mentions, check if bot is mentioned
|
||||
// Skip this check for DMs (GuildID is empty) - DMs should always be responded to
|
||||
if c.config.MentionOnly && m.GuildID != "" {
|
||||
isMentioned := false
|
||||
for _, mention := range m.Mentions {
|
||||
if mention.ID == c.botUserID {
|
||||
isMentioned = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isMentioned {
|
||||
logger.DebugCF("discord", "Message ignored - bot not mentioned", map[string]any{
|
||||
"user_id": m.Author.ID,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
senderID := m.Author.ID
|
||||
senderName := m.Author.Username
|
||||
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
|
||||
senderName += "#" + m.Author.Discriminator
|
||||
}
|
||||
|
||||
content := m.Content
|
||||
content = c.stripBotMention(content)
|
||||
mediaPaths := make([]string, 0, len(m.Attachments))
|
||||
localFiles := make([]string, 0, len(m.Attachments))
|
||||
|
||||
// Ensure temp files are cleaned up when function returns
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, attachment := range m.Attachments {
|
||||
isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
|
||||
|
||||
if isAudio {
|
||||
localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
|
||||
transcribedText := ""
|
||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
|
||||
result, err := c.transcriber.Transcribe(ctx, localPath)
|
||||
cancel() // Release context resources immediately to avoid leaks in for loop
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
|
||||
logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
|
||||
"text": result.Text,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
|
||||
}
|
||||
|
||||
content = appendContent(content, transcribedText)
|
||||
} else {
|
||||
logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
|
||||
"url": attachment.URL,
|
||||
"filename": attachment.Filename,
|
||||
})
|
||||
mediaPaths = append(mediaPaths, attachment.URL)
|
||||
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
|
||||
}
|
||||
} else {
|
||||
mediaPaths = append(mediaPaths, attachment.URL)
|
||||
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
|
||||
}
|
||||
}
|
||||
|
||||
if content == "" && len(mediaPaths) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if content == "" {
|
||||
content = "[media only]"
|
||||
}
|
||||
|
||||
// Start typing after all early returns — guaranteed to have a matching Send()
|
||||
c.startTyping(m.ChannelID)
|
||||
|
||||
logger.DebugCF("discord", "Received message", map[string]any{
|
||||
"sender_name": senderName,
|
||||
"sender_id": senderID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := m.ChannelID
|
||||
if m.GuildID == "" {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": m.ID,
|
||||
"user_id": senderID,
|
||||
"username": m.Author.Username,
|
||||
"display_name": senderName,
|
||||
"guild_id": m.GuildID,
|
||||
"channel_id": m.ChannelID,
|
||||
"is_dm": fmt.Sprintf("%t", m.GuildID == ""),
|
||||
"peer_kind": peerKind,
|
||||
"peer_id": peerID,
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
// startTyping starts a continuous typing indicator loop for the given chatID.
|
||||
// It stops any existing typing loop for that chatID before starting a new one.
|
||||
func (c *DiscordChannel) startTyping(chatID string) {
|
||||
c.typingMu.Lock()
|
||||
// Stop existing loop for this chatID if any
|
||||
if stop, ok := c.typingStop[chatID]; ok {
|
||||
close(stop)
|
||||
}
|
||||
stop := make(chan struct{})
|
||||
c.typingStop[chatID] = stop
|
||||
c.typingMu.Unlock()
|
||||
|
||||
go func() {
|
||||
if err := c.session.ChannelTyping(chatID); err != nil {
|
||||
logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err})
|
||||
}
|
||||
ticker := time.NewTicker(8 * time.Second)
|
||||
defer ticker.Stop()
|
||||
timeout := time.After(5 * time.Minute)
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-timeout:
|
||||
return
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.session.ChannelTyping(chatID); err != nil {
|
||||
logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err})
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// stopTyping stops the typing indicator loop for the given chatID.
|
||||
func (c *DiscordChannel) stopTyping(chatID string) {
|
||||
c.typingMu.Lock()
|
||||
defer c.typingMu.Unlock()
|
||||
if stop, ok := c.typingStop[chatID]; ok {
|
||||
close(stop)
|
||||
delete(c.typingStop, chatID)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "discord",
|
||||
})
|
||||
}
|
||||
|
||||
// stripBotMention removes the bot mention from the message content.
|
||||
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
|
||||
func (c *DiscordChannel) stripBotMention(text string) string {
|
||||
if c.botUserID == "" {
|
||||
return text
|
||||
}
|
||||
// Remove both regular mention <@USER_ID> and nickname mention <@!USER_ID>
|
||||
text = strings.ReplaceAll(text, fmt.Sprintf("<@%s>", c.botUserID), "")
|
||||
text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
//go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// FeishuChannel is a stub implementation for 32-bit architectures
|
||||
type FeishuChannel struct {
|
||||
*BaseChannel
|
||||
}
|
||||
|
||||
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
|
||||
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
return nil, errors.New(
|
||||
"feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config",
|
||||
)
|
||||
}
|
||||
|
||||
// Start is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
return errors.New("feishu channel is not supported on 32-bit architectures")
|
||||
}
|
||||
@@ -1,227 +0,0 @@
|
||||
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type FeishuChannel struct {
|
||||
*BaseChannel
|
||||
config config.FeishuConfig
|
||||
client *lark.Client
|
||||
wsClient *larkws.Client
|
||||
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
base := NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom)
|
||||
|
||||
return &FeishuChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: lark.NewClient(cfg.AppID, cfg.AppSecret),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
if c.config.AppID == "" || c.config.AppSecret == "" {
|
||||
return fmt.Errorf("feishu app_id or app_secret is empty")
|
||||
}
|
||||
|
||||
dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey).
|
||||
OnP2MessageReceiveV1(c.handleMessageReceive)
|
||||
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
c.mu.Lock()
|
||||
c.cancel = cancel
|
||||
c.wsClient = larkws.NewClient(
|
||||
c.config.AppID,
|
||||
c.config.AppSecret,
|
||||
larkws.WithEventHandler(dispatcher),
|
||||
)
|
||||
wsClient := c.wsClient
|
||||
c.mu.Unlock()
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("feishu", "Feishu channel started (websocket mode)")
|
||||
|
||||
go func() {
|
||||
if err := wsClient.Start(runCtx); err != nil {
|
||||
logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) Stop(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
c.cancel = nil
|
||||
}
|
||||
c.wsClient = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
c.setRunning(false)
|
||||
logger.InfoC("feishu", "Feishu channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("feishu channel not running")
|
||||
}
|
||||
|
||||
if msg.ChatID == "" {
|
||||
return fmt.Errorf("chat ID is empty")
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(map[string]string{"text": msg.Content})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal feishu content: %w", err)
|
||||
}
|
||||
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(larkim.ReceiveIdTypeChatId).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
ReceiveId(msg.ChatID).
|
||||
MsgType(larkim.MsgTypeText).
|
||||
Content(string(payload)).
|
||||
Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send feishu message: %w", err)
|
||||
}
|
||||
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg)
|
||||
}
|
||||
|
||||
logger.DebugCF("feishu", "Feishu message sent", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2MessageReceiveV1) error {
|
||||
if event == nil || event.Event == nil || event.Event.Message == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
message := event.Event.Message
|
||||
sender := event.Event.Sender
|
||||
|
||||
chatID := stringValue(message.ChatId)
|
||||
if chatID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
senderID := extractFeishuSenderID(sender)
|
||||
if senderID == "" {
|
||||
senderID = "unknown"
|
||||
}
|
||||
|
||||
content := extractFeishuMessageContent(message)
|
||||
if content == "" {
|
||||
content = "[empty message]"
|
||||
}
|
||||
|
||||
metadata := map[string]string{}
|
||||
if messageID := stringValue(message.MessageId); messageID != "" {
|
||||
metadata["message_id"] = messageID
|
||||
}
|
||||
if messageType := stringValue(message.MessageType); messageType != "" {
|
||||
metadata["message_type"] = messageType
|
||||
}
|
||||
if chatType := stringValue(message.ChatType); chatType != "" {
|
||||
metadata["chat_type"] = chatType
|
||||
}
|
||||
if sender != nil && sender.TenantKey != nil {
|
||||
metadata["tenant_key"] = *sender.TenantKey
|
||||
}
|
||||
|
||||
chatType := stringValue(message.ChatType)
|
||||
if chatType == "p2p" {
|
||||
metadata["peer_kind"] = "direct"
|
||||
metadata["peer_id"] = senderID
|
||||
} else {
|
||||
metadata["peer_kind"] = "group"
|
||||
metadata["peer_id"] = chatID
|
||||
}
|
||||
|
||||
logger.InfoCF("feishu", "Feishu message received", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": utils.Truncate(content, 80),
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractFeishuSenderID(sender *larkim.EventSender) string {
|
||||
if sender == nil || sender.SenderId == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if sender.SenderId.UserId != nil && *sender.SenderId.UserId != "" {
|
||||
return *sender.SenderId.UserId
|
||||
}
|
||||
if sender.SenderId.OpenId != nil && *sender.SenderId.OpenId != "" {
|
||||
return *sender.SenderId.OpenId
|
||||
}
|
||||
if sender.SenderId.UnionId != nil && *sender.SenderId.UnionId != "" {
|
||||
return *sender.SenderId.UnionId
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractFeishuMessageContent(message *larkim.EventMessage) string {
|
||||
if message == nil || message.Content == nil || *message.Content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText {
|
||||
var textPayload struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil {
|
||||
return textPayload.Text
|
||||
}
|
||||
}
|
||||
|
||||
return *message.Content
|
||||
}
|
||||
|
||||
func stringValue(v *string) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return *v
|
||||
}
|
||||
@@ -1,606 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
lineAPIBase = "https://api.line.me/v2/bot"
|
||||
lineDataAPIBase = "https://api-data.line.me/v2/bot"
|
||||
lineReplyEndpoint = lineAPIBase + "/message/reply"
|
||||
linePushEndpoint = lineAPIBase + "/message/push"
|
||||
lineContentEndpoint = lineDataAPIBase + "/message/%s/content"
|
||||
lineBotInfoEndpoint = lineAPIBase + "/info"
|
||||
lineLoadingEndpoint = lineAPIBase + "/chat/loading/start"
|
||||
lineReplyTokenMaxAge = 25 * time.Second
|
||||
)
|
||||
|
||||
type replyTokenEntry struct {
|
||||
token string
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// LINEChannel implements the Channel interface for LINE Official Account
|
||||
// using the LINE Messaging API with HTTP webhook for receiving messages
|
||||
// and REST API for sending messages.
|
||||
type LINEChannel struct {
|
||||
*BaseChannel
|
||||
config config.LINEConfig
|
||||
httpServer *http.Server
|
||||
botUserID string // Bot's user ID
|
||||
botBasicID string // Bot's basic ID (e.g. @216ru...)
|
||||
botDisplayName string // Bot's display name for text-based mention detection
|
||||
replyTokens sync.Map // chatID -> replyTokenEntry
|
||||
quoteTokens sync.Map // chatID -> quoteToken (string)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewLINEChannel creates a new LINE channel instance.
|
||||
func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINEChannel, error) {
|
||||
if cfg.ChannelSecret == "" || cfg.ChannelAccessToken == "" {
|
||||
return nil, fmt.Errorf("line channel_secret and channel_access_token are required")
|
||||
}
|
||||
|
||||
base := NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
return &LINEChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start launches the HTTP webhook server.
|
||||
func (c *LINEChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("line", "Starting LINE channel (Webhook Mode)")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Fetch bot profile to get bot's userId for mention detection
|
||||
if err := c.fetchBotInfo(); err != nil {
|
||||
logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
logger.InfoCF("line", "Bot info fetched", map[string]any{
|
||||
"bot_user_id": c.botUserID,
|
||||
"basic_id": c.botBasicID,
|
||||
"display_name": c.botDisplayName,
|
||||
})
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
path := c.config.WebhookPath
|
||||
if path == "" {
|
||||
path = "/webhook/line"
|
||||
}
|
||||
mux.HandleFunc(path, c.webhookHandler)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort)
|
||||
c.httpServer = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
logger.InfoCF("line", "LINE webhook server listening", map[string]any{
|
||||
"addr": addr,
|
||||
"path": path,
|
||||
})
|
||||
if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.ErrorCF("line", "Webhook server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("line", "LINE channel started (Webhook Mode)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchBotInfo retrieves the bot's userId, basicId, and displayName from the LINE API.
|
||||
func (c *LINEChannel) fetchBotInfo() error {
|
||||
req, err := http.NewRequest(http.MethodGet, lineBotInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("bot info API returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var info struct {
|
||||
UserID string `json:"userId"`
|
||||
BasicID string `json:"basicId"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.botUserID = info.UserID
|
||||
c.botBasicID = info.BasicID
|
||||
c.botDisplayName = info.DisplayName
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the HTTP server.
|
||||
func (c *LINEChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("line", "Stopping LINE channel")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
if c.httpServer != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
if err := c.httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
logger.ErrorCF("line", "Webhook server shutdown error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.setRunning(false)
|
||||
logger.InfoC("line", "LINE channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// webhookHandler handles incoming LINE webhook requests.
|
||||
func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
logger.ErrorCF("line", "Failed to read request body", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
signature := r.Header.Get("X-Line-Signature")
|
||||
if !c.verifySignature(body, signature) {
|
||||
logger.WarnC("line", "Invalid webhook signature")
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Events []lineEvent `json:"events"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
logger.ErrorCF("line", "Failed to parse webhook payload", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Return 200 immediately, process events asynchronously
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
for _, event := range payload.Events {
|
||||
go c.processEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// verifySignature validates the X-Line-Signature using HMAC-SHA256.
|
||||
func (c *LINEChannel) verifySignature(body []byte, signature string) bool {
|
||||
if signature == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(c.config.ChannelSecret))
|
||||
mac.Write(body)
|
||||
expected := base64.StdEncoding.EncodeToString(mac.Sum(nil))
|
||||
|
||||
return hmac.Equal([]byte(expected), []byte(signature))
|
||||
}
|
||||
|
||||
// LINE webhook event types
|
||||
type lineEvent struct {
|
||||
Type string `json:"type"`
|
||||
ReplyToken string `json:"replyToken"`
|
||||
Source lineSource `json:"source"`
|
||||
Message json.RawMessage `json:"message"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type lineSource struct {
|
||||
Type string `json:"type"` // "user", "group", "room"
|
||||
UserID string `json:"userId"`
|
||||
GroupID string `json:"groupId"`
|
||||
RoomID string `json:"roomId"`
|
||||
}
|
||||
|
||||
type lineMessage struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "text", "image", "video", "audio", "file", "sticker"
|
||||
Text string `json:"text"`
|
||||
QuoteToken string `json:"quoteToken"`
|
||||
Mention *struct {
|
||||
Mentionees []lineMentionee `json:"mentionees"`
|
||||
} `json:"mention"`
|
||||
ContentProvider struct {
|
||||
Type string `json:"type"`
|
||||
} `json:"contentProvider"`
|
||||
}
|
||||
|
||||
type lineMentionee struct {
|
||||
Index int `json:"index"`
|
||||
Length int `json:"length"`
|
||||
Type string `json:"type"` // "user", "all"
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
|
||||
func (c *LINEChannel) processEvent(event lineEvent) {
|
||||
if event.Type != "message" {
|
||||
logger.DebugCF("line", "Ignoring non-message event", map[string]any{
|
||||
"type": event.Type,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := event.Source.UserID
|
||||
chatID := c.resolveChatID(event.Source)
|
||||
isGroup := event.Source.Type == "group" || event.Source.Type == "room"
|
||||
|
||||
var msg lineMessage
|
||||
if err := json.Unmarshal(event.Message, &msg); err != nil {
|
||||
logger.ErrorCF("line", "Failed to parse message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// In group chats, only respond when the bot is mentioned
|
||||
if isGroup && !c.isBotMentioned(msg) {
|
||||
logger.DebugCF("line", "Ignoring group message without mention", map[string]any{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Store reply token for later use
|
||||
if event.ReplyToken != "" {
|
||||
c.replyTokens.Store(chatID, replyTokenEntry{
|
||||
token: event.ReplyToken,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// Store quote token for quoting the original message in reply
|
||||
if msg.QuoteToken != "" {
|
||||
c.quoteTokens.Store(chatID, msg.QuoteToken)
|
||||
}
|
||||
|
||||
var content string
|
||||
var mediaPaths []string
|
||||
localFiles := []string{}
|
||||
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("line", "Failed to cleanup temp file", map[string]any{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
switch msg.Type {
|
||||
case "text":
|
||||
content = msg.Text
|
||||
// Strip bot mention from text in group chats
|
||||
if isGroup {
|
||||
content = c.stripBotMention(content, msg)
|
||||
}
|
||||
case "image":
|
||||
localPath := c.downloadContent(msg.ID, "image.jpg")
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
content = "[image]"
|
||||
}
|
||||
case "audio":
|
||||
localPath := c.downloadContent(msg.ID, "audio.m4a")
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
content = "[audio]"
|
||||
}
|
||||
case "video":
|
||||
localPath := c.downloadContent(msg.ID, "video.mp4")
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
content = "[video]"
|
||||
}
|
||||
case "file":
|
||||
content = "[file]"
|
||||
case "sticker":
|
||||
content = "[sticker]"
|
||||
default:
|
||||
content = fmt.Sprintf("[%s]", msg.Type)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"platform": "line",
|
||||
"source_type": event.Source.Type,
|
||||
"message_id": msg.ID,
|
||||
}
|
||||
|
||||
if isGroup {
|
||||
metadata["peer_kind"] = "group"
|
||||
metadata["peer_id"] = chatID
|
||||
} else {
|
||||
metadata["peer_kind"] = "direct"
|
||||
metadata["peer_id"] = senderID
|
||||
}
|
||||
|
||||
logger.DebugCF("line", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"message_type": msg.Type,
|
||||
"is_group": isGroup,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Show typing/loading indicator (requires user ID, not group ID)
|
||||
c.sendLoading(senderID)
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
// isBotMentioned checks if the bot is mentioned in the message.
|
||||
// It first checks the mention metadata (userId match), then falls back
|
||||
// to text-based detection using the bot's display name, since LINE may
|
||||
// not include userId in mentionees for Official Accounts.
|
||||
func (c *LINEChannel) isBotMentioned(msg lineMessage) bool {
|
||||
// Check mention metadata
|
||||
if msg.Mention != nil {
|
||||
for _, m := range msg.Mention.Mentionees {
|
||||
if m.Type == "all" {
|
||||
return true
|
||||
}
|
||||
if c.botUserID != "" && m.UserID == c.botUserID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Mention metadata exists with mentionees but bot not matched by userId.
|
||||
// The bot IS likely mentioned (LINE includes mention struct when bot is @-ed),
|
||||
// so check if any mentionee overlaps with bot display name in text.
|
||||
if c.botDisplayName != "" {
|
||||
for _, m := range msg.Mention.Mentionees {
|
||||
if m.Index >= 0 && m.Length > 0 {
|
||||
runes := []rune(msg.Text)
|
||||
end := m.Index + m.Length
|
||||
if end <= len(runes) {
|
||||
mentionText := string(runes[m.Index:end])
|
||||
if strings.Contains(mentionText, c.botDisplayName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: text-based detection with display name
|
||||
if c.botDisplayName != "" && strings.Contains(msg.Text, "@"+c.botDisplayName) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// stripBotMention removes the @BotName mention text from the message.
|
||||
func (c *LINEChannel) stripBotMention(text string, msg lineMessage) string {
|
||||
stripped := false
|
||||
|
||||
// Try to strip using mention metadata indices
|
||||
if msg.Mention != nil {
|
||||
runes := []rune(text)
|
||||
for i := len(msg.Mention.Mentionees) - 1; i >= 0; i-- {
|
||||
m := msg.Mention.Mentionees[i]
|
||||
// Strip if userId matches OR if the mention text contains the bot display name
|
||||
shouldStrip := false
|
||||
if c.botUserID != "" && m.UserID == c.botUserID {
|
||||
shouldStrip = true
|
||||
} else if c.botDisplayName != "" && m.Index >= 0 && m.Length > 0 {
|
||||
end := m.Index + m.Length
|
||||
if end <= len(runes) {
|
||||
mentionText := string(runes[m.Index:end])
|
||||
if strings.Contains(mentionText, c.botDisplayName) {
|
||||
shouldStrip = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if shouldStrip {
|
||||
start := m.Index
|
||||
end := m.Index + m.Length
|
||||
if start >= 0 && end <= len(runes) {
|
||||
runes = append(runes[:start], runes[end:]...)
|
||||
stripped = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if stripped {
|
||||
return strings.TrimSpace(string(runes))
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: strip @DisplayName from text
|
||||
if c.botDisplayName != "" {
|
||||
text = strings.ReplaceAll(text, "@"+c.botDisplayName, "")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// resolveChatID determines the chat ID from the event source.
|
||||
// For group/room messages, use the group/room ID; for 1:1, use the user ID.
|
||||
func (c *LINEChannel) resolveChatID(source lineSource) string {
|
||||
switch source.Type {
|
||||
case "group":
|
||||
return source.GroupID
|
||||
case "room":
|
||||
return source.RoomID
|
||||
default:
|
||||
return source.UserID
|
||||
}
|
||||
}
|
||||
|
||||
// Send sends a message to LINE. It first tries the Reply API (free)
|
||||
// using a cached reply token, then falls back to the Push API.
|
||||
func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("line channel not running")
|
||||
}
|
||||
|
||||
// Load and consume quote token for this chat
|
||||
var quoteToken string
|
||||
if qt, ok := c.quoteTokens.LoadAndDelete(msg.ChatID); ok {
|
||||
quoteToken = qt.(string)
|
||||
}
|
||||
|
||||
// Try reply token first (free, valid for ~25 seconds)
|
||||
if entry, ok := c.replyTokens.LoadAndDelete(msg.ChatID); ok {
|
||||
tokenEntry := entry.(replyTokenEntry)
|
||||
if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge {
|
||||
if err := c.sendReply(ctx, tokenEntry.token, msg.Content, quoteToken); err == nil {
|
||||
logger.DebugCF("line", "Message sent via Reply API", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
"quoted": quoteToken != "",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
logger.DebugC("line", "Reply API failed, falling back to Push API")
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to Push API
|
||||
return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken)
|
||||
}
|
||||
|
||||
// buildTextMessage creates a text message object, optionally with quoteToken.
|
||||
func buildTextMessage(content, quoteToken string) map[string]string {
|
||||
msg := map[string]string{
|
||||
"type": "text",
|
||||
"text": content,
|
||||
}
|
||||
if quoteToken != "" {
|
||||
msg["quoteToken"] = quoteToken
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// sendReply sends a message using the LINE Reply API.
|
||||
func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteToken string) error {
|
||||
payload := map[string]any{
|
||||
"replyToken": replyToken,
|
||||
"messages": []map[string]string{buildTextMessage(content, quoteToken)},
|
||||
}
|
||||
|
||||
return c.callAPI(ctx, lineReplyEndpoint, payload)
|
||||
}
|
||||
|
||||
// sendPush sends a message using the LINE Push API.
|
||||
func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken string) error {
|
||||
payload := map[string]any{
|
||||
"to": to,
|
||||
"messages": []map[string]string{buildTextMessage(content, quoteToken)},
|
||||
}
|
||||
|
||||
return c.callAPI(ctx, linePushEndpoint, payload)
|
||||
}
|
||||
|
||||
// sendLoading sends a loading animation indicator to the chat.
|
||||
func (c *LINEChannel) sendLoading(chatID string) {
|
||||
payload := map[string]any{
|
||||
"chatId": chatID,
|
||||
"loadingSeconds": 60,
|
||||
}
|
||||
if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil {
|
||||
logger.DebugCF("line", "Failed to send loading indicator", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// callAPI makes an authenticated POST request to the LINE API.
|
||||
func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("API request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// downloadContent downloads media content from the LINE API.
|
||||
func (c *LINEChannel) downloadContent(messageID, filename string) string {
|
||||
url := fmt.Sprintf(lineContentEndpoint, messageID)
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "line",
|
||||
ExtraHeaders: map[string]string{
|
||||
"Authorization": "Bearer " + c.config.ChannelAccessToken,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -1,243 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
type MaixCamChannel struct {
|
||||
*BaseChannel
|
||||
config config.MaixCamConfig
|
||||
listener net.Listener
|
||||
clients map[net.Conn]bool
|
||||
clientsMux sync.RWMutex
|
||||
}
|
||||
|
||||
type MaixCamMessage struct {
|
||||
Type string `json:"type"`
|
||||
Tips string `json:"tips"`
|
||||
Timestamp float64 `json:"timestamp"`
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) {
|
||||
base := NewBaseChannel("maixcam", cfg, bus, cfg.AllowFrom)
|
||||
|
||||
return &MaixCamChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
clients: make(map[net.Conn]bool),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("maixcam", "Starting MaixCam channel server")
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
c.listener = listener
|
||||
c.setRunning(true)
|
||||
|
||||
logger.InfoCF("maixcam", "MaixCam server listening", map[string]any{
|
||||
"host": c.config.Host,
|
||||
"port": c.config.Port,
|
||||
})
|
||||
|
||||
go c.acceptConnections(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) acceptConnections(ctx context.Context) {
|
||||
logger.DebugC("maixcam", "Starting connection acceptor")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.InfoC("maixcam", "Stopping connection acceptor")
|
||||
return
|
||||
default:
|
||||
conn, err := c.listener.Accept()
|
||||
if err != nil {
|
||||
if c.running {
|
||||
logger.ErrorCF("maixcam", "Failed to accept connection", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoCF("maixcam", "New connection from MaixCam device", map[string]any{
|
||||
"remote_addr": conn.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
c.clientsMux.Lock()
|
||||
c.clients[conn] = true
|
||||
c.clientsMux.Unlock()
|
||||
|
||||
go c.handleConnection(conn, ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) {
|
||||
logger.DebugC("maixcam", "Handling MaixCam connection")
|
||||
|
||||
defer func() {
|
||||
conn.Close()
|
||||
c.clientsMux.Lock()
|
||||
delete(c.clients, conn)
|
||||
c.clientsMux.Unlock()
|
||||
logger.DebugC("maixcam", "Connection closed")
|
||||
}()
|
||||
|
||||
decoder := json.NewDecoder(conn)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
var msg MaixCamMessage
|
||||
if err := decoder.Decode(&msg); err != nil {
|
||||
if err.Error() != "EOF" {
|
||||
logger.ErrorCF("maixcam", "Failed to decode message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.processMessage(msg, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) processMessage(msg MaixCamMessage, conn net.Conn) {
|
||||
switch msg.Type {
|
||||
case "person_detected":
|
||||
c.handlePersonDetection(msg)
|
||||
case "heartbeat":
|
||||
logger.DebugC("maixcam", "Received heartbeat")
|
||||
case "status":
|
||||
c.handleStatusUpdate(msg)
|
||||
default:
|
||||
logger.WarnCF("maixcam", "Unknown message type", map[string]any{
|
||||
"type": msg.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) {
|
||||
logger.InfoCF("maixcam", "", map[string]any{
|
||||
"timestamp": msg.Timestamp,
|
||||
"data": msg.Data,
|
||||
})
|
||||
|
||||
senderID := "maixcam"
|
||||
chatID := "default"
|
||||
|
||||
classInfo, ok := msg.Data["class_name"].(string)
|
||||
if !ok {
|
||||
classInfo = "person"
|
||||
}
|
||||
|
||||
score, _ := msg.Data["score"].(float64)
|
||||
x, _ := msg.Data["x"].(float64)
|
||||
y, _ := msg.Data["y"].(float64)
|
||||
w, _ := msg.Data["w"].(float64)
|
||||
h, _ := msg.Data["h"].(float64)
|
||||
|
||||
content := fmt.Sprintf("📷 Person detected!\nClass: %s\nConfidence: %.2f%%\nPosition: (%.0f, %.0f)\nSize: %.0fx%.0f",
|
||||
classInfo, score*100, x, y, w, h)
|
||||
|
||||
metadata := map[string]string{
|
||||
"timestamp": fmt.Sprintf("%.0f", msg.Timestamp),
|
||||
"class_id": fmt.Sprintf("%.0f", msg.Data["class_id"]),
|
||||
"score": fmt.Sprintf("%.2f", score),
|
||||
"x": fmt.Sprintf("%.0f", x),
|
||||
"y": fmt.Sprintf("%.0f", y),
|
||||
"w": fmt.Sprintf("%.0f", w),
|
||||
"h": fmt.Sprintf("%.0f", h),
|
||||
"peer_kind": "channel",
|
||||
"peer_id": "default",
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, []string{}, metadata)
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) {
|
||||
logger.InfoCF("maixcam", "Status update from MaixCam", map[string]any{
|
||||
"status": msg.Data,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("maixcam", "Stopping MaixCam channel")
|
||||
c.setRunning(false)
|
||||
|
||||
if c.listener != nil {
|
||||
c.listener.Close()
|
||||
}
|
||||
|
||||
c.clientsMux.Lock()
|
||||
defer c.clientsMux.Unlock()
|
||||
|
||||
for conn := range c.clients {
|
||||
conn.Close()
|
||||
}
|
||||
c.clients = make(map[net.Conn]bool)
|
||||
|
||||
logger.InfoC("maixcam", "MaixCam channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("maixcam channel not running")
|
||||
}
|
||||
|
||||
c.clientsMux.RLock()
|
||||
defer c.clientsMux.RUnlock()
|
||||
|
||||
if len(c.clients) == 0 {
|
||||
logger.WarnC("maixcam", "No MaixCam devices connected")
|
||||
return fmt.Errorf("no connected MaixCam devices")
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"type": "command",
|
||||
"timestamp": float64(0),
|
||||
"message": msg.Content,
|
||||
"chat_id": msg.ChatID,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal response: %w", err)
|
||||
}
|
||||
|
||||
var sendErr error
|
||||
for conn := range c.clients {
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{
|
||||
"client": conn.RemoteAddr().String(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
sendErr = err
|
||||
}
|
||||
}
|
||||
|
||||
return sendErr
|
||||
}
|
||||
@@ -1,982 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
type OneBotChannel struct {
|
||||
*BaseChannel
|
||||
config config.OneBotConfig
|
||||
conn *websocket.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
dedup map[string]struct{}
|
||||
dedupRing []string
|
||||
dedupIdx int
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
echoCounter int64
|
||||
selfID int64
|
||||
pending map[string]chan json.RawMessage
|
||||
pendingMu sync.Mutex
|
||||
transcriber *voice.GroqTranscriber
|
||||
lastMessageID sync.Map
|
||||
pendingEmojiMsg sync.Map
|
||||
}
|
||||
|
||||
type oneBotRawEvent struct {
|
||||
PostType string `json:"post_type"`
|
||||
MessageType string `json:"message_type"`
|
||||
SubType string `json:"sub_type"`
|
||||
MessageID json.RawMessage `json:"message_id"`
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
GroupID json.RawMessage `json:"group_id"`
|
||||
RawMessage string `json:"raw_message"`
|
||||
Message json.RawMessage `json:"message"`
|
||||
Sender json.RawMessage `json:"sender"`
|
||||
SelfID json.RawMessage `json:"self_id"`
|
||||
Time json.RawMessage `json:"time"`
|
||||
MetaEventType string `json:"meta_event_type"`
|
||||
NoticeType string `json:"notice_type"`
|
||||
Echo string `json:"echo"`
|
||||
RetCode json.RawMessage `json:"retcode"`
|
||||
Status json.RawMessage `json:"status"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
type BotStatus struct {
|
||||
Online bool `json:"online"`
|
||||
Good bool `json:"good"`
|
||||
}
|
||||
|
||||
func isAPIResponse(raw json.RawMessage) bool {
|
||||
if len(raw) == 0 {
|
||||
return false
|
||||
}
|
||||
var s string
|
||||
if json.Unmarshal(raw, &s) == nil {
|
||||
return s == "ok" || s == "failed"
|
||||
}
|
||||
var bs BotStatus
|
||||
if json.Unmarshal(raw, &bs) == nil {
|
||||
return bs.Online || bs.Good
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type oneBotSender struct {
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
Nickname string `json:"nickname"`
|
||||
Card string `json:"card"`
|
||||
}
|
||||
|
||||
type oneBotAPIRequest struct {
|
||||
Action string `json:"action"`
|
||||
Params any `json:"params"`
|
||||
Echo string `json:"echo,omitempty"`
|
||||
}
|
||||
|
||||
type oneBotMessageSegment struct {
|
||||
Type string `json:"type"`
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
|
||||
base := NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
const dedupSize = 1024
|
||||
return &OneBotChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
dedup: make(map[string]struct{}, dedupSize),
|
||||
dedupRing: make([]string, dedupSize),
|
||||
dedupIdx: 0,
|
||||
pending: make(map[string]chan json.RawMessage),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
c.transcriber = transcriber
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) {
|
||||
go func() {
|
||||
_, err := c.sendAPIRequest("set_msg_emoji_like", map[string]any{
|
||||
"message_id": messageID,
|
||||
"emoji_id": emojiID,
|
||||
"set": set,
|
||||
}, 5*time.Second)
|
||||
if err != nil {
|
||||
logger.DebugCF("onebot", "Failed to set emoji like", map[string]any{
|
||||
"message_id": messageID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Start(ctx context.Context) error {
|
||||
if c.config.WSUrl == "" {
|
||||
return fmt.Errorf("OneBot ws_url not configured")
|
||||
}
|
||||
|
||||
logger.InfoCF("onebot", "Starting OneBot channel", map[string]any{
|
||||
"ws_url": c.config.WSUrl,
|
||||
})
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
if err := c.connect(); err != nil {
|
||||
logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
go c.listen()
|
||||
c.fetchSelfID()
|
||||
}
|
||||
|
||||
if c.config.ReconnectInterval > 0 {
|
||||
go c.reconnectLoop()
|
||||
} else {
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("failed to connect to OneBot and reconnect is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("onebot", "OneBot channel started successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) connect() error {
|
||||
dialer := websocket.DefaultDialer
|
||||
dialer.HandshakeTimeout = 10 * time.Second
|
||||
|
||||
header := make(map[string][]string)
|
||||
if c.config.AccessToken != "" {
|
||||
header["Authorization"] = []string{"Bearer " + c.config.AccessToken}
|
||||
}
|
||||
|
||||
conn, _, err := dialer.Dial(c.config.WSUrl, header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.SetPongHandler(func(appData string) error {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.pinger(conn)
|
||||
|
||||
logger.InfoC("onebot", "WebSocket connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) pinger(conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.writeMu.Lock()
|
||||
err := conn.WriteMessage(websocket.PingMessage, nil)
|
||||
c.writeMu.Unlock()
|
||||
if err != nil {
|
||||
logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) fetchSelfID() {
|
||||
resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second)
|
||||
if err != nil {
|
||||
logger.WarnCF("onebot", "Failed to get_login_info", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type loginInfo struct {
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
Nickname string `json:"nickname"`
|
||||
}
|
||||
for _, extract := range []func() (*loginInfo, error){
|
||||
func() (*loginInfo, error) {
|
||||
var w struct {
|
||||
Data loginInfo `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(resp, &w)
|
||||
return &w.Data, err
|
||||
},
|
||||
func() (*loginInfo, error) {
|
||||
var f loginInfo
|
||||
err := json.Unmarshal(resp, &f)
|
||||
return &f, err
|
||||
},
|
||||
} {
|
||||
info, err := extract()
|
||||
if err != nil || len(info.UserID) == 0 {
|
||||
continue
|
||||
}
|
||||
if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 {
|
||||
atomic.StoreInt64(&c.selfID, uid)
|
||||
logger.InfoCF("onebot", "Bot self ID retrieved", map[string]any{
|
||||
"self_id": uid,
|
||||
"nickname": info.Nickname,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]any{
|
||||
"response": string(resp),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.Duration) (json.RawMessage, error) {
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("WebSocket not connected")
|
||||
}
|
||||
|
||||
echo := fmt.Sprintf("api_%d_%d", time.Now().UnixNano(), atomic.AddInt64(&c.echoCounter, 1))
|
||||
|
||||
ch := make(chan json.RawMessage, 1)
|
||||
c.pendingMu.Lock()
|
||||
c.pending[echo] = ch
|
||||
c.pendingMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
c.pendingMu.Lock()
|
||||
delete(c.pending, echo)
|
||||
c.pendingMu.Unlock()
|
||||
}()
|
||||
|
||||
req := oneBotAPIRequest{
|
||||
Action: action,
|
||||
Params: params,
|
||||
Echo: echo,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal API request: %w", err)
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
err = conn.WriteMessage(websocket.TextMessage, data)
|
||||
c.writeMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write API request: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case resp := <-ch:
|
||||
return resp, nil
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("API request %s timed out after %v", action, timeout)
|
||||
case <-c.ctx.Done():
|
||||
return nil, fmt.Errorf("context cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) reconnectLoop() {
|
||||
interval := time.Duration(c.config.ReconnectInterval) * time.Second
|
||||
if interval < 5*time.Second {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-time.After(interval):
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
logger.InfoC("onebot", "Attempting to reconnect...")
|
||||
if err := c.connect(); err != nil {
|
||||
logger.ErrorCF("onebot", "Reconnect failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
go c.listen()
|
||||
c.fetchSelfID()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("onebot", "Stopping OneBot channel")
|
||||
c.setRunning(false)
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.pendingMu.Lock()
|
||||
for echo, ch := range c.pending {
|
||||
close(ch)
|
||||
delete(c.pending, echo)
|
||||
}
|
||||
c.pendingMu.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("OneBot channel not running")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
return fmt.Errorf("OneBot WebSocket not connected")
|
||||
}
|
||||
|
||||
action, params, err := c.buildSendRequest(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1))
|
||||
|
||||
req := oneBotAPIRequest{
|
||||
Action: action,
|
||||
Params: params,
|
||||
Echo: echo,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal OneBot request: %w", err)
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
err = conn.WriteMessage(websocket.TextMessage, data)
|
||||
c.writeMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("onebot", "Failed to send message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok {
|
||||
if mid, ok := msgID.(string); ok && mid != "" {
|
||||
c.setMsgEmojiLike(mid, 289, false)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment {
|
||||
var segments []oneBotMessageSegment
|
||||
|
||||
if lastMsgID, ok := c.lastMessageID.Load(chatID); ok {
|
||||
if msgID, ok := lastMsgID.(string); ok && msgID != "" {
|
||||
segments = append(segments, oneBotMessageSegment{
|
||||
Type: "reply",
|
||||
Data: map[string]any{"id": msgID},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
segments = append(segments, oneBotMessageSegment{
|
||||
Type: "text",
|
||||
Data: map[string]any{"text": content},
|
||||
})
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, any, error) {
|
||||
chatID := msg.ChatID
|
||||
segments := c.buildMessageSegments(chatID, msg.Content)
|
||||
|
||||
var action, idKey string
|
||||
var rawID string
|
||||
if rest, ok := strings.CutPrefix(chatID, "group:"); ok {
|
||||
action, idKey, rawID = "send_group_msg", "group_id", rest
|
||||
} else if rest, ok := strings.CutPrefix(chatID, "private:"); ok {
|
||||
action, idKey, rawID = "send_private_msg", "user_id", rest
|
||||
} else {
|
||||
action, idKey, rawID = "send_private_msg", "user_id", chatID
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(rawID, 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID)
|
||||
}
|
||||
return action, map[string]any{idKey: id, "message": segments}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) listen() {
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
logger.ErrorCF("onebot", "WebSocket read error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
c.mu.Lock()
|
||||
if c.conn == conn {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
var raw oneBotRawEvent
|
||||
if err := json.Unmarshal(message, &raw); err != nil {
|
||||
logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]any{
|
||||
"error": err.Error(),
|
||||
"payload": string(message),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "WebSocket event", map[string]any{
|
||||
"length": len(message),
|
||||
"post_type": raw.PostType,
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
|
||||
if raw.Echo != "" {
|
||||
c.pendingMu.Lock()
|
||||
ch, ok := c.pending[raw.Echo]
|
||||
c.pendingMu.Unlock()
|
||||
|
||||
if ok {
|
||||
select {
|
||||
case ch <- message:
|
||||
default:
|
||||
}
|
||||
} else {
|
||||
logger.DebugCF("onebot", "Received API response (no waiter)", map[string]any{
|
||||
"echo": raw.Echo,
|
||||
"status": string(raw.Status),
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if isAPIResponse(raw.Status) {
|
||||
logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]any{
|
||||
"status": string(raw.Status),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
c.handleRawEvent(&raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSONInt64(raw json.RawMessage) (int64, error) {
|
||||
if len(raw) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var n int64
|
||||
if err := json.Unmarshal(raw, &n); err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
return 0, fmt.Errorf("cannot parse as int64: %s", string(raw))
|
||||
}
|
||||
|
||||
func parseJSONString(raw json.RawMessage) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s
|
||||
}
|
||||
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
type parseMessageResult struct {
|
||||
Text string
|
||||
IsBotMentioned bool
|
||||
Media []string
|
||||
LocalFiles []string
|
||||
ReplyTo string
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult {
|
||||
if len(raw) == 0 {
|
||||
return parseMessageResult{}
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
mentioned := false
|
||||
if selfID > 0 {
|
||||
cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID)
|
||||
if strings.Contains(s, cqAt) {
|
||||
mentioned = true
|
||||
s = strings.ReplaceAll(s, cqAt, "")
|
||||
s = strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return parseMessageResult{Text: s, IsBotMentioned: mentioned}
|
||||
}
|
||||
|
||||
var segments []map[string]any
|
||||
if err := json.Unmarshal(raw, &segments); err != nil {
|
||||
return parseMessageResult{}
|
||||
}
|
||||
|
||||
var textParts []string
|
||||
mentioned := false
|
||||
selfIDStr := strconv.FormatInt(selfID, 10)
|
||||
var media []string
|
||||
var localFiles []string
|
||||
var replyTo string
|
||||
|
||||
for _, seg := range segments {
|
||||
segType, _ := seg["type"].(string)
|
||||
data, _ := seg["data"].(map[string]any)
|
||||
|
||||
switch segType {
|
||||
case "text":
|
||||
if data != nil {
|
||||
if t, ok := data["text"].(string); ok {
|
||||
textParts = append(textParts, t)
|
||||
}
|
||||
}
|
||||
|
||||
case "at":
|
||||
if data != nil && selfID > 0 {
|
||||
qqVal := fmt.Sprintf("%v", data["qq"])
|
||||
if qqVal == selfIDStr || qqVal == "all" {
|
||||
mentioned = true
|
||||
}
|
||||
}
|
||||
|
||||
case "image", "video", "file":
|
||||
if data != nil {
|
||||
url, _ := data["url"].(string)
|
||||
if url != "" {
|
||||
defaults := map[string]string{"image": "image.jpg", "video": "video.mp4", "file": "file"}
|
||||
filename := defaults[segType]
|
||||
if f, ok := data["file"].(string); ok && f != "" {
|
||||
filename = f
|
||||
} else if n, ok := data["name"].(string); ok && n != "" {
|
||||
filename = n
|
||||
}
|
||||
localPath := utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "onebot",
|
||||
})
|
||||
if localPath != "" {
|
||||
media = append(media, localPath)
|
||||
localFiles = append(localFiles, localPath)
|
||||
textParts = append(textParts, fmt.Sprintf("[%s]", segType))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "record":
|
||||
if data != nil {
|
||||
url, _ := data["url"].(string)
|
||||
if url != "" {
|
||||
localPath := utils.DownloadFile(url, "voice.amr", utils.DownloadOptions{
|
||||
LoggerPrefix: "onebot",
|
||||
})
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second)
|
||||
result, err := c.transcriber.Transcribe(tctx, localPath)
|
||||
tcancel()
|
||||
if err != nil {
|
||||
logger.WarnCF("onebot", "Voice transcription failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
textParts = append(textParts, "[voice (transcription failed)]")
|
||||
media = append(media, localPath)
|
||||
} else {
|
||||
textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text))
|
||||
}
|
||||
} else {
|
||||
textParts = append(textParts, "[voice]")
|
||||
media = append(media, localPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "reply":
|
||||
if data != nil {
|
||||
if id, ok := data["id"]; ok {
|
||||
replyTo = fmt.Sprintf("%v", id)
|
||||
}
|
||||
}
|
||||
|
||||
case "face":
|
||||
if data != nil {
|
||||
faceID, _ := data["id"]
|
||||
textParts = append(textParts, fmt.Sprintf("[face:%v]", faceID))
|
||||
}
|
||||
|
||||
case "forward":
|
||||
textParts = append(textParts, "[forward message]")
|
||||
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return parseMessageResult{
|
||||
Text: strings.TrimSpace(strings.Join(textParts, "")),
|
||||
IsBotMentioned: mentioned,
|
||||
Media: media,
|
||||
LocalFiles: localFiles,
|
||||
ReplyTo: replyTo,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
|
||||
switch raw.PostType {
|
||||
case "message":
|
||||
if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 {
|
||||
if !c.IsAllowed(strconv.FormatInt(userID, 10)) {
|
||||
logger.DebugCF("onebot", "Message rejected by allowlist", map[string]any{
|
||||
"user_id": userID,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.handleMessage(raw)
|
||||
|
||||
case "message_sent":
|
||||
logger.DebugCF("onebot", "Bot sent message event", map[string]any{
|
||||
"message_type": raw.MessageType,
|
||||
"message_id": parseJSONString(raw.MessageID),
|
||||
})
|
||||
|
||||
case "meta_event":
|
||||
c.handleMetaEvent(raw)
|
||||
|
||||
case "notice":
|
||||
c.handleNoticeEvent(raw)
|
||||
|
||||
case "request":
|
||||
logger.DebugCF("onebot", "Request event received", map[string]any{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
|
||||
case "":
|
||||
logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]any{
|
||||
"echo": raw.Echo,
|
||||
"status": raw.Status,
|
||||
})
|
||||
|
||||
default:
|
||||
logger.DebugCF("onebot", "Unknown post_type", map[string]any{
|
||||
"post_type": raw.PostType,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
|
||||
if raw.MetaEventType == "lifecycle" {
|
||||
logger.InfoCF("onebot", "Lifecycle event", map[string]any{"sub_type": raw.SubType})
|
||||
} else if raw.MetaEventType != "heartbeat" {
|
||||
logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) {
|
||||
fields := map[string]any{
|
||||
"notice_type": raw.NoticeType,
|
||||
"sub_type": raw.SubType,
|
||||
"group_id": parseJSONString(raw.GroupID),
|
||||
"user_id": parseJSONString(raw.UserID),
|
||||
"message_id": parseJSONString(raw.MessageID),
|
||||
}
|
||||
switch raw.NoticeType {
|
||||
case "group_recall", "group_increase", "group_decrease",
|
||||
"friend_add", "group_admin", "group_ban":
|
||||
logger.InfoCF("onebot", "Notice: "+raw.NoticeType, fields)
|
||||
default:
|
||||
logger.DebugCF("onebot", "Notice: "+raw.NoticeType, fields)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
// Parse fields from raw event
|
||||
userID, err := parseJSONInt64(raw.UserID)
|
||||
if err != nil {
|
||||
logger.WarnCF("onebot", "Failed to parse user_id", map[string]any{
|
||||
"error": err.Error(),
|
||||
"raw": string(raw.UserID),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
groupID, _ := parseJSONInt64(raw.GroupID)
|
||||
selfID, _ := parseJSONInt64(raw.SelfID)
|
||||
messageID := parseJSONString(raw.MessageID)
|
||||
|
||||
if selfID == 0 {
|
||||
selfID = atomic.LoadInt64(&c.selfID)
|
||||
}
|
||||
|
||||
parsed := c.parseMessageSegments(raw.Message, selfID)
|
||||
isBotMentioned := parsed.IsBotMentioned
|
||||
|
||||
content := raw.RawMessage
|
||||
if content == "" {
|
||||
content = parsed.Text
|
||||
} else if selfID > 0 {
|
||||
cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID)
|
||||
if strings.Contains(content, cqAt) {
|
||||
isBotMentioned = true
|
||||
content = strings.ReplaceAll(content, cqAt, "")
|
||||
content = strings.TrimSpace(content)
|
||||
}
|
||||
}
|
||||
|
||||
if parsed.Text != "" && content != parsed.Text && (len(parsed.Media) > 0 || parsed.ReplyTo != "") {
|
||||
content = parsed.Text
|
||||
}
|
||||
|
||||
var sender oneBotSender
|
||||
if len(raw.Sender) > 0 {
|
||||
if err := json.Unmarshal(raw.Sender, &sender); err != nil {
|
||||
logger.WarnCF("onebot", "Failed to parse sender", map[string]any{
|
||||
"error": err.Error(),
|
||||
"sender": string(raw.Sender),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up temp files when done
|
||||
if len(parsed.LocalFiles) > 0 {
|
||||
defer func() {
|
||||
for _, f := range parsed.LocalFiles {
|
||||
if err := os.Remove(f); err != nil {
|
||||
logger.DebugCF("onebot", "Failed to remove temp file", map[string]any{
|
||||
"path": f,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if c.isDuplicate(messageID) {
|
||||
logger.DebugCF("onebot", "Duplicate message, skipping", map[string]any{
|
||||
"message_id": messageID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if content == "" {
|
||||
logger.DebugCF("onebot", "Received empty message, ignoring", map[string]any{
|
||||
"message_id": messageID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := strconv.FormatInt(userID, 10)
|
||||
var chatID string
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": messageID,
|
||||
}
|
||||
|
||||
if parsed.ReplyTo != "" {
|
||||
metadata["reply_to_message_id"] = parsed.ReplyTo
|
||||
}
|
||||
|
||||
switch raw.MessageType {
|
||||
case "private":
|
||||
chatID = "private:" + senderID
|
||||
metadata["peer_kind"] = "direct"
|
||||
metadata["peer_id"] = senderID
|
||||
|
||||
case "group":
|
||||
groupIDStr := strconv.FormatInt(groupID, 10)
|
||||
chatID = "group:" + groupIDStr
|
||||
metadata["peer_kind"] = "group"
|
||||
metadata["peer_id"] = groupIDStr
|
||||
metadata["group_id"] = groupIDStr
|
||||
|
||||
senderUserID, _ := parseJSONInt64(sender.UserID)
|
||||
if senderUserID > 0 {
|
||||
metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10)
|
||||
}
|
||||
|
||||
if sender.Card != "" {
|
||||
metadata["sender_name"] = sender.Card
|
||||
} else if sender.Nickname != "" {
|
||||
metadata["sender_name"] = sender.Nickname
|
||||
}
|
||||
|
||||
triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned)
|
||||
if !triggered {
|
||||
logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]any{
|
||||
"sender": senderID,
|
||||
"group": groupIDStr,
|
||||
"is_mentioned": isBotMentioned,
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
return
|
||||
}
|
||||
content = strippedContent
|
||||
|
||||
default:
|
||||
logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]any{
|
||||
"type": raw.MessageType,
|
||||
"message_id": messageID,
|
||||
"user_id": userID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]any{
|
||||
"sender": senderID,
|
||||
"chat_id": chatID,
|
||||
"message_id": messageID,
|
||||
"length": len(content),
|
||||
"content": truncate(content, 100),
|
||||
"media_count": len(parsed.Media),
|
||||
})
|
||||
|
||||
if sender.Nickname != "" {
|
||||
metadata["nickname"] = sender.Nickname
|
||||
}
|
||||
|
||||
c.lastMessageID.Store(chatID, messageID)
|
||||
|
||||
if raw.MessageType == "group" && messageID != "" && messageID != "0" {
|
||||
c.setMsgEmojiLike(messageID, 289, true)
|
||||
c.pendingEmojiMsg.Store(chatID, messageID)
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, parsed.Media, metadata)
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) isDuplicate(messageID string) bool {
|
||||
if messageID == "" || messageID == "0" {
|
||||
return false
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if _, exists := c.dedup[messageID]; exists {
|
||||
return true
|
||||
}
|
||||
|
||||
if old := c.dedupRing[c.dedupIdx]; old != "" {
|
||||
delete(c.dedup, old)
|
||||
}
|
||||
c.dedupRing[c.dedupIdx] = messageID
|
||||
c.dedup[messageID] = struct{}{}
|
||||
c.dedupIdx = (c.dedupIdx + 1) % len(c.dedupRing)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= n {
|
||||
return s
|
||||
}
|
||||
return string(runes[:n]) + "..."
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) checkGroupTrigger(
|
||||
content string,
|
||||
isBotMentioned bool,
|
||||
) (triggered bool, strippedContent string) {
|
||||
if isBotMentioned {
|
||||
return true, strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
for _, prefix := range c.config.GroupTriggerPrefix {
|
||||
if prefix == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(content, prefix) {
|
||||
return true, strings.TrimSpace(strings.TrimPrefix(content, prefix))
|
||||
}
|
||||
}
|
||||
|
||||
return false, content
|
||||
}
|
||||
@@ -1,247 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tencent-connect/botgo"
|
||||
"github.com/tencent-connect/botgo/dto"
|
||||
"github.com/tencent-connect/botgo/event"
|
||||
"github.com/tencent-connect/botgo/openapi"
|
||||
"github.com/tencent-connect/botgo/token"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
type QQChannel struct {
|
||||
*BaseChannel
|
||||
config config.QQConfig
|
||||
api openapi.OpenAPI
|
||||
tokenSource oauth2.TokenSource
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
sessionManager botgo.SessionManager
|
||||
processedIDs map[string]bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) {
|
||||
base := NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
return &QQChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
processedIDs: make(map[string]bool),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *QQChannel) Start(ctx context.Context) error {
|
||||
if c.config.AppID == "" || c.config.AppSecret == "" {
|
||||
return fmt.Errorf("QQ app_id and app_secret not configured")
|
||||
}
|
||||
|
||||
logger.InfoC("qq", "Starting QQ bot (WebSocket mode)")
|
||||
|
||||
// create token source
|
||||
credentials := &token.QQBotCredentials{
|
||||
AppID: c.config.AppID,
|
||||
AppSecret: c.config.AppSecret,
|
||||
}
|
||||
c.tokenSource = token.NewQQBotTokenSource(credentials)
|
||||
|
||||
// create child context
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// start auto-refresh token goroutine
|
||||
if err := token.StartRefreshAccessToken(c.ctx, c.tokenSource); err != nil {
|
||||
return fmt.Errorf("failed to start token refresh: %w", err)
|
||||
}
|
||||
|
||||
// initialize OpenAPI client
|
||||
c.api = botgo.NewOpenAPI(c.config.AppID, c.tokenSource).WithTimeout(5 * time.Second)
|
||||
|
||||
// register event handlers
|
||||
intent := event.RegisterHandlers(
|
||||
c.handleC2CMessage(),
|
||||
c.handleGroupATMessage(),
|
||||
)
|
||||
|
||||
// get WebSocket endpoint
|
||||
wsInfo, err := c.api.WS(c.ctx, nil, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get websocket info: %w", err)
|
||||
}
|
||||
|
||||
logger.InfoCF("qq", "Got WebSocket info", map[string]any{
|
||||
"shards": wsInfo.Shards,
|
||||
})
|
||||
|
||||
// create and save sessionManager
|
||||
c.sessionManager = botgo.NewSessionManager()
|
||||
|
||||
// start WebSocket connection in goroutine to avoid blocking
|
||||
go func() {
|
||||
if err := c.sessionManager.Start(wsInfo, c.tokenSource, &intent); err != nil {
|
||||
logger.ErrorCF("qq", "WebSocket session error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
c.setRunning(false)
|
||||
}
|
||||
}()
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("qq", "QQ bot started successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *QQChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("qq", "Stopping QQ bot")
|
||||
c.setRunning(false)
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("QQ bot not running")
|
||||
}
|
||||
|
||||
// construct message
|
||||
msgToCreate := &dto.MessageToCreate{
|
||||
Content: msg.Content,
|
||||
}
|
||||
|
||||
// send C2C message
|
||||
_, err := c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate)
|
||||
if err != nil {
|
||||
logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleC2CMessage handles QQ private messages
|
||||
func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error {
|
||||
// deduplication check
|
||||
if c.isDuplicate(data.ID) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extract user info
|
||||
var senderID string
|
||||
if data.Author != nil && data.Author.ID != "" {
|
||||
senderID = data.Author.ID
|
||||
} else {
|
||||
logger.WarnC("qq", "Received message with no sender ID")
|
||||
return nil
|
||||
}
|
||||
|
||||
// extract message content
|
||||
content := data.Content
|
||||
if content == "" {
|
||||
logger.DebugC("qq", "Received empty message, ignoring")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.InfoCF("qq", "Received C2C message", map[string]any{
|
||||
"sender": senderID,
|
||||
"length": len(content),
|
||||
})
|
||||
|
||||
// forward to message bus
|
||||
metadata := map[string]string{
|
||||
"message_id": data.ID,
|
||||
"peer_kind": "direct",
|
||||
"peer_id": senderID,
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, senderID, content, []string{}, metadata)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleGroupATMessage handles group @messages
|
||||
func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
|
||||
return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error {
|
||||
// deduplication check
|
||||
if c.isDuplicate(data.ID) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extract user info
|
||||
var senderID string
|
||||
if data.Author != nil && data.Author.ID != "" {
|
||||
senderID = data.Author.ID
|
||||
} else {
|
||||
logger.WarnC("qq", "Received group message with no sender ID")
|
||||
return nil
|
||||
}
|
||||
|
||||
// extract message content (remove @bot part)
|
||||
content := data.Content
|
||||
if content == "" {
|
||||
logger.DebugC("qq", "Received empty group message, ignoring")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.InfoCF("qq", "Received group AT message", map[string]any{
|
||||
"sender": senderID,
|
||||
"group": data.GroupID,
|
||||
"length": len(content),
|
||||
})
|
||||
|
||||
// forward to message bus (use GroupID as ChatID)
|
||||
metadata := map[string]string{
|
||||
"message_id": data.ID,
|
||||
"group_id": data.GroupID,
|
||||
"peer_kind": "group",
|
||||
"peer_id": data.GroupID,
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, data.GroupID, content, []string{}, metadata)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// isDuplicate checks if message is duplicate
|
||||
func (c *QQChannel) isDuplicate(messageID string) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.processedIDs[messageID] {
|
||||
return true
|
||||
}
|
||||
|
||||
c.processedIDs[messageID] = true
|
||||
|
||||
// simple cleanup: limit map size
|
||||
if len(c.processedIDs) > 10000 {
|
||||
// clear half
|
||||
count := 0
|
||||
for id := range c.processedIDs {
|
||||
if count >= 5000 {
|
||||
break
|
||||
}
|
||||
delete(c.processedIDs, id)
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -1,443 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/slack-go/slack"
|
||||
"github.com/slack-go/slack/slackevents"
|
||||
"github.com/slack-go/slack/socketmode"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
type SlackChannel struct {
|
||||
*BaseChannel
|
||||
config config.SlackConfig
|
||||
api *slack.Client
|
||||
socketClient *socketmode.Client
|
||||
botUserID string
|
||||
teamID string
|
||||
transcriber *voice.GroqTranscriber
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
pendingAcks sync.Map
|
||||
}
|
||||
|
||||
type slackMessageRef struct {
|
||||
ChannelID string
|
||||
Timestamp string
|
||||
}
|
||||
|
||||
func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) {
|
||||
if cfg.BotToken == "" || cfg.AppToken == "" {
|
||||
return nil, fmt.Errorf("slack bot_token and app_token are required")
|
||||
}
|
||||
|
||||
api := slack.New(
|
||||
cfg.BotToken,
|
||||
slack.OptionAppLevelToken(cfg.AppToken),
|
||||
)
|
||||
|
||||
socketClient := socketmode.New(api)
|
||||
|
||||
base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
return &SlackChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
api: api,
|
||||
socketClient: socketClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
c.transcriber = transcriber
|
||||
}
|
||||
|
||||
func (c *SlackChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("slack", "Starting Slack channel (Socket Mode)")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
authResp, err := c.api.AuthTest()
|
||||
if err != nil {
|
||||
return fmt.Errorf("slack auth test failed: %w", err)
|
||||
}
|
||||
c.botUserID = authResp.UserID
|
||||
c.teamID = authResp.TeamID
|
||||
|
||||
logger.InfoCF("slack", "Slack bot connected", map[string]any{
|
||||
"bot_user_id": c.botUserID,
|
||||
"team": authResp.Team,
|
||||
})
|
||||
|
||||
go c.eventLoop()
|
||||
|
||||
go func() {
|
||||
if err := c.socketClient.RunContext(c.ctx); err != nil {
|
||||
if c.ctx.Err() == nil {
|
||||
logger.ErrorCF("slack", "Socket Mode connection error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("slack", "Slack channel started (Socket Mode)")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SlackChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("slack", "Stopping Slack channel")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.setRunning(false)
|
||||
logger.InfoC("slack", "Slack channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("slack channel not running")
|
||||
}
|
||||
|
||||
channelID, threadTS := parseSlackChatID(msg.ChatID)
|
||||
if channelID == "" {
|
||||
return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
|
||||
}
|
||||
|
||||
opts := []slack.MsgOption{
|
||||
slack.MsgOptionText(msg.Content, false),
|
||||
}
|
||||
|
||||
if threadTS != "" {
|
||||
opts = append(opts, slack.MsgOptionTS(threadTS))
|
||||
}
|
||||
|
||||
_, _, err := c.api.PostMessageContext(ctx, channelID, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send slack message: %w", err)
|
||||
}
|
||||
|
||||
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
|
||||
msgRef := ref.(slackMessageRef)
|
||||
c.api.AddReaction("white_check_mark", slack.ItemRef{
|
||||
Channel: msgRef.ChannelID,
|
||||
Timestamp: msgRef.Timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Message sent", map[string]any{
|
||||
"channel_id": channelID,
|
||||
"thread_ts": threadTS,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SlackChannel) eventLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case event, ok := <-c.socketClient.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
switch event.Type {
|
||||
case socketmode.EventTypeEventsAPI:
|
||||
c.handleEventsAPI(event)
|
||||
case socketmode.EventTypeSlashCommand:
|
||||
c.handleSlashCommand(event)
|
||||
case socketmode.EventTypeInteractive:
|
||||
if event.Request != nil {
|
||||
c.socketClient.Ack(*event.Request)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
|
||||
if event.Request != nil {
|
||||
c.socketClient.Ack(*event.Request)
|
||||
}
|
||||
|
||||
eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
switch ev := eventsAPIEvent.InnerEvent.Data.(type) {
|
||||
case *slackevents.MessageEvent:
|
||||
c.handleMessageEvent(ev)
|
||||
case *slackevents.AppMentionEvent:
|
||||
c.handleAppMention(ev)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
if ev.User == c.botUserID || ev.User == "" {
|
||||
return
|
||||
}
|
||||
if ev.BotID != "" {
|
||||
return
|
||||
}
|
||||
if ev.SubType != "" && ev.SubType != "file_share" {
|
||||
return
|
||||
}
|
||||
|
||||
// check allowlist to avoid downloading attachments for rejected users
|
||||
if !c.IsAllowed(ev.User) {
|
||||
logger.DebugCF("slack", "Message rejected by allowlist", map[string]any{
|
||||
"user_id": ev.User,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := ev.User
|
||||
channelID := ev.Channel
|
||||
threadTS := ev.ThreadTimeStamp
|
||||
messageTS := ev.TimeStamp
|
||||
|
||||
chatID := channelID
|
||||
if threadTS != "" {
|
||||
chatID = channelID + "/" + threadTS
|
||||
}
|
||||
|
||||
c.api.AddReaction("eyes", slack.ItemRef{
|
||||
Channel: channelID,
|
||||
Timestamp: messageTS,
|
||||
})
|
||||
|
||||
c.pendingAcks.Store(chatID, slackMessageRef{
|
||||
ChannelID: channelID,
|
||||
Timestamp: messageTS,
|
||||
})
|
||||
|
||||
content := ev.Text
|
||||
content = c.stripBotMention(content)
|
||||
|
||||
var mediaPaths []string
|
||||
localFiles := []string{} // track local files that need cleanup
|
||||
|
||||
// ensure temp files are cleaned up when function returns
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("slack", "Failed to cleanup temp file", map[string]any{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if ev.Message != nil && len(ev.Message.Files) > 0 {
|
||||
for _, file := range ev.Message.Files {
|
||||
localPath := c.downloadSlackFile(file)
|
||||
if localPath == "" {
|
||||
continue
|
||||
}
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
|
||||
if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
result, err := c.transcriber.Transcribe(ctx, localPath)
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("slack", "Voice transcription failed", map[string]any{"error": err.Error()})
|
||||
content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name)
|
||||
} else {
|
||||
content += fmt.Sprintf("\n[voice transcription: %s]", result.Text)
|
||||
}
|
||||
} else {
|
||||
content += fmt.Sprintf("\n[file: %s]", file.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
"thread_ts": threadTS,
|
||||
"platform": "slack",
|
||||
"peer_kind": peerKind,
|
||||
"peer_id": peerID,
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
"has_thread": threadTS != "",
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
if ev.User == c.botUserID {
|
||||
return
|
||||
}
|
||||
|
||||
if !c.IsAllowed(ev.User) {
|
||||
logger.DebugCF("slack", "Mention rejected by allowlist", map[string]any{
|
||||
"user_id": ev.User,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := ev.User
|
||||
channelID := ev.Channel
|
||||
threadTS := ev.ThreadTimeStamp
|
||||
messageTS := ev.TimeStamp
|
||||
|
||||
var chatID string
|
||||
if threadTS != "" {
|
||||
chatID = channelID + "/" + threadTS
|
||||
} else {
|
||||
chatID = channelID + "/" + messageTS
|
||||
}
|
||||
|
||||
c.api.AddReaction("eyes", slack.ItemRef{
|
||||
Channel: channelID,
|
||||
Timestamp: messageTS,
|
||||
})
|
||||
|
||||
c.pendingAcks.Store(chatID, slackMessageRef{
|
||||
ChannelID: channelID,
|
||||
Timestamp: messageTS,
|
||||
})
|
||||
|
||||
content := c.stripBotMention(ev.Text)
|
||||
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
mentionPeerKind := "channel"
|
||||
mentionPeerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
mentionPeerKind = "direct"
|
||||
mentionPeerID = senderID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
"thread_ts": threadTS,
|
||||
"platform": "slack",
|
||||
"is_mention": "true",
|
||||
"peer_kind": mentionPeerKind,
|
||||
"peer_id": mentionPeerID,
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
cmd, ok := event.Data.(slack.SlashCommand)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if event.Request != nil {
|
||||
c.socketClient.Ack(*event.Request)
|
||||
}
|
||||
|
||||
if !c.IsAllowed(cmd.UserID) {
|
||||
logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]any{
|
||||
"user_id": cmd.UserID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := cmd.UserID
|
||||
channelID := cmd.ChannelID
|
||||
chatID := channelID
|
||||
content := cmd.Text
|
||||
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "help"
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"channel_id": channelID,
|
||||
"platform": "slack",
|
||||
"is_command": "true",
|
||||
"trigger_id": cmd.TriggerID,
|
||||
"peer_kind": "channel",
|
||||
"peer_id": channelID,
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Slash command received", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"command": cmd.Command,
|
||||
"text": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
|
||||
downloadURL := file.URLPrivateDownload
|
||||
if downloadURL == "" {
|
||||
downloadURL = file.URLPrivate
|
||||
}
|
||||
if downloadURL == "" {
|
||||
logger.ErrorCF("slack", "No download URL for file", map[string]any{"file_id": file.ID})
|
||||
return ""
|
||||
}
|
||||
|
||||
return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{
|
||||
LoggerPrefix: "slack",
|
||||
ExtraHeaders: map[string]string{
|
||||
"Authorization": "Bearer " + c.config.BotToken,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *SlackChannel) stripBotMention(text string) string {
|
||||
mention := fmt.Sprintf("<@%s>", c.botUserID)
|
||||
text = strings.ReplaceAll(text, mention, "")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func parseSlackChatID(chatID string) (channelID, threadTS string) {
|
||||
parts := strings.SplitN(chatID, "/", 2)
|
||||
channelID = parts[0]
|
||||
if len(parts) > 1 {
|
||||
threadTS = parts[1]
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1,174 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestParseSlackChatID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chatID string
|
||||
wantChanID string
|
||||
wantThread string
|
||||
}{
|
||||
{
|
||||
name: "channel only",
|
||||
chatID: "C123456",
|
||||
wantChanID: "C123456",
|
||||
wantThread: "",
|
||||
},
|
||||
{
|
||||
name: "channel with thread",
|
||||
chatID: "C123456/1234567890.123456",
|
||||
wantChanID: "C123456",
|
||||
wantThread: "1234567890.123456",
|
||||
},
|
||||
{
|
||||
name: "DM channel",
|
||||
chatID: "D987654",
|
||||
wantChanID: "D987654",
|
||||
wantThread: "",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
chatID: "",
|
||||
wantChanID: "",
|
||||
wantThread: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chanID, threadTS := parseSlackChatID(tt.chatID)
|
||||
if chanID != tt.wantChanID {
|
||||
t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID)
|
||||
}
|
||||
if threadTS != tt.wantThread {
|
||||
t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripBotMention(t *testing.T) {
|
||||
ch := &SlackChannel{botUserID: "U12345BOT"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "mention at start",
|
||||
input: "<@U12345BOT> hello there",
|
||||
want: "hello there",
|
||||
},
|
||||
{
|
||||
name: "mention in middle",
|
||||
input: "hey <@U12345BOT> can you help",
|
||||
want: "hey can you help",
|
||||
},
|
||||
{
|
||||
name: "no mention",
|
||||
input: "hello world",
|
||||
want: "hello world",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "only mention",
|
||||
input: "<@U12345BOT>",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ch.stripBotMention(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSlackChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("missing bot token", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
BotToken: "",
|
||||
AppToken: "xapp-test",
|
||||
}
|
||||
_, err := NewSlackChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing bot_token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing app token", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
BotToken: "xoxb-test",
|
||||
AppToken: "",
|
||||
}
|
||||
_, err := NewSlackChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing app_token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
BotToken: "xoxb-test",
|
||||
AppToken: "xapp-test",
|
||||
AllowFrom: []string{"U123"},
|
||||
}
|
||||
ch, err := NewSlackChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ch.Name() != "slack" {
|
||||
t.Errorf("Name() = %q, want %q", ch.Name(), "slack")
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("new channel should not be running")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSlackChannelIsAllowed(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("empty allowlist allows all", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
BotToken: "xoxb-test",
|
||||
AppToken: "xapp-test",
|
||||
AllowFrom: []string{},
|
||||
}
|
||||
ch, _ := NewSlackChannel(cfg, msgBus)
|
||||
if !ch.IsAllowed("U_ANYONE") {
|
||||
t.Error("empty allowlist should allow all users")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allowlist restricts users", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
BotToken: "xoxb-test",
|
||||
AppToken: "xapp-test",
|
||||
AllowFrom: []string{"U_ALLOWED"},
|
||||
}
|
||||
ch, _ := NewSlackChannel(cfg, msgBus)
|
||||
if !ch.IsAllowed("U_ALLOWED") {
|
||||
t.Error("allowed user should pass allowlist check")
|
||||
}
|
||||
if ch.IsAllowed("U_BLOCKED") {
|
||||
t.Error("non-allowed user should be blocked")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,529 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
"github.com/mymmrac/telego/telegohandler"
|
||||
th "github.com/mymmrac/telego/telegohandler"
|
||||
tu "github.com/mymmrac/telego/telegoutil"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
type TelegramChannel struct {
|
||||
*BaseChannel
|
||||
bot *telego.Bot
|
||||
commands TelegramCommander
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
transcriber *voice.GroqTranscriber
|
||||
placeholders sync.Map // chatID -> messageID
|
||||
stopThinking sync.Map // chatID -> thinkingCancel
|
||||
}
|
||||
|
||||
type thinkingCancel struct {
|
||||
fn context.CancelFunc
|
||||
}
|
||||
|
||||
func (c *thinkingCancel) Cancel() {
|
||||
if c != nil && c.fn != nil {
|
||||
c.fn()
|
||||
}
|
||||
}
|
||||
|
||||
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
var opts []telego.BotOption
|
||||
telegramCfg := cfg.Channels.Telegram
|
||||
|
||||
if telegramCfg.Proxy != "" {
|
||||
proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr)
|
||||
}
|
||||
opts = append(opts, telego.WithHTTPClient(&http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
},
|
||||
}))
|
||||
} else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
|
||||
// Use environment proxy if configured
|
||||
opts = append(opts, telego.WithHTTPClient(&http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
bot, err := telego.NewBot(telegramCfg.Token, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
|
||||
}
|
||||
|
||||
base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom)
|
||||
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
commands: NewTelegramCommands(bot, cfg),
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
chatIDs: make(map[string]int64),
|
||||
transcriber: nil,
|
||||
placeholders: sync.Map{},
|
||||
stopThinking: sync.Map{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
c.transcriber = transcriber
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
|
||||
|
||||
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
|
||||
Timeout: 30,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start long polling: %w", err)
|
||||
}
|
||||
|
||||
bh, err := telegohandler.NewBotHandler(c.bot, updates)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create bot handler: %w", err)
|
||||
}
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
c.commands.Help(ctx, message)
|
||||
return nil
|
||||
}, th.CommandEqual("help"))
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Start(ctx, message)
|
||||
}, th.CommandEqual("start"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Show(ctx, message)
|
||||
}, th.CommandEqual("show"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.List(ctx, message)
|
||||
}, th.CommandEqual("list"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.handleMessage(ctx, &message)
|
||||
}, th.AnyMessage())
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoCF("telegram", "Telegram bot connected", map[string]any{
|
||||
"username": c.bot.Username(),
|
||||
})
|
||||
|
||||
go bh.Start()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
bh.Stop()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("telegram", "Stopping Telegram bot...")
|
||||
c.setRunning(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("telegram bot not running")
|
||||
}
|
||||
|
||||
chatID, err := parseChatID(msg.ChatID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid chat ID: %w", err)
|
||||
}
|
||||
|
||||
// Stop thinking animation
|
||||
if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
|
||||
if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
|
||||
cf.Cancel()
|
||||
}
|
||||
c.stopThinking.Delete(msg.ChatID)
|
||||
}
|
||||
|
||||
htmlContent := markdownToTelegramHTML(msg.Content)
|
||||
|
||||
// Try to edit placeholder
|
||||
if pID, ok := c.placeholders.Load(msg.ChatID); ok {
|
||||
c.placeholders.Delete(msg.ChatID)
|
||||
editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent)
|
||||
editMsg.ParseMode = telego.ModeHTML
|
||||
|
||||
if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil {
|
||||
return nil
|
||||
}
|
||||
// Fallback to new message if edit fails
|
||||
}
|
||||
|
||||
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
|
||||
tgMsg.ParseMode = telego.ModeHTML
|
||||
|
||||
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
tgMsg.ParseMode = ""
|
||||
_, err = c.bot.SendMessage(ctx, tgMsg)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
|
||||
if message == nil {
|
||||
return fmt.Errorf("message is nil")
|
||||
}
|
||||
|
||||
user := message.From
|
||||
if user == nil {
|
||||
return fmt.Errorf("message sender (user) is nil")
|
||||
}
|
||||
|
||||
senderID := fmt.Sprintf("%d", user.ID)
|
||||
if user.Username != "" {
|
||||
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
|
||||
}
|
||||
|
||||
// check allowlist to avoid downloading attachments for rejected users
|
||||
if !c.IsAllowed(senderID) {
|
||||
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{
|
||||
"user_id": senderID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
chatID := message.Chat.ID
|
||||
c.chatIDs[senderID] = chatID
|
||||
|
||||
content := ""
|
||||
mediaPaths := []string{}
|
||||
localFiles := []string{} // track local files that need cleanup
|
||||
|
||||
// ensure temp files are cleaned up when function returns
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]any{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if message.Text != "" {
|
||||
content += message.Text
|
||||
}
|
||||
|
||||
if message.Caption != "" {
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += message.Caption
|
||||
}
|
||||
|
||||
if len(message.Photo) > 0 {
|
||||
photo := message.Photo[len(message.Photo)-1]
|
||||
photoPath := c.downloadPhoto(ctx, photo.FileID)
|
||||
if photoPath != "" {
|
||||
localFiles = append(localFiles, photoPath)
|
||||
mediaPaths = append(mediaPaths, photoPath)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += "[image: photo]"
|
||||
}
|
||||
}
|
||||
|
||||
if message.Voice != nil {
|
||||
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
|
||||
if voicePath != "" {
|
||||
localFiles = append(localFiles, voicePath)
|
||||
mediaPaths = append(mediaPaths, voicePath)
|
||||
|
||||
transcribedText := ""
|
||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
transcriberCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := c.transcriber.Transcribe(transcriberCtx, voicePath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("telegram", "Voice transcription failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
"path": voicePath,
|
||||
})
|
||||
transcribedText = "[voice (transcription failed)]"
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
|
||||
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]any{
|
||||
"text": result.Text,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
transcribedText = "[voice]"
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += transcribedText
|
||||
}
|
||||
}
|
||||
|
||||
if message.Audio != nil {
|
||||
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
|
||||
if audioPath != "" {
|
||||
localFiles = append(localFiles, audioPath)
|
||||
mediaPaths = append(mediaPaths, audioPath)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += "[audio]"
|
||||
}
|
||||
}
|
||||
|
||||
if message.Document != nil {
|
||||
docPath := c.downloadFile(ctx, message.Document.FileID, "")
|
||||
if docPath != "" {
|
||||
localFiles = append(localFiles, docPath)
|
||||
mediaPaths = append(mediaPaths, docPath)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += "[file]"
|
||||
}
|
||||
}
|
||||
|
||||
if content == "" {
|
||||
content = "[empty message]"
|
||||
}
|
||||
|
||||
logger.DebugCF("telegram", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": fmt.Sprintf("%d", chatID),
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Thinking indicator
|
||||
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
|
||||
if err != nil {
|
||||
logger.ErrorCF("telegram", "Failed to send chat action", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Stop any previous thinking animation
|
||||
chatIDStr := fmt.Sprintf("%d", chatID)
|
||||
if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
|
||||
if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
|
||||
cf.Cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Create cancel function for thinking state
|
||||
_, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
|
||||
|
||||
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
|
||||
if err == nil {
|
||||
pID := pMsg.MessageID
|
||||
c.placeholders.Store(chatIDStr, pID)
|
||||
}
|
||||
|
||||
peerKind := "direct"
|
||||
peerID := fmt.Sprintf("%d", user.ID)
|
||||
if message.Chat.Type != "private" {
|
||||
peerKind = "group"
|
||||
peerID = fmt.Sprintf("%d", chatID)
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": fmt.Sprintf("%d", message.MessageID),
|
||||
"user_id": fmt.Sprintf("%d", user.ID),
|
||||
"username": user.Username,
|
||||
"first_name": user.FirstName,
|
||||
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
||||
"peer_kind": peerKind,
|
||||
"peer_id": peerID,
|
||||
}
|
||||
|
||||
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
|
||||
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
|
||||
if err != nil {
|
||||
logger.ErrorCF("telegram", "Failed to get photo file", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
return c.downloadFileWithInfo(file, ".jpg")
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string {
|
||||
if file.FilePath == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
url := c.bot.FileDownloadURL(file.FilePath)
|
||||
logger.DebugCF("telegram", "File URL", map[string]any{"url": url})
|
||||
|
||||
// Use FilePath as filename for better identification
|
||||
filename := file.FilePath + ext
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "telegram",
|
||||
})
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
|
||||
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
|
||||
if err != nil {
|
||||
logger.ErrorCF("telegram", "Failed to get file", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
return c.downloadFileWithInfo(file, ext)
|
||||
}
|
||||
|
||||
func parseChatID(chatIDStr string) (int64, error) {
|
||||
var id int64
|
||||
_, err := fmt.Sscanf(chatIDStr, "%d", &id)
|
||||
return id, err
|
||||
}
|
||||
|
||||
func markdownToTelegramHTML(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
codeBlocks := extractCodeBlocks(text)
|
||||
text = codeBlocks.text
|
||||
|
||||
inlineCodes := extractInlineCodes(text)
|
||||
text = inlineCodes.text
|
||||
|
||||
text = regexp.MustCompile(`^#{1,6}\s+(.+)$`).ReplaceAllString(text, "$1")
|
||||
|
||||
text = regexp.MustCompile(`^>\s*(.*)$`).ReplaceAllString(text, "$1")
|
||||
|
||||
text = escapeHTML(text)
|
||||
|
||||
text = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`).ReplaceAllString(text, `<a href="$2">$1</a>`)
|
||||
|
||||
text = regexp.MustCompile(`\*\*(.+?)\*\*`).ReplaceAllString(text, "<b>$1</b>")
|
||||
|
||||
text = regexp.MustCompile(`__(.+?)__`).ReplaceAllString(text, "<b>$1</b>")
|
||||
|
||||
reItalic := regexp.MustCompile(`_([^_]+)_`)
|
||||
text = reItalic.ReplaceAllStringFunc(text, func(s string) string {
|
||||
match := reItalic.FindStringSubmatch(s)
|
||||
if len(match) < 2 {
|
||||
return s
|
||||
}
|
||||
return "<i>" + match[1] + "</i>"
|
||||
})
|
||||
|
||||
text = regexp.MustCompile(`~~(.+?)~~`).ReplaceAllString(text, "<s>$1</s>")
|
||||
|
||||
text = regexp.MustCompile(`^[-*]\s+`).ReplaceAllString(text, "• ")
|
||||
|
||||
for i, code := range inlineCodes.codes {
|
||||
escaped := escapeHTML(code)
|
||||
text = strings.ReplaceAll(text, fmt.Sprintf("\x00IC%d\x00", i), fmt.Sprintf("<code>%s</code>", escaped))
|
||||
}
|
||||
|
||||
for i, code := range codeBlocks.codes {
|
||||
escaped := escapeHTML(code)
|
||||
text = strings.ReplaceAll(
|
||||
text,
|
||||
fmt.Sprintf("\x00CB%d\x00", i),
|
||||
fmt.Sprintf("<pre><code>%s</code></pre>", escaped),
|
||||
)
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
type codeBlockMatch struct {
|
||||
text string
|
||||
codes []string
|
||||
}
|
||||
|
||||
func extractCodeBlocks(text string) codeBlockMatch {
|
||||
re := regexp.MustCompile("```[\\w]*\\n?([\\s\\S]*?)```")
|
||||
matches := re.FindAllStringSubmatch(text, -1)
|
||||
|
||||
codes := make([]string, 0, len(matches))
|
||||
for _, match := range matches {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = re.ReplaceAllStringFunc(text, func(m string) string {
|
||||
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return codeBlockMatch{text: text, codes: codes}
|
||||
}
|
||||
|
||||
type inlineCodeMatch struct {
|
||||
text string
|
||||
codes []string
|
||||
}
|
||||
|
||||
func extractInlineCodes(text string) inlineCodeMatch {
|
||||
re := regexp.MustCompile("`([^`]+)`")
|
||||
matches := re.FindAllStringSubmatch(text, -1)
|
||||
|
||||
codes := make([]string, 0, len(matches))
|
||||
for _, match := range matches {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = re.ReplaceAllStringFunc(text, func(m string) string {
|
||||
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return inlineCodeMatch{text: text, codes: codes}
|
||||
}
|
||||
|
||||
func escapeHTML(text string) string {
|
||||
text = strings.ReplaceAll(text, "&", "&")
|
||||
text = strings.ReplaceAll(text, "<", "<")
|
||||
text = strings.ReplaceAll(text, ">", ">")
|
||||
return text
|
||||
}
|
||||
@@ -1,156 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type TelegramCommander interface {
|
||||
Help(ctx context.Context, message telego.Message) error
|
||||
Start(ctx context.Context, message telego.Message) error
|
||||
Show(ctx context.Context, message telego.Message) error
|
||||
List(ctx context.Context, message telego.Message) error
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
bot *telego.Bot
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
|
||||
return &cmd{
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func commandArgs(text string) string {
|
||||
parts := strings.SplitN(text, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
|
||||
msg := `/start - Start the bot
|
||||
/help - Show this help message
|
||||
/show [model|channel] - Show current configuration
|
||||
/list [models|channels] - List available options
|
||||
`
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: msg,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Hello! I am PicoClaw 🦞",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /show [model|channel]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "model":
|
||||
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
|
||||
c.config.Agents.Defaults.Model,
|
||||
c.config.Agents.Defaults.Provider)
|
||||
case "channel":
|
||||
response = "Current Channel: telegram"
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) List(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /list [models|channels]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "models":
|
||||
provider := c.config.Agents.Defaults.Provider
|
||||
if provider == "" {
|
||||
provider = "configured default"
|
||||
}
|
||||
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml",
|
||||
c.config.Agents.Defaults.Model, provider)
|
||||
|
||||
case "channels":
|
||||
var enabled []string
|
||||
if c.config.Channels.Telegram.Enabled {
|
||||
enabled = append(enabled, "telegram")
|
||||
}
|
||||
if c.config.Channels.WhatsApp.Enabled {
|
||||
enabled = append(enabled, "whatsapp")
|
||||
}
|
||||
if c.config.Channels.Feishu.Enabled {
|
||||
enabled = append(enabled, "feishu")
|
||||
}
|
||||
if c.config.Channels.Discord.Enabled {
|
||||
enabled = append(enabled, "discord")
|
||||
}
|
||||
if c.config.Channels.Slack.Enabled {
|
||||
enabled = append(enabled, "slack")
|
||||
}
|
||||
response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))
|
||||
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
@@ -1,605 +0,0 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// WeCom Bot (企业微信智能机器人) channel implementation
|
||||
// Uses webhook callback mode for receiving messages and webhook API for sending replies
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人)
|
||||
// Uses webhook callback mode - simpler than WeCom App but only supports passive replies
|
||||
type WeComBotChannel struct {
|
||||
*BaseChannel
|
||||
config config.WeComConfig
|
||||
server *http.Server
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
|
||||
msgMu sync.RWMutex
|
||||
}
|
||||
|
||||
// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT)
|
||||
type WeComBotMessage struct {
|
||||
MsgID string `json:"msgid"`
|
||||
AIBotID string `json:"aibotid"`
|
||||
ChatID string `json:"chatid"` // Session ID, only present for group chats
|
||||
ChatType string `json:"chattype"` // "single" for DM, "group" for group chat
|
||||
From struct {
|
||||
UserID string `json:"userid"`
|
||||
} `json:"from"`
|
||||
ResponseURL string `json:"response_url"`
|
||||
MsgType string `json:"msgtype"` // text, image, voice, file, mixed
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
Image struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"image"`
|
||||
Voice struct {
|
||||
Content string `json:"content"` // Voice to text content
|
||||
} `json:"voice"`
|
||||
File struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"file"`
|
||||
Mixed struct {
|
||||
MsgItem []struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
Image struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"image"`
|
||||
} `json:"msg_item"`
|
||||
} `json:"mixed"`
|
||||
Quote struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
} `json:"quote"`
|
||||
}
|
||||
|
||||
// WeComBotReplyMessage represents the reply message structure
|
||||
type WeComBotReplyMessage struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// NewWeComBotChannel creates a new WeCom Bot channel instance
|
||||
func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) {
|
||||
if cfg.Token == "" || cfg.WebhookURL == "" {
|
||||
return nil, fmt.Errorf("wecom token and webhook_url are required")
|
||||
}
|
||||
|
||||
base := NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
return &WeComBotChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
processedMsgs: make(map[string]bool),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the channel name
|
||||
func (c *WeComBotChannel) Name() string {
|
||||
return "wecom"
|
||||
}
|
||||
|
||||
// Start initializes the WeCom Bot channel with HTTP webhook server
|
||||
func (c *WeComBotChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom", "Starting WeCom Bot channel...")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Setup HTTP server for webhook
|
||||
mux := http.NewServeMux()
|
||||
webhookPath := c.config.WebhookPath
|
||||
if webhookPath == "" {
|
||||
webhookPath = "/webhook/wecom"
|
||||
}
|
||||
mux.HandleFunc(webhookPath, c.handleWebhook)
|
||||
|
||||
// Health check endpoint
|
||||
mux.HandleFunc("/health/wecom", c.handleHealth)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort)
|
||||
c.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoCF("wecom", "WeCom Bot channel started", map[string]any{
|
||||
"address": addr,
|
||||
"path": webhookPath,
|
||||
})
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.ErrorCF("wecom", "HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the WeCom Bot channel
|
||||
func (c *WeComBotChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("wecom", "Stopping WeCom Bot channel...")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
if c.server != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
c.server.Shutdown(shutdownCtx)
|
||||
}
|
||||
|
||||
c.setRunning(false)
|
||||
logger.InfoC("wecom", "WeCom Bot channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends a message to WeCom user via webhook API
|
||||
// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message
|
||||
// For delayed responses, we use the webhook URL
|
||||
func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("wecom channel not running")
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom", "Sending message via webhook", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
return c.sendWebhookReply(ctx, msg.ChatID, msg.Content)
|
||||
}
|
||||
|
||||
// handleWebhook handles incoming webhook requests from WeCom
|
||||
func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
// Handle verification request
|
||||
c.handleVerification(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
// Handle message callback
|
||||
c.handleMessageCallback(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handleVerification handles the URL verification request from WeCom
|
||||
func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
echostr := query.Get("echostr")
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
|
||||
logger.WarnC("wecom", "Signature verification failed")
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt echostr
|
||||
// For AIBOT (智能机器人), receiveid should be empty string ""
|
||||
// Reference: https://developer.work.weixin.qq.com/document/path/101033
|
||||
decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, "")
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove BOM and whitespace as per WeCom documentation
|
||||
// The response must be plain text without quotes, BOM, or newlines
|
||||
decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
|
||||
decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
|
||||
w.Write([]byte(decryptedEchoStr))
|
||||
}
|
||||
|
||||
// handleMessageCallback handles incoming messages from WeCom
|
||||
func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" {
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// Parse XML to get encrypted message
|
||||
var encryptedMsg struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
AgentID string `xml:"AgentID"`
|
||||
}
|
||||
|
||||
if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to parse XML", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid XML", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
|
||||
logger.WarnC("wecom", "Message signature verification failed")
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt message
|
||||
// For AIBOT (智能机器人), receiveid should be empty string ""
|
||||
// Reference: https://developer.work.weixin.qq.com/document/path/101033
|
||||
decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "")
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse decrypted JSON message (AIBOT uses JSON format)
|
||||
var msg WeComBotMessage
|
||||
if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
|
||||
logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid message format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the message asynchronously with context
|
||||
go c.processMessage(ctx, msg)
|
||||
|
||||
// Return success response immediately
|
||||
// WeCom Bot requires response within configured timeout (default 5 seconds)
|
||||
w.Write([]byte("success"))
|
||||
}
|
||||
|
||||
// processMessage processes the received message
|
||||
func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) {
|
||||
// Skip unsupported message types
|
||||
if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" &&
|
||||
msg.MsgType != "mixed" {
|
||||
logger.DebugCF("wecom", "Skipping non-supported message type", map[string]any{
|
||||
"msg_type": msg.MsgType,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Message deduplication: Use msg_id to prevent duplicate processing
|
||||
msgID := msg.MsgID
|
||||
c.msgMu.Lock()
|
||||
if c.processedMsgs[msgID] {
|
||||
c.msgMu.Unlock()
|
||||
logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{
|
||||
"msg_id": msgID,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.processedMsgs[msgID] = true
|
||||
c.msgMu.Unlock()
|
||||
|
||||
// Clean up old messages periodically (keep last 1000)
|
||||
if len(c.processedMsgs) > 1000 {
|
||||
c.msgMu.Lock()
|
||||
c.processedMsgs = make(map[string]bool)
|
||||
c.msgMu.Unlock()
|
||||
}
|
||||
|
||||
senderID := msg.From.UserID
|
||||
|
||||
// Determine if this is a group chat or direct message
|
||||
// ChatType: "single" for DM, "group" for group chat
|
||||
isGroupChat := msg.ChatType == "group"
|
||||
|
||||
var chatID, peerKind, peerID string
|
||||
if isGroupChat {
|
||||
// Group chat: use ChatID as chatID and peer_id
|
||||
chatID = msg.ChatID
|
||||
peerKind = "group"
|
||||
peerID = msg.ChatID
|
||||
} else {
|
||||
// Direct message: use senderID as chatID and peer_id
|
||||
chatID = senderID
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
// Extract content based on message type
|
||||
var content string
|
||||
switch msg.MsgType {
|
||||
case "text":
|
||||
content = msg.Text.Content
|
||||
case "voice":
|
||||
content = msg.Voice.Content // Voice to text content
|
||||
case "mixed":
|
||||
// For mixed messages, concatenate text items
|
||||
for _, item := range msg.Mixed.MsgItem {
|
||||
if item.MsgType == "text" {
|
||||
content += item.Text.Content
|
||||
}
|
||||
}
|
||||
case "image", "file":
|
||||
// For image and file, we don't have text content
|
||||
content = ""
|
||||
}
|
||||
|
||||
// Build metadata
|
||||
metadata := map[string]string{
|
||||
"msg_type": msg.MsgType,
|
||||
"msg_id": msg.MsgID,
|
||||
"platform": "wecom",
|
||||
"peer_kind": peerKind,
|
||||
"peer_id": peerID,
|
||||
"response_url": msg.ResponseURL,
|
||||
}
|
||||
if isGroupChat {
|
||||
metadata["chat_id"] = msg.ChatID
|
||||
metadata["sender_id"] = senderID
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"msg_type": msg.MsgType,
|
||||
"peer_kind": peerKind,
|
||||
"is_group_chat": isGroupChat,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
}
|
||||
|
||||
// sendWebhookReply sends a reply using the webhook URL
|
||||
func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error {
|
||||
reply := WeComBotReplyMessage{
|
||||
MsgType: "text",
|
||||
}
|
||||
reply.Text.Content = content
|
||||
|
||||
jsonData, err := json.Marshal(reply)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal reply: %w", err)
|
||||
}
|
||||
|
||||
// Use configurable timeout (default 5 seconds)
|
||||
timeout := c.config.ReplyTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send webhook reply: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Check response
|
||||
var result struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.ErrCode != 0 {
|
||||
return fmt.Errorf("webhook API error: %s (code: %d)", result.ErrMsg, result.ErrCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleHealth handles health check requests
|
||||
func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
status := map[string]any{
|
||||
"status": "ok",
|
||||
"running": c.IsRunning(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
|
||||
// WeCom common utilities for both WeCom Bot and WeCom App
|
||||
// The following functions were moved from wecom_common.go
|
||||
|
||||
// WeComVerifySignature verifies the message signature for WeCom
|
||||
// This is a common function used by both WeCom Bot and WeCom App
|
||||
func WeComVerifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
|
||||
if token == "" {
|
||||
return true // Skip verification if token is not set
|
||||
}
|
||||
|
||||
// Sort parameters
|
||||
params := []string{token, timestamp, nonce, msgEncrypt}
|
||||
sort.Strings(params)
|
||||
|
||||
// Concatenate
|
||||
str := strings.Join(params, "")
|
||||
|
||||
// SHA1 hash
|
||||
hash := sha1.Sum([]byte(str))
|
||||
expectedSignature := fmt.Sprintf("%x", hash)
|
||||
|
||||
return expectedSignature == msgSignature
|
||||
}
|
||||
|
||||
// WeComDecryptMessage decrypts the encrypted message using AES
|
||||
// This is a common function used by both WeCom Bot and WeCom App
|
||||
// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id
|
||||
func WeComDecryptMessage(encryptedMsg, encodingAESKey string) (string, error) {
|
||||
return WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, "")
|
||||
}
|
||||
|
||||
// WeComDecryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid
|
||||
// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification.
|
||||
func WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) {
|
||||
if encodingAESKey == "" {
|
||||
// No encryption, return as is (base64 decode)
|
||||
decoded, err := base64.StdEncoding.DecodeString(encryptedMsg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(decoded), nil
|
||||
}
|
||||
|
||||
// Decode AES key (base64)
|
||||
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode AES key: %w", err)
|
||||
}
|
||||
|
||||
// Decode encrypted message
|
||||
cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode message: %w", err)
|
||||
}
|
||||
|
||||
// AES decrypt
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
if len(cipherText) < aes.BlockSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
// IV is the first 16 bytes of AESKey
|
||||
iv := aesKey[:aes.BlockSize]
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
plainText := make([]byte, len(cipherText))
|
||||
mode.CryptBlocks(plainText, cipherText)
|
||||
|
||||
// Remove PKCS7 padding
|
||||
plainText, err = pkcs7UnpadWeCom(plainText)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unpad: %w", err)
|
||||
}
|
||||
|
||||
// Parse message structure
|
||||
// Format: random(16) + msg_len(4) + msg + receiveid
|
||||
if len(plainText) < 20 {
|
||||
return "", fmt.Errorf("decrypted message too short")
|
||||
}
|
||||
|
||||
msgLen := binary.BigEndian.Uint32(plainText[16:20])
|
||||
if int(msgLen) > len(plainText)-20 {
|
||||
return "", fmt.Errorf("invalid message length")
|
||||
}
|
||||
|
||||
msg := plainText[20 : 20+msgLen]
|
||||
|
||||
// Verify receiveid if provided
|
||||
if receiveid != "" && len(plainText) > 20+int(msgLen) {
|
||||
actualReceiveID := string(plainText[20+msgLen:])
|
||||
if actualReceiveID != receiveid {
|
||||
return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
|
||||
}
|
||||
}
|
||||
|
||||
return string(msg), nil
|
||||
}
|
||||
|
||||
// pkcs7UnpadWeCom removes PKCS7 padding with validation
|
||||
// WeCom uses block size of 32 (not standard AES block size of 16)
|
||||
const wecomBlockSize = 32
|
||||
|
||||
func pkcs7UnpadWeCom(data []byte) ([]byte, error) {
|
||||
if len(data) == 0 {
|
||||
return data, nil
|
||||
}
|
||||
padding := int(data[len(data)-1])
|
||||
// WeCom uses 32-byte block size for PKCS7 padding
|
||||
if padding == 0 || padding > wecomBlockSize {
|
||||
return nil, fmt.Errorf("invalid padding size: %d", padding)
|
||||
}
|
||||
if padding > len(data) {
|
||||
return nil, fmt.Errorf("padding size larger than data")
|
||||
}
|
||||
// Verify all padding bytes
|
||||
for i := 0; i < padding; i++ {
|
||||
if data[len(data)-1-i] != byte(padding) {
|
||||
return nil, fmt.Errorf("invalid padding byte at position %d", i)
|
||||
}
|
||||
}
|
||||
return data[:len(data)-padding], nil
|
||||
}
|
||||
@@ -1,639 +0,0 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// WeCom App (企业微信自建应用) channel implementation
|
||||
// Supports receiving messages via webhook callback and sending messages proactively
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
wecomAPIBase = "https://qyapi.weixin.qq.com"
|
||||
)
|
||||
|
||||
// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用)
|
||||
type WeComAppChannel struct {
|
||||
*BaseChannel
|
||||
config config.WeComAppConfig
|
||||
server *http.Server
|
||||
accessToken string
|
||||
tokenExpiry time.Time
|
||||
tokenMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
|
||||
msgMu sync.RWMutex
|
||||
}
|
||||
|
||||
// WeComXMLMessage represents the XML message structure from WeCom
|
||||
type WeComXMLMessage struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
FromUserName string `xml:"FromUserName"`
|
||||
CreateTime int64 `xml:"CreateTime"`
|
||||
MsgType string `xml:"MsgType"`
|
||||
Content string `xml:"Content"`
|
||||
MsgId int64 `xml:"MsgId"`
|
||||
AgentID int64 `xml:"AgentID"`
|
||||
PicUrl string `xml:"PicUrl"`
|
||||
MediaId string `xml:"MediaId"`
|
||||
Format string `xml:"Format"`
|
||||
ThumbMediaId string `xml:"ThumbMediaId"`
|
||||
LocationX float64 `xml:"Location_X"`
|
||||
LocationY float64 `xml:"Location_Y"`
|
||||
Scale int `xml:"Scale"`
|
||||
Label string `xml:"Label"`
|
||||
Title string `xml:"Title"`
|
||||
Description string `xml:"Description"`
|
||||
Url string `xml:"Url"`
|
||||
Event string `xml:"Event"`
|
||||
EventKey string `xml:"EventKey"`
|
||||
}
|
||||
|
||||
// WeComTextMessage represents text message for sending
|
||||
type WeComTextMessage struct {
|
||||
ToUser string `json:"touser"`
|
||||
MsgType string `json:"msgtype"`
|
||||
AgentID int64 `json:"agentid"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
Safe int `json:"safe,omitempty"`
|
||||
}
|
||||
|
||||
// WeComMarkdownMessage represents markdown message for sending
|
||||
type WeComMarkdownMessage struct {
|
||||
ToUser string `json:"touser"`
|
||||
MsgType string `json:"msgtype"`
|
||||
AgentID int64 `json:"agentid"`
|
||||
Markdown struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"markdown"`
|
||||
}
|
||||
|
||||
// WeComImageMessage represents image message for sending
|
||||
type WeComImageMessage struct {
|
||||
ToUser string `json:"touser"`
|
||||
MsgType string `json:"msgtype"`
|
||||
AgentID int64 `json:"agentid"`
|
||||
Image struct {
|
||||
MediaID string `json:"media_id"`
|
||||
} `json:"image"`
|
||||
}
|
||||
|
||||
// WeComAccessTokenResponse represents the access token API response
|
||||
type WeComAccessTokenResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// WeComSendMessageResponse represents the send message API response
|
||||
type WeComSendMessageResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
InvalidUser string `json:"invaliduser"`
|
||||
InvalidParty string `json:"invalidparty"`
|
||||
InvalidTag string `json:"invalidtag"`
|
||||
}
|
||||
|
||||
// PKCS7Padding adds PKCS7 padding
|
||||
type PKCS7Padding struct{}
|
||||
|
||||
// NewWeComAppChannel creates a new WeCom App channel instance
|
||||
func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) {
|
||||
if cfg.CorpID == "" || cfg.CorpSecret == "" || cfg.AgentID == 0 {
|
||||
return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required")
|
||||
}
|
||||
|
||||
base := NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
return &WeComAppChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
processedMsgs: make(map[string]bool),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name returns the channel name
|
||||
func (c *WeComAppChannel) Name() string {
|
||||
return "wecom_app"
|
||||
}
|
||||
|
||||
// Start initializes the WeCom App channel with HTTP webhook server
|
||||
func (c *WeComAppChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("wecom_app", "Starting WeCom App channel...")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Get initial access token
|
||||
if err := c.refreshAccessToken(); err != nil {
|
||||
logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Start token refresh goroutine
|
||||
go c.tokenRefreshLoop()
|
||||
|
||||
// Setup HTTP server for webhook
|
||||
mux := http.NewServeMux()
|
||||
webhookPath := c.config.WebhookPath
|
||||
if webhookPath == "" {
|
||||
webhookPath = "/webhook/wecom-app"
|
||||
}
|
||||
mux.HandleFunc(webhookPath, c.handleWebhook)
|
||||
|
||||
// Health check endpoint
|
||||
mux.HandleFunc("/health/wecom-app", c.handleHealth)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort)
|
||||
c.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoCF("wecom_app", "WeCom App channel started", map[string]any{
|
||||
"address": addr,
|
||||
"path": webhookPath,
|
||||
})
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.ErrorCF("wecom_app", "HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the WeCom App channel
|
||||
func (c *WeComAppChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("wecom_app", "Stopping WeCom App channel...")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
if c.server != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
c.server.Shutdown(shutdownCtx)
|
||||
}
|
||||
|
||||
c.setRunning(false)
|
||||
logger.InfoC("wecom_app", "WeCom App channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends a message to WeCom user proactively using access token
|
||||
func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("wecom_app channel not running")
|
||||
}
|
||||
|
||||
accessToken := c.getAccessToken()
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("no valid access token available")
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom_app", "Sending message", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content)
|
||||
}
|
||||
|
||||
// handleWebhook handles incoming webhook requests from WeCom
|
||||
func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Log all incoming requests for debugging
|
||||
logger.DebugCF("wecom_app", "Received webhook request", map[string]any{
|
||||
"method": r.Method,
|
||||
"url": r.URL.String(),
|
||||
"path": r.URL.Path,
|
||||
"query": r.URL.RawQuery,
|
||||
})
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
// Handle verification request
|
||||
c.handleVerification(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
// Handle message callback
|
||||
c.handleMessageCallback(ctx, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
logger.WarnCF("wecom_app", "Method not allowed", map[string]any{
|
||||
"method": r.Method,
|
||||
})
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handleVerification handles the URL verification request from WeCom
|
||||
func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
echostr := query.Get("echostr")
|
||||
|
||||
logger.DebugCF("wecom_app", "Handling verification request", map[string]any{
|
||||
"msg_signature": msgSignature,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
"echostr": echostr,
|
||||
"corp_id": c.config.CorpID,
|
||||
})
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
|
||||
logger.ErrorC("wecom_app", "Missing parameters in verification request")
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
|
||||
logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{
|
||||
"token": c.config.Token,
|
||||
"msg_signature": msgSignature,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
})
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugC("wecom_app", "Signature verification passed")
|
||||
|
||||
// Decrypt echostr with CorpID verification
|
||||
// For WeCom App (自建应用), receiveid should be corp_id
|
||||
logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{
|
||||
"encoding_aes_key": c.config.EncodingAESKey,
|
||||
"corp_id": c.config.CorpID,
|
||||
})
|
||||
decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{
|
||||
"error": err.Error(),
|
||||
"encoding_aes_key": c.config.EncodingAESKey,
|
||||
"corp_id": c.config.CorpID,
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{
|
||||
"decrypted": decryptedEchoStr,
|
||||
})
|
||||
|
||||
// Remove BOM and whitespace as per WeCom documentation
|
||||
// The response must be plain text without quotes, BOM, or newlines
|
||||
decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
|
||||
decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
|
||||
w.Write([]byte(decryptedEchoStr))
|
||||
}
|
||||
|
||||
// handleMessageCallback handles incoming messages from WeCom
|
||||
func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
msgSignature := query.Get("msg_signature")
|
||||
timestamp := query.Get("timestamp")
|
||||
nonce := query.Get("nonce")
|
||||
|
||||
if msgSignature == "" || timestamp == "" || nonce == "" {
|
||||
http.Error(w, "Missing parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// Parse XML to get encrypted message
|
||||
var encryptedMsg struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
AgentID string `xml:"AgentID"`
|
||||
}
|
||||
|
||||
if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid XML", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
|
||||
logger.WarnC("wecom_app", "Message signature verification failed")
|
||||
http.Error(w, "Invalid signature", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt message with CorpID verification
|
||||
// For WeCom App (自建应用), receiveid should be corp_id
|
||||
decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID)
|
||||
if err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Decryption failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse decrypted XML message
|
||||
var msg WeComXMLMessage
|
||||
if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Invalid message format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the message with context
|
||||
go c.processMessage(ctx, msg)
|
||||
|
||||
// Return success response immediately
|
||||
// WeCom App requires response within configured timeout (default 5 seconds)
|
||||
w.Write([]byte("success"))
|
||||
}
|
||||
|
||||
// processMessage processes the received message
|
||||
func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) {
|
||||
// Skip non-text messages for now (can be extended)
|
||||
if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" {
|
||||
logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{
|
||||
"msg_type": msg.MsgType,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Message deduplication: Use msg_id to prevent duplicate processing
|
||||
// As per WeCom documentation, use msg_id for deduplication
|
||||
msgID := fmt.Sprintf("%d", msg.MsgId)
|
||||
c.msgMu.Lock()
|
||||
if c.processedMsgs[msgID] {
|
||||
c.msgMu.Unlock()
|
||||
logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
|
||||
"msg_id": msgID,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.processedMsgs[msgID] = true
|
||||
c.msgMu.Unlock()
|
||||
|
||||
// Clean up old messages periodically (keep last 1000)
|
||||
if len(c.processedMsgs) > 1000 {
|
||||
c.msgMu.Lock()
|
||||
c.processedMsgs = make(map[string]bool)
|
||||
c.msgMu.Unlock()
|
||||
}
|
||||
|
||||
senderID := msg.FromUserName
|
||||
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
|
||||
|
||||
// Build metadata
|
||||
// WeCom App only supports direct messages (private chat)
|
||||
metadata := map[string]string{
|
||||
"msg_type": msg.MsgType,
|
||||
"msg_id": fmt.Sprintf("%d", msg.MsgId),
|
||||
"agent_id": fmt.Sprintf("%d", msg.AgentID),
|
||||
"platform": "wecom_app",
|
||||
"media_id": msg.MediaId,
|
||||
"create_time": fmt.Sprintf("%d", msg.CreateTime),
|
||||
"peer_kind": "direct",
|
||||
"peer_id": senderID,
|
||||
}
|
||||
|
||||
content := msg.Content
|
||||
|
||||
logger.DebugCF("wecom_app", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"msg_type": msg.MsgType,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
}
|
||||
|
||||
// tokenRefreshLoop periodically refreshes the access token
|
||||
func (c *WeComAppChannel) tokenRefreshLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.refreshAccessToken(); err != nil {
|
||||
logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// refreshAccessToken gets a new access token from WeCom API
|
||||
func (c *WeComAppChannel) refreshAccessToken() error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s",
|
||||
wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret))
|
||||
|
||||
resp, err := http.Get(apiURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to request access token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var tokenResp WeComAccessTokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.ErrCode != 0 {
|
||||
return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode)
|
||||
}
|
||||
|
||||
c.tokenMu.Lock()
|
||||
c.accessToken = tokenResp.AccessToken
|
||||
c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early
|
||||
c.tokenMu.Unlock()
|
||||
|
||||
logger.DebugC("wecom_app", "Access token refreshed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccessToken returns the current valid access token
|
||||
func (c *WeComAppChannel) getAccessToken() string {
|
||||
c.tokenMu.RLock()
|
||||
defer c.tokenMu.RUnlock()
|
||||
|
||||
if time.Now().After(c.tokenExpiry) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return c.accessToken
|
||||
}
|
||||
|
||||
// sendTextMessage sends a text message to a user
|
||||
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
|
||||
|
||||
msg := WeComTextMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "text",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Text.Content = content
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Use configurable timeout (default 5 seconds)
|
||||
timeout := c.config.ReplyTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var sendResp WeComSendMessageResponse
|
||||
if err := json.Unmarshal(body, &sendResp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if sendResp.ErrCode != 0 {
|
||||
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendMarkdownMessage sends a markdown message to a user
|
||||
func (c *WeComAppChannel) sendMarkdownMessage(ctx context.Context, accessToken, userID, content string) error {
|
||||
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
|
||||
|
||||
msg := WeComMarkdownMessage{
|
||||
ToUser: userID,
|
||||
MsgType: "markdown",
|
||||
AgentID: c.config.AgentID,
|
||||
}
|
||||
msg.Markdown.Content = content
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
// Use configurable timeout (default 5 seconds)
|
||||
timeout := c.config.ReplyTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = 5
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var sendResp WeComSendMessageResponse
|
||||
if err := json.Unmarshal(body, &sendResp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if sendResp.ErrCode != 0 {
|
||||
return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleHealth handles health check requests
|
||||
func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
status := map[string]any{
|
||||
"status": "ok",
|
||||
"running": c.IsRunning(),
|
||||
"has_token": c.getAccessToken() != "",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,785 +0,0 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// WeCom Bot (企业微信智能机器人) channel tests
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// generateTestAESKey generates a valid test AES key
|
||||
func generateTestAESKey() string {
|
||||
// AES key needs to be 32 bytes (256 bits) for AES-256
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
// Return base64 encoded key without padding
|
||||
return base64.StdEncoding.EncodeToString(key)[:43]
|
||||
}
|
||||
|
||||
// encryptTestMessage encrypts a message for testing (AIBOT JSON format)
|
||||
func encryptTestMessage(message, aesKey string) (string, error) {
|
||||
// Decode AES key
|
||||
key, err := base64.StdEncoding.DecodeString(aesKey + "=")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Prepare message: random(16) + msg_len(4) + msg + receiveid
|
||||
random := make([]byte, 0, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
random = append(random, byte(i))
|
||||
}
|
||||
|
||||
msgBytes := []byte(message)
|
||||
receiveID := []byte("test_aibot_id")
|
||||
|
||||
msgLen := uint32(len(msgBytes))
|
||||
lenBytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(lenBytes, msgLen)
|
||||
|
||||
plainText := append(random, lenBytes...)
|
||||
plainText = append(plainText, msgBytes...)
|
||||
plainText = append(plainText, receiveID...)
|
||||
|
||||
// PKCS7 padding
|
||||
blockSize := aes.BlockSize
|
||||
padding := blockSize - len(plainText)%blockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
plainText = append(plainText, padText...)
|
||||
|
||||
// Encrypt
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize])
|
||||
cipherText := make([]byte, len(plainText))
|
||||
mode.CryptBlocks(cipherText, plainText)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||
}
|
||||
|
||||
// generateSignature generates a signature for testing
|
||||
func generateSignature(token, timestamp, nonce, msgEncrypt string) string {
|
||||
params := []string{token, timestamp, nonce, msgEncrypt}
|
||||
sort.Strings(params)
|
||||
str := strings.Join(params, "")
|
||||
hash := sha1.Sum([]byte(str))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
func TestNewWeComBotChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("missing token", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{
|
||||
Token: "",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
}
|
||||
_, err := NewWeComBotChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing webhook_url", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "",
|
||||
}
|
||||
_, err := NewWeComBotChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing webhook_url, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
AllowFrom: []string{"user1", "user2"},
|
||||
}
|
||||
ch, err := NewWeComBotChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ch.Name() != "wecom" {
|
||||
t.Errorf("Name() = %q, want %q", ch.Name(), "wecom")
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("new channel should not be running")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotChannelIsAllowed(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("empty allowlist allows all", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
AllowFrom: []string{},
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
if !ch.IsAllowed("any_user") {
|
||||
t.Error("empty allowlist should allow all users")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allowlist restricts users", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
AllowFrom: []string{"allowed_user"},
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
if !ch.IsAllowed("allowed_user") {
|
||||
t.Error("allowed user should pass allowlist check")
|
||||
}
|
||||
if ch.IsAllowed("blocked_user") {
|
||||
t.Error("non-allowed user should be blocked")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotVerifySignature(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("valid signature", func(t *testing.T) {
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
msgEncrypt := "test_message"
|
||||
expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt)
|
||||
|
||||
if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) {
|
||||
t.Error("valid signature should pass verification")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid signature", func(t *testing.T) {
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
msgEncrypt := "test_message"
|
||||
|
||||
if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) {
|
||||
t.Error("invalid signature should fail verification")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty token skips verification", func(t *testing.T) {
|
||||
// Create a channel manually with empty token to test the behavior
|
||||
cfgEmpty := config.WeComConfig{
|
||||
Token: "",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
}
|
||||
chEmpty := &WeComBotChannel{
|
||||
config: cfgEmpty,
|
||||
}
|
||||
|
||||
if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
|
||||
t.Error("empty token should skip verification and return true")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotDecryptMessage(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("decrypt without AES key", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
EncodingAESKey: "",
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
// Without AES key, message should be base64 decoded only
|
||||
plainText := "hello world"
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(plainText))
|
||||
|
||||
result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != plainText {
|
||||
t.Errorf("decryptMessage() = %q, want %q", result, plainText)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decrypt with AES key", func(t *testing.T) {
|
||||
aesKey := generateTestAESKey()
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
EncodingAESKey: aesKey,
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
originalMsg := "<xml><Content>Hello</Content></xml>"
|
||||
encrypted, err := encryptTestMessage(originalMsg, aesKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to encrypt test message: %v", err)
|
||||
}
|
||||
|
||||
result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != originalMsg {
|
||||
t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid base64", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
EncodingAESKey: "",
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
_, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid base64, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid AES key", func(t *testing.T) {
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
EncodingAESKey: "invalid_key",
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
_, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid AES key, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotPKCS7Unpad(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "valid padding 3 bytes",
|
||||
input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
|
||||
expected: []byte("hello"),
|
||||
},
|
||||
{
|
||||
name: "valid padding 16 bytes (full block)",
|
||||
input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
|
||||
expected: []byte("123456789012345"),
|
||||
},
|
||||
{
|
||||
name: "invalid padding larger than data",
|
||||
input: []byte{20},
|
||||
expected: nil, // should return error
|
||||
},
|
||||
{
|
||||
name: "invalid padding zero",
|
||||
input: append([]byte("test"), byte(0)),
|
||||
expected: nil, // should return error
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := pkcs7UnpadWeCom(tt.input)
|
||||
if tt.expected == nil {
|
||||
// This case should return an error
|
||||
if err == nil {
|
||||
t.Errorf("pkcs7UnpadWeCom() expected error for invalid padding, got result: %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("pkcs7UnpadWeCom() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(result, tt.expected) {
|
||||
t.Errorf("pkcs7UnpadWeCom() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComBotHandleVerification(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
aesKey := generateTestAESKey()
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
EncodingAESKey: aesKey,
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("valid verification request", func(t *testing.T) {
|
||||
echostr := "test_echostr_123"
|
||||
encryptedEchostr, _ := encryptTestMessage(echostr, aesKey)
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
|
||||
nil,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleVerification(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != echostr {
|
||||
t.Errorf("response body = %q, want %q", w.Body.String(), echostr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing parameters", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig×tamp=ts", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleVerification(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid signature", func(t *testing.T) {
|
||||
echostr := "test_echostr"
|
||||
encryptedEchostr, _ := encryptTestMessage(echostr, aesKey)
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
|
||||
nil,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleVerification(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotHandleMessageCallback(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
aesKey := generateTestAESKey()
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
EncodingAESKey: aesKey,
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("valid direct message callback", func(t *testing.T) {
|
||||
// Create JSON message for direct chat (single)
|
||||
jsonMsg := `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`
|
||||
|
||||
// Encrypt message
|
||||
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
|
||||
|
||||
// Create encrypted XML wrapper
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: encrypted,
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encrypted)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != "success" {
|
||||
t.Errorf("response body = %q, want %q", w.Body.String(), "success")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid group message callback", func(t *testing.T) {
|
||||
// Create JSON message for group chat
|
||||
jsonMsg := `{
|
||||
"msgid": "test_msg_id_456",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chatid": "group_chat_id_123",
|
||||
"chattype": "group",
|
||||
"from": {"userid": "user456"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello Group"}
|
||||
}`
|
||||
|
||||
// Encrypt message
|
||||
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
|
||||
|
||||
// Create encrypted XML wrapper
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: encrypted,
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encrypted)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if w.Body.String() != "success" {
|
||||
t.Errorf("response body = %q, want %q", w.Body.String(), "success")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing parameters", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid XML", func(t *testing.T) {
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, "")
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
strings.NewReader("invalid xml"),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid signature", func(t *testing.T) {
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: "encrypted_data",
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleMessageCallback(context.Background(), w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotProcessMessage(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("process direct text message", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_123",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatType: "single",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "text",
|
||||
}
|
||||
msg.From.UserID = "user123"
|
||||
msg.Text.Content = "Hello World"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
|
||||
t.Run("process group text message", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_456",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatID: "group_chat_id_123",
|
||||
ChatType: "group",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "text",
|
||||
}
|
||||
msg.From.UserID = "user456"
|
||||
msg.Text.Content = "Hello Group"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
|
||||
t.Run("process voice message", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_789",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatType: "single",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "voice",
|
||||
}
|
||||
msg.From.UserID = "user123"
|
||||
msg.Voice.Content = "Voice message text"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
|
||||
t.Run("skip unsupported message type", func(t *testing.T) {
|
||||
msg := WeComBotMessage{
|
||||
MsgID: "test_msg_id_000",
|
||||
AIBotID: "test_aibot_id",
|
||||
ChatType: "single",
|
||||
ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
MsgType: "video",
|
||||
}
|
||||
msg.From.UserID = "user123"
|
||||
|
||||
// Should not panic
|
||||
ch.processMessage(context.Background(), msg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotHandleWebhook(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
t.Run("GET request calls verification", func(t *testing.T) {
|
||||
echostr := "test_echostr"
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(echostr))
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encoded)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded,
|
||||
nil,
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleWebhook(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("POST request calls message callback", func(t *testing.T) {
|
||||
encryptedWrapper := struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
}{
|
||||
Encrypt: base64.StdEncoding.EncodeToString([]byte("test")),
|
||||
}
|
||||
wrapperData, _ := xml.Marshal(encryptedWrapper)
|
||||
|
||||
timestamp := "1234567890"
|
||||
nonce := "test_nonce"
|
||||
signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt)
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
|
||||
bytes.NewReader(wrapperData),
|
||||
)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleWebhook(w, req)
|
||||
|
||||
// Should not be method not allowed
|
||||
if w.Code == http.StatusMethodNotAllowed {
|
||||
t.Error("POST request should not return Method Not Allowed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported method", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleWebhook(w, req)
|
||||
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWeComBotHandleHealth(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.WeComConfig{
|
||||
Token: "test_token",
|
||||
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
}
|
||||
ch, _ := NewWeComBotChannel(cfg, msgBus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ch.handleHealth(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want %q", contentType, "application/json")
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "status") || !strings.Contains(body, "running") {
|
||||
t.Errorf("response body should contain status and running fields, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComBotReplyMessage(t *testing.T) {
|
||||
msg := WeComBotReplyMessage{
|
||||
MsgType: "text",
|
||||
}
|
||||
msg.Text.Content = "Hello World"
|
||||
|
||||
if msg.MsgType != "text" {
|
||||
t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
|
||||
}
|
||||
if msg.Text.Content != "Hello World" {
|
||||
t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeComBotMessageStructure(t *testing.T) {
|
||||
jsonData := `{
|
||||
"msgid": "test_msg_id_123",
|
||||
"aibotid": "test_aibot_id",
|
||||
"chatid": "group_chat_id_123",
|
||||
"chattype": "group",
|
||||
"from": {"userid": "user123"},
|
||||
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
|
||||
"msgtype": "text",
|
||||
"text": {"content": "Hello World"}
|
||||
}`
|
||||
|
||||
var msg WeComBotMessage
|
||||
err := json.Unmarshal([]byte(jsonData), &msg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
|
||||
if msg.MsgID != "test_msg_id_123" {
|
||||
t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123")
|
||||
}
|
||||
if msg.AIBotID != "test_aibot_id" {
|
||||
t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id")
|
||||
}
|
||||
if msg.ChatID != "group_chat_id_123" {
|
||||
t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123")
|
||||
}
|
||||
if msg.ChatType != "group" {
|
||||
t.Errorf("ChatType = %q, want %q", msg.ChatType, "group")
|
||||
}
|
||||
if msg.From.UserID != "user123" {
|
||||
t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123")
|
||||
}
|
||||
if msg.MsgType != "text" {
|
||||
t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
|
||||
}
|
||||
if msg.Text.Content != "Hello World" {
|
||||
t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
|
||||
}
|
||||
}
|
||||
@@ -1,192 +0,0 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type WhatsAppChannel struct {
|
||||
*BaseChannel
|
||||
conn *websocket.Conn
|
||||
config config.WhatsAppConfig
|
||||
url string
|
||||
mu sync.Mutex
|
||||
connected bool
|
||||
}
|
||||
|
||||
func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) {
|
||||
base := NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom)
|
||||
|
||||
return &WhatsAppChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
url: cfg.BridgeURL,
|
||||
connected: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *WhatsAppChannel) Start(ctx context.Context) error {
|
||||
log.Printf("Starting WhatsApp channel connecting to %s...", c.url)
|
||||
|
||||
dialer := websocket.DefaultDialer
|
||||
dialer.HandshakeTimeout = 10 * time.Second
|
||||
|
||||
conn, _, err := dialer.Dial(c.url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.connected = true
|
||||
c.mu.Unlock()
|
||||
|
||||
c.setRunning(true)
|
||||
log.Println("WhatsApp channel connected")
|
||||
|
||||
go c.listen(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WhatsAppChannel) Stop(ctx context.Context) error {
|
||||
log.Println("Stopping WhatsApp channel...")
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
if err := c.conn.Close(); err != nil {
|
||||
log.Printf("Error closing WhatsApp connection: %v", err)
|
||||
}
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
c.connected = false
|
||||
c.setRunning(false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("whatsapp connection not established")
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "message",
|
||||
"to": msg.ChatID,
|
||||
"content": msg.Content,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %w", err)
|
||||
}
|
||||
|
||||
if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WhatsAppChannel) listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
log.Printf("WhatsApp read error: %v", err)
|
||||
time.Sleep(2 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
var msg map[string]any
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
log.Printf("Failed to unmarshal WhatsApp message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
msgType, ok := msg["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if msgType == "message" {
|
||||
c.handleIncomingMessage(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
|
||||
senderID, ok := msg["from"].(string)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
chatID, ok := msg["chat"].(string)
|
||||
if !ok {
|
||||
chatID = senderID
|
||||
}
|
||||
|
||||
content, ok := msg["content"].(string)
|
||||
if !ok {
|
||||
content = ""
|
||||
}
|
||||
|
||||
var mediaPaths []string
|
||||
if mediaData, ok := msg["media"].([]any); ok {
|
||||
mediaPaths = make([]string, 0, len(mediaData))
|
||||
for _, m := range mediaData {
|
||||
if path, ok := m.(string); ok {
|
||||
mediaPaths = append(mediaPaths, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
metadata := make(map[string]string)
|
||||
if messageID, ok := msg["id"].(string); ok {
|
||||
metadata["message_id"] = messageID
|
||||
}
|
||||
if userName, ok := msg["from_name"].(string); ok {
|
||||
metadata["user_name"] = userName
|
||||
}
|
||||
|
||||
if chatID == senderID {
|
||||
metadata["peer_kind"] = "direct"
|
||||
metadata["peer_id"] = senderID
|
||||
} else {
|
||||
metadata["peer_kind"] = "group"
|
||||
metadata["peer_id"] = chatID
|
||||
}
|
||||
|
||||
log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
||||
}
|
||||
Reference in New Issue
Block a user