mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #853 from nayihz/feat_discord_proxy
feat(discord): add proxy support and tests
This commit is contained in:
@@ -3,12 +3,15 @@ package discord
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
@@ -40,6 +43,9 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
|
||||
return nil, fmt.Errorf("failed to create discord session: %w", err)
|
||||
}
|
||||
|
||||
if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(2000),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
@@ -465,9 +471,43 @@ func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func()
|
||||
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "discord",
|
||||
ProxyURL: c.config.Proxy,
|
||||
})
|
||||
}
|
||||
|
||||
func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
|
||||
var proxyFunc func(*http.Request) (*url.URL, error)
|
||||
if proxyAddr != "" {
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err)
|
||||
}
|
||||
proxyFunc = http.ProxyURL(proxyURL)
|
||||
} else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
|
||||
proxyFunc = http.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
if proxyFunc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
transport := &http.Transport{Proxy: proxyFunc}
|
||||
session.Client = &http.Client{
|
||||
Timeout: sendTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
if session.Dialer != nil {
|
||||
dialerCopy := *session.Dialer
|
||||
dialerCopy.Proxy = proxyFunc
|
||||
session.Dialer = &dialerCopy
|
||||
} else {
|
||||
session.Dialer = &websocket.Dialer{Proxy: proxyFunc}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stripBotMention removes the bot mention from the message content.
|
||||
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
|
||||
func (c *DiscordChannel) stripBotMention(text string) string {
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
)
|
||||
|
||||
func TestApplyDiscordProxy_CustomProxy(t *testing.T) {
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
|
||||
if err = applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil {
|
||||
t.Fatalf("applyDiscordProxy() error: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error: %v", err)
|
||||
}
|
||||
|
||||
restProxy := session.Client.Transport.(*http.Transport).Proxy
|
||||
restProxyURL, err := restProxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("rest proxy func error: %v", err)
|
||||
}
|
||||
if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want {
|
||||
t.Fatalf("REST proxy = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
wsProxyURL, err := session.Dialer.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("ws proxy func error: %v", err)
|
||||
}
|
||||
if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want {
|
||||
t.Fatalf("WS proxy = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyDiscordProxy_FromEnvironment(t *testing.T) {
|
||||
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
|
||||
t.Setenv("http_proxy", "http://127.0.0.1:8888")
|
||||
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
|
||||
t.Setenv("https_proxy", "http://127.0.0.1:8888")
|
||||
t.Setenv("ALL_PROXY", "")
|
||||
t.Setenv("all_proxy", "")
|
||||
t.Setenv("NO_PROXY", "")
|
||||
t.Setenv("no_proxy", "")
|
||||
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
|
||||
if err = applyDiscordProxy(session, ""); err != nil {
|
||||
t.Fatalf("applyDiscordProxy() error: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error: %v", err)
|
||||
}
|
||||
|
||||
gotURL, err := session.Dialer.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("ws proxy func error: %v", err)
|
||||
}
|
||||
|
||||
wantURL, err := url.Parse("http://127.0.0.1:8888")
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error: %v", err)
|
||||
}
|
||||
if gotURL.String() != wantURL.String() {
|
||||
t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) {
|
||||
session, err := discordgo.New("Bot test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("discordgo.New() error: %v", err)
|
||||
}
|
||||
|
||||
if err = applyDiscordProxy(session, "://bad-proxy"); err == nil {
|
||||
t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil")
|
||||
}
|
||||
}
|
||||
@@ -271,6 +271,7 @@ type FeishuConfig struct {
|
||||
type DiscordConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
|
||||
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
|
||||
MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
|
||||
+19
-4
@@ -3,6 +3,7 @@ package utils
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -52,11 +53,12 @@ type DownloadOptions struct {
|
||||
Timeout time.Duration
|
||||
ExtraHeaders map[string]string
|
||||
LoggerPrefix string
|
||||
ProxyURL string
|
||||
}
|
||||
|
||||
// DownloadFile downloads a file from URL to a local temp directory.
|
||||
// Returns the local file path or empty string on error.
|
||||
func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
|
||||
// Set defaults
|
||||
if opts.Timeout == 0 {
|
||||
opts.Timeout = 60 * time.Second
|
||||
@@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
req, err := http.NewRequest("GET", urlStr, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{
|
||||
"error": err.Error(),
|
||||
@@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: opts.Timeout}
|
||||
if opts.ProxyURL != "" {
|
||||
proxyURL, parseErr := url.Parse(opts.ProxyURL)
|
||||
if parseErr != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{
|
||||
"error": parseErr.Error(),
|
||||
"proxy": opts.ProxyURL,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{
|
||||
"error": err.Error(),
|
||||
"url": url,
|
||||
"url": urlStr,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
@@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{
|
||||
"status": resp.StatusCode,
|
||||
"url": url,
|
||||
"url": urlStr,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user