mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into refactor-inbound-context-routing-session
# Conflicts: # pkg/agent/eventbus_test.go # pkg/agent/loop.go # pkg/bus/bus.go # pkg/bus/types.go # pkg/channels/pico/pico.go # pkg/channels/telegram/telegram.go # pkg/config/config.go # web/backend/api/session.go # web/backend/api/session_test.go
This commit is contained in:
@@ -3,6 +3,7 @@ package discord
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -14,6 +15,8 @@ import (
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio"
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -42,6 +45,15 @@ type DiscordChannel struct {
|
||||
typingMu sync.Mutex
|
||||
typingStop map[string]chan struct{} // chatID → stop signal
|
||||
botUserID string // stored for mention checking
|
||||
bus *bus.MessageBus
|
||||
tts tts.TTSProvider
|
||||
voiceMu sync.RWMutex
|
||||
voiceSSRC map[string]map[uint32]string // guildID -> ssrc -> userID
|
||||
|
||||
// TTS interruption: cancel active playback when user speaks
|
||||
ttsMu sync.Mutex
|
||||
cancelTTS context.CancelFunc
|
||||
ttsPlayID uint64
|
||||
}
|
||||
|
||||
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
|
||||
@@ -73,6 +85,8 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
|
||||
config: cfg,
|
||||
ctx: context.Background(),
|
||||
typingStop: make(map[string]chan struct{}),
|
||||
bus: bus,
|
||||
voiceSSRC: make(map[string]map[uint32]string),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -90,6 +104,8 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
|
||||
|
||||
c.session.AddHandler(c.handleMessage)
|
||||
|
||||
go c.listenVoiceControl(c.ctx)
|
||||
|
||||
if err := c.session.Open(); err != nil {
|
||||
return fmt.Errorf("failed to open discord session: %w", err)
|
||||
}
|
||||
@@ -142,6 +158,25 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]s
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if c.tts != nil {
|
||||
if ch, err := c.session.State.Channel(channelID); err == nil && ch.GuildID != "" {
|
||||
if vc, ok := c.session.VoiceConnections[ch.GuildID]; ok && vc != nil {
|
||||
// Cancel any previous TTS playback
|
||||
c.ttsMu.Lock()
|
||||
if c.cancelTTS != nil {
|
||||
c.cancelTTS()
|
||||
}
|
||||
ttsCtx, ttsCancel := context.WithCancel(c.ctx)
|
||||
c.ttsPlayID++
|
||||
playID := c.ttsPlayID
|
||||
c.cancelTTS = ttsCancel
|
||||
c.ttsMu.Unlock()
|
||||
|
||||
go c.playTTS(ttsCtx, vc, msg.Content, playID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgID, err := c.sendChunk(ctx, channelID, msg.Content, msg.ReplyToMessageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -359,6 +394,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
return
|
||||
}
|
||||
|
||||
if c.handleVoiceCommand(s, m) {
|
||||
return
|
||||
}
|
||||
|
||||
content := m.Content
|
||||
|
||||
// In guild (group) channels, apply unified group trigger filtering
|
||||
@@ -642,3 +681,134 @@ func (c *DiscordChannel) stripBotMention(text string) string {
|
||||
text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) listenVoiceControl(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case ctrl, ok := <-c.bus.VoiceControlsChan():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if ctrl.Type == "command" && ctrl.Action == "leave" {
|
||||
if strings.HasPrefix(ctrl.SessionID, "discord_vc_") {
|
||||
guildID := strings.TrimPrefix(ctrl.SessionID, "discord_vc_")
|
||||
vc, exists := c.session.VoiceConnections[guildID]
|
||||
if exists && vc != nil {
|
||||
vc.Disconnect(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) playTTS(ctx context.Context, vc *discordgo.VoiceConnection, text string, playID uint64) {
|
||||
// Capture the cancel func associated with this playback (if any).
|
||||
// Clear cancelTTS when playback finishes (normal or interrupted),
|
||||
// but only if it still refers to this playback's cancel func.
|
||||
defer func() {
|
||||
c.ttsMu.Lock()
|
||||
if c.ttsPlayID == playID {
|
||||
c.cancelTTS = nil
|
||||
}
|
||||
c.ttsMu.Unlock()
|
||||
}()
|
||||
|
||||
sentences := audio.SplitSentences(text)
|
||||
if len(sentences) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoCF("discord", "Starting streamed TTS", map[string]any{"sentences": len(sentences)})
|
||||
|
||||
// Pipeline: prefetch next sentence's audio while playing current
|
||||
type ttResult struct {
|
||||
stream io.ReadCloser
|
||||
err error
|
||||
}
|
||||
|
||||
var prefetch chan ttResult
|
||||
|
||||
// Ensure any in-flight prefetch is drained on exit to prevent stream leaks,
|
||||
// but avoid blocking indefinitely if the prefetch goroutine is stuck or never sends.
|
||||
defer func() {
|
||||
if prefetch != nil {
|
||||
select {
|
||||
case result := <-prefetch:
|
||||
if result.stream != nil {
|
||||
result.stream.Close()
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Timed out waiting for a prefetched result; avoid blocking on exit.
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for i, sentence := range sentences {
|
||||
// Check for cancellation (interruption)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.InfoCF("discord", "TTS interrupted", map[string]any{"at_sentence": i})
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Start prefetching the NEXT sentence while we process the current one
|
||||
var nextPrefetch chan ttResult
|
||||
if i+1 < len(sentences) {
|
||||
nextPrefetch = make(chan ttResult, 1)
|
||||
nextSentence := sentences[i+1]
|
||||
go func() {
|
||||
s, e := c.tts.Synthesize(ctx, nextSentence)
|
||||
nextPrefetch <- ttResult{s, e}
|
||||
}()
|
||||
}
|
||||
|
||||
// Get the current sentence's audio
|
||||
var stream io.ReadCloser
|
||||
var err error
|
||||
|
||||
if prefetch != nil {
|
||||
// Use prefetched result from previous iteration, but be responsive to cancellation.
|
||||
var result ttResult
|
||||
select {
|
||||
case result = <-prefetch:
|
||||
stream, err = result.stream, result.err
|
||||
case <-ctx.Done():
|
||||
// Context canceled while waiting for prefetched audio; abort playback.
|
||||
logger.InfoCF(
|
||||
"discord",
|
||||
"TTS interrupted while waiting for prefetched audio",
|
||||
map[string]any{"at_sentence": i},
|
||||
)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// First sentence: synthesize directly
|
||||
stream, err = c.tts.Synthesize(ctx, sentence)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if stream != nil {
|
||||
stream.Close()
|
||||
}
|
||||
logger.ErrorCF("discord", "TTS synthesize failed", map[string]any{"error": err.Error(), "sentence": i})
|
||||
prefetch = nextPrefetch
|
||||
continue
|
||||
}
|
||||
|
||||
if err := streamOggOpusToDiscord(ctx, vc, stream); err != nil {
|
||||
logger.ErrorCF("discord", "TTS playback failed", map[string]any{"error": err.Error(), "sentence": i})
|
||||
}
|
||||
stream.Close()
|
||||
|
||||
prefetch = nextPrefetch
|
||||
}
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *DiscordChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/audio/tts"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -8,6 +9,10 @@ import (
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewDiscordChannel(cfg.Channels.Discord, b)
|
||||
ch, err := NewDiscordChannel(cfg.Channels.Discord, b)
|
||||
if err == nil {
|
||||
ch.tts = tts.DetectTTS(cfg)
|
||||
}
|
||||
return ch, err
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/audio"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func (c *DiscordChannel) setVoiceUserID(guildID string, ssrc uint32, userID string) {
|
||||
if userID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.voiceMu.Lock()
|
||||
defer c.voiceMu.Unlock()
|
||||
|
||||
ssrcMap, ok := c.voiceSSRC[guildID]
|
||||
if !ok {
|
||||
ssrcMap = make(map[uint32]string)
|
||||
c.voiceSSRC[guildID] = ssrcMap
|
||||
}
|
||||
ssrcMap[ssrc] = userID
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) voiceUserID(guildID string, ssrc uint32) string {
|
||||
c.voiceMu.RLock()
|
||||
defer c.voiceMu.RUnlock()
|
||||
|
||||
ssrcMap, ok := c.voiceSSRC[guildID]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return ssrcMap[ssrc]
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) handleVoiceCommand(s *discordgo.Session, m *discordgo.MessageCreate) bool {
|
||||
if m.Content == "!vc join" {
|
||||
vs, err := s.State.VoiceState(m.GuildID, m.Author.ID)
|
||||
if err != nil || vs == nil {
|
||||
if _, sendErr := s.ChannelMessageSend(
|
||||
m.ChannelID,
|
||||
"You need to be in a voice channel first!",
|
||||
); sendErr != nil {
|
||||
logger.InfoCF("discord", "Failed to send voice channel requirement message", map[string]any{
|
||||
"channel": m.ChannelID,
|
||||
"error": sendErr,
|
||||
})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
logger.InfoCF("discord", "Joining voice channel", map[string]any{"channel": vs.ChannelID})
|
||||
vc, err := s.ChannelVoiceJoin(c.ctx, m.GuildID, vs.ChannelID, false, false)
|
||||
if err != nil {
|
||||
if _, sendErr := s.ChannelMessageSend(
|
||||
m.ChannelID,
|
||||
fmt.Sprintf("Failed to join voice channel: %v", err),
|
||||
); sendErr != nil {
|
||||
logger.InfoCF("discord", "Failed to send voice join error message", map[string]any{
|
||||
"channel": m.ChannelID,
|
||||
"error": sendErr,
|
||||
})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
go c.receiveVoice(vc, m.GuildID, m.ChannelID)
|
||||
if _, sendErr := s.ChannelMessageSend(
|
||||
m.ChannelID,
|
||||
"Joined Voice Channel! Listening for audio...",
|
||||
); sendErr != nil {
|
||||
logger.InfoCF("discord", "Failed to send voice join success message", map[string]any{
|
||||
"channel": m.ChannelID,
|
||||
"error": sendErr,
|
||||
})
|
||||
}
|
||||
return true
|
||||
} else if m.Content == "!vc leave" {
|
||||
vc, exists := s.VoiceConnections[m.GuildID]
|
||||
if exists && vc != nil {
|
||||
if err := vc.Disconnect(c.ctx); err != nil {
|
||||
logger.InfoCF("discord", "Failed to disconnect from voice channel", map[string]any{
|
||||
"guild": m.GuildID,
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
if _, sendErr := s.ChannelMessageSend(m.ChannelID, "Left Voice Channel."); sendErr != nil {
|
||||
logger.InfoCF("discord", "Failed to send voice leave success message", map[string]any{
|
||||
"channel": m.ChannelID,
|
||||
"error": sendErr,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
if _, sendErr := s.ChannelMessageSend(m.ChannelID, "Not in a voice channel."); sendErr != nil {
|
||||
logger.InfoCF("discord", "Failed to send voice not-in-channel message", map[string]any{
|
||||
"channel": m.ChannelID,
|
||||
"error": sendErr,
|
||||
})
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func VoiceReceiveActive(vc *discordgo.VoiceConnection) bool {
|
||||
return vc != nil && vc.OpusRecv != nil
|
||||
}
|
||||
|
||||
func streamOggOpusToDiscord(ctx context.Context, vc *discordgo.VoiceConnection, r io.Reader) (retErr error) {
|
||||
// Recover from panic if vc.OpusSend is closed mid-send (e.g. on disconnect)
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
retErr = fmt.Errorf("voice connection closed during playback")
|
||||
logger.RecoverPanicNoExit(rec)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for the speaking transition to register
|
||||
vc.Speaking(true)
|
||||
defer vc.Speaking(false)
|
||||
|
||||
return audio.DecodeOggOpus(r, func(frame []byte) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case vc.OpusSend <- frame:
|
||||
return nil
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) receiveVoice(vc *discordgo.VoiceConnection, guildID string, chatID string) {
|
||||
logger.InfoCF("discord", "Started listening for voice", map[string]any{"guild": guildID})
|
||||
|
||||
vc.AddHandler(func(_ *discordgo.VoiceConnection, vs *discordgo.VoiceSpeakingUpdate) {
|
||||
if vs == nil {
|
||||
return
|
||||
}
|
||||
c.setVoiceUserID(guildID, uint32(vs.SSRC), vs.UserID)
|
||||
})
|
||||
|
||||
defer func() {
|
||||
c.voiceMu.Lock()
|
||||
delete(c.voiceSSRC, guildID)
|
||||
c.voiceMu.Unlock()
|
||||
}()
|
||||
|
||||
go func(ctx context.Context, vc *discordgo.VoiceConnection) {
|
||||
// Recover from potential panics if OpusSend is closed mid-send.
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
logger.WarnCF("discord", "Recovered from panic while sending wake-up frames", map[string]any{
|
||||
"error": rec,
|
||||
"guild": guildID,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
// If the voice connection or OpusSend are not available, nothing to do.
|
||||
if vc == nil || vc.OpusSend == nil {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(250 * time.Millisecond) // Wait a bit for connection to settle
|
||||
|
||||
// Abort if the context has already been canceled.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
vc.Speaking(true)
|
||||
defer vc.Speaking(false)
|
||||
|
||||
silenceFrame := []byte{0xF8, 0xFF, 0xFE}
|
||||
for i := 0; i < 5; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case vc.OpusSend <- silenceFrame:
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
logger.DebugCF("discord", "Sent wake-up silence frames", map[string]any{"guild": guildID})
|
||||
}(c.ctx, vc)
|
||||
sessionID := fmt.Sprintf("discord_vc_%s", guildID)
|
||||
|
||||
c.bus.PublishVoiceControl(c.ctx, bus.VoiceControl{
|
||||
SessionID: sessionID,
|
||||
Type: "state",
|
||||
Action: "listening",
|
||||
})
|
||||
|
||||
var sequence uint64 = 0
|
||||
var interruptCount int
|
||||
var lastInterruptAt time.Time
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case p, ok := <-vc.OpusRecv:
|
||||
if !ok {
|
||||
logger.InfoCF("discord", "Voice channel closed", map[string]any{"guild": guildID})
|
||||
// Cancel any TTS that may still be playing
|
||||
c.ttsMu.Lock()
|
||||
if c.cancelTTS != nil {
|
||||
c.cancelTTS()
|
||||
c.cancelTTS = nil
|
||||
}
|
||||
c.ttsMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if p == nil {
|
||||
logger.DebugCF("discord", "Received nil Opus packet", nil)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(p.Opus) == 0 {
|
||||
logger.DebugCF("discord", "Received empty Opus packet", map[string]any{
|
||||
"seq": p.Sequence,
|
||||
"ssrc": p.SSRC,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.DebugCF("discord", "Received Opus packet", map[string]any{
|
||||
"seq": p.Sequence,
|
||||
"len": len(p.Opus),
|
||||
"ssrc": p.SSRC,
|
||||
})
|
||||
// Interruption detection: if user sends voice while TTS is playing,
|
||||
// cancel TTS after a short debounce (3 packets in 200ms)
|
||||
now := time.Now()
|
||||
if now.Sub(lastInterruptAt) > 500*time.Millisecond {
|
||||
interruptCount = 0
|
||||
}
|
||||
interruptCount++
|
||||
lastInterruptAt = now
|
||||
|
||||
if interruptCount >= 3 {
|
||||
c.ttsMu.Lock()
|
||||
if c.cancelTTS != nil {
|
||||
c.cancelTTS()
|
||||
c.cancelTTS = nil
|
||||
logger.InfoCF("discord", "TTS interrupted by user voice", nil)
|
||||
}
|
||||
c.ttsMu.Unlock()
|
||||
interruptCount = 0
|
||||
}
|
||||
|
||||
userID := c.voiceUserID(guildID, p.SSRC)
|
||||
if userID == "" {
|
||||
logger.DebugCF("discord", "Dropping voice packet without user mapping", map[string]any{
|
||||
"ssrc": p.SSRC,
|
||||
"guild": guildID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "discord",
|
||||
PlatformID: userID,
|
||||
CanonicalID: identity.BuildCanonicalID("discord", userID),
|
||||
}
|
||||
if !c.IsAllowedSender(sender) {
|
||||
logger.DebugCF("discord", "Voice packet rejected by allowlist", map[string]any{
|
||||
"user_id": userID,
|
||||
"guild": guildID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
sequence++
|
||||
|
||||
chunk := bus.AudioChunk{
|
||||
SessionID: sessionID,
|
||||
SpeakerID: userID,
|
||||
ChatID: chatID,
|
||||
Channel: "discord",
|
||||
Sequence: sequence,
|
||||
Timestamp: p.Timestamp,
|
||||
SampleRate: 48000,
|
||||
Channels: 2,
|
||||
Format: "opus",
|
||||
Data: p.Opus,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 100*time.Millisecond)
|
||||
err := c.bus.PublishAudioChunk(ctx, chunk)
|
||||
cancel()
|
||||
if err != nil {
|
||||
logger.ErrorCF("discord", "Failed to publish audio chunk", map[string]any{
|
||||
"guild": guildID,
|
||||
"sessionID": sessionID,
|
||||
"sequence": sequence,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
)
|
||||
|
||||
// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions.
|
||||
@@ -145,3 +147,8 @@ func extractImageKeysRecursive(v any, feishuKeys, externalURLs *[]string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *FeishuChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -696,3 +696,8 @@ func (c *LINEChannel) downloadContent(messageID, filename string) string {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *LINEChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -444,6 +444,23 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
m.initChannel("irc", "IRC")
|
||||
}
|
||||
|
||||
if channels.VK.Enabled && channels.VK.Token.String() != "" && channels.VK.GroupID != 0 {
|
||||
m.initChannel("vk", "VK")
|
||||
}
|
||||
|
||||
if channels.TeamsWebhook.Enabled && len(channels.TeamsWebhook.Webhooks) > 0 {
|
||||
hasValidTarget := false
|
||||
for _, target := range channels.TeamsWebhook.Webhooks {
|
||||
if target.WebhookURL.String() != "" {
|
||||
hasValidTarget = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasValidTarget {
|
||||
m.initChannel("teams_webhook", "Teams Webhook")
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel initialization completed", map[string]any{
|
||||
"enabled_channels": len(m.channels),
|
||||
})
|
||||
|
||||
@@ -62,6 +62,13 @@ func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) {
|
||||
value["app_secret"] = ch.Feishu.AppSecret.String()
|
||||
value["encrypt_key"] = ch.Feishu.EncryptKey.String()
|
||||
value["verification_token"] = ch.Feishu.VerificationToken.String()
|
||||
case "teams_webhook":
|
||||
// Expose webhook URLs for hash computation (they contain secrets)
|
||||
webhooks := make(map[string]string)
|
||||
for name, target := range ch.TeamsWebhook.Webhooks {
|
||||
webhooks[name] = target.WebhookURL.String()
|
||||
}
|
||||
value["webhooks"] = webhooks
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,4 +173,13 @@ func updateKeys(newcfg, old *config.ChannelsConfig) {
|
||||
newcfg.Feishu.EncryptKey = old.Feishu.EncryptKey
|
||||
newcfg.Feishu.VerificationToken = old.Feishu.VerificationToken
|
||||
}
|
||||
if newcfg.TeamsWebhook.Enabled {
|
||||
// Copy SecureString webhook URLs from old config
|
||||
for name, oldTarget := range old.TeamsWebhook.Webhooks {
|
||||
if newTarget, ok := newcfg.TeamsWebhook.Webhooks[name]; ok {
|
||||
newTarget.WebhookURL = oldTarget.WebhookURL
|
||||
newcfg.TeamsWebhook.Webhooks[name] = newTarget
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
type mockChannel struct {
|
||||
BaseChannel
|
||||
sendFn func(ctx context.Context, msg bus.OutboundMessage) error
|
||||
startFn func(ctx context.Context) error
|
||||
stopFn func(ctx context.Context) error
|
||||
sentMessages []bus.OutboundMessage
|
||||
placeholdersSent int
|
||||
editedMessages int
|
||||
@@ -33,8 +35,19 @@ func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
|
||||
return nil, m.sendFn(ctx, msg)
|
||||
}
|
||||
|
||||
func (m *mockChannel) Start(ctx context.Context) error { return nil }
|
||||
func (m *mockChannel) Stop(ctx context.Context) error { return nil }
|
||||
func (m *mockChannel) Start(ctx context.Context) error {
|
||||
if m.startFn != nil {
|
||||
return m.startFn(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockChannel) Stop(ctx context.Context) error {
|
||||
if m.stopFn != nil {
|
||||
return m.stopFn(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
m.placeholdersSent++
|
||||
@@ -86,6 +99,101 @@ func newTestManager() *Manager {
|
||||
return &Manager{
|
||||
channels: make(map[string]Channel),
|
||||
workers: make(map[string]*channelWorker),
|
||||
bus: bus.NewMessageBus(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAll_AllChannelsFail_ReturnsJoinedError(t *testing.T) {
|
||||
m := newTestManager()
|
||||
errA := errors.New("channel-a start failed")
|
||||
errB := errors.New("channel-b start failed")
|
||||
|
||||
m.channels["a"] = &mockChannel{
|
||||
startFn: func(_ context.Context) error { return errA },
|
||||
}
|
||||
m.channels["b"] = &mockChannel{
|
||||
startFn: func(_ context.Context) error { return errB },
|
||||
}
|
||||
|
||||
err := m.StartAll(t.Context())
|
||||
if err == nil {
|
||||
t.Fatal("expected StartAll to fail when all channels fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to start any enabled channels") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !errors.Is(err, errA) {
|
||||
t.Fatalf("expected error to wrap errA, got: %v", err)
|
||||
}
|
||||
if !errors.Is(err, errB) {
|
||||
t.Fatalf("expected error to wrap errB, got: %v", err)
|
||||
}
|
||||
if len(m.workers) != 0 {
|
||||
t.Fatalf("expected no workers on full startup failure, got %d", len(m.workers))
|
||||
}
|
||||
if m.dispatchTask != nil {
|
||||
t.Fatal("expected dispatch task to be cleared on full startup failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
|
||||
m := newTestManager()
|
||||
errBad := errors.New("bad channel start failed")
|
||||
processed := make(chan struct{}, 1)
|
||||
|
||||
m.channels["good"] = &mockChannel{
|
||||
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
|
||||
if msg.Channel == "good" {
|
||||
select {
|
||||
case processed <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
m.channels["bad"] = &mockChannel{
|
||||
startFn: func(_ context.Context) error { return errBad },
|
||||
}
|
||||
|
||||
err := m.StartAll(t.Context())
|
||||
if err != nil {
|
||||
t.Fatalf("expected StartAll to succeed with partial channel failures, got: %v", err)
|
||||
}
|
||||
if len(m.workers) != 1 {
|
||||
t.Fatalf("expected exactly 1 active worker, got %d", len(m.workers))
|
||||
}
|
||||
if _, ok := m.workers["good"]; !ok {
|
||||
t.Fatal("expected worker for successful channel 'good'")
|
||||
}
|
||||
if _, ok := m.workers["bad"]; ok {
|
||||
t.Fatal("did not expect worker for failed channel 'bad'")
|
||||
}
|
||||
if m.dispatchTask == nil {
|
||||
t.Fatal("expected dispatch task to run when at least one channel starts")
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: "good",
|
||||
ChatID: "chat-1",
|
||||
Content: "hello",
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishOutbound() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-processed:
|
||||
// worker processed outbound message as expected
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected successful channel worker to process outbound message")
|
||||
}
|
||||
|
||||
stopCtx, stopCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer stopCancel()
|
||||
if err := m.StopAll(stopCtx); err != nil {
|
||||
t.Fatalf("StopAll() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1300,3 +1300,8 @@ func stripUserMentionWithRegexp(text string, userID id.UserID, mentionR *regexp.
|
||||
cleaned = strings.TrimLeft(cleaned, ",:; ")
|
||||
return strings.TrimSpace(cleaned)
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *MatrixChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -1117,3 +1117,8 @@ func truncate(s string, n int) string {
|
||||
}
|
||||
return string(runes[:n]) + "..."
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *OneBotChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -262,3 +262,57 @@ func TestSend_ClosedConnection(t *testing.T) {
|
||||
|
||||
ch.Stop(ctx)
|
||||
}
|
||||
|
||||
func TestParseInlineImageMedia_Valid(t *testing.T) {
|
||||
media, err := parseInlineImageMedia(map[string]any{
|
||||
"media": []any{
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+X2ioAAAAASUVORK5CYII=",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("parseInlineImageMedia() error = %v", err)
|
||||
}
|
||||
if len(media) != 1 {
|
||||
t.Fatalf("len(media) = %d, want 1", len(media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPicoChannel_HandleMessageSend_AllowsMediaOnly(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
ch, err := NewPicoChannel(config.PicoConfig{
|
||||
Token: *config.NewSecureString("test-token"),
|
||||
}, mb)
|
||||
if err != nil {
|
||||
t.Fatalf("NewPicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Start() error = %v", err)
|
||||
}
|
||||
defer ch.Stop(ctx)
|
||||
|
||||
pc := &picoConn{id: "conn-1", sessionID: "sess-1"}
|
||||
ch.handleMessageSend(pc, PicoMessage{
|
||||
ID: "msg-1",
|
||||
Payload: map[string]any{
|
||||
"media": []any{
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+X2ioAAAAASUVORK5CYII=",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case msg := <-mb.InboundChan():
|
||||
if msg.Content != "" {
|
||||
t.Fatalf("msg.Content = %q, want empty", msg.Content)
|
||||
}
|
||||
if len(msg.Media) != 1 || !strings.HasPrefix(msg.Media[0], "data:image/png;base64,") {
|
||||
t.Fatalf("msg.Media = %#v, want inline image payload", msg.Media)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for inbound media message")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,19 +39,18 @@ func newMessage(msgType string, payload map[string]any) PicoMessage {
|
||||
}
|
||||
}
|
||||
|
||||
// newError creates an error PicoMessage.
|
||||
func newError(code, message string) PicoMessage {
|
||||
return newMessage(TypeError, map[string]any{
|
||||
func newErrorWithPayload(code, message string, extra map[string]any) PicoMessage {
|
||||
payload := map[string]any{
|
||||
"code": code,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
func newErrorWithPayload(code, message string, payload map[string]any) PicoMessage {
|
||||
if payload == nil {
|
||||
payload = map[string]any{}
|
||||
}
|
||||
payload["code"] = code
|
||||
payload["message"] = message
|
||||
for key, value := range extra {
|
||||
payload[key] = value
|
||||
}
|
||||
return newMessage(TypeError, payload)
|
||||
}
|
||||
|
||||
// newError creates an error PicoMessage.
|
||||
func newError(code, message string) PicoMessage {
|
||||
return newErrorWithPayload(code, message, nil)
|
||||
}
|
||||
|
||||
@@ -1003,3 +1003,8 @@ func sanitizeURLs(text string) string {
|
||||
return scheme + domain + path
|
||||
})
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *QQChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package teamswebhook
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("teams_webhook", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewTeamsWebhookChannel(cfg.Channels.TeamsWebhook, b)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,422 @@
|
||||
package teamswebhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
goteamsnotify "github.com/atc0005/go-teams-notify/v2"
|
||||
"github.com/atc0005/go-teams-notify/v2/adaptivecard"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// statusCodeRe extracts HTTP status codes from error messages like "401 Unauthorized".
|
||||
var statusCodeRe = regexp.MustCompile(`\b([45]\d{2})\b`)
|
||||
|
||||
// markdownTableRe matches a markdown table block (header + separator + rows).
|
||||
// It captures the entire table including all rows.
|
||||
var markdownTableRe = regexp.MustCompile(`(?m)^(\|[^\n]+\|)\n(\|[-:\|\s]+\|)\n((?:\|[^\n]+\|\n?)+)`)
|
||||
|
||||
// teamsMessageSender abstracts the Teams client for testability.
|
||||
type teamsMessageSender interface {
|
||||
SendWithContext(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error
|
||||
}
|
||||
|
||||
// classifyTeamsError extracts HTTP status code from error message and classifies it.
|
||||
// The go-teams-notify library returns errors like "error on notification: 401 Unauthorized, ...".
|
||||
// This allows proper retry behavior: 4xx errors are permanent, 5xx are temporary.
|
||||
func classifyTeamsError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
errMsg := err.Error()
|
||||
if matches := statusCodeRe.FindStringSubmatch(errMsg); len(matches) > 1 {
|
||||
if statusCode, parseErr := strconv.Atoi(matches[1]); parseErr == nil {
|
||||
return channels.ClassifySendError(statusCode, err)
|
||||
}
|
||||
}
|
||||
// Fallback: treat as temporary network error (retryable)
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
|
||||
// TeamsWebhookChannel is an output-only channel that sends messages
|
||||
// to Microsoft Teams via Power Automate workflow webhooks.
|
||||
// Multiple webhook targets can be configured and selected via ChatID.
|
||||
type TeamsWebhookChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.TeamsWebhookConfig
|
||||
client teamsMessageSender
|
||||
}
|
||||
|
||||
// NewTeamsWebhookChannel creates a new Teams webhook channel.
|
||||
func NewTeamsWebhookChannel(
|
||||
cfg config.TeamsWebhookConfig,
|
||||
bus *bus.MessageBus,
|
||||
) (*TeamsWebhookChannel, error) {
|
||||
if len(cfg.Webhooks) == 0 {
|
||||
return nil, fmt.Errorf("teams_webhook: at least one webhook target is required")
|
||||
}
|
||||
|
||||
// Require "default" webhook target
|
||||
if _, hasDefault := cfg.Webhooks["default"]; !hasDefault {
|
||||
return nil, fmt.Errorf("teams_webhook: a 'default' webhook target is required")
|
||||
}
|
||||
|
||||
// Validate all webhook targets have valid HTTPS URLs
|
||||
for name, target := range cfg.Webhooks {
|
||||
webhookURL := target.WebhookURL.String()
|
||||
if webhookURL == "" {
|
||||
return nil, fmt.Errorf("teams_webhook: webhook %q has empty webhook_url", name)
|
||||
}
|
||||
parsed, err := url.Parse(webhookURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("teams_webhook: webhook %q has invalid URL: %w", name, err)
|
||||
}
|
||||
if !strings.EqualFold(parsed.Scheme, "https") {
|
||||
return nil, fmt.Errorf("teams_webhook: webhook %q must use HTTPS (got %q)", name, parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel(
|
||||
"teams_webhook",
|
||||
cfg,
|
||||
bus,
|
||||
[]string{
|
||||
"*",
|
||||
}, // Output-only channel; "*" suppresses misleading "allows EVERYONE" audit warning
|
||||
channels.WithMaxMessageLength(24000), // Power Automate webhook payload limit is 28KB
|
||||
)
|
||||
|
||||
client := goteamsnotify.NewTeamsClient()
|
||||
|
||||
return &TeamsWebhookChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start initializes the channel. For output-only channels, this is a no-op.
|
||||
func (c *TeamsWebhookChannel) Start(ctx context.Context) error {
|
||||
targets := make([]string, 0, len(c.config.Webhooks))
|
||||
for name := range c.config.Webhooks {
|
||||
targets = append(targets, name)
|
||||
}
|
||||
sort.Strings(targets)
|
||||
logger.InfoCF("teams_webhook", "Starting Teams webhook channel (output-only)", map[string]any{
|
||||
"targets": targets,
|
||||
})
|
||||
c.SetRunning(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop shuts down the channel.
|
||||
func (c *TeamsWebhookChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("teams_webhook", "Stopping Teams webhook channel")
|
||||
c.SetRunning(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send delivers a message to the specified Teams webhook target.
|
||||
// The target is selected by msg.ChatID which must match a key in the webhooks map.
|
||||
func (c *TeamsWebhookChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Look up webhook target by ChatID, fall back to "default" if empty or unknown
|
||||
targetName := msg.ChatID
|
||||
if targetName == "" {
|
||||
targetName = "default"
|
||||
}
|
||||
|
||||
target, ok := c.config.Webhooks[targetName]
|
||||
if !ok {
|
||||
// Log warning and fall back to default target
|
||||
logger.WarnCF("teams_webhook", "Unknown target, falling back to default", map[string]any{
|
||||
"requested": msg.ChatID,
|
||||
"using": "default",
|
||||
})
|
||||
target = c.config.Webhooks["default"]
|
||||
}
|
||||
|
||||
// Build an Adaptive Card for rich formatting
|
||||
card, err := c.buildAdaptiveCard(msg, target)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("teams_webhook: failed to build card: %w", err)
|
||||
}
|
||||
|
||||
// Create the message with the card
|
||||
teamsMsg, err := adaptivecard.NewMessageFromCard(card)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("teams_webhook: failed to create message: %w", err)
|
||||
}
|
||||
|
||||
// Send to Teams
|
||||
if err := c.client.SendWithContext(ctx, target.WebhookURL.String(), teamsMsg); err != nil {
|
||||
// Log without raw error to avoid leaking webhook URL (embedded in net/http errors)
|
||||
logger.ErrorCF("teams_webhook", "Failed to send message to Teams webhook", map[string]any{
|
||||
"target": msg.ChatID,
|
||||
})
|
||||
// Classify error based on status code extracted from error message.
|
||||
// The go-teams-notify library includes status in errors like "401 Unauthorized".
|
||||
// Use ClassifySendError for proper retry behavior (4xx = permanent, 5xx = temporary).
|
||||
classifiedErr := classifyTeamsError(err)
|
||||
return nil, fmt.Errorf("teams_webhook: send failed: %w", classifiedErr)
|
||||
}
|
||||
|
||||
logger.DebugCF("teams_webhook", "Message sent successfully", map[string]any{
|
||||
"target": msg.ChatID,
|
||||
})
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// buildAdaptiveCard creates a formatted Adaptive Card from the outbound message.
|
||||
// It detects markdown tables and converts them to native Adaptive Card Table elements,
|
||||
// since TextBlocks only support a limited markdown subset (no tables).
|
||||
func (c *TeamsWebhookChannel) buildAdaptiveCard(
|
||||
msg bus.OutboundMessage,
|
||||
target config.TeamsWebhookTarget,
|
||||
) (adaptivecard.Card, error) {
|
||||
card := adaptivecard.NewCard()
|
||||
card.Type = adaptivecard.TypeAdaptiveCard
|
||||
|
||||
// Set full width for Teams rendering
|
||||
card.MSTeams.Width = "Full"
|
||||
|
||||
// Add title if configured on the target
|
||||
title := target.Title
|
||||
if title == "" {
|
||||
title = "PicoClaw Notification"
|
||||
}
|
||||
|
||||
titleBlock := adaptivecard.NewTextBlock(title, true)
|
||||
titleBlock.Size = adaptivecard.SizeLarge
|
||||
titleBlock.Weight = adaptivecard.WeightBolder
|
||||
titleBlock.Style = adaptivecard.TextBlockStyleHeading
|
||||
|
||||
if err := card.AddElement(false, titleBlock); err != nil {
|
||||
return card, err
|
||||
}
|
||||
|
||||
content := msg.Content
|
||||
if content == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
// Split content into text segments and tables
|
||||
// TextBlocks support: bold, italic, bullet/numbered lists, links
|
||||
// TextBlocks do NOT support: headers, tables, images
|
||||
segments := splitContentWithTables(content)
|
||||
|
||||
for _, seg := range segments {
|
||||
if seg.isTable {
|
||||
// Convert markdown table to Adaptive Card Table element
|
||||
tableElement, err := parseMarkdownTable(seg.content)
|
||||
if err != nil {
|
||||
// Fallback: render as preformatted text if parsing fails
|
||||
logger.WarnCF("teams_webhook", "Failed to parse markdown table, using fallback", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
block := adaptivecard.NewTextBlock("```\n"+seg.content+"\n```", true)
|
||||
block.Wrap = true
|
||||
if err := card.AddElement(false, block); err != nil {
|
||||
return card, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := card.AddElement(false, tableElement); err != nil {
|
||||
return card, err
|
||||
}
|
||||
} else {
|
||||
// Regular text content
|
||||
text := strings.TrimSpace(seg.content)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
block := adaptivecard.NewTextBlock(text, true)
|
||||
block.Wrap = true
|
||||
if err := card.AddElement(false, block); err != nil {
|
||||
return card, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return card, nil
|
||||
}
|
||||
|
||||
// contentSegment represents either a text block or a table in the message content.
|
||||
type contentSegment struct {
|
||||
content string
|
||||
isTable bool
|
||||
}
|
||||
|
||||
// splitContentWithTables splits content into alternating text and table segments.
|
||||
func splitContentWithTables(content string) []contentSegment {
|
||||
var segments []contentSegment
|
||||
|
||||
matches := markdownTableRe.FindAllStringSubmatchIndex(content, -1)
|
||||
if len(matches) == 0 {
|
||||
// No tables found, return entire content as text
|
||||
return []contentSegment{{content: content, isTable: false}}
|
||||
}
|
||||
|
||||
lastEnd := 0
|
||||
for _, match := range matches {
|
||||
// Text before this table
|
||||
if match[0] > lastEnd {
|
||||
segments = append(segments, contentSegment{
|
||||
content: content[lastEnd:match[0]],
|
||||
isTable: false,
|
||||
})
|
||||
}
|
||||
// The table itself
|
||||
segments = append(segments, contentSegment{
|
||||
content: content[match[0]:match[1]],
|
||||
isTable: true,
|
||||
})
|
||||
lastEnd = match[1]
|
||||
}
|
||||
|
||||
// Text after the last table
|
||||
if lastEnd < len(content) {
|
||||
segments = append(segments, contentSegment{
|
||||
content: content[lastEnd:],
|
||||
isTable: false,
|
||||
})
|
||||
}
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
// parseMarkdownTable converts a markdown table string to an Adaptive Card Table element.
|
||||
func parseMarkdownTable(tableStr string) (adaptivecard.Element, error) {
|
||||
lines := strings.Split(strings.TrimSpace(tableStr), "\n")
|
||||
if len(lines) < 2 {
|
||||
return adaptivecard.Element{}, fmt.Errorf("table must have at least header and separator rows")
|
||||
}
|
||||
|
||||
// Track header content length per column for width calculation
|
||||
var headerLengths []int
|
||||
|
||||
// Parse all rows (header + data rows, skip separator)
|
||||
var allRows [][]adaptivecard.TableCell
|
||||
for i, line := range lines {
|
||||
// Skip separator row (contains only |, -, :, and spaces)
|
||||
if i == 1 && isSeparatorRow(line) {
|
||||
continue
|
||||
}
|
||||
|
||||
cells := parseTableRow(line)
|
||||
if len(cells) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var tableCells []adaptivecard.TableCell
|
||||
for _, cellText := range cells {
|
||||
trimmedText := strings.TrimSpace(cellText)
|
||||
|
||||
// Use header row (first row) to determine column widths
|
||||
if i == 0 {
|
||||
headerLengths = append(headerLengths, len(trimmedText))
|
||||
}
|
||||
|
||||
textBlock := adaptivecard.Element{
|
||||
Type: adaptivecard.TypeElementTextBlock,
|
||||
Text: trimmedText,
|
||||
Wrap: true,
|
||||
}
|
||||
cell := adaptivecard.TableCell{
|
||||
Type: adaptivecard.TypeTableCell,
|
||||
Items: []*adaptivecard.Element{&textBlock},
|
||||
}
|
||||
tableCells = append(tableCells, cell)
|
||||
}
|
||||
allRows = append(allRows, tableCells)
|
||||
}
|
||||
|
||||
if len(allRows) == 0 {
|
||||
return adaptivecard.Element{}, fmt.Errorf("no valid rows found in table")
|
||||
}
|
||||
|
||||
// Create table with first row as headers
|
||||
firstRowAsHeaders := true
|
||||
showGridLines := true
|
||||
|
||||
table, err := adaptivecard.NewTableFromTableCells(allRows, 0, firstRowAsHeaders, showGridLines)
|
||||
if err != nil {
|
||||
return adaptivecard.Element{}, fmt.Errorf("failed to create table: %w", err)
|
||||
}
|
||||
|
||||
// Set column widths based on header content length
|
||||
table.Columns = calculateColumnWidths(headerLengths)
|
||||
|
||||
return table, nil
|
||||
}
|
||||
|
||||
// calculateColumnWidths creates TableColumnDefinition entries with widths
|
||||
// proportional to the max content length of each column.
|
||||
func calculateColumnWidths(maxLengths []int) []adaptivecard.Column {
|
||||
if len(maxLengths) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use content length as relative weight, with a minimum of 1
|
||||
columns := make([]adaptivecard.Column, len(maxLengths))
|
||||
for i, length := range maxLengths {
|
||||
weight := length
|
||||
if weight < 1 {
|
||||
weight = 1
|
||||
}
|
||||
columns[i] = adaptivecard.Column{
|
||||
Type: "TableColumnDefinition",
|
||||
Width: weight,
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// isSeparatorRow checks if a line is a markdown table separator (e.g., |---|---|).
|
||||
func isSeparatorRow(line string) bool {
|
||||
// Remove pipes and spaces, check if only dashes and colons remain
|
||||
cleaned := strings.ReplaceAll(line, "|", "")
|
||||
cleaned = strings.ReplaceAll(cleaned, " ", "")
|
||||
cleaned = strings.ReplaceAll(cleaned, "-", "")
|
||||
cleaned = strings.ReplaceAll(cleaned, ":", "")
|
||||
return cleaned == ""
|
||||
}
|
||||
|
||||
// parseTableRow extracts cell values from a markdown table row.
|
||||
func parseTableRow(line string) []string {
|
||||
// Trim leading/trailing pipes and split by |
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.TrimPrefix(line, "|")
|
||||
line = strings.TrimSuffix(line, "|")
|
||||
|
||||
if line == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(line, "|")
|
||||
var cells []string
|
||||
for _, p := range parts {
|
||||
cells = append(cells, strings.TrimSpace(p))
|
||||
}
|
||||
return cells
|
||||
}
|
||||
@@ -0,0 +1,583 @@
|
||||
package teamswebhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
goteamsnotify "github.com/atc0005/go-teams-notify/v2"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// mockTeamsClient implements teamsMessageSender for testing.
|
||||
type mockTeamsClient struct {
|
||||
sendFunc func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error
|
||||
}
|
||||
|
||||
func (m *mockTeamsClient) SendWithContext(
|
||||
ctx context.Context,
|
||||
webhookURL string,
|
||||
message goteamsnotify.TeamsMessage,
|
||||
) error {
|
||||
if m.sendFunc != nil {
|
||||
return m.sendFunc(ctx, webhookURL, message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewTeamsWebhookChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
// Test missing webhooks
|
||||
_, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: nil,
|
||||
}, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing webhooks")
|
||||
}
|
||||
|
||||
// Test missing "default" webhook
|
||||
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
Title: "Alerts",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing 'default' webhook")
|
||||
}
|
||||
|
||||
// Test empty webhook URL
|
||||
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {Title: "Default"},
|
||||
},
|
||||
}, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty webhook_url")
|
||||
}
|
||||
|
||||
// Test HTTP URL (should fail, must be HTTPS)
|
||||
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("http://example.com/webhook"),
|
||||
Title: "Default",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for HTTP webhook URL (must be HTTPS)")
|
||||
}
|
||||
|
||||
// Test valid config with HTTPS (must include "default")
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
Title: "Default",
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook1"),
|
||||
Title: "Alerts",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if ch.Name() != "teams_webhook" {
|
||||
t.Errorf("expected name 'teams_webhook', got %q", ch.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_StartStop(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
if ch.IsRunning() {
|
||||
t.Error("channel should not be running before Start")
|
||||
}
|
||||
|
||||
if err := ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
|
||||
if !ch.IsRunning() {
|
||||
t.Error("channel should be running after Start")
|
||||
}
|
||||
|
||||
if err := ch.Stop(ctx); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
|
||||
if ch.IsRunning() {
|
||||
t.Error("channel should not be running after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
Title: "Default",
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
Title: "Custom Title",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
target := ch.config.Webhooks["alerts"]
|
||||
msg := bus.OutboundMessage{
|
||||
Content: "Test message content",
|
||||
ChatID: "alerts",
|
||||
}
|
||||
|
||||
card, err := ch.buildAdaptiveCard(msg, target)
|
||||
if err != nil {
|
||||
t.Fatalf("buildAdaptiveCard failed: %v", err)
|
||||
}
|
||||
|
||||
if card.Type != "AdaptiveCard" {
|
||||
t.Errorf("expected card type 'AdaptiveCard', got %q", card.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_SendNotRunning(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Content: "test", ChatID: "default"}
|
||||
|
||||
_, err = ch.Send(ctx, msg)
|
||||
if err == nil {
|
||||
t.Error("expected error when sending while not running")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chatID string
|
||||
}{
|
||||
{"unknown target falls back to default", "unknown"},
|
||||
{"empty ChatID uses default", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var sentURL string
|
||||
ch.client = &mockTeamsClient{
|
||||
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
|
||||
sentURL = webhookURL
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_ = ch.Start(ctx)
|
||||
defer ch.Stop(ctx)
|
||||
|
||||
msg := bus.OutboundMessage{Content: "test", ChatID: tt.chatID}
|
||||
_, err = ch.Send(ctx, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success, got error: %v", err)
|
||||
}
|
||||
|
||||
if sentURL != "https://example.com/webhook-default" {
|
||||
t.Errorf("expected default webhook URL, got %q", sentURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
Title: "Default",
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
|
||||
Title: "Test Alerts",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Inject mock client
|
||||
var sentURL string
|
||||
ch.client = &mockTeamsClient{
|
||||
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
|
||||
sentURL = webhookURL
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_ = ch.Start(ctx)
|
||||
defer ch.Stop(ctx)
|
||||
|
||||
msg := bus.OutboundMessage{Content: "Hello Teams!", ChatID: "alerts"}
|
||||
|
||||
_, err = ch.Send(ctx, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if sentURL != "https://example.com/webhook-alerts" {
|
||||
t.Errorf("expected webhook URL 'https://example.com/webhook-alerts', got %q", sentURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_SendError(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Inject mock client that returns an error
|
||||
ch.client = &mockTeamsClient{
|
||||
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
|
||||
return errors.New("error on notification: 401 Unauthorized, forbidden")
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_ = ch.Start(ctx)
|
||||
defer ch.Stop(ctx)
|
||||
|
||||
msg := bus.OutboundMessage{Content: "test", ChatID: "alerts"}
|
||||
|
||||
_, err = ch.Send(ctx, msg)
|
||||
if err == nil {
|
||||
t.Error("expected error from failed send")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitContentWithTables(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
wantSegs int
|
||||
wantTbl int // number of table segments
|
||||
}{
|
||||
{
|
||||
name: "no tables",
|
||||
content: "Just some text\nwith multiple lines",
|
||||
wantSegs: 1,
|
||||
wantTbl: 0,
|
||||
},
|
||||
{
|
||||
name: "single table",
|
||||
content: `| Col1 | Col2 |
|
||||
|------|------|
|
||||
| A | B |
|
||||
| C | D |`,
|
||||
wantSegs: 1,
|
||||
wantTbl: 1,
|
||||
},
|
||||
{
|
||||
name: "text before table",
|
||||
content: `Here is some text.
|
||||
|
||||
| Col1 | Col2 |
|
||||
|------|------|
|
||||
| A | B |`,
|
||||
wantSegs: 2,
|
||||
wantTbl: 1,
|
||||
},
|
||||
{
|
||||
name: "text before and after table",
|
||||
content: `Before table.
|
||||
|
||||
| Col1 | Col2 |
|
||||
|------|------|
|
||||
| A | B |
|
||||
|
||||
After table.`,
|
||||
wantSegs: 3,
|
||||
wantTbl: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple tables",
|
||||
content: `First table:
|
||||
|
||||
| A | B |
|
||||
|---|---|
|
||||
| 1 | 2 |
|
||||
|
||||
Second table:
|
||||
|
||||
| X | Y |
|
||||
|---|---|
|
||||
| 3 | 4 |`,
|
||||
wantSegs: 4,
|
||||
wantTbl: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
segs := splitContentWithTables(tt.content)
|
||||
if len(segs) != tt.wantSegs {
|
||||
t.Errorf("got %d segments, want %d", len(segs), tt.wantSegs)
|
||||
}
|
||||
tableCount := 0
|
||||
for _, s := range segs {
|
||||
if s.isTable {
|
||||
tableCount++
|
||||
}
|
||||
}
|
||||
if tableCount != tt.wantTbl {
|
||||
t.Errorf("got %d tables, want %d", tableCount, tt.wantTbl)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMarkdownTable(t *testing.T) {
|
||||
tableStr := `| Name | Value |
|
||||
|------|-------|
|
||||
| foo | 123 |
|
||||
| bar | 456 |`
|
||||
|
||||
elem, err := parseMarkdownTable(tableStr)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if elem.Type != "Table" {
|
||||
t.Errorf("expected type 'Table', got %q", elem.Type)
|
||||
}
|
||||
|
||||
// Should have 3 rows (header + 2 data rows)
|
||||
if len(elem.Rows) != 3 {
|
||||
t.Errorf("expected 3 rows, got %d", len(elem.Rows))
|
||||
}
|
||||
|
||||
// Should have 2 columns with widths based on content length
|
||||
if len(elem.Columns) != 2 {
|
||||
t.Errorf("expected 2 columns, got %d", len(elem.Columns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMarkdownTableColumnWidths(t *testing.T) {
|
||||
// Column widths are based on HEADER row only:
|
||||
// Col1: "Description" (11 chars)
|
||||
// Col2: "X" (1 char)
|
||||
// Col3: "Amount" (6 chars)
|
||||
tableStr := `| Description | X | Amount |
|
||||
|-------------|---|--------|
|
||||
| Short | Y | 100 |
|
||||
| Longer text | Z | 50 |`
|
||||
|
||||
elem, err := parseMarkdownTable(tableStr)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(elem.Columns) != 3 {
|
||||
t.Fatalf("expected 3 columns, got %d", len(elem.Columns))
|
||||
}
|
||||
|
||||
// Verify column widths are based on header content length
|
||||
w1, ok1 := elem.Columns[0].Width.(int)
|
||||
w2, ok2 := elem.Columns[1].Width.(int)
|
||||
w3, ok3 := elem.Columns[2].Width.(int)
|
||||
|
||||
if !ok1 || !ok2 || !ok3 {
|
||||
t.Fatalf("expected int widths, got types: %T, %T, %T",
|
||||
elem.Columns[0].Width, elem.Columns[1].Width, elem.Columns[2].Width)
|
||||
}
|
||||
|
||||
// Header lengths: "Description" = 11, "X" = 1, "Amount" = 6
|
||||
if w1 != 11 {
|
||||
t.Errorf("expected col1 width 11 (from 'Description'), got %d", w1)
|
||||
}
|
||||
if w2 != 1 {
|
||||
t.Errorf("expected col2 width 1 (from 'X'), got %d", w2)
|
||||
}
|
||||
if w3 != 6 {
|
||||
t.Errorf("expected col3 width 6 (from 'Amount'), got %d", w3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateColumnWidths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
maxLengths []int
|
||||
wantWidths []int
|
||||
}{
|
||||
{
|
||||
name: "equal lengths",
|
||||
maxLengths: []int{10, 10, 10},
|
||||
wantWidths: []int{10, 10, 10},
|
||||
},
|
||||
{
|
||||
name: "varying lengths",
|
||||
maxLengths: []int{5, 20, 10},
|
||||
wantWidths: []int{5, 20, 10},
|
||||
},
|
||||
{
|
||||
name: "zero length gets minimum of 1",
|
||||
maxLengths: []int{0, 5, 0},
|
||||
wantWidths: []int{1, 5, 1},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
maxLengths: []int{},
|
||||
wantWidths: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cols := calculateColumnWidths(tt.maxLengths)
|
||||
|
||||
if tt.wantWidths == nil {
|
||||
if cols != nil {
|
||||
t.Errorf("expected nil, got %v", cols)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(cols) != len(tt.wantWidths) {
|
||||
t.Fatalf("expected %d columns, got %d", len(tt.wantWidths), len(cols))
|
||||
}
|
||||
|
||||
for i, col := range cols {
|
||||
width, ok := col.Width.(int)
|
||||
if !ok {
|
||||
t.Errorf("column %d: expected int width, got %T", i, col.Width)
|
||||
continue
|
||||
}
|
||||
if width != tt.wantWidths[i] {
|
||||
t.Errorf("column %d: expected width %d, got %d", i, tt.wantWidths[i], width)
|
||||
}
|
||||
if col.Type != "TableColumnDefinition" {
|
||||
t.Errorf("column %d: expected type 'TableColumnDefinition', got %q", i, col.Type)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTableRow(t *testing.T) {
|
||||
tests := []struct {
|
||||
line string
|
||||
want []string
|
||||
}{
|
||||
{"| A | B | C |", []string{"A", "B", "C"}},
|
||||
{"|A|B|C|", []string{"A", "B", "C"}},
|
||||
{"| foo | bar |", []string{"foo", "bar"}},
|
||||
{"", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := parseTableRow(tt.line)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("parseTableRow(%q): got %v, want %v", tt.line, got, tt.want)
|
||||
continue
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("parseTableRow(%q)[%d]: got %q, want %q", tt.line, i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSeparatorRow(t *testing.T) {
|
||||
tests := []struct {
|
||||
line string
|
||||
want bool
|
||||
}{
|
||||
{"|---|---|", true},
|
||||
{"| --- | --- |", true},
|
||||
{"|:---|---:|", true},
|
||||
{"| :---: | :---: |", true},
|
||||
{"| A | B |", false},
|
||||
{"| foo | bar |", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := isSeparatorRow(tt.line)
|
||||
if got != tt.want {
|
||||
t.Errorf("isSeparatorRow(%q): got %v, want %v", tt.line, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1190,3 +1190,8 @@ func isPostConnectError(err error) bool {
|
||||
strings.Contains(msg, "connection closed by foreign host") ||
|
||||
strings.Contains(msg, "broken pipe")
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *TelegramChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package vk
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("vk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewVKChannel(cfg, b)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,282 @@
|
||||
package vk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/SevereCloud/vksdk/v3/api"
|
||||
"github.com/SevereCloud/vksdk/v3/api/params"
|
||||
"github.com/SevereCloud/vksdk/v3/events"
|
||||
"github.com/SevereCloud/vksdk/v3/longpoll-bot"
|
||||
"github.com/SevereCloud/vksdk/v3/object"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
type VKChannel struct {
|
||||
*channels.BaseChannel
|
||||
vk *api.VK
|
||||
lp *longpoll.LongPoll
|
||||
config *config.Config
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewVKChannel(cfg *config.Config, bus *bus.MessageBus) (*VKChannel, error) {
|
||||
vkCfg := cfg.Channels.VK
|
||||
|
||||
vk := api.NewVK(vkCfg.Token.String())
|
||||
|
||||
base := channels.NewBaseChannel(
|
||||
"vk",
|
||||
vkCfg,
|
||||
bus,
|
||||
vkCfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(4000),
|
||||
channels.WithGroupTrigger(vkCfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(vkCfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &VKChannel{
|
||||
BaseChannel: base,
|
||||
vk: vk,
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *VKChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("vk", "Starting VK bot (Long Poll mode)...")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
groupID := c.config.Channels.VK.GroupID
|
||||
if groupID == 0 {
|
||||
c.cancel()
|
||||
return fmt.Errorf("group_id is required for VK bot")
|
||||
}
|
||||
|
||||
lp, err := longpoll.NewLongPoll(c.vk, groupID)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
return fmt.Errorf("failed to create long poll: %w", err)
|
||||
}
|
||||
c.lp = lp
|
||||
|
||||
lp.MessageNew(func(_ context.Context, obj events.MessageNewObject) {
|
||||
c.handleMessage(obj.Message)
|
||||
})
|
||||
|
||||
c.SetRunning(true)
|
||||
|
||||
logger.InfoCF("vk", "VK bot connected", map[string]any{
|
||||
"group_id": groupID,
|
||||
})
|
||||
|
||||
go func() {
|
||||
if err := lp.Run(); err != nil {
|
||||
logger.ErrorCF("vk", "Long poll failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *VKChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("vk", "Stopping VK bot...")
|
||||
c.SetRunning(false)
|
||||
|
||||
if c.lp != nil {
|
||||
c.lp.Shutdown()
|
||||
}
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
|
||||
if msg.Action.Type != "" {
|
||||
return
|
||||
}
|
||||
|
||||
if bool(msg.Out) {
|
||||
return
|
||||
}
|
||||
|
||||
peerID := msg.PeerID
|
||||
chatID := strconv.Itoa(peerID)
|
||||
|
||||
fromID := msg.FromID
|
||||
userID := strconv.Itoa(fromID)
|
||||
|
||||
platformID := userID
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "vk",
|
||||
PlatformID: platformID,
|
||||
CanonicalID: identity.BuildCanonicalID("vk", platformID),
|
||||
DisplayName: c.getUserName(fromID),
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(sender) {
|
||||
logger.DebugCF("vk", "Message from unauthorized user", map[string]any{
|
||||
"peer_id": peerID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
text := msg.Text
|
||||
if text == "" && len(msg.Attachments) > 0 {
|
||||
text = c.processAttachments(msg.Attachments)
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
|
||||
groupTrigger := c.config.Channels.VK.GroupTrigger
|
||||
isGroupChat := peerID != fromID
|
||||
|
||||
if isGroupChat {
|
||||
isMentioned := c.isMentioned(msg)
|
||||
if isMentioned {
|
||||
text = c.stripBotMention(text)
|
||||
}
|
||||
respond, cleaned := c.ShouldRespondInGroup(isMentioned, text)
|
||||
if !respond {
|
||||
return
|
||||
}
|
||||
text = cleaned
|
||||
_ = groupTrigger
|
||||
}
|
||||
|
||||
chatType := "direct"
|
||||
if isGroupChat {
|
||||
chatType = "group"
|
||||
}
|
||||
|
||||
messageID := strconv.Itoa(msg.ConversationMessageID)
|
||||
|
||||
metadata := map[string]string{
|
||||
"user_id": userID,
|
||||
"is_group": fmt.Sprintf("%t", isGroupChat),
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, text, nil, bus.InboundContext{
|
||||
Channel: "vk",
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
SenderID: userID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isGroupChat && c.isMentioned(msg),
|
||||
Raw: metadata,
|
||||
}, sender)
|
||||
}
|
||||
|
||||
func (c *VKChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
peerID, err := strconv.Atoi(msg.ChatID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
if msg.Content == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var messageIDs []string
|
||||
chunks := channels.SplitMessage(msg.Content, 4000)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
if chunk == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
b := params.NewMessagesSendBuilder()
|
||||
b.Message(chunk)
|
||||
b.RandomID(0)
|
||||
b.PeerID(peerID)
|
||||
|
||||
if msg.ReplyToMessageID != "" {
|
||||
if replyID, err := strconv.Atoi(msg.ReplyToMessageID); err == nil {
|
||||
b.ReplyTo(replyID)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := c.vk.MessagesSend(b.Params)
|
||||
if err != nil {
|
||||
logger.ErrorCF("vk", "Failed to send message", map[string]any{
|
||||
"error": err.Error(),
|
||||
"peer_id": peerID,
|
||||
})
|
||||
return messageIDs, fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
|
||||
messageIDs = append(messageIDs, strconv.Itoa(resp))
|
||||
}
|
||||
|
||||
return messageIDs, nil
|
||||
}
|
||||
|
||||
func (c *VKChannel) isMentioned(msg object.MessagesMessage) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *VKChannel) stripBotMention(text string) string {
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func (c *VKChannel) getUserName(userID int) string {
|
||||
users, err := c.vk.UsersGet(api.Params{
|
||||
"user_ids": userID,
|
||||
})
|
||||
if err != nil || len(users) == 0 {
|
||||
return strconv.Itoa(userID)
|
||||
}
|
||||
|
||||
user := users[0]
|
||||
return fmt.Sprintf("%s %s", user.FirstName, user.LastName)
|
||||
}
|
||||
|
||||
func (c *VKChannel) processAttachments(attachments []object.MessagesMessageAttachment) string {
|
||||
var parts []string
|
||||
|
||||
for _, att := range attachments {
|
||||
switch att.Type {
|
||||
case "photo":
|
||||
parts = append(parts, "[photo]")
|
||||
case "video":
|
||||
parts = append(parts, "[video]")
|
||||
case "audio":
|
||||
parts = append(parts, "[audio]")
|
||||
case "doc":
|
||||
if att.Doc.Title != "" {
|
||||
parts = append(parts, fmt.Sprintf("[document: %s]", att.Doc.Title))
|
||||
} else {
|
||||
parts = append(parts, "[document]")
|
||||
}
|
||||
case "audio_message":
|
||||
parts = append(parts, "[voice]")
|
||||
case "sticker":
|
||||
parts = append(parts, "[sticker]")
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func (c *VKChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
package vk
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestNewVKChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("missing group_id", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
},
|
||||
},
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during creation: %v", err)
|
||||
}
|
||||
if ch.Name() != "vk" {
|
||||
t.Errorf("Name() = %q, want %q", ch.Name(), "vk")
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("new channel should not be running")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid config with group_id", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
},
|
||||
},
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ch.Name() != "vk" {
|
||||
t.Errorf("Name() = %q, want %q", ch.Name(), "vk")
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("new channel should not be running")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with allow_from", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
AllowFrom: []string{"123456789"},
|
||||
},
|
||||
},
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !ch.IsAllowedSender(bus.SenderInfo{PlatformID: "123456789"}) {
|
||||
t.Error("user 123456789 should be allowed")
|
||||
}
|
||||
if ch.IsAllowedSender(bus.SenderInfo{PlatformID: "999999999"}) {
|
||||
t.Error("user 999999999 should not be allowed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with group_trigger", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
GroupTrigger: config.GroupTriggerConfig{
|
||||
MentionOnly: false,
|
||||
Prefixes: []string{"/bot", "!bot"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ch.Name() != "vk" {
|
||||
t.Errorf("Name() = %q, want %q", ch.Name(), "vk")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestVKChannel_MaxMessageLength(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
},
|
||||
},
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
maxLen := ch.MaxMessageLength()
|
||||
if maxLen != 4000 {
|
||||
t.Errorf("MaxMessageLength() = %d, want 4000", maxLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVKChannel_SplitMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
maxLen int
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "short message",
|
||||
content: "hello",
|
||||
maxLen: 4000,
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "exact length",
|
||||
content: string(make([]byte, 4000)),
|
||||
maxLen: 4000,
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "needs split",
|
||||
content: string(make([]byte, 5000)),
|
||||
maxLen: 4000,
|
||||
want: 2,
|
||||
},
|
||||
{
|
||||
name: "empty message",
|
||||
content: "",
|
||||
maxLen: 4000,
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := channels.SplitMessage(tt.content, tt.maxLen)
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("SplitMessage() got %d parts, want %d parts", len(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVKChannel_ProcessAttachments(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attachments []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty attachments",
|
||||
attachments: []string{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "photo attachment",
|
||||
attachments: []string{"photo"},
|
||||
want: "[photo]",
|
||||
},
|
||||
{
|
||||
name: "video attachment",
|
||||
attachments: []string{"video"},
|
||||
want: "[video]",
|
||||
},
|
||||
{
|
||||
name: "audio attachment",
|
||||
attachments: []string{"audio"},
|
||||
want: "[audio]",
|
||||
},
|
||||
{
|
||||
name: "document attachment",
|
||||
attachments: []string{"doc"},
|
||||
want: "[doc]",
|
||||
},
|
||||
{
|
||||
name: "sticker attachment",
|
||||
attachments: []string{"sticker"},
|
||||
want: "[sticker]",
|
||||
},
|
||||
{
|
||||
name: "audio_message attachment",
|
||||
attachments: []string{"audio_message"},
|
||||
want: "[voice]",
|
||||
},
|
||||
{
|
||||
name: "multiple attachments",
|
||||
attachments: []string{"photo", "video", "audio"},
|
||||
want: "[photo] [video] [audio]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var result string
|
||||
for i, att := range tt.attachments {
|
||||
if i > 0 {
|
||||
result += " "
|
||||
}
|
||||
if att == "audio_message" {
|
||||
result += "[voice]"
|
||||
} else {
|
||||
result += "[" + att + "]"
|
||||
}
|
||||
}
|
||||
if result != tt.want {
|
||||
t.Errorf("processAttachments() = %q, want %q", result, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVKChannel_VoiceCapabilities(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := &config.Config{
|
||||
Channels: config.ChannelsConfig{
|
||||
VK: config.VKConfig{
|
||||
Enabled: true,
|
||||
Token: *config.NewSecureString("test_token"),
|
||||
GroupID: 123456789,
|
||||
},
|
||||
},
|
||||
}
|
||||
ch, err := NewVKChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
caps := ch.VoiceCapabilities()
|
||||
if !caps.ASR {
|
||||
t.Error("VoiceCapabilities().ASR should be true")
|
||||
}
|
||||
if !caps.TTS {
|
||||
t.Error("VoiceCapabilities().TTS should be true")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package channels
|
||||
|
||||
// VoiceCapabilities describes whether ASR (speech-to-text) and TTS (text-to-speech)
|
||||
// are available for a channel under the current configuration.
|
||||
type VoiceCapabilities struct {
|
||||
ASR bool
|
||||
TTS bool
|
||||
}
|
||||
|
||||
// VoiceCapabilityProvider is an optional interface for channels that want to
|
||||
// explicitly declare their ASR/TTS support.
|
||||
type VoiceCapabilityProvider interface {
|
||||
VoiceCapabilities() VoiceCapabilities
|
||||
}
|
||||
|
||||
// Deprecated: Channels should implement VoiceCapabilityProvider instead.
|
||||
// To be removed once all existing capable channels conform to the interface.
|
||||
var asrCapableChannels = map[string]bool{
|
||||
"discord": true,
|
||||
"telegram": true,
|
||||
"matrix": true,
|
||||
"qq": true,
|
||||
"weixin": true,
|
||||
"line": true,
|
||||
"feishu": true,
|
||||
"onebot": true,
|
||||
}
|
||||
|
||||
// DetectVoiceCapabilities returns ASR/TTS availability for a channel, gated by
|
||||
// whether providers are configured.
|
||||
func DetectVoiceCapabilities(channelName string, ch Channel, asrAvailable bool, ttsAvailable bool) VoiceCapabilities {
|
||||
if ch == nil {
|
||||
return VoiceCapabilities{}
|
||||
}
|
||||
|
||||
if vcp, ok := ch.(VoiceCapabilityProvider); ok {
|
||||
caps := vcp.VoiceCapabilities()
|
||||
if !asrAvailable {
|
||||
caps.ASR = false
|
||||
}
|
||||
if !ttsAvailable {
|
||||
caps.TTS = false
|
||||
}
|
||||
return caps
|
||||
}
|
||||
|
||||
caps := VoiceCapabilities{}
|
||||
if asrAvailable {
|
||||
caps.ASR = asrCapableChannels[channelName]
|
||||
}
|
||||
if ttsAvailable {
|
||||
if _, ok := ch.(MediaSender); ok {
|
||||
caps.TTS = true
|
||||
}
|
||||
}
|
||||
|
||||
return caps
|
||||
}
|
||||
@@ -414,3 +414,8 @@ func (c *WeixinChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// VoiceCapabilities returns the voice capabilities of the channel.
|
||||
func (c *WeixinChannel) VoiceCapabilities() channels.VoiceCapabilities {
|
||||
return channels.VoiceCapabilities{ASR: true, TTS: true}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user