feat(discord): add proxy support and tests

This commit is contained in:
nayihz
2026-02-27 14:35:23 +08:00
parent 2c8416e658
commit b5a4bb28b6
5 changed files with 156 additions and 4 deletions
+41
View File
@@ -3,12 +3,15 @@ package channels
import (
"context"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/bwmarrin/discordgo"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
@@ -39,6 +42,10 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
return nil, fmt.Errorf("failed to create discord session: %w", err)
}
if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
return nil, err
}
base := NewBaseChannel("discord", cfg, bus, cfg.AllowFrom)
return &DiscordChannel{
@@ -357,9 +364,43 @@ func (c *DiscordChannel) stopTyping(chatID string) {
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "discord",
ProxyURL: c.config.Proxy,
})
}
func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
var proxyFunc func(*http.Request) (*url.URL, error)
if proxyAddr != "" {
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err)
}
proxyFunc = http.ProxyURL(proxyURL)
} else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
proxyFunc = http.ProxyFromEnvironment
}
if proxyFunc == nil {
return nil
}
transport := &http.Transport{Proxy: proxyFunc}
session.Client = &http.Client{
Timeout: 20 * time.Second,
Transport: transport,
}
if session.Dialer != nil {
dialerCopy := *session.Dialer
dialerCopy.Proxy = proxyFunc
session.Dialer = &dialerCopy
} else {
session.Dialer = &websocket.Dialer{Proxy: proxyFunc}
}
return nil
}
// stripBotMention removes the bot mention from the message content.
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
func (c *DiscordChannel) stripBotMention(text string) string {
+94
View File
@@ -0,0 +1,94 @@
//go:build discord_proxy
// +build discord_proxy
package channels
import (
"net/http"
"net/url"
"testing"
"github.com/bwmarrin/discordgo"
)
func TestApplyDiscordProxy_CustomProxy(t *testing.T) {
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}
if err := applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil {
t.Fatalf("applyDiscordProxy() error: %v", err)
}
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
restProxy := session.Client.Transport.(*http.Transport).Proxy
restProxyURL, err := restProxy(req)
if err != nil {
t.Fatalf("rest proxy func error: %v", err)
}
if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want {
t.Fatalf("REST proxy = %q, want %q", got, want)
}
wsProxyURL, err := session.Dialer.Proxy(req)
if err != nil {
t.Fatalf("ws proxy func error: %v", err)
}
if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want {
t.Fatalf("WS proxy = %q, want %q", got, want)
}
}
func TestApplyDiscordProxy_FromEnvironment(t *testing.T) {
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
t.Setenv("http_proxy", "http://127.0.0.1:8888")
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
t.Setenv("https_proxy", "http://127.0.0.1:8888")
t.Setenv("ALL_PROXY", "")
t.Setenv("all_proxy", "")
t.Setenv("NO_PROXY", "")
t.Setenv("no_proxy", "")
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}
if err := applyDiscordProxy(session, ""); err != nil {
t.Fatalf("applyDiscordProxy() error: %v", err)
}
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
gotURL, err := session.Dialer.Proxy(req)
if err != nil {
t.Fatalf("ws proxy func error: %v", err)
}
wantURL, err := url.Parse("http://127.0.0.1:8888")
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
if gotURL.String() != wantURL.String() {
t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String())
}
}
func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) {
session, err := discordgo.New("Bot test-token")
if err != nil {
t.Fatalf("discordgo.New() error: %v", err)
}
if err := applyDiscordProxy(session, "://bad-proxy"); err == nil {
t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil")
}
}