diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 1de910c83..c3bcbff8d 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "os" + "regexp" "strings" "sync" "time" @@ -26,6 +27,12 @@ const ( sendTimeout = 10 * time.Second ) +var ( + // Pre-compiled regexes for resolveDiscordRefs (avoid re-compiling per call) + channelRefRe = regexp.MustCompile(`<#(\d+)>`) + msgLinkRe = regexp.MustCompile(`https://(?:discord\.com|discordapp\.com)/channels/(\d+)/(\d+)/(\d+)`) +) + type DiscordChannel struct { *channels.BaseChannel session *discordgo.Session @@ -338,6 +345,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = c.stripBotMention(content) } + // Resolve Discord refs in main content before concatenation to avoid + // double-expanding links that appear in the referenced message. + content = c.resolveDiscordRefs(s, content, m.GuildID) + + // Prepend referenced (quoted) message content if this is a reply + if m.MessageReference != nil && m.ReferencedMessage != nil { + refContent := m.ReferencedMessage.Content + if refContent != "" { + refAuthor := "unknown" + if m.ReferencedMessage.Author != nil { + refAuthor = m.ReferencedMessage.Author.Username + } + refContent = c.resolveDiscordRefs(s, refContent, m.GuildID) + content = fmt.Sprintf("[quoted message from %s]: %s\n\n%s", + refAuthor, refContent, content) + } + } + senderID := m.Author.ID mediaPaths := make([]string, 0, len(m.Attachments)) @@ -508,6 +533,51 @@ func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error { return nil } +// resolveDiscordRefs resolves channel references (<#id> → #channel-name) and +// expands Discord message links to show the linked message content. +// Only links pointing to the same guild are expanded to prevent cross-guild leakage. +func (c *DiscordChannel) resolveDiscordRefs(s *discordgo.Session, text string, guildID string) string { + // 1. Resolve channel references: <#id> → #channel-name + text = channelRefRe.ReplaceAllStringFunc(text, func(match string) string { + parts := channelRefRe.FindStringSubmatch(match) + if len(parts) < 2 { + return match + } + // Prefer session state cache to avoid API calls + if ch, err := s.State.Channel(parts[1]); err == nil { + return "#" + ch.Name + } + if ch, err := s.Channel(parts[1]); err == nil { + return "#" + ch.Name + } + return match + }) + + // 2. Expand Discord message links (max 3, same guild only) + matches := msgLinkRe.FindAllStringSubmatch(text, 3) + for _, m := range matches { + if len(m) < 4 { + continue + } + linkGuildID, channelID, messageID := m[1], m[2], m[3] + // Security: only expand links from the same guild + if linkGuildID != guildID { + continue + } + msg, err := s.ChannelMessage(channelID, messageID) + if err != nil || msg == nil || msg.Content == "" { + continue + } + author := "unknown" + if msg.Author != nil { + author = msg.Author.Username + } + text += fmt.Sprintf("\n[linked message from %s]: %s", author, msg.Content) + } + + return text +} + // stripBotMention removes the bot mention from the message content. // Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). func (c *DiscordChannel) stripBotMention(text string) string { diff --git a/pkg/channels/discord/discord_resolve_test.go b/pkg/channels/discord/discord_resolve_test.go new file mode 100644 index 000000000..4bc65cc18 --- /dev/null +++ b/pkg/channels/discord/discord_resolve_test.go @@ -0,0 +1,98 @@ +package discord + +import ( + "testing" +) + +func TestChannelRefRegex(t *testing.T) { + tests := []struct { + name string + input string + wantID string + wantOK bool + }{ + {"basic channel ref", "<#123456789>", "123456789", true}, + {"long id", "<#9876543210123456>", "9876543210123456", true}, + {"no match plain text", "hello world", "", false}, + {"no match partial", "<#>", "", false}, + {"no match letters", "<#abc>", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := channelRefRe.FindStringSubmatch(tt.input) + if tt.wantOK { + if len(matches) < 2 || matches[1] != tt.wantID { + t.Errorf("channelRefRe(%q) = %v, want ID %q", tt.input, matches, tt.wantID) + } + } else { + if len(matches) >= 2 { + t.Errorf("channelRefRe(%q) should not match, got %v", tt.input, matches) + } + } + }) + } +} + +func TestMsgLinkRegex(t *testing.T) { + tests := []struct { + name string + input string + wantGuild string + wantChan string + wantMsg string + wantOK bool + }{ + { + "discord.com link", + "https://discord.com/channels/111/222/333", + "111", "222", "333", true, + }, + { + "discordapp.com link", + "https://discordapp.com/channels/111/222/333", + "111", "222", "333", true, + }, + { + "real world ids", + "check this https://discord.com/channels/9000000000000001/9000000000000002/9000000000000003 please", + "9000000000000001", "9000000000000002", "9000000000000003", true, + }, + {"no match http", "http://discord.com/channels/1/2/3", "", "", "", false}, + {"no match missing segment", "https://discord.com/channels/1/2", "", "", "", false}, + {"no match plain text", "hello world", "", "", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := msgLinkRe.FindStringSubmatch(tt.input) + if tt.wantOK { + if len(matches) < 4 { + t.Fatalf("msgLinkRe(%q) didn't match, want guild=%s chan=%s msg=%s", + tt.input, tt.wantGuild, tt.wantChan, tt.wantMsg) + } + if matches[1] != tt.wantGuild || matches[2] != tt.wantChan || matches[3] != tt.wantMsg { + t.Errorf("msgLinkRe(%q) = guild=%s chan=%s msg=%s, want %s/%s/%s", + tt.input, matches[1], matches[2], matches[3], + tt.wantGuild, tt.wantChan, tt.wantMsg) + } + } else { + if len(matches) >= 4 { + t.Errorf("msgLinkRe(%q) should not match, got %v", tt.input, matches) + } + } + }) + } +} + +func TestMsgLinkRegex_MultipleMatches(t *testing.T) { + input := "see https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6 and https://discord.com/channels/7/8/9 and https://discord.com/channels/10/11/12" + matches := msgLinkRe.FindAllStringSubmatch(input, 3) + if len(matches) != 3 { + t.Fatalf("expected 3 matches (capped), got %d", len(matches)) + } + // Verify the 3rd match is 7/8/9 (not 10/11/12) + if matches[2][1] != "7" || matches[2][2] != "8" || matches[2][3] != "9" { + t.Errorf("3rd match = %v, want guild=7 chan=8 msg=9", matches[2]) + } +}