Merge branch 'main' into refactor-inbound-context-routing-session

# Conflicts:
#	pkg/agent/eventbus_test.go
#	pkg/agent/loop.go
#	pkg/bus/bus.go
#	pkg/bus/types.go
#	pkg/channels/pico/pico.go
#	pkg/channels/telegram/telegram.go
#	pkg/config/config.go
#	web/backend/api/session.go
#	web/backend/api/session_test.go
This commit is contained in:
Hoshina
2026-04-07 21:41:02 +08:00
282 changed files with 33064 additions and 3251 deletions
+170
View File
@@ -3,6 +3,7 @@ package discord
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"os"
@@ -14,6 +15,8 @@ import (
"github.com/bwmarrin/discordgo"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/audio"
"github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
@@ -42,6 +45,15 @@ type DiscordChannel struct {
typingMu sync.Mutex
typingStop map[string]chan struct{} // chatID → stop signal
botUserID string // stored for mention checking
bus *bus.MessageBus
tts tts.TTSProvider
voiceMu sync.RWMutex
voiceSSRC map[string]map[uint32]string // guildID -> ssrc -> userID
// TTS interruption: cancel active playback when user speaks
ttsMu sync.Mutex
cancelTTS context.CancelFunc
ttsPlayID uint64
}
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -73,6 +85,8 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
config: cfg,
ctx: context.Background(),
typingStop: make(map[string]chan struct{}),
bus: bus,
voiceSSRC: make(map[string]map[uint32]string),
}, nil
}
@@ -90,6 +104,8 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
c.session.AddHandler(c.handleMessage)
go c.listenVoiceControl(c.ctx)
if err := c.session.Open(); err != nil {
return fmt.Errorf("failed to open discord session: %w", err)
}
@@ -142,6 +158,25 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]s
return nil, nil
}
if c.tts != nil {
if ch, err := c.session.State.Channel(channelID); err == nil && ch.GuildID != "" {
if vc, ok := c.session.VoiceConnections[ch.GuildID]; ok && vc != nil {
// Cancel any previous TTS playback
c.ttsMu.Lock()
if c.cancelTTS != nil {
c.cancelTTS()
}
ttsCtx, ttsCancel := context.WithCancel(c.ctx)
c.ttsPlayID++
playID := c.ttsPlayID
c.cancelTTS = ttsCancel
c.ttsMu.Unlock()
go c.playTTS(ttsCtx, vc, msg.Content, playID)
}
}
}
msgID, err := c.sendChunk(ctx, channelID, msg.Content, msg.ReplyToMessageID)
if err != nil {
return nil, err
@@ -359,6 +394,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
if c.handleVoiceCommand(s, m) {
return
}
content := m.Content
// In guild (group) channels, apply unified group trigger filtering
@@ -642,3 +681,134 @@ func (c *DiscordChannel) stripBotMention(text string) string {
text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "")
return strings.TrimSpace(text)
}
func (c *DiscordChannel) listenVoiceControl(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case ctrl, ok := <-c.bus.VoiceControlsChan():
if !ok {
return
}
if ctrl.Type == "command" && ctrl.Action == "leave" {
if strings.HasPrefix(ctrl.SessionID, "discord_vc_") {
guildID := strings.TrimPrefix(ctrl.SessionID, "discord_vc_")
vc, exists := c.session.VoiceConnections[guildID]
if exists && vc != nil {
vc.Disconnect(ctx)
}
}
}
}
}
}
func (c *DiscordChannel) playTTS(ctx context.Context, vc *discordgo.VoiceConnection, text string, playID uint64) {
// Capture the cancel func associated with this playback (if any).
// Clear cancelTTS when playback finishes (normal or interrupted),
// but only if it still refers to this playback's cancel func.
defer func() {
c.ttsMu.Lock()
if c.ttsPlayID == playID {
c.cancelTTS = nil
}
c.ttsMu.Unlock()
}()
sentences := audio.SplitSentences(text)
if len(sentences) == 0 {
return
}
logger.InfoCF("discord", "Starting streamed TTS", map[string]any{"sentences": len(sentences)})
// Pipeline: prefetch next sentence's audio while playing current
type ttResult struct {
stream io.ReadCloser
err error
}
var prefetch chan ttResult
// Ensure any in-flight prefetch is drained on exit to prevent stream leaks,
// but avoid blocking indefinitely if the prefetch goroutine is stuck or never sends.
defer func() {
if prefetch != nil {
select {
case result := <-prefetch:
if result.stream != nil {
result.stream.Close()
}
case <-time.After(100 * time.Millisecond):
// Timed out waiting for a prefetched result; avoid blocking on exit.
}
}
}()
for i, sentence := range sentences {
// Check for cancellation (interruption)
select {
case <-ctx.Done():
logger.InfoCF("discord", "TTS interrupted", map[string]any{"at_sentence": i})
return
default:
}
// Start prefetching the NEXT sentence while we process the current one
var nextPrefetch chan ttResult
if i+1 < len(sentences) {
nextPrefetch = make(chan ttResult, 1)
nextSentence := sentences[i+1]
go func() {
s, e := c.tts.Synthesize(ctx, nextSentence)
nextPrefetch <- ttResult{s, e}
}()
}
// Get the current sentence's audio
var stream io.ReadCloser
var err error
if prefetch != nil {
// Use prefetched result from previous iteration, but be responsive to cancellation.
var result ttResult
select {
case result = <-prefetch:
stream, err = result.stream, result.err
case <-ctx.Done():
// Context canceled while waiting for prefetched audio; abort playback.
logger.InfoCF(
"discord",
"TTS interrupted while waiting for prefetched audio",
map[string]any{"at_sentence": i},
)
return
}
} else {
// First sentence: synthesize directly
stream, err = c.tts.Synthesize(ctx, sentence)
}
if err != nil {
if stream != nil {
stream.Close()
}
logger.ErrorCF("discord", "TTS synthesize failed", map[string]any{"error": err.Error(), "sentence": i})
prefetch = nextPrefetch
continue
}
if err := streamOggOpusToDiscord(ctx, vc, stream); err != nil {
logger.ErrorCF("discord", "TTS playback failed", map[string]any{"error": err.Error(), "sentence": i})
}
stream.Close()
prefetch = nextPrefetch
}
}
// VoiceCapabilities returns the voice capabilities of the channel.
func (c *DiscordChannel) VoiceCapabilities() channels.VoiceCapabilities {
return channels.VoiceCapabilities{ASR: true, TTS: true}
}
+6 -1
View File
@@ -1,6 +1,7 @@
package discord
import (
"github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
@@ -8,6 +9,10 @@ import (
func init() {
channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewDiscordChannel(cfg.Channels.Discord, b)
ch, err := NewDiscordChannel(cfg.Channels.Discord, b)
if err == nil {
ch.tts = tts.DetectTTS(cfg)
}
return ch, err
})
}
+314
View File
@@ -0,0 +1,314 @@
package discord
import (
"context"
"fmt"
"io"
"time"
"github.com/bwmarrin/discordgo"
"github.com/sipeed/picoclaw/pkg/audio"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
)
func (c *DiscordChannel) setVoiceUserID(guildID string, ssrc uint32, userID string) {
if userID == "" {
return
}
c.voiceMu.Lock()
defer c.voiceMu.Unlock()
ssrcMap, ok := c.voiceSSRC[guildID]
if !ok {
ssrcMap = make(map[uint32]string)
c.voiceSSRC[guildID] = ssrcMap
}
ssrcMap[ssrc] = userID
}
func (c *DiscordChannel) voiceUserID(guildID string, ssrc uint32) string {
c.voiceMu.RLock()
defer c.voiceMu.RUnlock()
ssrcMap, ok := c.voiceSSRC[guildID]
if !ok {
return ""
}
return ssrcMap[ssrc]
}
func (c *DiscordChannel) handleVoiceCommand(s *discordgo.Session, m *discordgo.MessageCreate) bool {
if m.Content == "!vc join" {
vs, err := s.State.VoiceState(m.GuildID, m.Author.ID)
if err != nil || vs == nil {
if _, sendErr := s.ChannelMessageSend(
m.ChannelID,
"You need to be in a voice channel first!",
); sendErr != nil {
logger.InfoCF("discord", "Failed to send voice channel requirement message", map[string]any{
"channel": m.ChannelID,
"error": sendErr,
})
}
return true
}
logger.InfoCF("discord", "Joining voice channel", map[string]any{"channel": vs.ChannelID})
vc, err := s.ChannelVoiceJoin(c.ctx, m.GuildID, vs.ChannelID, false, false)
if err != nil {
if _, sendErr := s.ChannelMessageSend(
m.ChannelID,
fmt.Sprintf("Failed to join voice channel: %v", err),
); sendErr != nil {
logger.InfoCF("discord", "Failed to send voice join error message", map[string]any{
"channel": m.ChannelID,
"error": sendErr,
})
}
return true
}
go c.receiveVoice(vc, m.GuildID, m.ChannelID)
if _, sendErr := s.ChannelMessageSend(
m.ChannelID,
"Joined Voice Channel! Listening for audio...",
); sendErr != nil {
logger.InfoCF("discord", "Failed to send voice join success message", map[string]any{
"channel": m.ChannelID,
"error": sendErr,
})
}
return true
} else if m.Content == "!vc leave" {
vc, exists := s.VoiceConnections[m.GuildID]
if exists && vc != nil {
if err := vc.Disconnect(c.ctx); err != nil {
logger.InfoCF("discord", "Failed to disconnect from voice channel", map[string]any{
"guild": m.GuildID,
"error": err,
})
}
if _, sendErr := s.ChannelMessageSend(m.ChannelID, "Left Voice Channel."); sendErr != nil {
logger.InfoCF("discord", "Failed to send voice leave success message", map[string]any{
"channel": m.ChannelID,
"error": sendErr,
})
}
} else {
if _, sendErr := s.ChannelMessageSend(m.ChannelID, "Not in a voice channel."); sendErr != nil {
logger.InfoCF("discord", "Failed to send voice not-in-channel message", map[string]any{
"channel": m.ChannelID,
"error": sendErr,
})
}
}
return true
}
return false
}
func VoiceReceiveActive(vc *discordgo.VoiceConnection) bool {
return vc != nil && vc.OpusRecv != nil
}
func streamOggOpusToDiscord(ctx context.Context, vc *discordgo.VoiceConnection, r io.Reader) (retErr error) {
// Recover from panic if vc.OpusSend is closed mid-send (e.g. on disconnect)
defer func() {
if rec := recover(); rec != nil {
retErr = fmt.Errorf("voice connection closed during playback")
logger.RecoverPanicNoExit(rec)
}
}()
// Wait for the speaking transition to register
vc.Speaking(true)
defer vc.Speaking(false)
return audio.DecodeOggOpus(r, func(frame []byte) error {
select {
case <-ctx.Done():
return ctx.Err()
case vc.OpusSend <- frame:
return nil
}
})
}
func (c *DiscordChannel) receiveVoice(vc *discordgo.VoiceConnection, guildID string, chatID string) {
logger.InfoCF("discord", "Started listening for voice", map[string]any{"guild": guildID})
vc.AddHandler(func(_ *discordgo.VoiceConnection, vs *discordgo.VoiceSpeakingUpdate) {
if vs == nil {
return
}
c.setVoiceUserID(guildID, uint32(vs.SSRC), vs.UserID)
})
defer func() {
c.voiceMu.Lock()
delete(c.voiceSSRC, guildID)
c.voiceMu.Unlock()
}()
go func(ctx context.Context, vc *discordgo.VoiceConnection) {
// Recover from potential panics if OpusSend is closed mid-send.
defer func() {
if rec := recover(); rec != nil {
logger.WarnCF("discord", "Recovered from panic while sending wake-up frames", map[string]any{
"error": rec,
"guild": guildID,
})
}
}()
// If the voice connection or OpusSend are not available, nothing to do.
if vc == nil || vc.OpusSend == nil {
return
}
time.Sleep(250 * time.Millisecond) // Wait a bit for connection to settle
// Abort if the context has already been canceled.
select {
case <-ctx.Done():
return
default:
}
vc.Speaking(true)
defer vc.Speaking(false)
silenceFrame := []byte{0xF8, 0xFF, 0xFE}
for i := 0; i < 5; i++ {
select {
case <-ctx.Done():
return
case vc.OpusSend <- silenceFrame:
}
time.Sleep(20 * time.Millisecond)
}
logger.DebugCF("discord", "Sent wake-up silence frames", map[string]any{"guild": guildID})
}(c.ctx, vc)
sessionID := fmt.Sprintf("discord_vc_%s", guildID)
c.bus.PublishVoiceControl(c.ctx, bus.VoiceControl{
SessionID: sessionID,
Type: "state",
Action: "listening",
})
var sequence uint64 = 0
var interruptCount int
var lastInterruptAt time.Time
for {
select {
case <-c.ctx.Done():
return
case p, ok := <-vc.OpusRecv:
if !ok {
logger.InfoCF("discord", "Voice channel closed", map[string]any{"guild": guildID})
// Cancel any TTS that may still be playing
c.ttsMu.Lock()
if c.cancelTTS != nil {
c.cancelTTS()
c.cancelTTS = nil
}
c.ttsMu.Unlock()
return
}
if p == nil {
logger.DebugCF("discord", "Received nil Opus packet", nil)
continue
}
if len(p.Opus) == 0 {
logger.DebugCF("discord", "Received empty Opus packet", map[string]any{
"seq": p.Sequence,
"ssrc": p.SSRC,
})
continue
}
logger.DebugCF("discord", "Received Opus packet", map[string]any{
"seq": p.Sequence,
"len": len(p.Opus),
"ssrc": p.SSRC,
})
// Interruption detection: if user sends voice while TTS is playing,
// cancel TTS after a short debounce (3 packets in 200ms)
now := time.Now()
if now.Sub(lastInterruptAt) > 500*time.Millisecond {
interruptCount = 0
}
interruptCount++
lastInterruptAt = now
if interruptCount >= 3 {
c.ttsMu.Lock()
if c.cancelTTS != nil {
c.cancelTTS()
c.cancelTTS = nil
logger.InfoCF("discord", "TTS interrupted by user voice", nil)
}
c.ttsMu.Unlock()
interruptCount = 0
}
userID := c.voiceUserID(guildID, p.SSRC)
if userID == "" {
logger.DebugCF("discord", "Dropping voice packet without user mapping", map[string]any{
"ssrc": p.SSRC,
"guild": guildID,
})
continue
}
sender := bus.SenderInfo{
Platform: "discord",
PlatformID: userID,
CanonicalID: identity.BuildCanonicalID("discord", userID),
}
if !c.IsAllowedSender(sender) {
logger.DebugCF("discord", "Voice packet rejected by allowlist", map[string]any{
"user_id": userID,
"guild": guildID,
})
continue
}
sequence++
chunk := bus.AudioChunk{
SessionID: sessionID,
SpeakerID: userID,
ChatID: chatID,
Channel: "discord",
Sequence: sequence,
Timestamp: p.Timestamp,
SampleRate: 48000,
Channels: 2,
Format: "opus",
Data: p.Opus,
}
ctx, cancel := context.WithTimeout(c.ctx, 100*time.Millisecond)
err := c.bus.PublishAudioChunk(ctx, chunk)
cancel()
if err != nil {
logger.ErrorCF("discord", "Failed to publish audio chunk", map[string]any{
"guild": guildID,
"sessionID": sessionID,
"sequence": sequence,
"error": err.Error(),
})
}
}
}
}
+7
View File
@@ -6,6 +6,8 @@ import (
"strings"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
"github.com/sipeed/picoclaw/pkg/channels"
)
// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions.
@@ -145,3 +147,8 @@ func extractImageKeysRecursive(v any, feishuKeys, externalURLs *[]string) {
}
}
}
// VoiceCapabilities returns the voice capabilities of the channel.
func (c *FeishuChannel) VoiceCapabilities() channels.VoiceCapabilities {
return channels.VoiceCapabilities{ASR: true, TTS: true}
}
+5
View File
@@ -696,3 +696,8 @@ func (c *LINEChannel) downloadContent(messageID, filename string) string {
},
})
}
// VoiceCapabilities returns the voice capabilities of the channel.
func (c *LINEChannel) VoiceCapabilities() channels.VoiceCapabilities {
return channels.VoiceCapabilities{ASR: true, TTS: true}
}
+17
View File
@@ -444,6 +444,23 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
m.initChannel("irc", "IRC")
}
if channels.VK.Enabled && channels.VK.Token.String() != "" && channels.VK.GroupID != 0 {
m.initChannel("vk", "VK")
}
if channels.TeamsWebhook.Enabled && len(channels.TeamsWebhook.Webhooks) > 0 {
hasValidTarget := false
for _, target := range channels.TeamsWebhook.Webhooks {
if target.WebhookURL.String() != "" {
hasValidTarget = true
break
}
}
if hasValidTarget {
m.initChannel("teams_webhook", "Teams Webhook")
}
}
logger.InfoCF("channels", "Channel initialization completed", map[string]any{
"enabled_channels": len(m.channels),
})
+16
View File
@@ -62,6 +62,13 @@ func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) {
value["app_secret"] = ch.Feishu.AppSecret.String()
value["encrypt_key"] = ch.Feishu.EncryptKey.String()
value["verification_token"] = ch.Feishu.VerificationToken.String()
case "teams_webhook":
// Expose webhook URLs for hash computation (they contain secrets)
webhooks := make(map[string]string)
for name, target := range ch.TeamsWebhook.Webhooks {
webhooks[name] = target.WebhookURL.String()
}
value["webhooks"] = webhooks
}
}
@@ -166,4 +173,13 @@ func updateKeys(newcfg, old *config.ChannelsConfig) {
newcfg.Feishu.EncryptKey = old.Feishu.EncryptKey
newcfg.Feishu.VerificationToken = old.Feishu.VerificationToken
}
if newcfg.TeamsWebhook.Enabled {
// Copy SecureString webhook URLs from old config
for name, oldTarget := range old.TeamsWebhook.Webhooks {
if newTarget, ok := newcfg.TeamsWebhook.Webhooks[name]; ok {
newTarget.WebhookURL = oldTarget.WebhookURL
newcfg.TeamsWebhook.Webhooks[name] = newTarget
}
}
}
}
+110 -2
View File
@@ -19,6 +19,8 @@ import (
type mockChannel struct {
BaseChannel
sendFn func(ctx context.Context, msg bus.OutboundMessage) error
startFn func(ctx context.Context) error
stopFn func(ctx context.Context) error
sentMessages []bus.OutboundMessage
placeholdersSent int
editedMessages int
@@ -33,8 +35,19 @@ func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
return nil, m.sendFn(ctx, msg)
}
func (m *mockChannel) Start(ctx context.Context) error { return nil }
func (m *mockChannel) Stop(ctx context.Context) error { return nil }
func (m *mockChannel) Start(ctx context.Context) error {
if m.startFn != nil {
return m.startFn(ctx)
}
return nil
}
func (m *mockChannel) Stop(ctx context.Context) error {
if m.stopFn != nil {
return m.stopFn(ctx)
}
return nil
}
func (m *mockChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
m.placeholdersSent++
@@ -86,6 +99,101 @@ func newTestManager() *Manager {
return &Manager{
channels: make(map[string]Channel),
workers: make(map[string]*channelWorker),
bus: bus.NewMessageBus(),
}
}
func TestStartAll_AllChannelsFail_ReturnsJoinedError(t *testing.T) {
m := newTestManager()
errA := errors.New("channel-a start failed")
errB := errors.New("channel-b start failed")
m.channels["a"] = &mockChannel{
startFn: func(_ context.Context) error { return errA },
}
m.channels["b"] = &mockChannel{
startFn: func(_ context.Context) error { return errB },
}
err := m.StartAll(t.Context())
if err == nil {
t.Fatal("expected StartAll to fail when all channels fail")
}
if !strings.Contains(err.Error(), "failed to start any enabled channels") {
t.Fatalf("unexpected error: %v", err)
}
if !errors.Is(err, errA) {
t.Fatalf("expected error to wrap errA, got: %v", err)
}
if !errors.Is(err, errB) {
t.Fatalf("expected error to wrap errB, got: %v", err)
}
if len(m.workers) != 0 {
t.Fatalf("expected no workers on full startup failure, got %d", len(m.workers))
}
if m.dispatchTask != nil {
t.Fatal("expected dispatch task to be cleared on full startup failure")
}
}
func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
m := newTestManager()
errBad := errors.New("bad channel start failed")
processed := make(chan struct{}, 1)
m.channels["good"] = &mockChannel{
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
if msg.Channel == "good" {
select {
case processed <- struct{}{}:
default:
}
}
return nil
},
}
m.channels["bad"] = &mockChannel{
startFn: func(_ context.Context) error { return errBad },
}
err := m.StartAll(t.Context())
if err != nil {
t.Fatalf("expected StartAll to succeed with partial channel failures, got: %v", err)
}
if len(m.workers) != 1 {
t.Fatalf("expected exactly 1 active worker, got %d", len(m.workers))
}
if _, ok := m.workers["good"]; !ok {
t.Fatal("expected worker for successful channel 'good'")
}
if _, ok := m.workers["bad"]; ok {
t.Fatal("did not expect worker for failed channel 'bad'")
}
if m.dispatchTask == nil {
t.Fatal("expected dispatch task to run when at least one channel starts")
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: "good",
ChatID: "chat-1",
Content: "hello",
}); err != nil {
t.Fatalf("PublishOutbound() error = %v", err)
}
select {
case <-processed:
// worker processed outbound message as expected
case <-time.After(2 * time.Second):
t.Fatal("expected successful channel worker to process outbound message")
}
stopCtx, stopCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer stopCancel()
if err := m.StopAll(stopCtx); err != nil {
t.Fatalf("StopAll() error = %v", err)
}
}
+5
View File
@@ -1300,3 +1300,8 @@ func stripUserMentionWithRegexp(text string, userID id.UserID, mentionR *regexp.
cleaned = strings.TrimLeft(cleaned, ",:; ")
return strings.TrimSpace(cleaned)
}
// VoiceCapabilities returns the voice capabilities of the channel.
func (c *MatrixChannel) VoiceCapabilities() channels.VoiceCapabilities {
return channels.VoiceCapabilities{ASR: true, TTS: true}
}
+5
View File
@@ -1117,3 +1117,8 @@ func truncate(s string, n int) string {
}
return string(runes[:n]) + "..."
}
// VoiceCapabilities returns the voice capabilities of the channel.
func (c *OneBotChannel) VoiceCapabilities() channels.VoiceCapabilities {
return channels.VoiceCapabilities{ASR: true, TTS: true}
}
+54
View File
@@ -262,3 +262,57 @@ func TestSend_ClosedConnection(t *testing.T) {
ch.Stop(ctx)
}
func TestParseInlineImageMedia_Valid(t *testing.T) {
media, err := parseInlineImageMedia(map[string]any{
"media": []any{
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+X2ioAAAAASUVORK5CYII=",
},
})
if err != nil {
t.Fatalf("parseInlineImageMedia() error = %v", err)
}
if len(media) != 1 {
t.Fatalf("len(media) = %d, want 1", len(media))
}
}
func TestPicoChannel_HandleMessageSend_AllowsMediaOnly(t *testing.T) {
mb := bus.NewMessageBus()
ch, err := NewPicoChannel(config.PicoConfig{
Token: *config.NewSecureString("test-token"),
}, mb)
if err != nil {
t.Fatalf("NewPicoChannel() error = %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := ch.Start(ctx); err != nil {
t.Fatalf("Start() error = %v", err)
}
defer ch.Stop(ctx)
pc := &picoConn{id: "conn-1", sessionID: "sess-1"}
ch.handleMessageSend(pc, PicoMessage{
ID: "msg-1",
Payload: map[string]any{
"media": []any{
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+X2ioAAAAASUVORK5CYII=",
},
},
})
select {
case msg := <-mb.InboundChan():
if msg.Content != "" {
t.Fatalf("msg.Content = %q, want empty", msg.Content)
}
if len(msg.Media) != 1 || !strings.HasPrefix(msg.Media[0], "data:image/png;base64,") {
t.Fatalf("msg.Media = %#v, want inline image payload", msg.Media)
}
case <-ctx.Done():
t.Fatal("timed out waiting for inbound media message")
}
}
+10 -11
View File
@@ -39,19 +39,18 @@ func newMessage(msgType string, payload map[string]any) PicoMessage {
}
}
// newError creates an error PicoMessage.
func newError(code, message string) PicoMessage {
return newMessage(TypeError, map[string]any{
func newErrorWithPayload(code, message string, extra map[string]any) PicoMessage {
payload := map[string]any{
"code": code,
"message": message,
})
}
func newErrorWithPayload(code, message string, payload map[string]any) PicoMessage {
if payload == nil {
payload = map[string]any{}
}
payload["code"] = code
payload["message"] = message
for key, value := range extra {
payload[key] = value
}
return newMessage(TypeError, payload)
}
// newError creates an error PicoMessage.
func newError(code, message string) PicoMessage {
return newErrorWithPayload(code, message, nil)
}
+5
View File
@@ -1003,3 +1003,8 @@ func sanitizeURLs(text string) string {
return scheme + domain + path
})
}
// VoiceCapabilities returns the voice capabilities of the channel.
func (c *QQChannel) VoiceCapabilities() channels.VoiceCapabilities {
return channels.VoiceCapabilities{ASR: true, TTS: true}
}
+13
View File
@@ -0,0 +1,13 @@
package teamswebhook
import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func init() {
channels.RegisterFactory("teams_webhook", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewTeamsWebhookChannel(cfg.Channels.TeamsWebhook, b)
})
}
+422
View File
@@ -0,0 +1,422 @@
package teamswebhook
import (
"context"
"fmt"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
goteamsnotify "github.com/atc0005/go-teams-notify/v2"
"github.com/atc0005/go-teams-notify/v2/adaptivecard"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
// statusCodeRe extracts HTTP status codes from error messages like "401 Unauthorized".
var statusCodeRe = regexp.MustCompile(`\b([45]\d{2})\b`)
// markdownTableRe matches a markdown table block (header + separator + rows).
// It captures the entire table including all rows.
var markdownTableRe = regexp.MustCompile(`(?m)^(\|[^\n]+\|)\n(\|[-:\|\s]+\|)\n((?:\|[^\n]+\|\n?)+)`)
// teamsMessageSender abstracts the Teams client for testability.
type teamsMessageSender interface {
SendWithContext(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error
}
// classifyTeamsError extracts HTTP status code from error message and classifies it.
// The go-teams-notify library returns errors like "error on notification: 401 Unauthorized, ...".
// This allows proper retry behavior: 4xx errors are permanent, 5xx are temporary.
func classifyTeamsError(err error) error {
if err == nil {
return nil
}
errMsg := err.Error()
if matches := statusCodeRe.FindStringSubmatch(errMsg); len(matches) > 1 {
if statusCode, parseErr := strconv.Atoi(matches[1]); parseErr == nil {
return channels.ClassifySendError(statusCode, err)
}
}
// Fallback: treat as temporary network error (retryable)
return channels.ClassifyNetError(err)
}
// TeamsWebhookChannel is an output-only channel that sends messages
// to Microsoft Teams via Power Automate workflow webhooks.
// Multiple webhook targets can be configured and selected via ChatID.
type TeamsWebhookChannel struct {
*channels.BaseChannel
config config.TeamsWebhookConfig
client teamsMessageSender
}
// NewTeamsWebhookChannel creates a new Teams webhook channel.
func NewTeamsWebhookChannel(
cfg config.TeamsWebhookConfig,
bus *bus.MessageBus,
) (*TeamsWebhookChannel, error) {
if len(cfg.Webhooks) == 0 {
return nil, fmt.Errorf("teams_webhook: at least one webhook target is required")
}
// Require "default" webhook target
if _, hasDefault := cfg.Webhooks["default"]; !hasDefault {
return nil, fmt.Errorf("teams_webhook: a 'default' webhook target is required")
}
// Validate all webhook targets have valid HTTPS URLs
for name, target := range cfg.Webhooks {
webhookURL := target.WebhookURL.String()
if webhookURL == "" {
return nil, fmt.Errorf("teams_webhook: webhook %q has empty webhook_url", name)
}
parsed, err := url.Parse(webhookURL)
if err != nil {
return nil, fmt.Errorf("teams_webhook: webhook %q has invalid URL: %w", name, err)
}
if !strings.EqualFold(parsed.Scheme, "https") {
return nil, fmt.Errorf("teams_webhook: webhook %q must use HTTPS (got %q)", name, parsed.Scheme)
}
}
base := channels.NewBaseChannel(
"teams_webhook",
cfg,
bus,
[]string{
"*",
}, // Output-only channel; "*" suppresses misleading "allows EVERYONE" audit warning
channels.WithMaxMessageLength(24000), // Power Automate webhook payload limit is 28KB
)
client := goteamsnotify.NewTeamsClient()
return &TeamsWebhookChannel{
BaseChannel: base,
config: cfg,
client: client,
}, nil
}
// Start initializes the channel. For output-only channels, this is a no-op.
func (c *TeamsWebhookChannel) Start(ctx context.Context) error {
targets := make([]string, 0, len(c.config.Webhooks))
for name := range c.config.Webhooks {
targets = append(targets, name)
}
sort.Strings(targets)
logger.InfoCF("teams_webhook", "Starting Teams webhook channel (output-only)", map[string]any{
"targets": targets,
})
c.SetRunning(true)
return nil
}
// Stop shuts down the channel.
func (c *TeamsWebhookChannel) Stop(ctx context.Context) error {
logger.InfoC("teams_webhook", "Stopping Teams webhook channel")
c.SetRunning(false)
return nil
}
// Send delivers a message to the specified Teams webhook target.
// The target is selected by msg.ChatID which must match a key in the webhooks map.
func (c *TeamsWebhookChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
if !c.IsRunning() {
return nil, channels.ErrNotRunning
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
// Look up webhook target by ChatID, fall back to "default" if empty or unknown
targetName := msg.ChatID
if targetName == "" {
targetName = "default"
}
target, ok := c.config.Webhooks[targetName]
if !ok {
// Log warning and fall back to default target
logger.WarnCF("teams_webhook", "Unknown target, falling back to default", map[string]any{
"requested": msg.ChatID,
"using": "default",
})
target = c.config.Webhooks["default"]
}
// Build an Adaptive Card for rich formatting
card, err := c.buildAdaptiveCard(msg, target)
if err != nil {
return nil, fmt.Errorf("teams_webhook: failed to build card: %w", err)
}
// Create the message with the card
teamsMsg, err := adaptivecard.NewMessageFromCard(card)
if err != nil {
return nil, fmt.Errorf("teams_webhook: failed to create message: %w", err)
}
// Send to Teams
if err := c.client.SendWithContext(ctx, target.WebhookURL.String(), teamsMsg); err != nil {
// Log without raw error to avoid leaking webhook URL (embedded in net/http errors)
logger.ErrorCF("teams_webhook", "Failed to send message to Teams webhook", map[string]any{
"target": msg.ChatID,
})
// Classify error based on status code extracted from error message.
// The go-teams-notify library includes status in errors like "401 Unauthorized".
// Use ClassifySendError for proper retry behavior (4xx = permanent, 5xx = temporary).
classifiedErr := classifyTeamsError(err)
return nil, fmt.Errorf("teams_webhook: send failed: %w", classifiedErr)
}
logger.DebugCF("teams_webhook", "Message sent successfully", map[string]any{
"target": msg.ChatID,
})
return nil, nil
}
// buildAdaptiveCard creates a formatted Adaptive Card from the outbound message.
// It detects markdown tables and converts them to native Adaptive Card Table elements,
// since TextBlocks only support a limited markdown subset (no tables).
func (c *TeamsWebhookChannel) buildAdaptiveCard(
msg bus.OutboundMessage,
target config.TeamsWebhookTarget,
) (adaptivecard.Card, error) {
card := adaptivecard.NewCard()
card.Type = adaptivecard.TypeAdaptiveCard
// Set full width for Teams rendering
card.MSTeams.Width = "Full"
// Add title if configured on the target
title := target.Title
if title == "" {
title = "PicoClaw Notification"
}
titleBlock := adaptivecard.NewTextBlock(title, true)
titleBlock.Size = adaptivecard.SizeLarge
titleBlock.Weight = adaptivecard.WeightBolder
titleBlock.Style = adaptivecard.TextBlockStyleHeading
if err := card.AddElement(false, titleBlock); err != nil {
return card, err
}
content := msg.Content
if content == "" {
content = "(empty message)"
}
// Split content into text segments and tables
// TextBlocks support: bold, italic, bullet/numbered lists, links
// TextBlocks do NOT support: headers, tables, images
segments := splitContentWithTables(content)
for _, seg := range segments {
if seg.isTable {
// Convert markdown table to Adaptive Card Table element
tableElement, err := parseMarkdownTable(seg.content)
if err != nil {
// Fallback: render as preformatted text if parsing fails
logger.WarnCF("teams_webhook", "Failed to parse markdown table, using fallback", map[string]any{
"error": err.Error(),
})
block := adaptivecard.NewTextBlock("```\n"+seg.content+"\n```", true)
block.Wrap = true
if err := card.AddElement(false, block); err != nil {
return card, err
}
continue
}
if err := card.AddElement(false, tableElement); err != nil {
return card, err
}
} else {
// Regular text content
text := strings.TrimSpace(seg.content)
if text == "" {
continue
}
block := adaptivecard.NewTextBlock(text, true)
block.Wrap = true
if err := card.AddElement(false, block); err != nil {
return card, err
}
}
}
return card, nil
}
// contentSegment represents either a text block or a table in the message content.
type contentSegment struct {
content string
isTable bool
}
// splitContentWithTables splits content into alternating text and table segments.
func splitContentWithTables(content string) []contentSegment {
var segments []contentSegment
matches := markdownTableRe.FindAllStringSubmatchIndex(content, -1)
if len(matches) == 0 {
// No tables found, return entire content as text
return []contentSegment{{content: content, isTable: false}}
}
lastEnd := 0
for _, match := range matches {
// Text before this table
if match[0] > lastEnd {
segments = append(segments, contentSegment{
content: content[lastEnd:match[0]],
isTable: false,
})
}
// The table itself
segments = append(segments, contentSegment{
content: content[match[0]:match[1]],
isTable: true,
})
lastEnd = match[1]
}
// Text after the last table
if lastEnd < len(content) {
segments = append(segments, contentSegment{
content: content[lastEnd:],
isTable: false,
})
}
return segments
}
// parseMarkdownTable converts a markdown table string to an Adaptive Card Table element.
func parseMarkdownTable(tableStr string) (adaptivecard.Element, error) {
lines := strings.Split(strings.TrimSpace(tableStr), "\n")
if len(lines) < 2 {
return adaptivecard.Element{}, fmt.Errorf("table must have at least header and separator rows")
}
// Track header content length per column for width calculation
var headerLengths []int
// Parse all rows (header + data rows, skip separator)
var allRows [][]adaptivecard.TableCell
for i, line := range lines {
// Skip separator row (contains only |, -, :, and spaces)
if i == 1 && isSeparatorRow(line) {
continue
}
cells := parseTableRow(line)
if len(cells) == 0 {
continue
}
var tableCells []adaptivecard.TableCell
for _, cellText := range cells {
trimmedText := strings.TrimSpace(cellText)
// Use header row (first row) to determine column widths
if i == 0 {
headerLengths = append(headerLengths, len(trimmedText))
}
textBlock := adaptivecard.Element{
Type: adaptivecard.TypeElementTextBlock,
Text: trimmedText,
Wrap: true,
}
cell := adaptivecard.TableCell{
Type: adaptivecard.TypeTableCell,
Items: []*adaptivecard.Element{&textBlock},
}
tableCells = append(tableCells, cell)
}
allRows = append(allRows, tableCells)
}
if len(allRows) == 0 {
return adaptivecard.Element{}, fmt.Errorf("no valid rows found in table")
}
// Create table with first row as headers
firstRowAsHeaders := true
showGridLines := true
table, err := adaptivecard.NewTableFromTableCells(allRows, 0, firstRowAsHeaders, showGridLines)
if err != nil {
return adaptivecard.Element{}, fmt.Errorf("failed to create table: %w", err)
}
// Set column widths based on header content length
table.Columns = calculateColumnWidths(headerLengths)
return table, nil
}
// calculateColumnWidths creates TableColumnDefinition entries with widths
// proportional to the max content length of each column.
func calculateColumnWidths(maxLengths []int) []adaptivecard.Column {
if len(maxLengths) == 0 {
return nil
}
// Use content length as relative weight, with a minimum of 1
columns := make([]adaptivecard.Column, len(maxLengths))
for i, length := range maxLengths {
weight := length
if weight < 1 {
weight = 1
}
columns[i] = adaptivecard.Column{
Type: "TableColumnDefinition",
Width: weight,
}
}
return columns
}
// isSeparatorRow checks if a line is a markdown table separator (e.g., |---|---|).
func isSeparatorRow(line string) bool {
// Remove pipes and spaces, check if only dashes and colons remain
cleaned := strings.ReplaceAll(line, "|", "")
cleaned = strings.ReplaceAll(cleaned, " ", "")
cleaned = strings.ReplaceAll(cleaned, "-", "")
cleaned = strings.ReplaceAll(cleaned, ":", "")
return cleaned == ""
}
// parseTableRow extracts cell values from a markdown table row.
func parseTableRow(line string) []string {
// Trim leading/trailing pipes and split by |
line = strings.TrimSpace(line)
line = strings.TrimPrefix(line, "|")
line = strings.TrimSuffix(line, "|")
if line == "" {
return nil
}
parts := strings.Split(line, "|")
var cells []string
for _, p := range parts {
cells = append(cells, strings.TrimSpace(p))
}
return cells
}
@@ -0,0 +1,583 @@
package teamswebhook
import (
"context"
"errors"
"testing"
goteamsnotify "github.com/atc0005/go-teams-notify/v2"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
// mockTeamsClient implements teamsMessageSender for testing.
type mockTeamsClient struct {
sendFunc func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error
}
func (m *mockTeamsClient) SendWithContext(
ctx context.Context,
webhookURL string,
message goteamsnotify.TeamsMessage,
) error {
if m.sendFunc != nil {
return m.sendFunc(ctx, webhookURL, message)
}
return nil
}
func TestNewTeamsWebhookChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
// Test missing webhooks
_, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: nil,
}, msgBus)
if err == nil {
t.Error("expected error for missing webhooks")
}
// Test missing "default" webhook
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
Title: "Alerts",
},
},
}, msgBus)
if err == nil {
t.Error("expected error for missing 'default' webhook")
}
// Test empty webhook URL
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {Title: "Default"},
},
}, msgBus)
if err == nil {
t.Error("expected error for empty webhook_url")
}
// Test HTTP URL (should fail, must be HTTPS)
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("http://example.com/webhook"),
Title: "Default",
},
},
}, msgBus)
if err == nil {
t.Error("expected error for HTTP webhook URL (must be HTTPS)")
}
// Test valid config with HTTPS (must include "default")
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
Title: "Default",
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook1"),
Title: "Alerts",
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "teams_webhook" {
t.Errorf("expected name 'teams_webhook', got %q", ch.Name())
}
}
func TestTeamsWebhookChannel_StartStop(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ctx := context.Background()
if ch.IsRunning() {
t.Error("channel should not be running before Start")
}
if err := ch.Start(ctx); err != nil {
t.Fatalf("Start failed: %v", err)
}
if !ch.IsRunning() {
t.Error("channel should be running after Start")
}
if err := ch.Stop(ctx); err != nil {
t.Fatalf("Stop failed: %v", err)
}
if ch.IsRunning() {
t.Error("channel should not be running after Stop")
}
}
func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
Title: "Default",
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
Title: "Custom Title",
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
target := ch.config.Webhooks["alerts"]
msg := bus.OutboundMessage{
Content: "Test message content",
ChatID: "alerts",
}
card, err := ch.buildAdaptiveCard(msg, target)
if err != nil {
t.Fatalf("buildAdaptiveCard failed: %v", err)
}
if card.Type != "AdaptiveCard" {
t.Errorf("expected card type 'AdaptiveCard', got %q", card.Type)
}
}
func TestTeamsWebhookChannel_SendNotRunning(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ctx := context.Background()
msg := bus.OutboundMessage{Content: "test", ChatID: "default"}
_, err = ch.Send(ctx, msg)
if err == nil {
t.Error("expected error when sending while not running")
}
}
func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
tests := []struct {
name string
chatID string
}{
{"unknown target falls back to default", "unknown"},
{"empty ChatID uses default", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var sentURL string
ch.client = &mockTeamsClient{
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
sentURL = webhookURL
return nil
},
}
ctx := context.Background()
_ = ch.Start(ctx)
defer ch.Stop(ctx)
msg := bus.OutboundMessage{Content: "test", ChatID: tt.chatID}
_, err = ch.Send(ctx, msg)
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
if sentURL != "https://example.com/webhook-default" {
t.Errorf("expected default webhook URL, got %q", sentURL)
}
})
}
}
func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
Title: "Default",
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
Title: "Test Alerts",
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Inject mock client
var sentURL string
ch.client = &mockTeamsClient{
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
sentURL = webhookURL
return nil
},
}
ctx := context.Background()
_ = ch.Start(ctx)
defer ch.Stop(ctx)
msg := bus.OutboundMessage{Content: "Hello Teams!", ChatID: "alerts"}
_, err = ch.Send(ctx, msg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if sentURL != "https://example.com/webhook-alerts" {
t.Errorf("expected webhook URL 'https://example.com/webhook-alerts', got %q", sentURL)
}
}
func TestTeamsWebhookChannel_SendError(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Inject mock client that returns an error
ch.client = &mockTeamsClient{
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
return errors.New("error on notification: 401 Unauthorized, forbidden")
},
}
ctx := context.Background()
_ = ch.Start(ctx)
defer ch.Stop(ctx)
msg := bus.OutboundMessage{Content: "test", ChatID: "alerts"}
_, err = ch.Send(ctx, msg)
if err == nil {
t.Error("expected error from failed send")
}
}
func TestSplitContentWithTables(t *testing.T) {
tests := []struct {
name string
content string
wantSegs int
wantTbl int // number of table segments
}{
{
name: "no tables",
content: "Just some text\nwith multiple lines",
wantSegs: 1,
wantTbl: 0,
},
{
name: "single table",
content: `| Col1 | Col2 |
|------|------|
| A | B |
| C | D |`,
wantSegs: 1,
wantTbl: 1,
},
{
name: "text before table",
content: `Here is some text.
| Col1 | Col2 |
|------|------|
| A | B |`,
wantSegs: 2,
wantTbl: 1,
},
{
name: "text before and after table",
content: `Before table.
| Col1 | Col2 |
|------|------|
| A | B |
After table.`,
wantSegs: 3,
wantTbl: 1,
},
{
name: "multiple tables",
content: `First table:
| A | B |
|---|---|
| 1 | 2 |
Second table:
| X | Y |
|---|---|
| 3 | 4 |`,
wantSegs: 4,
wantTbl: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
segs := splitContentWithTables(tt.content)
if len(segs) != tt.wantSegs {
t.Errorf("got %d segments, want %d", len(segs), tt.wantSegs)
}
tableCount := 0
for _, s := range segs {
if s.isTable {
tableCount++
}
}
if tableCount != tt.wantTbl {
t.Errorf("got %d tables, want %d", tableCount, tt.wantTbl)
}
})
}
}
func TestParseMarkdownTable(t *testing.T) {
tableStr := `| Name | Value |
|------|-------|
| foo | 123 |
| bar | 456 |`
elem, err := parseMarkdownTable(tableStr)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if elem.Type != "Table" {
t.Errorf("expected type 'Table', got %q", elem.Type)
}
// Should have 3 rows (header + 2 data rows)
if len(elem.Rows) != 3 {
t.Errorf("expected 3 rows, got %d", len(elem.Rows))
}
// Should have 2 columns with widths based on content length
if len(elem.Columns) != 2 {
t.Errorf("expected 2 columns, got %d", len(elem.Columns))
}
}
func TestParseMarkdownTableColumnWidths(t *testing.T) {
// Column widths are based on HEADER row only:
// Col1: "Description" (11 chars)
// Col2: "X" (1 char)
// Col3: "Amount" (6 chars)
tableStr := `| Description | X | Amount |
|-------------|---|--------|
| Short | Y | 100 |
| Longer text | Z | 50 |`
elem, err := parseMarkdownTable(tableStr)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(elem.Columns) != 3 {
t.Fatalf("expected 3 columns, got %d", len(elem.Columns))
}
// Verify column widths are based on header content length
w1, ok1 := elem.Columns[0].Width.(int)
w2, ok2 := elem.Columns[1].Width.(int)
w3, ok3 := elem.Columns[2].Width.(int)
if !ok1 || !ok2 || !ok3 {
t.Fatalf("expected int widths, got types: %T, %T, %T",
elem.Columns[0].Width, elem.Columns[1].Width, elem.Columns[2].Width)
}
// Header lengths: "Description" = 11, "X" = 1, "Amount" = 6
if w1 != 11 {
t.Errorf("expected col1 width 11 (from 'Description'), got %d", w1)
}
if w2 != 1 {
t.Errorf("expected col2 width 1 (from 'X'), got %d", w2)
}
if w3 != 6 {
t.Errorf("expected col3 width 6 (from 'Amount'), got %d", w3)
}
}
func TestCalculateColumnWidths(t *testing.T) {
tests := []struct {
name string
maxLengths []int
wantWidths []int
}{
{
name: "equal lengths",
maxLengths: []int{10, 10, 10},
wantWidths: []int{10, 10, 10},
},
{
name: "varying lengths",
maxLengths: []int{5, 20, 10},
wantWidths: []int{5, 20, 10},
},
{
name: "zero length gets minimum of 1",
maxLengths: []int{0, 5, 0},
wantWidths: []int{1, 5, 1},
},
{
name: "empty input",
maxLengths: []int{},
wantWidths: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cols := calculateColumnWidths(tt.maxLengths)
if tt.wantWidths == nil {
if cols != nil {
t.Errorf("expected nil, got %v", cols)
}
return
}
if len(cols) != len(tt.wantWidths) {
t.Fatalf("expected %d columns, got %d", len(tt.wantWidths), len(cols))
}
for i, col := range cols {
width, ok := col.Width.(int)
if !ok {
t.Errorf("column %d: expected int width, got %T", i, col.Width)
continue
}
if width != tt.wantWidths[i] {
t.Errorf("column %d: expected width %d, got %d", i, tt.wantWidths[i], width)
}
if col.Type != "TableColumnDefinition" {
t.Errorf("column %d: expected type 'TableColumnDefinition', got %q", i, col.Type)
}
}
})
}
}
func TestParseTableRow(t *testing.T) {
tests := []struct {
line string
want []string
}{
{"| A | B | C |", []string{"A", "B", "C"}},
{"|A|B|C|", []string{"A", "B", "C"}},
{"| foo | bar |", []string{"foo", "bar"}},
{"", nil},
}
for _, tt := range tests {
got := parseTableRow(tt.line)
if len(got) != len(tt.want) {
t.Errorf("parseTableRow(%q): got %v, want %v", tt.line, got, tt.want)
continue
}
for i := range got {
if got[i] != tt.want[i] {
t.Errorf("parseTableRow(%q)[%d]: got %q, want %q", tt.line, i, got[i], tt.want[i])
}
}
}
}
func TestIsSeparatorRow(t *testing.T) {
tests := []struct {
line string
want bool
}{
{"|---|---|", true},
{"| --- | --- |", true},
{"|:---|---:|", true},
{"| :---: | :---: |", true},
{"| A | B |", false},
{"| foo | bar |", false},
}
for _, tt := range tests {
got := isSeparatorRow(tt.line)
if got != tt.want {
t.Errorf("isSeparatorRow(%q): got %v, want %v", tt.line, got, tt.want)
}
}
}
+5
View File
@@ -1190,3 +1190,8 @@ func isPostConnectError(err error) bool {
strings.Contains(msg, "connection closed by foreign host") ||
strings.Contains(msg, "broken pipe")
}
// VoiceCapabilities returns the voice capabilities of the channel.
func (c *TelegramChannel) VoiceCapabilities() channels.VoiceCapabilities {
return channels.VoiceCapabilities{ASR: true, TTS: true}
}
+13
View File
@@ -0,0 +1,13 @@
package vk
import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func init() {
channels.RegisterFactory("vk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewVKChannel(cfg, b)
})
}
+282
View File
@@ -0,0 +1,282 @@
package vk
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/SevereCloud/vksdk/v3/api"
"github.com/SevereCloud/vksdk/v3/api/params"
"github.com/SevereCloud/vksdk/v3/events"
"github.com/SevereCloud/vksdk/v3/longpoll-bot"
"github.com/SevereCloud/vksdk/v3/object"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
)
type VKChannel struct {
*channels.BaseChannel
vk *api.VK
lp *longpoll.LongPoll
config *config.Config
ctx context.Context
cancel context.CancelFunc
}
func NewVKChannel(cfg *config.Config, bus *bus.MessageBus) (*VKChannel, error) {
vkCfg := cfg.Channels.VK
vk := api.NewVK(vkCfg.Token.String())
base := channels.NewBaseChannel(
"vk",
vkCfg,
bus,
vkCfg.AllowFrom,
channels.WithMaxMessageLength(4000),
channels.WithGroupTrigger(vkCfg.GroupTrigger),
channels.WithReasoningChannelID(vkCfg.ReasoningChannelID),
)
return &VKChannel{
BaseChannel: base,
vk: vk,
config: cfg,
}, nil
}
func (c *VKChannel) Start(ctx context.Context) error {
logger.InfoC("vk", "Starting VK bot (Long Poll mode)...")
c.ctx, c.cancel = context.WithCancel(ctx)
groupID := c.config.Channels.VK.GroupID
if groupID == 0 {
c.cancel()
return fmt.Errorf("group_id is required for VK bot")
}
lp, err := longpoll.NewLongPoll(c.vk, groupID)
if err != nil {
c.cancel()
return fmt.Errorf("failed to create long poll: %w", err)
}
c.lp = lp
lp.MessageNew(func(_ context.Context, obj events.MessageNewObject) {
c.handleMessage(obj.Message)
})
c.SetRunning(true)
logger.InfoCF("vk", "VK bot connected", map[string]any{
"group_id": groupID,
})
go func() {
if err := lp.Run(); err != nil {
logger.ErrorCF("vk", "Long poll failed", map[string]any{
"error": err.Error(),
})
}
}()
return nil
}
func (c *VKChannel) Stop(ctx context.Context) error {
logger.InfoC("vk", "Stopping VK bot...")
c.SetRunning(false)
if c.lp != nil {
c.lp.Shutdown()
}
if c.cancel != nil {
c.cancel()
}
return nil
}
func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
if msg.Action.Type != "" {
return
}
if bool(msg.Out) {
return
}
peerID := msg.PeerID
chatID := strconv.Itoa(peerID)
fromID := msg.FromID
userID := strconv.Itoa(fromID)
platformID := userID
sender := bus.SenderInfo{
Platform: "vk",
PlatformID: platformID,
CanonicalID: identity.BuildCanonicalID("vk", platformID),
DisplayName: c.getUserName(fromID),
}
if !c.IsAllowedSender(sender) {
logger.DebugCF("vk", "Message from unauthorized user", map[string]any{
"peer_id": peerID,
})
return
}
text := msg.Text
if text == "" && len(msg.Attachments) > 0 {
text = c.processAttachments(msg.Attachments)
}
if text == "" {
return
}
groupTrigger := c.config.Channels.VK.GroupTrigger
isGroupChat := peerID != fromID
if isGroupChat {
isMentioned := c.isMentioned(msg)
if isMentioned {
text = c.stripBotMention(text)
}
respond, cleaned := c.ShouldRespondInGroup(isMentioned, text)
if !respond {
return
}
text = cleaned
_ = groupTrigger
}
chatType := "direct"
if isGroupChat {
chatType = "group"
}
messageID := strconv.Itoa(msg.ConversationMessageID)
metadata := map[string]string{
"user_id": userID,
"is_group": fmt.Sprintf("%t", isGroupChat),
}
c.HandleInboundContext(c.ctx, chatID, text, nil, bus.InboundContext{
Channel: "vk",
ChatID: chatID,
ChatType: chatType,
SenderID: userID,
MessageID: messageID,
Mentioned: isGroupChat && c.isMentioned(msg),
Raw: metadata,
}, sender)
}
func (c *VKChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
if !c.IsRunning() {
return nil, channels.ErrNotRunning
}
peerID, err := strconv.Atoi(msg.ChatID)
if err != nil {
return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
}
if msg.Content == "" {
return nil, nil
}
var messageIDs []string
chunks := channels.SplitMessage(msg.Content, 4000)
for _, chunk := range chunks {
if chunk == "" {
continue
}
b := params.NewMessagesSendBuilder()
b.Message(chunk)
b.RandomID(0)
b.PeerID(peerID)
if msg.ReplyToMessageID != "" {
if replyID, err := strconv.Atoi(msg.ReplyToMessageID); err == nil {
b.ReplyTo(replyID)
}
}
resp, err := c.vk.MessagesSend(b.Params)
if err != nil {
logger.ErrorCF("vk", "Failed to send message", map[string]any{
"error": err.Error(),
"peer_id": peerID,
})
return messageIDs, fmt.Errorf("failed to send message: %w", err)
}
messageIDs = append(messageIDs, strconv.Itoa(resp))
}
return messageIDs, nil
}
func (c *VKChannel) isMentioned(msg object.MessagesMessage) bool {
return false
}
func (c *VKChannel) stripBotMention(text string) string {
return strings.TrimSpace(text)
}
func (c *VKChannel) getUserName(userID int) string {
users, err := c.vk.UsersGet(api.Params{
"user_ids": userID,
})
if err != nil || len(users) == 0 {
return strconv.Itoa(userID)
}
user := users[0]
return fmt.Sprintf("%s %s", user.FirstName, user.LastName)
}
func (c *VKChannel) processAttachments(attachments []object.MessagesMessageAttachment) string {
var parts []string
for _, att := range attachments {
switch att.Type {
case "photo":
parts = append(parts, "[photo]")
case "video":
parts = append(parts, "[video]")
case "audio":
parts = append(parts, "[audio]")
case "doc":
if att.Doc.Title != "" {
parts = append(parts, fmt.Sprintf("[document: %s]", att.Doc.Title))
} else {
parts = append(parts, "[document]")
}
case "audio_message":
parts = append(parts, "[voice]")
case "sticker":
parts = append(parts, "[sticker]")
}
}
return strings.Join(parts, " ")
}
func (c *VKChannel) VoiceCapabilities() channels.VoiceCapabilities {
return channels.VoiceCapabilities{ASR: true, TTS: true}
}
+260
View File
@@ -0,0 +1,260 @@
package vk
import (
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestNewVKChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("missing group_id", func(t *testing.T) {
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
},
},
}
ch, err := NewVKChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error during creation: %v", err)
}
if ch.Name() != "vk" {
t.Errorf("Name() = %q, want %q", ch.Name(), "vk")
}
if ch.IsRunning() {
t.Error("new channel should not be running")
}
})
t.Run("valid config with group_id", func(t *testing.T) {
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
},
},
}
ch, err := NewVKChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "vk" {
t.Errorf("Name() = %q, want %q", ch.Name(), "vk")
}
if ch.IsRunning() {
t.Error("new channel should not be running")
}
})
t.Run("with allow_from", func(t *testing.T) {
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
AllowFrom: []string{"123456789"},
},
},
}
ch, err := NewVKChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !ch.IsAllowedSender(bus.SenderInfo{PlatformID: "123456789"}) {
t.Error("user 123456789 should be allowed")
}
if ch.IsAllowedSender(bus.SenderInfo{PlatformID: "999999999"}) {
t.Error("user 999999999 should not be allowed")
}
})
t.Run("with group_trigger", func(t *testing.T) {
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
GroupTrigger: config.GroupTriggerConfig{
MentionOnly: false,
Prefixes: []string{"/bot", "!bot"},
},
},
},
}
ch, err := NewVKChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "vk" {
t.Errorf("Name() = %q, want %q", ch.Name(), "vk")
}
})
}
func TestVKChannel_MaxMessageLength(t *testing.T) {
msgBus := bus.NewMessageBus()
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
},
},
}
ch, err := NewVKChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
maxLen := ch.MaxMessageLength()
if maxLen != 4000 {
t.Errorf("MaxMessageLength() = %d, want 4000", maxLen)
}
}
func TestVKChannel_SplitMessage(t *testing.T) {
tests := []struct {
name string
content string
maxLen int
want int
}{
{
name: "short message",
content: "hello",
maxLen: 4000,
want: 1,
},
{
name: "exact length",
content: string(make([]byte, 4000)),
maxLen: 4000,
want: 1,
},
{
name: "needs split",
content: string(make([]byte, 5000)),
maxLen: 4000,
want: 2,
},
{
name: "empty message",
content: "",
maxLen: 4000,
want: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := channels.SplitMessage(tt.content, tt.maxLen)
if len(got) != tt.want {
t.Errorf("SplitMessage() got %d parts, want %d parts", len(got), tt.want)
}
})
}
}
func TestVKChannel_ProcessAttachments(t *testing.T) {
tests := []struct {
name string
attachments []string
want string
}{
{
name: "empty attachments",
attachments: []string{},
want: "",
},
{
name: "photo attachment",
attachments: []string{"photo"},
want: "[photo]",
},
{
name: "video attachment",
attachments: []string{"video"},
want: "[video]",
},
{
name: "audio attachment",
attachments: []string{"audio"},
want: "[audio]",
},
{
name: "document attachment",
attachments: []string{"doc"},
want: "[doc]",
},
{
name: "sticker attachment",
attachments: []string{"sticker"},
want: "[sticker]",
},
{
name: "audio_message attachment",
attachments: []string{"audio_message"},
want: "[voice]",
},
{
name: "multiple attachments",
attachments: []string{"photo", "video", "audio"},
want: "[photo] [video] [audio]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result string
for i, att := range tt.attachments {
if i > 0 {
result += " "
}
if att == "audio_message" {
result += "[voice]"
} else {
result += "[" + att + "]"
}
}
if result != tt.want {
t.Errorf("processAttachments() = %q, want %q", result, tt.want)
}
})
}
}
func TestVKChannel_VoiceCapabilities(t *testing.T) {
msgBus := bus.NewMessageBus()
cfg := &config.Config{
Channels: config.ChannelsConfig{
VK: config.VKConfig{
Enabled: true,
Token: *config.NewSecureString("test_token"),
GroupID: 123456789,
},
},
}
ch, err := NewVKChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
caps := ch.VoiceCapabilities()
if !caps.ASR {
t.Error("VoiceCapabilities().ASR should be true")
}
if !caps.TTS {
t.Error("VoiceCapabilities().TTS should be true")
}
}
+58
View File
@@ -0,0 +1,58 @@
package channels
// VoiceCapabilities describes whether ASR (speech-to-text) and TTS (text-to-speech)
// are available for a channel under the current configuration.
type VoiceCapabilities struct {
ASR bool
TTS bool
}
// VoiceCapabilityProvider is an optional interface for channels that want to
// explicitly declare their ASR/TTS support.
type VoiceCapabilityProvider interface {
VoiceCapabilities() VoiceCapabilities
}
// Deprecated: Channels should implement VoiceCapabilityProvider instead.
// To be removed once all existing capable channels conform to the interface.
var asrCapableChannels = map[string]bool{
"discord": true,
"telegram": true,
"matrix": true,
"qq": true,
"weixin": true,
"line": true,
"feishu": true,
"onebot": true,
}
// DetectVoiceCapabilities returns ASR/TTS availability for a channel, gated by
// whether providers are configured.
func DetectVoiceCapabilities(channelName string, ch Channel, asrAvailable bool, ttsAvailable bool) VoiceCapabilities {
if ch == nil {
return VoiceCapabilities{}
}
if vcp, ok := ch.(VoiceCapabilityProvider); ok {
caps := vcp.VoiceCapabilities()
if !asrAvailable {
caps.ASR = false
}
if !ttsAvailable {
caps.TTS = false
}
return caps
}
caps := VoiceCapabilities{}
if asrAvailable {
caps.ASR = asrCapableChannels[channelName]
}
if ttsAvailable {
if _, ok := ch.(MediaSender); ok {
caps.TTS = true
}
}
return caps
}
+5
View File
@@ -414,3 +414,8 @@ func (c *WeixinChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st
return nil, nil
}
// VoiceCapabilities returns the voice capabilities of the channel.
func (c *WeixinChannel) VoiceCapabilities() channels.VoiceCapabilities {
return channels.VoiceCapabilities{ASR: true, TTS: true}
}