mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(host): complete launcher and gateway multi-host binding support
- add shared netbind planning for strict tcp4/tcp6 bind semantics - support launcher/gateway host env overrides and launcher-to-gateway forwarding - cover host binding and forwarding with network and subprocess env tests
This commit is contained in:
+36
-9
@@ -11,6 +11,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
@@ -86,6 +87,7 @@ type Manager struct {
|
||||
dispatchTask *asyncTask
|
||||
mux *dynamicServeMux
|
||||
httpServer *http.Server
|
||||
httpListeners []net.Listener
|
||||
mu sync.RWMutex
|
||||
placeholders sync.Map // "channel:chatID" → placeholderID (string)
|
||||
typingStops sync.Map // "channel:chatID" → func()
|
||||
@@ -474,6 +476,12 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
// It registers health endpoints from the health server and discovers channels
|
||||
// that implement WebhookHandler and/or HealthChecker to register their handlers.
|
||||
func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
|
||||
m.SetupHTTPServerListeners(nil, addr, healthServer)
|
||||
}
|
||||
|
||||
// SetupHTTPServerListeners creates a shared HTTP server on pre-opened listeners.
|
||||
// When listeners is empty it falls back to Addr-based ListenAndServe behavior.
|
||||
func (m *Manager) SetupHTTPServerListeners(listeners []net.Listener, addr string, healthServer *health.Server) {
|
||||
m.mux = newDynamicServeMux()
|
||||
|
||||
// Register health endpoints
|
||||
@@ -490,6 +498,7 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
}
|
||||
m.httpListeners = append([]net.Listener(nil), listeners...)
|
||||
}
|
||||
|
||||
// registerHTTPHandlersLocked registers webhook and health-check handlers for
|
||||
@@ -619,16 +628,33 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
|
||||
// Start shared HTTP server if configured
|
||||
if m.httpServer != nil {
|
||||
go func() {
|
||||
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
|
||||
"addr": m.httpServer.Addr,
|
||||
})
|
||||
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
if len(m.httpListeners) > 0 {
|
||||
for _, listener := range m.httpListeners {
|
||||
ln := listener
|
||||
go func() {
|
||||
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
|
||||
"addr": ln.Addr().String(),
|
||||
})
|
||||
if err := m.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"addr": ln.Addr().String(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
go func() {
|
||||
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
|
||||
"addr": m.httpServer.Addr,
|
||||
})
|
||||
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel startup completed", map[string]any{
|
||||
@@ -655,6 +681,7 @@ func (m *Manager) StopAll(ctx context.Context) error {
|
||||
})
|
||||
}
|
||||
m.httpServer = nil
|
||||
m.httpListeners = nil
|
||||
}
|
||||
|
||||
// Cancel dispatcher
|
||||
|
||||
@@ -1082,7 +1082,10 @@ func LoadConfig(path string) (*Config, error) {
|
||||
if err = InitChannelList(cfg.Channels); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.Gateway.Host = resolveGatewayHostFromEnv(gatewayHostBeforeEnv)
|
||||
cfg.Gateway.Host, err = resolveGatewayHostFromEnv(gatewayHostBeforeEnv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid gateway host: %w", err)
|
||||
}
|
||||
|
||||
// Expand multi-key configs into separate entries for key-level failover
|
||||
cfg.ModelList = expandMultiKeyModels(cfg.ModelList)
|
||||
|
||||
@@ -39,7 +39,7 @@ const (
|
||||
EnvBinary = "PICOCLAW_BINARY"
|
||||
|
||||
// EnvGatewayHost overrides the host address for the gateway server.
|
||||
// Default: "127.0.0.1"
|
||||
// Default: "localhost"
|
||||
EnvGatewayHost = "PICOCLAW_GATEWAY_HOST"
|
||||
)
|
||||
|
||||
|
||||
+16
-107
@@ -2,12 +2,11 @@ package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
)
|
||||
|
||||
const DefaultGatewayLogLevel = "warn"
|
||||
@@ -52,119 +51,29 @@ func EffectiveGatewayLogLevel(cfg *Config) string {
|
||||
return normalizeGatewayLogLevel(cfg.Gateway.LogLevel)
|
||||
}
|
||||
|
||||
var (
|
||||
gatewayIPFamiliesOnce sync.Once
|
||||
gatewayHasIPv4 bool
|
||||
gatewayHasIPv6 bool
|
||||
)
|
||||
|
||||
func detectGatewayIPFamilies() (bool, bool) {
|
||||
gatewayIPFamiliesOnce.Do(func() {
|
||||
if ips, err := net.LookupIP("localhost"); err == nil {
|
||||
for _, ip := range ips {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
gatewayHasIPv4 = true
|
||||
continue
|
||||
}
|
||||
gatewayHasIPv6 = true
|
||||
}
|
||||
}
|
||||
|
||||
if gatewayHasIPv4 && gatewayHasIPv6 {
|
||||
return
|
||||
}
|
||||
|
||||
if addrs, err := net.InterfaceAddrs(); err == nil {
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok || ipnet.IP == nil {
|
||||
continue
|
||||
}
|
||||
if ipnet.IP.To4() != nil {
|
||||
gatewayHasIPv4 = true
|
||||
continue
|
||||
}
|
||||
gatewayHasIPv6 = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return gatewayHasIPv4, gatewayHasIPv6
|
||||
}
|
||||
|
||||
func selectAdaptiveGatewayLoopbackHost(hasIPv4, hasIPv6 bool) string {
|
||||
switch {
|
||||
case hasIPv4 && hasIPv6:
|
||||
return "localhost"
|
||||
case hasIPv6:
|
||||
return "::1"
|
||||
case hasIPv4:
|
||||
return "127.0.0.1"
|
||||
default:
|
||||
return "localhost"
|
||||
}
|
||||
}
|
||||
|
||||
func selectAdaptiveGatewayAnyHost(hasIPv4, hasIPv6 bool) string {
|
||||
switch {
|
||||
case hasIPv4 && hasIPv6:
|
||||
return "::"
|
||||
case hasIPv6:
|
||||
return "::"
|
||||
case hasIPv4:
|
||||
return "0.0.0.0"
|
||||
default:
|
||||
return "::"
|
||||
}
|
||||
}
|
||||
|
||||
func resolveAdaptiveGatewayLoopbackHost() string {
|
||||
hasIPv4, hasIPv6 := detectGatewayIPFamilies()
|
||||
return selectAdaptiveGatewayLoopbackHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func resolveAdaptiveGatewayAnyHost() string {
|
||||
hasIPv4, hasIPv6 := detectGatewayIPFamilies()
|
||||
return selectAdaptiveGatewayAnyHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func normalizeGatewayHost(host string) string {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
host = strings.TrimSpace(DefaultConfig().Gateway.Host)
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return resolveAdaptiveGatewayLoopbackHost()
|
||||
}
|
||||
|
||||
trimmed := strings.Trim(host, "[]")
|
||||
if ip := net.ParseIP(trimmed); ip != nil && ip.IsUnspecified() {
|
||||
return resolveAdaptiveGatewayAnyHost()
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
func resolveGatewayHostFromEnv(baseHost string) string {
|
||||
func resolveGatewayHostFromEnv(baseHost string) (string, error) {
|
||||
envHost, ok := os.LookupEnv(EnvGatewayHost)
|
||||
if !ok {
|
||||
return normalizeGatewayHost(baseHost)
|
||||
return normalizeGatewayHostInput(baseHost)
|
||||
}
|
||||
|
||||
envHost = strings.TrimSpace(envHost)
|
||||
if envHost == "" {
|
||||
return normalizeGatewayHost(baseHost)
|
||||
return normalizeGatewayHostInput(baseHost)
|
||||
}
|
||||
|
||||
return normalizeGatewayHost(envHost)
|
||||
return normalizeGatewayHostInput(envHost)
|
||||
}
|
||||
|
||||
func normalizeGatewayHostInput(host string) (string, error) {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
host = strings.TrimSpace(DefaultConfig().Gateway.Host)
|
||||
}
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
return netbind.NormalizeHostInput(host)
|
||||
}
|
||||
|
||||
// ResolveGatewayLogLevel reads the configured gateway log level without triggering
|
||||
|
||||
@@ -39,7 +39,10 @@ func TestLoadConfig_GatewayHostBlankEnvFallsBackToConfigHost(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
want := normalizeGatewayHost("localhost")
|
||||
want, err := normalizeGatewayHostInput("localhost")
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeGatewayHostInput() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.Host != want {
|
||||
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, want)
|
||||
}
|
||||
@@ -54,13 +57,16 @@ func TestLoadConfig_GatewayHostBlankEnvAndConfigFallsBackToDefault(t *testing.T)
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
|
||||
defaultHost := normalizeGatewayHost(DefaultConfig().Gateway.Host)
|
||||
defaultHost, err := normalizeGatewayHostInput(DefaultConfig().Gateway.Host)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeGatewayHostInput() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.Host != defaultHost {
|
||||
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, defaultHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_GatewayHostEnvWildcardUsesAdaptiveAnyHost(t *testing.T) {
|
||||
func TestLoadConfig_GatewayHostEnvPreservesExplicitWildcardHost(t *testing.T) {
|
||||
configPath := writeGatewayHostTestConfig(t, "localhost")
|
||||
t.Setenv(EnvGatewayHost, " 0.0.0.0 ")
|
||||
|
||||
@@ -69,8 +75,24 @@ func TestLoadConfig_GatewayHostEnvWildcardUsesAdaptiveAnyHost(t *testing.T) {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
|
||||
want := normalizeGatewayHost("0.0.0.0")
|
||||
want, err := normalizeGatewayHostInput("0.0.0.0")
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeGatewayHostInput() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.Host != want {
|
||||
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_GatewayHostEnvNormalizesMultiHostInput(t *testing.T) {
|
||||
configPath := writeGatewayHostTestConfig(t, "localhost")
|
||||
t.Setenv(EnvGatewayHost, " [::1] , 127.0.0.1 , ::1 ")
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.Host != "::1,127.0.0.1" {
|
||||
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, "::1,127.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
+39
-8
@@ -44,6 +44,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
"github.com/sipeed/picoclaw/pkg/pid"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
@@ -161,13 +162,30 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
|
||||
logger.Infof("Log level set to %q", effectiveLogLevel)
|
||||
}
|
||||
|
||||
bindPlan, listenResult, err := openGatewayListeners(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening gateway listeners: %w", err)
|
||||
}
|
||||
|
||||
// Enforce singleton: write PID file with generated token.
|
||||
pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
pidData, err := pid.WritePidFile(homePath, bindPlan.ProbeHost, cfg.Gateway.Port)
|
||||
if err != nil {
|
||||
logger.Warnf("write pid file failed: %v", err)
|
||||
for _, ln := range listenResult.Listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
return fmt.Errorf("singleton check failed: %w", err)
|
||||
}
|
||||
defer pid.RemovePidFile(homePath)
|
||||
closeListeners := true
|
||||
defer func() {
|
||||
if !closeListeners {
|
||||
return
|
||||
}
|
||||
for _, ln := range listenResult.Listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
@@ -195,10 +213,11 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
|
||||
"skills_available": skillsInfo["available"],
|
||||
})
|
||||
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token)
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token, listenResult)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
closeListeners = false
|
||||
|
||||
// Setup manual reload channel for /reload endpoint
|
||||
manualReloadChan := make(chan struct{}, 1)
|
||||
@@ -219,8 +238,9 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
|
||||
runningServices.HealthServer.SetReloadFunc(reloadTrigger)
|
||||
agentLoop.SetReloadFunc(reloadTrigger)
|
||||
|
||||
listenAddr := net.JoinHostPort(cfg.Gateway.Host, strconv.Itoa(cfg.Gateway.Port))
|
||||
fmt.Printf("✓ Gateway started on %s\n", listenAddr)
|
||||
for _, bindHost := range listenResult.BindHosts {
|
||||
fmt.Printf("✓ Gateway started on %s\n", net.JoinHostPort(bindHost, strconv.Itoa(cfg.Gateway.Port)))
|
||||
}
|
||||
fmt.Println("Press Ctrl+C to stop")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -323,6 +343,7 @@ func setupAndStartServices(
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
authToken string,
|
||||
listenResult netbind.OpenResult,
|
||||
) (*services, error) {
|
||||
runningServices := &services{}
|
||||
|
||||
@@ -393,10 +414,20 @@ func setupAndStartServices(
|
||||
fmt.Println("⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
addr := net.JoinHostPort(cfg.Gateway.Host, strconv.Itoa(cfg.Gateway.Port))
|
||||
runningServices.authToken = authToken
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, authToken)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
runningServices.HealthServer = health.NewServer(listenResult.ProbeHost, cfg.Gateway.Port, authToken)
|
||||
|
||||
listenAddr := ""
|
||||
if len(listenResult.Listeners) > 0 {
|
||||
listenAddr = listenResult.Listeners[0].Addr().String()
|
||||
} else {
|
||||
listenAddr = net.JoinHostPort(listenResult.ProbeHost, strconv.Itoa(cfg.Gateway.Port))
|
||||
}
|
||||
runningServices.ChannelManager.SetupHTTPServerListeners(
|
||||
listenResult.Listeners,
|
||||
listenAddr,
|
||||
runningServices.HealthServer,
|
||||
)
|
||||
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("error starting channels: %w", err)
|
||||
@@ -412,7 +443,7 @@ func setupAndStartServices(
|
||||
voiceAgent.Start(vaCtx)
|
||||
}
|
||||
|
||||
healthAddr := net.JoinHostPort(cfg.Gateway.Host, strconv.Itoa(cfg.Gateway.Port))
|
||||
healthAddr := net.JoinHostPort(listenResult.ProbeHost, strconv.Itoa(cfg.Gateway.Port))
|
||||
fmt.Printf(
|
||||
"✓ Health endpoints available at http://%s/health, /ready and /reload (POST)\n",
|
||||
healthAddr,
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
)
|
||||
|
||||
func openGatewayListeners(host string, port int) (netbind.Plan, netbind.OpenResult, error) {
|
||||
plan, err := netbind.BuildPlan(host, netbind.DefaultLoopback)
|
||||
if err != nil {
|
||||
return netbind.Plan{}, netbind.OpenResult{}, err
|
||||
}
|
||||
|
||||
result, err := netbind.OpenPlan(plan, strconv.Itoa(port))
|
||||
if err != nil {
|
||||
return netbind.Plan{}, netbind.OpenResult{}, err
|
||||
}
|
||||
|
||||
return plan, result, nil
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
)
|
||||
|
||||
func TestOpenGatewayListeners_HonorsIPv6OnlyHost(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := netbind.DetectIPFamilies()
|
||||
if !hasIPv6 {
|
||||
t.Skip("IPv6 is unavailable in this environment")
|
||||
}
|
||||
|
||||
_, result, err := openGatewayListeners("::", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("openGatewayListeners() error = %v", err)
|
||||
}
|
||||
startGatewayTestHTTPServer(t, result.Listeners)
|
||||
port := mustGatewayAtoi(t, result.Port)
|
||||
|
||||
requireGatewayHTTPReachable(t, "::1", port)
|
||||
if hasIPv4 {
|
||||
requireGatewayHTTPUnreachable(t, "127.0.0.1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenGatewayListeners_SupportsExplicitMultiHost(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := netbind.DetectIPFamilies()
|
||||
if !hasIPv4 || !hasIPv6 {
|
||||
t.Skip("dual-stack loopback is unavailable in this environment")
|
||||
}
|
||||
|
||||
_, result, err := openGatewayListeners("127.0.0.1,::1", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("openGatewayListeners() error = %v", err)
|
||||
}
|
||||
startGatewayTestHTTPServer(t, result.Listeners)
|
||||
port := mustGatewayAtoi(t, result.Port)
|
||||
|
||||
requireGatewayHTTPReachable(t, "127.0.0.1", port)
|
||||
requireGatewayHTTPReachable(t, "::1", port)
|
||||
}
|
||||
|
||||
func startGatewayTestHTTPServer(t *testing.T, listeners []net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
server := &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}),
|
||||
}
|
||||
|
||||
errCh := make(chan error, len(listeners))
|
||||
for _, listener := range listeners {
|
||||
ln := listener
|
||||
go func() {
|
||||
errCh <- server.Serve(ln)
|
||||
}()
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = server.Shutdown(ctx)
|
||||
for range listeners {
|
||||
err := <-errCh
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Fatalf("server.Serve() error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func requireGatewayHTTPReachable(t *testing.T, host string, port int) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
err := gatewayHTTPGet(host, port)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("expected %s:%d to be reachable: %v", host, port, err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func requireGatewayHTTPUnreachable(t *testing.T, host string, port int) {
|
||||
t.Helper()
|
||||
if err := gatewayHTTPGet(host, port); err == nil {
|
||||
t.Fatalf("expected %s:%d to be unreachable", host, port)
|
||||
}
|
||||
}
|
||||
|
||||
func gatewayHTTPGet(host string, port int) error {
|
||||
client := &http.Client{
|
||||
Timeout: 300 * time.Millisecond,
|
||||
Transport: &http.Transport{
|
||||
Proxy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.New(resp.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustGatewayAtoi(t *testing.T, value string) int {
|
||||
t.Helper()
|
||||
n, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
t.Fatalf("Atoi(%q) error = %v", value, err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,580 @@
|
||||
package netbind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type DefaultMode int
|
||||
|
||||
const (
|
||||
DefaultLoopback DefaultMode = iota
|
||||
DefaultAny
|
||||
)
|
||||
|
||||
type groupKind int
|
||||
|
||||
const (
|
||||
groupAdaptiveLoopback groupKind = iota
|
||||
groupAdaptiveAny
|
||||
groupExact
|
||||
)
|
||||
|
||||
type exactBinding struct {
|
||||
host string
|
||||
network string
|
||||
v6Only bool
|
||||
}
|
||||
|
||||
type bindGroup struct {
|
||||
kind groupKind
|
||||
allowIPv4 bool
|
||||
allowIPv6 bool
|
||||
exact exactBinding
|
||||
}
|
||||
|
||||
type Plan struct {
|
||||
groups []bindGroup
|
||||
ProbeHost string
|
||||
}
|
||||
|
||||
type OpenResult struct {
|
||||
Listeners []net.Listener
|
||||
BindHosts []string
|
||||
Port string
|
||||
ProbeHost string
|
||||
}
|
||||
|
||||
type tokenKind int
|
||||
|
||||
const (
|
||||
tokenName tokenKind = iota
|
||||
tokenLocalhost
|
||||
tokenStar
|
||||
tokenIPv4
|
||||
tokenIPv6
|
||||
tokenIPv4Any
|
||||
tokenIPv6Any
|
||||
)
|
||||
|
||||
type hostToken struct {
|
||||
kind tokenKind
|
||||
canonical string
|
||||
key string
|
||||
}
|
||||
|
||||
var (
|
||||
ipFamiliesOnce sync.Once
|
||||
hasIPv4 bool
|
||||
hasIPv6 bool
|
||||
)
|
||||
|
||||
func DetectIPFamilies() (bool, bool) {
|
||||
ipFamiliesOnce.Do(func() {
|
||||
if ips, err := net.LookupIP("localhost"); err == nil {
|
||||
for _, ip := range ips {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
hasIPv4 = true
|
||||
continue
|
||||
}
|
||||
hasIPv6 = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasIPv4 && hasIPv6 {
|
||||
return
|
||||
}
|
||||
|
||||
if addrs, err := net.InterfaceAddrs(); err == nil {
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok || ipnet.IP == nil {
|
||||
continue
|
||||
}
|
||||
if ipnet.IP.To4() != nil {
|
||||
hasIPv4 = true
|
||||
continue
|
||||
}
|
||||
hasIPv6 = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return hasIPv4, hasIPv6
|
||||
}
|
||||
|
||||
func SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6 bool) string {
|
||||
switch {
|
||||
case hasIPv4 && hasIPv6:
|
||||
return "localhost"
|
||||
case hasIPv6:
|
||||
return "::1"
|
||||
case hasIPv4:
|
||||
return "127.0.0.1"
|
||||
default:
|
||||
return "localhost"
|
||||
}
|
||||
}
|
||||
|
||||
func SelectAdaptiveAnyHost(hasIPv4, hasIPv6 bool) string {
|
||||
switch {
|
||||
case hasIPv4 && hasIPv6:
|
||||
return "::"
|
||||
case hasIPv6:
|
||||
return "::"
|
||||
case hasIPv4:
|
||||
return "0.0.0.0"
|
||||
default:
|
||||
return "::"
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveAdaptiveLoopbackHost() string {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
return SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func ResolveAdaptiveAnyHost() string {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
return SelectAdaptiveAnyHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func IsLoopbackHost(host string) bool {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
ip := net.ParseIP(strings.Trim(host, "[]"))
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func IsUnspecifiedHost(host string) bool {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
ip := net.ParseIP(strings.Trim(host, "[]"))
|
||||
return ip != nil && ip.IsUnspecified()
|
||||
}
|
||||
|
||||
func NormalizeHostInput(raw string) (string, error) {
|
||||
tokens, err := parseHostTokens(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
parts = append(parts, token.canonical)
|
||||
}
|
||||
return strings.Join(parts, ","), nil
|
||||
}
|
||||
|
||||
func BuildPlan(raw string, defaultMode DefaultMode) (Plan, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return buildDefaultPlan(defaultMode), nil
|
||||
}
|
||||
|
||||
tokens, err := parseHostTokens(raw)
|
||||
if err != nil {
|
||||
return Plan{}, err
|
||||
}
|
||||
|
||||
for _, token := range tokens {
|
||||
if token.kind == tokenStar {
|
||||
return Plan{
|
||||
groups: []bindGroup{{kind: groupAdaptiveAny}},
|
||||
ProbeHost: ResolveAdaptiveLoopbackHost(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasIPv4Any := false
|
||||
hasIPv6Any := false
|
||||
for _, token := range tokens {
|
||||
switch token.kind {
|
||||
case tokenIPv4Any:
|
||||
hasIPv4Any = true
|
||||
case tokenIPv6Any:
|
||||
hasIPv6Any = true
|
||||
}
|
||||
}
|
||||
|
||||
allowLocalhostIPv4 := !hasIPv4Any
|
||||
allowLocalhostIPv6 := !hasIPv6Any
|
||||
|
||||
groups := make([]bindGroup, 0, len(tokens))
|
||||
seenExact := make(map[string]struct{}, len(tokens))
|
||||
addedLocalhost := false
|
||||
|
||||
for _, token := range tokens {
|
||||
switch token.kind {
|
||||
case tokenLocalhost:
|
||||
if addedLocalhost || (!allowLocalhostIPv4 && !allowLocalhostIPv6) {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupAdaptiveLoopback,
|
||||
allowIPv4: allowLocalhostIPv4,
|
||||
allowIPv6: allowLocalhostIPv6,
|
||||
})
|
||||
addedLocalhost = true
|
||||
case tokenIPv4Any:
|
||||
key := "exact:tcp4:0.0.0.0"
|
||||
if _, ok := seenExact[key]; ok {
|
||||
continue
|
||||
}
|
||||
seenExact[key] = struct{}{}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupExact,
|
||||
exact: exactBinding{
|
||||
host: "0.0.0.0",
|
||||
network: "tcp4",
|
||||
},
|
||||
})
|
||||
case tokenIPv6Any:
|
||||
key := "exact:tcp6:::"
|
||||
if _, ok := seenExact[key]; ok {
|
||||
continue
|
||||
}
|
||||
seenExact[key] = struct{}{}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupExact,
|
||||
exact: exactBinding{
|
||||
host: "::",
|
||||
network: "tcp6",
|
||||
v6Only: true,
|
||||
},
|
||||
})
|
||||
case tokenIPv4:
|
||||
if hasIPv4Any {
|
||||
continue
|
||||
}
|
||||
key := "exact:tcp4:" + strings.ToLower(token.canonical)
|
||||
if _, ok := seenExact[key]; ok {
|
||||
continue
|
||||
}
|
||||
seenExact[key] = struct{}{}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupExact,
|
||||
exact: exactBinding{
|
||||
host: token.canonical,
|
||||
network: "tcp4",
|
||||
},
|
||||
})
|
||||
case tokenIPv6:
|
||||
if hasIPv6Any {
|
||||
continue
|
||||
}
|
||||
key := "exact:tcp6:" + strings.ToLower(token.canonical)
|
||||
if _, ok := seenExact[key]; ok {
|
||||
continue
|
||||
}
|
||||
seenExact[key] = struct{}{}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupExact,
|
||||
exact: exactBinding{
|
||||
host: token.canonical,
|
||||
network: "tcp6",
|
||||
v6Only: true,
|
||||
},
|
||||
})
|
||||
case tokenName:
|
||||
key := "exact:tcp:" + token.key
|
||||
if _, ok := seenExact[key]; ok {
|
||||
continue
|
||||
}
|
||||
seenExact[key] = struct{}{}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupExact,
|
||||
exact: exactBinding{
|
||||
host: token.canonical,
|
||||
network: "tcp",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
plan := Plan{groups: groups}
|
||||
plan.ProbeHost = probeHostForGroups(groups)
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func OpenPlan(plan Plan, port string) (OpenResult, error) {
|
||||
if port == "" {
|
||||
return OpenResult{}, errors.New("port cannot be empty")
|
||||
}
|
||||
|
||||
selectedPort := port
|
||||
listeners := make([]net.Listener, 0, len(plan.groups))
|
||||
bindHosts := make([]string, 0, len(plan.groups))
|
||||
bindSeen := make(map[string]struct{}, len(plan.groups))
|
||||
|
||||
closeAll := func() {
|
||||
for _, ln := range listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
}
|
||||
|
||||
for _, group := range plan.groups {
|
||||
groupListeners, groupHosts, actualPort, err := openGroup(group, selectedPort)
|
||||
if err != nil {
|
||||
closeAll()
|
||||
return OpenResult{}, err
|
||||
}
|
||||
if selectedPort == "0" && actualPort != "" {
|
||||
selectedPort = actualPort
|
||||
}
|
||||
listeners = append(listeners, groupListeners...)
|
||||
for _, host := range groupHosts {
|
||||
key := strings.ToLower(host)
|
||||
if _, ok := bindSeen[key]; ok {
|
||||
continue
|
||||
}
|
||||
bindSeen[key] = struct{}{}
|
||||
bindHosts = append(bindHosts, host)
|
||||
}
|
||||
}
|
||||
|
||||
return OpenResult{
|
||||
Listeners: listeners,
|
||||
BindHosts: bindHosts,
|
||||
Port: selectedPort,
|
||||
ProbeHost: plan.ProbeHost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildDefaultPlan(defaultMode DefaultMode) Plan {
|
||||
switch defaultMode {
|
||||
case DefaultAny:
|
||||
return Plan{
|
||||
groups: []bindGroup{{kind: groupAdaptiveAny}},
|
||||
ProbeHost: ResolveAdaptiveLoopbackHost(),
|
||||
}
|
||||
default:
|
||||
return Plan{
|
||||
groups: []bindGroup{{
|
||||
kind: groupAdaptiveLoopback,
|
||||
allowIPv4: true,
|
||||
allowIPv6: true,
|
||||
}},
|
||||
ProbeHost: ResolveAdaptiveLoopbackHost(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func probeHostForGroups(groups []bindGroup) string {
|
||||
hasIPv4Any := false
|
||||
hasIPv6Any := false
|
||||
for _, group := range groups {
|
||||
if group.kind == groupAdaptiveLoopback {
|
||||
switch {
|
||||
case group.allowIPv4 && group.allowIPv6:
|
||||
return ResolveAdaptiveLoopbackHost()
|
||||
case group.allowIPv6:
|
||||
return "::1"
|
||||
case group.allowIPv4:
|
||||
return "127.0.0.1"
|
||||
}
|
||||
}
|
||||
if group.kind == groupAdaptiveAny {
|
||||
return ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
if group.kind != groupExact {
|
||||
continue
|
||||
}
|
||||
switch group.exact.host {
|
||||
case "0.0.0.0":
|
||||
hasIPv4Any = true
|
||||
case "::":
|
||||
hasIPv6Any = true
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case hasIPv4Any && hasIPv6Any:
|
||||
return ResolveAdaptiveLoopbackHost()
|
||||
case hasIPv6Any:
|
||||
return "::1"
|
||||
case hasIPv4Any:
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if group.kind == groupExact {
|
||||
return group.exact.host
|
||||
}
|
||||
}
|
||||
return ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
|
||||
func parseHostTokens(raw string) ([]hostToken, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, errors.New("host cannot be empty")
|
||||
}
|
||||
|
||||
parts := strings.Split(raw, ",")
|
||||
tokens := make([]hostToken, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, part := range parts {
|
||||
token, err := parseHostToken(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := seen[token.key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[token.key] = struct{}{}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
return nil, errors.New("host cannot be empty")
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func parseHostToken(raw string) (hostToken, error) {
|
||||
host := strings.TrimSpace(raw)
|
||||
if host == "" {
|
||||
return hostToken{}, errors.New("host list contains an empty entry")
|
||||
}
|
||||
|
||||
if host == "*" {
|
||||
return hostToken{kind: tokenStar, canonical: "*", key: "*"}, nil
|
||||
}
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return hostToken{kind: tokenLocalhost, canonical: "localhost", key: "localhost"}, nil
|
||||
}
|
||||
|
||||
trimmed := strings.Trim(host, "[]")
|
||||
if ip := net.ParseIP(trimmed); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
canonical := ip4.String()
|
||||
kind := tokenIPv4
|
||||
if ip4.IsUnspecified() {
|
||||
kind = tokenIPv4Any
|
||||
}
|
||||
return hostToken{kind: kind, canonical: canonical, key: canonical}, nil
|
||||
}
|
||||
|
||||
canonical := ip.String()
|
||||
kind := tokenIPv6
|
||||
if ip.IsUnspecified() {
|
||||
kind = tokenIPv6Any
|
||||
}
|
||||
return hostToken{kind: kind, canonical: canonical, key: strings.ToLower(canonical)}, nil
|
||||
}
|
||||
|
||||
return hostToken{
|
||||
kind: tokenName,
|
||||
canonical: host,
|
||||
key: strings.ToLower(host),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func openGroup(group bindGroup, port string) ([]net.Listener, []string, string, error) {
|
||||
switch group.kind {
|
||||
case groupAdaptiveLoopback:
|
||||
return openAdaptiveLoopbackGroup(group.allowIPv6, group.allowIPv4, port)
|
||||
case groupAdaptiveAny:
|
||||
return openAdaptiveAnyGroup(port)
|
||||
case groupExact:
|
||||
ln, actualPort, err := openExactListener(group.exact, port)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
return []net.Listener{ln}, []string{group.exact.host}, actualPort, nil
|
||||
default:
|
||||
return nil, nil, "", fmt.Errorf("unsupported bind group kind: %d", group.kind)
|
||||
}
|
||||
}
|
||||
|
||||
func openAdaptiveLoopbackGroup(allowIPv6, allowIPv4 bool, port string) ([]net.Listener, []string, string, error) {
|
||||
if allowIPv6 && allowIPv4 {
|
||||
if ln6, actualPort, err6 := openExactListener(exactBinding{host: "::1", network: "tcp6", v6Only: true}, port); err6 == nil {
|
||||
if ln4, _, err4 := openExactListener(exactBinding{host: "127.0.0.1", network: "tcp4"}, actualPort); err4 == nil {
|
||||
return []net.Listener{ln6, ln4}, []string{"::1", "127.0.0.1"}, actualPort, nil
|
||||
}
|
||||
_ = ln6.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if allowIPv6 {
|
||||
ln6, actualPort, err := openExactListener(exactBinding{host: "::1", network: "tcp6", v6Only: true}, port)
|
||||
if err == nil {
|
||||
return []net.Listener{ln6}, []string{"::1"}, actualPort, nil
|
||||
}
|
||||
}
|
||||
|
||||
if allowIPv4 {
|
||||
ln4, actualPort, err := openExactListener(exactBinding{host: "127.0.0.1", network: "tcp4"}, port)
|
||||
if err == nil {
|
||||
return []net.Listener{ln4}, []string{"127.0.0.1"}, actualPort, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, "", fmt.Errorf("failed to open adaptive localhost listener on port %s", port)
|
||||
}
|
||||
|
||||
func openAdaptiveAnyGroup(port string) ([]net.Listener, []string, string, error) {
|
||||
// Intentionally bind tcp/:: here. Go's compatibility layer handles dual-stack
|
||||
// wildcard binding where the platform supports it, while tcp4 remains the
|
||||
// fallback for IPv4-only environments.
|
||||
if ln, actualPort, err := openExactListener(exactBinding{host: "::", network: "tcp"}, port); err == nil {
|
||||
return []net.Listener{ln}, []string{"::"}, actualPort, nil
|
||||
}
|
||||
|
||||
ln4, actualPort, err := openExactListener(exactBinding{host: "0.0.0.0", network: "tcp4"}, port)
|
||||
if err != nil {
|
||||
return nil, nil, "", fmt.Errorf("failed to open adaptive any-host listener on port %s", port)
|
||||
}
|
||||
return []net.Listener{ln4}, []string{"0.0.0.0"}, actualPort, nil
|
||||
}
|
||||
|
||||
func openExactListener(binding exactBinding, port string) (net.Listener, string, error) {
|
||||
listenConfig := net.ListenConfig{}
|
||||
if binding.network == "tcp6" && binding.v6Only {
|
||||
listenConfig.Control = applyIPv6OnlyControl(true)
|
||||
}
|
||||
|
||||
ln, err := listenConfig.Listen(context.Background(), binding.network, net.JoinHostPort(binding.host, port))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
actualPort, err := listenerPort(ln)
|
||||
if err != nil {
|
||||
_ = ln.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return ln, actualPort, nil
|
||||
}
|
||||
|
||||
func listenerPort(ln net.Listener) (string, error) {
|
||||
addr, ok := ln.Addr().(*net.TCPAddr)
|
||||
if ok {
|
||||
return strconv.Itoa(addr.Port), nil
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(ln.Addr().String())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return port, nil
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
package netbind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNormalizeHostInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "single host", raw: "127.0.0.1", want: "127.0.0.1"},
|
||||
{name: "trim and dedupe", raw: " [::1] , ::1 , 127.0.0.1 ", want: "::1,127.0.0.1"},
|
||||
{name: "star preserved", raw: "*,127.0.0.1", want: "*,127.0.0.1"},
|
||||
{name: "reject empty", raw: "127.0.0.1, ", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NormalizeHostInput(tt.raw)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("NormalizeHostInput() err = %v, wantErr %t", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("NormalizeHostInput() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPlan_DefaultAnyUsesLoopbackProbe(t *testing.T) {
|
||||
plan, err := BuildPlan("", DefaultAny)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
if plan.ProbeHost != ResolveAdaptiveLoopbackHost() {
|
||||
t.Fatalf("ProbeHost = %q, want %q", plan.ProbeHost, ResolveAdaptiveLoopbackHost())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPlan_LocalhostSupportsLoopbackCommunication(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
|
||||
plan, err := BuildPlan("localhost", DefaultLoopback)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
if hasIPv6 {
|
||||
requireHTTPReachable(t, "::1", port)
|
||||
}
|
||||
if hasIPv4 {
|
||||
requireHTTPReachable(t, "127.0.0.1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPlan_DefaultAnySupportsDualStackLoopback(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
|
||||
plan, err := BuildPlan("", DefaultAny)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
if hasIPv6 {
|
||||
requireHTTPReachable(t, "::1", port)
|
||||
}
|
||||
if hasIPv4 {
|
||||
requireHTTPReachable(t, "127.0.0.1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPlan_ExplicitIPv6AnyIsIPv6Only(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
if !hasIPv6 {
|
||||
t.Skip("IPv6 is unavailable in this environment")
|
||||
}
|
||||
|
||||
plan, err := BuildPlan("::", DefaultLoopback)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
requireHTTPReachable(t, "::1", port)
|
||||
if hasIPv4 {
|
||||
requireHTTPUnreachable(t, "127.0.0.1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPlan_ExplicitIPv4AnyIsIPv4Only(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
if !hasIPv4 {
|
||||
t.Skip("IPv4 is unavailable in this environment")
|
||||
}
|
||||
|
||||
plan, err := BuildPlan("0.0.0.0", DefaultLoopback)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
requireHTTPReachable(t, "127.0.0.1", port)
|
||||
if hasIPv6 {
|
||||
requireHTTPUnreachable(t, "::1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPlan_MultiHostSupportsExplicitIPv4AndIPv6(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
if !hasIPv4 || !hasIPv6 {
|
||||
t.Skip("dual-stack loopback is unavailable in this environment")
|
||||
}
|
||||
|
||||
plan, err := BuildPlan("127.0.0.1,::1", DefaultLoopback)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
requireHTTPReachable(t, "127.0.0.1", port)
|
||||
requireHTTPReachable(t, "::1", port)
|
||||
}
|
||||
|
||||
func TestOpenPlan_WildcardRulesKeepIPv4AndIPv6AnyHosts(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
if !hasIPv4 || !hasIPv6 {
|
||||
t.Skip("dual-stack loopback is unavailable in this environment")
|
||||
}
|
||||
|
||||
plan, err := BuildPlan("::,::1,0.0.0.0,127.0.0.1", DefaultLoopback)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
requireHTTPReachable(t, "127.0.0.1", port)
|
||||
requireHTTPReachable(t, "::1", port)
|
||||
if len(result.BindHosts) != 2 {
|
||||
t.Fatalf("len(BindHosts) = %d, want 2 (%#v)", len(result.BindHosts), result.BindHosts)
|
||||
}
|
||||
}
|
||||
|
||||
func startTestHTTPServer(t *testing.T, listeners []net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
server := &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}),
|
||||
}
|
||||
|
||||
errCh := make(chan error, len(listeners))
|
||||
for _, listener := range listeners {
|
||||
ln := listener
|
||||
go func() {
|
||||
errCh <- server.Serve(ln)
|
||||
}()
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = server.Shutdown(ctx)
|
||||
for range listeners {
|
||||
err := <-errCh
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Fatalf("server.Serve() error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func requireHTTPReachable(t *testing.T, host string, port int) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
err := httpGET(host, port)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("expected %s:%d to be reachable: %v", host, port, err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func requireHTTPUnreachable(t *testing.T, host string, port int) {
|
||||
t.Helper()
|
||||
|
||||
if err := httpGET(host, port); err == nil {
|
||||
t.Fatalf("expected %s:%d to be unreachable", host, port)
|
||||
}
|
||||
}
|
||||
|
||||
func httpGET(host string, port int) error {
|
||||
client := &http.Client{
|
||||
Timeout: 300 * time.Millisecond,
|
||||
Transport: &http.Transport{
|
||||
Proxy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.New(resp.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustAtoi(t *testing.T, value string) int {
|
||||
t.Helper()
|
||||
n, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
t.Fatalf("Atoi(%q) error = %v", value, err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
//go:build !windows
|
||||
|
||||
package netbind
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func applyIPv6OnlyControl(enabled bool) func(string, string, syscall.RawConn) error {
|
||||
return func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var controlErr error
|
||||
if err := rawConn.Control(func(fd uintptr) {
|
||||
value := 0
|
||||
if enabled {
|
||||
value = 1
|
||||
}
|
||||
controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, value)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return controlErr
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
//go:build windows
|
||||
|
||||
package netbind
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func applyIPv6OnlyControl(enabled bool) func(string, string, syscall.RawConn) error {
|
||||
return func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var controlErr error
|
||||
if err := rawConn.Control(func(fd uintptr) {
|
||||
value := 0
|
||||
if enabled {
|
||||
value = 1
|
||||
}
|
||||
controlErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, value)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return controlErr
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user