Merge pull request #2514 from lc6464/fix/issue-2488-host-binding

feat(launcher): add host overrides for launcher and gateway
This commit is contained in:
美電球
2026-04-14 23:48:24 +08:00
committed by GitHub
29 changed files with 2420 additions and 99 deletions
+36 -9
View File
@@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"math"
"net"
"net/http"
"sort"
"sync"
@@ -86,6 +87,7 @@ type Manager struct {
dispatchTask *asyncTask
mux *dynamicServeMux
httpServer *http.Server
httpListeners []net.Listener
mu sync.RWMutex
placeholders sync.Map // "channel:chatID" → placeholderID (string)
typingStops sync.Map // "channel:chatID" → func()
@@ -474,6 +476,12 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
// It registers health endpoints from the health server and discovers channels
// that implement WebhookHandler and/or HealthChecker to register their handlers.
func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
m.SetupHTTPServerListeners(nil, addr, healthServer)
}
// SetupHTTPServerListeners creates a shared HTTP server on pre-opened listeners.
// When listeners is empty it falls back to Addr-based ListenAndServe behavior.
func (m *Manager) SetupHTTPServerListeners(listeners []net.Listener, addr string, healthServer *health.Server) {
m.mux = newDynamicServeMux()
// Register health endpoints
@@ -490,6 +498,7 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
}
m.httpListeners = append([]net.Listener(nil), listeners...)
}
// registerHTTPHandlersLocked registers webhook and health-check handlers for
@@ -619,16 +628,33 @@ func (m *Manager) StartAll(ctx context.Context) error {
// Start shared HTTP server if configured
if m.httpServer != nil {
go func() {
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
"addr": m.httpServer.Addr,
})
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
"error": err.Error(),
})
if len(m.httpListeners) > 0 {
for _, listener := range m.httpListeners {
ln := listener
go func() {
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
"addr": ln.Addr().String(),
})
if err := m.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
"addr": ln.Addr().String(),
"error": err.Error(),
})
}
}()
}
}()
} else {
go func() {
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
"addr": m.httpServer.Addr,
})
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
"error": err.Error(),
})
}
}()
}
}
logger.InfoCF("channels", "Channel startup completed", map[string]any{
@@ -655,6 +681,7 @@ func (m *Manager) StopAll(ctx context.Context) error {
})
}
m.httpServer = nil
m.httpListeners = nil
}
// Cancel dispatcher
+6
View File
@@ -1136,6 +1136,8 @@ func LoadConfig(path string) (*Config, error) {
applyLegacyBindingsMigration(data, cfg)
gatewayHostBeforeEnv := cfg.Gateway.Host
if err = env.Parse(cfg); err != nil {
return nil, err
}
@@ -1144,6 +1146,10 @@ func LoadConfig(path string) (*Config, error) {
if err = InitChannelList(cfg.Channels); err != nil {
return nil, err
}
cfg.Gateway.Host, err = resolveGatewayHostFromEnv(gatewayHostBeforeEnv)
if err != nil {
return nil, fmt.Errorf("invalid gateway host: %w", err)
}
// Expand multi-key configs into separate entries for key-level failover
cfg.ModelList = expandMultiKeyModels(cfg.ModelList)
+2 -2
View File
@@ -503,7 +503,7 @@ func TestDefaultConfig_Temperature(t *testing.T) {
func TestDefaultConfig_Gateway(t *testing.T) {
cfg := DefaultConfig()
if cfg.Gateway.Host != "127.0.0.1" {
if cfg.Gateway.Host != "localhost" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
@@ -739,7 +739,7 @@ func TestConfig_Complete(t *testing.T) {
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
if cfg.Gateway.Host != "127.0.0.1" {
if cfg.Gateway.Host != "localhost" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
+1 -1
View File
@@ -259,7 +259,7 @@ func DefaultConfig() *Config {
},
},
Gateway: GatewayConfig{
Host: "127.0.0.1",
Host: "localhost",
Port: 18790,
HotReload: false,
LogLevel: DefaultGatewayLogLevel,
+1 -1
View File
@@ -39,7 +39,7 @@ const (
EnvBinary = "PICOCLAW_BINARY"
// EnvGatewayHost overrides the host address for the gateway server.
// Default: "127.0.0.1"
// Default: "localhost"
EnvGatewayHost = "PICOCLAW_GATEWAY_HOST"
)
+27
View File
@@ -3,8 +3,10 @@ package config
import (
"encoding/json"
"os"
"strings"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/netbind"
)
const DefaultGatewayLogLevel = "warn"
@@ -49,6 +51,31 @@ func EffectiveGatewayLogLevel(cfg *Config) string {
return normalizeGatewayLogLevel(cfg.Gateway.LogLevel)
}
func resolveGatewayHostFromEnv(baseHost string) (string, error) {
envHost, ok := os.LookupEnv(EnvGatewayHost)
if !ok {
return normalizeGatewayHostInput(baseHost)
}
envHost = strings.TrimSpace(envHost)
if envHost == "" {
return normalizeGatewayHostInput(baseHost)
}
return normalizeGatewayHostInput(envHost)
}
func normalizeGatewayHostInput(host string) (string, error) {
host = strings.TrimSpace(host)
if host == "" {
host = strings.TrimSpace(DefaultConfig().Gateway.Host)
}
if host == "" {
host = "localhost"
}
return netbind.NormalizeHostInput(host)
}
// ResolveGatewayLogLevel reads the configured gateway log level without triggering
// the full config loader, so startup code can apply logging before config load logs run.
// The PICOCLAW_LOG_LEVEL environment variable overrides the file value.
+98
View File
@@ -0,0 +1,98 @@
package config
import (
"fmt"
"os"
"path/filepath"
"testing"
)
func writeGatewayHostTestConfig(t *testing.T, host string) string {
t.Helper()
configPath := filepath.Join(t.TempDir(), "config.json")
raw := fmt.Sprintf(`{"version":2,"gateway":{"host":%q,"port":18790}}`, host)
if err := os.WriteFile(configPath, []byte(raw), 0o600); err != nil {
t.Fatalf("WriteFile(configPath): %v", err)
}
return configPath
}
func TestLoadConfig_GatewayHostEnvTrimmed(t *testing.T) {
configPath := writeGatewayHostTestConfig(t, "127.0.0.1")
t.Setenv(EnvGatewayHost, " ::1 ")
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if cfg.Gateway.Host != "::1" {
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, "::1")
}
}
func TestLoadConfig_GatewayHostBlankEnvFallsBackToConfigHost(t *testing.T) {
configPath := writeGatewayHostTestConfig(t, " localhost ")
t.Setenv(EnvGatewayHost, " ")
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
want, err := normalizeGatewayHostInput("localhost")
if err != nil {
t.Fatalf("normalizeGatewayHostInput() error: %v", err)
}
if cfg.Gateway.Host != want {
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, want)
}
}
func TestLoadConfig_GatewayHostBlankEnvAndConfigFallsBackToDefault(t *testing.T) {
configPath := writeGatewayHostTestConfig(t, " ")
t.Setenv(EnvGatewayHost, " ")
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
defaultHost, err := normalizeGatewayHostInput(DefaultConfig().Gateway.Host)
if err != nil {
t.Fatalf("normalizeGatewayHostInput() error: %v", err)
}
if cfg.Gateway.Host != defaultHost {
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, defaultHost)
}
}
func TestLoadConfig_GatewayHostEnvPreservesExplicitWildcardHost(t *testing.T) {
configPath := writeGatewayHostTestConfig(t, "localhost")
t.Setenv(EnvGatewayHost, " 0.0.0.0 ")
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
want, err := normalizeGatewayHostInput("0.0.0.0")
if err != nil {
t.Fatalf("normalizeGatewayHostInput() error: %v", err)
}
if cfg.Gateway.Host != want {
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, want)
}
}
func TestLoadConfig_GatewayHostEnvNormalizesMultiHostInput(t *testing.T) {
configPath := writeGatewayHostTestConfig(t, "localhost")
t.Setenv(EnvGatewayHost, " [::1] , 127.0.0.1 , ::1 ")
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if cfg.Gateway.Host != "::1,127.0.0.1" {
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, "::1,127.0.0.1")
}
}
+43 -9
View File
@@ -3,10 +3,12 @@ package gateway
import (
"context"
"fmt"
"net"
"os"
"os/signal"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -42,6 +44,7 @@ import (
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/netbind"
"github.com/sipeed/picoclaw/pkg/pid"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/state"
@@ -159,13 +162,30 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
logger.Infof("Log level set to %q", effectiveLogLevel)
}
bindPlan, listenResult, err := openGatewayListeners(cfg.Gateway.Host, cfg.Gateway.Port)
if err != nil {
return fmt.Errorf("error opening gateway listeners: %w", err)
}
// Enforce singleton: write PID file with generated token.
pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port)
pidData, err := pid.WritePidFile(homePath, bindPlan.ProbeHost, cfg.Gateway.Port)
if err != nil {
logger.Warnf("write pid file failed: %v", err)
for _, ln := range listenResult.Listeners {
_ = ln.Close()
}
return fmt.Errorf("singleton check failed: %w", err)
}
defer pid.RemovePidFile(homePath)
closeListeners := true
defer func() {
if !closeListeners {
return
}
for _, ln := range listenResult.Listeners {
_ = ln.Close()
}
}()
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
if err != nil {
@@ -193,10 +213,11 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
"skills_available": skillsInfo["available"],
})
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token)
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token, listenResult)
if err != nil {
return err
}
closeListeners = false
// Setup manual reload channel for /reload endpoint
manualReloadChan := make(chan struct{}, 1)
@@ -217,7 +238,9 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
runningServices.HealthServer.SetReloadFunc(reloadTrigger)
agentLoop.SetReloadFunc(reloadTrigger)
fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port)
for _, bindHost := range listenResult.BindHosts {
fmt.Printf("✓ Gateway started on %s\n", net.JoinHostPort(bindHost, strconv.Itoa(cfg.Gateway.Port)))
}
fmt.Println("Press Ctrl+C to stop")
ctx, cancel := context.WithCancel(context.Background())
@@ -320,6 +343,7 @@ func setupAndStartServices(
agentLoop *agent.AgentLoop,
msgBus *bus.MessageBus,
authToken string,
listenResult netbind.OpenResult,
) (*services, error) {
runningServices := &services{}
@@ -390,10 +414,20 @@ func setupAndStartServices(
fmt.Println("⚠ Warning: No channels enabled")
}
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.authToken = authToken
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, authToken)
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
runningServices.HealthServer = health.NewServer(listenResult.ProbeHost, cfg.Gateway.Port, authToken)
var listenAddr string
if len(listenResult.Listeners) > 0 {
listenAddr = listenResult.Listeners[0].Addr().String()
} else {
listenAddr = net.JoinHostPort(listenResult.ProbeHost, strconv.Itoa(cfg.Gateway.Port))
}
runningServices.ChannelManager.SetupHTTPServerListeners(
listenResult.Listeners,
listenAddr,
runningServices.HealthServer,
)
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
return nil, fmt.Errorf("error starting channels: %w", err)
@@ -409,10 +443,10 @@ func setupAndStartServices(
voiceAgent.Start(vaCtx)
}
healthAddr := net.JoinHostPort(listenResult.ProbeHost, strconv.Itoa(cfg.Gateway.Port))
fmt.Printf(
"✓ Health endpoints available at http://%s:%d/health, /ready and /reload (POST)\n",
cfg.Gateway.Host,
cfg.Gateway.Port,
"✓ Health endpoints available at http://%s/health, /ready and /reload (POST)\n",
healthAddr,
)
stateManager := state.NewManager(cfg.WorkspacePath())
+21
View File
@@ -0,0 +1,21 @@
package gateway
import (
"strconv"
"github.com/sipeed/picoclaw/pkg/netbind"
)
func openGatewayListeners(host string, port int) (netbind.Plan, netbind.OpenResult, error) {
plan, err := netbind.BuildPlan(host, netbind.DefaultLoopback)
if err != nil {
return netbind.Plan{}, netbind.OpenResult{}, err
}
result, err := netbind.OpenPlan(plan, strconv.Itoa(port))
if err != nil {
return netbind.Plan{}, netbind.OpenResult{}, err
}
return plan, result, nil
}
+130
View File
@@ -0,0 +1,130 @@
package gateway
import (
"context"
"errors"
"io"
"net"
"net/http"
"strconv"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/netbind"
)
func TestOpenGatewayListeners_HonorsIPv6OnlyHost(t *testing.T) {
hasIPv4, hasIPv6 := netbind.DetectIPFamilies()
if !hasIPv6 {
t.Skip("IPv6 is unavailable in this environment")
}
_, result, err := openGatewayListeners("::", 0)
if err != nil {
t.Fatalf("openGatewayListeners() error = %v", err)
}
startGatewayTestHTTPServer(t, result.Listeners)
port := mustGatewayAtoi(t, result.Port)
requireGatewayHTTPReachable(t, "::1", port)
if hasIPv4 {
requireGatewayHTTPUnreachable(t, "127.0.0.1", port)
}
}
func TestOpenGatewayListeners_SupportsExplicitMultiHost(t *testing.T) {
hasIPv4, hasIPv6 := netbind.DetectIPFamilies()
if !hasIPv4 || !hasIPv6 {
t.Skip("dual-stack loopback is unavailable in this environment")
}
_, result, err := openGatewayListeners("127.0.0.1,::1", 0)
if err != nil {
t.Fatalf("openGatewayListeners() error = %v", err)
}
startGatewayTestHTTPServer(t, result.Listeners)
port := mustGatewayAtoi(t, result.Port)
requireGatewayHTTPReachable(t, "127.0.0.1", port)
requireGatewayHTTPReachable(t, "::1", port)
}
func startGatewayTestHTTPServer(t *testing.T, listeners []net.Listener) {
t.Helper()
server := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "ok")
}),
}
errCh := make(chan error, len(listeners))
for _, listener := range listeners {
ln := listener
go func() {
errCh <- server.Serve(ln)
}()
}
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = server.Shutdown(ctx)
for range listeners {
err := <-errCh
if err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Fatalf("server.Serve() error = %v", err)
}
}
})
}
func requireGatewayHTTPReachable(t *testing.T, host string, port int) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for {
err := gatewayHTTPGet(host, port)
if err == nil {
return
}
if time.Now().After(deadline) {
t.Fatalf("expected %s:%d to be reachable: %v", host, port, err)
}
time.Sleep(50 * time.Millisecond)
}
}
func requireGatewayHTTPUnreachable(t *testing.T, host string, port int) {
t.Helper()
if err := gatewayHTTPGet(host, port); err == nil {
t.Fatalf("expected %s:%d to be unreachable", host, port)
}
}
func gatewayHTTPGet(host string, port int) error {
client := &http.Client{
Timeout: 300 * time.Millisecond,
Transport: &http.Transport{
Proxy: nil,
},
}
resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port)))
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.New(resp.Status)
}
return nil
}
func mustGatewayAtoi(t *testing.T, value string) int {
t.Helper()
n, err := strconv.Atoi(value)
if err != nil {
t.Fatalf("Atoi(%q) error = %v", value, err)
}
return n
}
+3 -2
View File
@@ -4,10 +4,11 @@ import (
"context"
"crypto/subtle"
"encoding/json"
"fmt"
"maps"
"net"
"net/http"
"os"
"strconv"
"sync"
"time"
)
@@ -49,7 +50,7 @@ func NewServer(host string, port int, token string) *Server {
mux.HandleFunc("/ready", s.readyHandler)
mux.HandleFunc("/reload", s.reloadHandler)
addr := fmt.Sprintf("%s:%d", host, port)
addr := net.JoinHostPort(host, strconv.Itoa(port))
s.server = &http.Server{
Addr: addr,
Handler: mux,
+10
View File
@@ -305,6 +305,16 @@ func TestNewServer(t *testing.T) {
}
}
func TestNewServer_IPv6ListenAddrFormatting(t *testing.T) {
s := NewServer("::", 18790, "")
if s.server == nil {
t.Fatal("server should be initialized")
}
if s.server.Addr != "[::]:18790" {
t.Fatalf("server.Addr = %q, want %q", s.server.Addr, "[::]:18790")
}
}
func TestStartContext_Cancellation(t *testing.T) {
s := NewServer("127.0.0.1", 0, "")
+606
View File
@@ -0,0 +1,606 @@
package netbind
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync"
)
type DefaultMode int
const (
DefaultLoopback DefaultMode = iota
DefaultAny
)
type groupKind int
const (
groupAdaptiveLoopback groupKind = iota
groupAdaptiveAny
groupExact
)
type exactBinding struct {
host string
network string
v6Only bool
}
type bindGroup struct {
kind groupKind
allowIPv4 bool
allowIPv6 bool
exact exactBinding
}
type Plan struct {
groups []bindGroup
ProbeHost string
}
type OpenResult struct {
Listeners []net.Listener
BindHosts []string
Port string
ProbeHost string
}
type tokenKind int
const (
tokenName tokenKind = iota
tokenLocalhost
tokenStar
tokenIPv4
tokenIPv6
tokenIPv4Any
tokenIPv6Any
)
type hostToken struct {
kind tokenKind
canonical string
key string
}
var (
ipFamiliesOnce sync.Once
hasIPv4 bool
hasIPv6 bool
)
func DetectIPFamilies() (bool, bool) {
ipFamiliesOnce.Do(func() {
if ips, err := net.LookupIP("localhost"); err == nil {
for _, ip := range ips {
if ip == nil {
continue
}
if ip.To4() != nil {
hasIPv4 = true
continue
}
hasIPv6 = true
}
}
if hasIPv4 && hasIPv6 {
return
}
if addrs, err := net.InterfaceAddrs(); err == nil {
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok || ipnet.IP == nil {
continue
}
if ipnet.IP.To4() != nil {
hasIPv4 = true
continue
}
hasIPv6 = true
}
}
})
return hasIPv4, hasIPv6
}
func SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6 bool) string {
switch {
case hasIPv4 && hasIPv6:
return "localhost"
case hasIPv6:
return "::1"
case hasIPv4:
return "127.0.0.1"
default:
return "localhost"
}
}
func SelectAdaptiveAnyHost(hasIPv4, hasIPv6 bool) string {
switch {
case hasIPv4 && hasIPv6:
return "::"
case hasIPv6:
return "::"
case hasIPv4:
return "0.0.0.0"
default:
return "::"
}
}
func ResolveAdaptiveLoopbackHost() string {
hasIPv4, hasIPv6 := DetectIPFamilies()
return SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6)
}
func ResolveAdaptiveAnyHost() string {
hasIPv4, hasIPv6 := DetectIPFamilies()
return SelectAdaptiveAnyHost(hasIPv4, hasIPv6)
}
func IsLoopbackHost(host string) bool {
host = strings.TrimSpace(host)
if host == "" {
return false
}
if strings.EqualFold(host, "localhost") {
return true
}
ip := net.ParseIP(strings.Trim(host, "[]"))
return ip != nil && ip.IsLoopback()
}
func IsUnspecifiedHost(host string) bool {
host = strings.TrimSpace(host)
if host == "" {
return false
}
ip := net.ParseIP(strings.Trim(host, "[]"))
return ip != nil && ip.IsUnspecified()
}
func NormalizeHostInput(raw string) (string, error) {
tokens, err := parseHostTokens(raw)
if err != nil {
return "", err
}
parts := make([]string, 0, len(tokens))
for _, token := range tokens {
parts = append(parts, token.canonical)
}
return strings.Join(parts, ","), nil
}
func BuildPlan(raw string, defaultMode DefaultMode) (Plan, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return buildDefaultPlan(defaultMode), nil
}
tokens, err := parseHostTokens(raw)
if err != nil {
return Plan{}, err
}
for _, token := range tokens {
if token.kind == tokenStar {
return Plan{
groups: []bindGroup{{kind: groupAdaptiveAny}},
ProbeHost: ResolveAdaptiveLoopbackHost(),
}, nil
}
}
hasIPv4Any := false
hasIPv6Any := false
for _, token := range tokens {
switch token.kind {
case tokenIPv4Any:
hasIPv4Any = true
case tokenIPv6Any:
hasIPv6Any = true
}
}
allowLocalhostIPv4 := !hasIPv4Any
allowLocalhostIPv6 := !hasIPv6Any
groups := make([]bindGroup, 0, len(tokens))
seenExact := make(map[string]struct{}, len(tokens))
addedLocalhost := false
for _, token := range tokens {
switch token.kind {
case tokenLocalhost:
if addedLocalhost || (!allowLocalhostIPv4 && !allowLocalhostIPv6) {
continue
}
groups = append(groups, bindGroup{
kind: groupAdaptiveLoopback,
allowIPv4: allowLocalhostIPv4,
allowIPv6: allowLocalhostIPv6,
})
addedLocalhost = true
case tokenIPv4Any:
key := "exact:tcp4:0.0.0.0"
if _, ok := seenExact[key]; ok {
continue
}
seenExact[key] = struct{}{}
groups = append(groups, bindGroup{
kind: groupExact,
exact: exactBinding{
host: "0.0.0.0",
network: "tcp4",
},
})
case tokenIPv6Any:
key := "exact:tcp6:::"
if _, ok := seenExact[key]; ok {
continue
}
seenExact[key] = struct{}{}
groups = append(groups, bindGroup{
kind: groupExact,
exact: exactBinding{
host: "::",
network: "tcp6",
v6Only: true,
},
})
case tokenIPv4:
if hasIPv4Any {
continue
}
key := "exact:tcp4:" + strings.ToLower(token.canonical)
if _, ok := seenExact[key]; ok {
continue
}
seenExact[key] = struct{}{}
groups = append(groups, bindGroup{
kind: groupExact,
exact: exactBinding{
host: token.canonical,
network: "tcp4",
},
})
case tokenIPv6:
if hasIPv6Any {
continue
}
key := "exact:tcp6:" + strings.ToLower(token.canonical)
if _, ok := seenExact[key]; ok {
continue
}
seenExact[key] = struct{}{}
groups = append(groups, bindGroup{
kind: groupExact,
exact: exactBinding{
host: token.canonical,
network: "tcp6",
v6Only: true,
},
})
case tokenName:
key := "exact:tcp:" + token.key
if _, ok := seenExact[key]; ok {
continue
}
seenExact[key] = struct{}{}
groups = append(groups, bindGroup{
kind: groupExact,
exact: exactBinding{
host: token.canonical,
network: "tcp",
},
})
}
}
plan := Plan{groups: groups}
plan.ProbeHost = probeHostForGroups(groups)
return plan, nil
}
func OpenPlan(plan Plan, port string) (OpenResult, error) {
if port == "" {
return OpenResult{}, errors.New("port cannot be empty")
}
selectedPort := port
listeners := make([]net.Listener, 0, len(plan.groups))
bindHosts := make([]string, 0, len(plan.groups))
bindSeen := make(map[string]struct{}, len(plan.groups))
closeAll := func() {
for _, ln := range listeners {
_ = ln.Close()
}
}
for _, group := range plan.groups {
groupListeners, groupHosts, actualPort, err := openGroup(group, selectedPort)
if err != nil {
closeAll()
return OpenResult{}, err
}
if selectedPort == "0" && actualPort != "" {
selectedPort = actualPort
}
listeners = append(listeners, groupListeners...)
for _, host := range groupHosts {
key := strings.ToLower(host)
if _, ok := bindSeen[key]; ok {
continue
}
bindSeen[key] = struct{}{}
bindHosts = append(bindHosts, host)
}
}
return OpenResult{
Listeners: listeners,
BindHosts: bindHosts,
Port: selectedPort,
ProbeHost: plan.ProbeHost,
}, nil
}
func buildDefaultPlan(defaultMode DefaultMode) Plan {
switch defaultMode {
case DefaultAny:
return Plan{
groups: []bindGroup{{kind: groupAdaptiveAny}},
ProbeHost: ResolveAdaptiveLoopbackHost(),
}
default:
return Plan{
groups: []bindGroup{{
kind: groupAdaptiveLoopback,
allowIPv4: true,
allowIPv6: true,
}},
ProbeHost: ResolveAdaptiveLoopbackHost(),
}
}
}
func probeHostForGroups(groups []bindGroup) string {
hasIPv4Any := false
hasIPv6Any := false
for _, group := range groups {
if group.kind == groupAdaptiveLoopback {
switch {
case group.allowIPv4 && group.allowIPv6:
return ResolveAdaptiveLoopbackHost()
case group.allowIPv6:
return "::1"
case group.allowIPv4:
return "127.0.0.1"
}
}
if group.kind == groupAdaptiveAny {
return ResolveAdaptiveLoopbackHost()
}
if group.kind != groupExact {
continue
}
switch group.exact.host {
case "0.0.0.0":
hasIPv4Any = true
case "::":
hasIPv6Any = true
}
}
switch {
case hasIPv4Any && hasIPv6Any:
return ResolveAdaptiveLoopbackHost()
case hasIPv6Any:
return "::1"
case hasIPv4Any:
return "127.0.0.1"
}
for _, group := range groups {
if group.kind == groupExact {
return group.exact.host
}
}
return ResolveAdaptiveLoopbackHost()
}
func parseHostTokens(raw string) ([]hostToken, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, errors.New("host cannot be empty")
}
parts := strings.Split(raw, ",")
tokens := make([]hostToken, 0, len(parts))
seen := make(map[string]struct{}, len(parts))
for _, part := range parts {
token, err := parseHostToken(part)
if err != nil {
return nil, err
}
if _, ok := seen[token.key]; ok {
continue
}
seen[token.key] = struct{}{}
tokens = append(tokens, token)
}
if len(tokens) == 0 {
return nil, errors.New("host cannot be empty")
}
return tokens, nil
}
func parseHostToken(raw string) (hostToken, error) {
host := strings.TrimSpace(raw)
if host == "" {
return hostToken{}, errors.New("host list contains an empty entry")
}
if host == "*" {
return hostToken{kind: tokenStar, canonical: "*", key: "*"}, nil
}
if strings.EqualFold(host, "localhost") {
return hostToken{kind: tokenLocalhost, canonical: "localhost", key: "localhost"}, nil
}
trimmed := strings.Trim(host, "[]")
if ip := net.ParseIP(trimmed); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
canonical := ip4.String()
kind := tokenIPv4
if ip4.IsUnspecified() {
kind = tokenIPv4Any
}
return hostToken{kind: kind, canonical: canonical, key: canonical}, nil
}
canonical := ip.String()
kind := tokenIPv6
if ip.IsUnspecified() {
kind = tokenIPv6Any
}
return hostToken{kind: kind, canonical: canonical, key: strings.ToLower(canonical)}, nil
}
return hostToken{
kind: tokenName,
canonical: host,
key: strings.ToLower(host),
}, nil
}
func openGroup(group bindGroup, port string) ([]net.Listener, []string, string, error) {
switch group.kind {
case groupAdaptiveLoopback:
return openAdaptiveLoopbackGroup(group.allowIPv6, group.allowIPv4, port)
case groupAdaptiveAny:
return openAdaptiveAnyGroup(port)
case groupExact:
ln, actualPort, err := openExactListener(group.exact, port)
if err != nil {
return nil, nil, "", err
}
return []net.Listener{ln}, []string{group.exact.host}, actualPort, nil
default:
return nil, nil, "", fmt.Errorf("unsupported bind group kind: %d", group.kind)
}
}
func openAdaptiveLoopbackGroup(allowIPv6, allowIPv4 bool, port string) ([]net.Listener, []string, string, error) {
if allowIPv6 && allowIPv4 {
if ln6, actualPort, err6 := openExactListener(
exactBinding{host: "::1", network: "tcp6", v6Only: true},
port,
); err6 == nil {
if ln4, _, err4 := openExactListener(
exactBinding{host: "127.0.0.1", network: "tcp4"},
actualPort,
); err4 == nil {
return []net.Listener{ln6, ln4}, []string{"::1", "127.0.0.1"}, actualPort, nil
}
_ = ln6.Close()
}
}
if allowIPv6 {
ln6, actualPort, err := openExactListener(exactBinding{host: "::1", network: "tcp6", v6Only: true}, port)
if err == nil {
return []net.Listener{ln6}, []string{"::1"}, actualPort, nil
}
}
if allowIPv4 {
ln4, actualPort, err := openExactListener(exactBinding{host: "127.0.0.1", network: "tcp4"}, port)
if err == nil {
return []net.Listener{ln4}, []string{"127.0.0.1"}, actualPort, nil
}
}
return nil, nil, "", fmt.Errorf("failed to open adaptive localhost listener on port %s", port)
}
func openAdaptiveAnyGroup(port string) ([]net.Listener, []string, string, error) {
hasIPv4, hasIPv6 := DetectIPFamilies()
if hasIPv4 && hasIPv6 {
if ln6, actualPort, err6 := openExactListener(
exactBinding{host: "::", network: "tcp6", v6Only: true},
port,
); err6 == nil {
if ln4, _, err4 := openExactListener(
exactBinding{host: "0.0.0.0", network: "tcp4"},
actualPort,
); err4 == nil {
return []net.Listener{ln6, ln4}, []string{"::", "0.0.0.0"}, actualPort, nil
}
_ = ln6.Close()
}
}
if hasIPv6 {
ln6, actualPort, err := openExactListener(exactBinding{host: "::", network: "tcp6", v6Only: true}, port)
if err == nil {
return []net.Listener{ln6}, []string{"::"}, actualPort, nil
}
}
if hasIPv4 {
ln4, actualPort, err := openExactListener(exactBinding{host: "0.0.0.0", network: "tcp4"}, port)
if err == nil {
return []net.Listener{ln4}, []string{"0.0.0.0"}, actualPort, nil
}
}
return nil, nil, "", fmt.Errorf("failed to open adaptive any-host listener on port %s", port)
}
func openExactListener(binding exactBinding, port string) (net.Listener, string, error) {
listenConfig := net.ListenConfig{}
if binding.network == "tcp6" && binding.v6Only {
listenConfig.Control = applyIPv6OnlyControl(true)
}
ln, err := listenConfig.Listen(context.Background(), binding.network, net.JoinHostPort(binding.host, port))
if err != nil {
return nil, "", err
}
actualPort, err := listenerPort(ln)
if err != nil {
_ = ln.Close()
return nil, "", err
}
return ln, actualPort, nil
}
func listenerPort(ln net.Listener) (string, error) {
addr, ok := ln.Addr().(*net.TCPAddr)
if ok {
return strconv.Itoa(addr.Port), nil
}
_, port, err := net.SplitHostPort(ln.Addr().String())
if err != nil {
return "", err
}
return port, nil
}
+280
View File
@@ -0,0 +1,280 @@
package netbind
import (
"context"
"errors"
"io"
"net"
"net/http"
"strconv"
"testing"
"time"
)
func TestNormalizeHostInput(t *testing.T) {
tests := []struct {
name string
raw string
want string
wantErr bool
}{
{name: "single host", raw: "127.0.0.1", want: "127.0.0.1"},
{name: "trim and dedupe", raw: " [::1] , ::1 , 127.0.0.1 ", want: "::1,127.0.0.1"},
{name: "star preserved", raw: "*,127.0.0.1", want: "*,127.0.0.1"},
{name: "reject empty", raw: "127.0.0.1, ", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NormalizeHostInput(tt.raw)
if (err != nil) != tt.wantErr {
t.Fatalf("NormalizeHostInput() err = %v, wantErr %t", err, tt.wantErr)
}
if tt.wantErr {
return
}
if got != tt.want {
t.Fatalf("NormalizeHostInput() = %q, want %q", got, tt.want)
}
})
}
}
func TestBuildPlan_DefaultAnyUsesLoopbackProbe(t *testing.T) {
plan, err := BuildPlan("", DefaultAny)
if err != nil {
t.Fatalf("BuildPlan() error = %v", err)
}
if plan.ProbeHost != ResolveAdaptiveLoopbackHost() {
t.Fatalf("ProbeHost = %q, want %q", plan.ProbeHost, ResolveAdaptiveLoopbackHost())
}
}
func TestOpenPlan_LocalhostSupportsLoopbackCommunication(t *testing.T) {
hasIPv4, hasIPv6 := DetectIPFamilies()
plan, err := BuildPlan("localhost", DefaultLoopback)
if err != nil {
t.Fatalf("BuildPlan() error = %v", err)
}
result, err := OpenPlan(plan, "0")
if err != nil {
t.Fatalf("OpenPlan() error = %v", err)
}
startTestHTTPServer(t, result.Listeners)
port := mustAtoi(t, result.Port)
if hasIPv6 {
requireHTTPReachable(t, "::1", port)
}
if hasIPv4 {
requireHTTPReachable(t, "127.0.0.1", port)
}
}
func TestOpenPlan_DefaultAnySupportsDualStackLoopback(t *testing.T) {
hasIPv4, hasIPv6 := DetectIPFamilies()
plan, err := BuildPlan("", DefaultAny)
if err != nil {
t.Fatalf("BuildPlan() error = %v", err)
}
result, err := OpenPlan(plan, "0")
if err != nil {
t.Fatalf("OpenPlan() error = %v", err)
}
startTestHTTPServer(t, result.Listeners)
port := mustAtoi(t, result.Port)
if hasIPv6 {
requireHTTPReachable(t, "::1", port)
}
if hasIPv4 {
requireHTTPReachable(t, "127.0.0.1", port)
}
switch {
case hasIPv4 && hasIPv6:
if len(result.BindHosts) != 2 {
t.Fatalf("len(BindHosts) = %d, want 2 (%#v)", len(result.BindHosts), result.BindHosts)
}
case hasIPv6 || hasIPv4:
if len(result.BindHosts) != 1 {
t.Fatalf("len(BindHosts) = %d, want 1 (%#v)", len(result.BindHosts), result.BindHosts)
}
}
}
func TestOpenPlan_ExplicitIPv6AnyIsIPv6Only(t *testing.T) {
hasIPv4, hasIPv6 := DetectIPFamilies()
if !hasIPv6 {
t.Skip("IPv6 is unavailable in this environment")
}
plan, err := BuildPlan("::", DefaultLoopback)
if err != nil {
t.Fatalf("BuildPlan() error = %v", err)
}
result, err := OpenPlan(plan, "0")
if err != nil {
t.Fatalf("OpenPlan() error = %v", err)
}
startTestHTTPServer(t, result.Listeners)
port := mustAtoi(t, result.Port)
requireHTTPReachable(t, "::1", port)
if hasIPv4 {
requireHTTPUnreachable(t, "127.0.0.1", port)
}
}
func TestOpenPlan_ExplicitIPv4AnyIsIPv4Only(t *testing.T) {
hasIPv4, hasIPv6 := DetectIPFamilies()
if !hasIPv4 {
t.Skip("IPv4 is unavailable in this environment")
}
plan, err := BuildPlan("0.0.0.0", DefaultLoopback)
if err != nil {
t.Fatalf("BuildPlan() error = %v", err)
}
result, err := OpenPlan(plan, "0")
if err != nil {
t.Fatalf("OpenPlan() error = %v", err)
}
startTestHTTPServer(t, result.Listeners)
port := mustAtoi(t, result.Port)
requireHTTPReachable(t, "127.0.0.1", port)
if hasIPv6 {
requireHTTPUnreachable(t, "::1", port)
}
}
func TestOpenPlan_MultiHostSupportsExplicitIPv4AndIPv6(t *testing.T) {
hasIPv4, hasIPv6 := DetectIPFamilies()
if !hasIPv4 || !hasIPv6 {
t.Skip("dual-stack loopback is unavailable in this environment")
}
plan, err := BuildPlan("127.0.0.1,::1", DefaultLoopback)
if err != nil {
t.Fatalf("BuildPlan() error = %v", err)
}
result, err := OpenPlan(plan, "0")
if err != nil {
t.Fatalf("OpenPlan() error = %v", err)
}
startTestHTTPServer(t, result.Listeners)
port := mustAtoi(t, result.Port)
requireHTTPReachable(t, "127.0.0.1", port)
requireHTTPReachable(t, "::1", port)
}
func TestOpenPlan_WildcardRulesKeepIPv4AndIPv6AnyHosts(t *testing.T) {
hasIPv4, hasIPv6 := DetectIPFamilies()
if !hasIPv4 || !hasIPv6 {
t.Skip("dual-stack loopback is unavailable in this environment")
}
plan, err := BuildPlan("::,::1,0.0.0.0,127.0.0.1", DefaultLoopback)
if err != nil {
t.Fatalf("BuildPlan() error = %v", err)
}
result, err := OpenPlan(plan, "0")
if err != nil {
t.Fatalf("OpenPlan() error = %v", err)
}
startTestHTTPServer(t, result.Listeners)
port := mustAtoi(t, result.Port)
requireHTTPReachable(t, "127.0.0.1", port)
requireHTTPReachable(t, "::1", port)
if len(result.BindHosts) != 2 {
t.Fatalf("len(BindHosts) = %d, want 2 (%#v)", len(result.BindHosts), result.BindHosts)
}
}
func startTestHTTPServer(t *testing.T, listeners []net.Listener) {
t.Helper()
server := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "ok")
}),
}
errCh := make(chan error, len(listeners))
for _, listener := range listeners {
ln := listener
go func() {
errCh <- server.Serve(ln)
}()
}
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = server.Shutdown(ctx)
for range listeners {
err := <-errCh
if err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Fatalf("server.Serve() error = %v", err)
}
}
})
}
func requireHTTPReachable(t *testing.T, host string, port int) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for {
err := httpGET(host, port)
if err == nil {
return
}
if time.Now().After(deadline) {
t.Fatalf("expected %s:%d to be reachable: %v", host, port, err)
}
time.Sleep(50 * time.Millisecond)
}
}
func requireHTTPUnreachable(t *testing.T, host string, port int) {
t.Helper()
if err := httpGET(host, port); err == nil {
t.Fatalf("expected %s:%d to be unreachable", host, port)
}
}
func httpGET(host string, port int) error {
client := &http.Client{
Timeout: 300 * time.Millisecond,
Transport: &http.Transport{
Proxy: nil,
},
}
resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port)))
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.New(resp.Status)
}
return nil
}
func mustAtoi(t *testing.T, value string) int {
t.Helper()
n, err := strconv.Atoi(value)
if err != nil {
t.Fatalf("Atoi(%q) error = %v", value, err)
}
return n
}
+25
View File
@@ -0,0 +1,25 @@
//go:build !windows
package netbind
import (
"syscall"
"golang.org/x/sys/unix"
)
func applyIPv6OnlyControl(enabled bool) func(string, string, syscall.RawConn) error {
return func(_, _ string, rawConn syscall.RawConn) error {
var controlErr error
if err := rawConn.Control(func(fd uintptr) {
value := 0
if enabled {
value = 1
}
controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, value)
}); err != nil {
return err
}
return controlErr
}
}
+25
View File
@@ -0,0 +1,25 @@
//go:build windows
package netbind
import (
"syscall"
"golang.org/x/sys/windows"
)
func applyIPv6OnlyControl(enabled bool) func(string, string, syscall.RawConn) error {
return func(_, _ string, rawConn syscall.RawConn) error {
var controlErr error
if err := rawConn.Control(func(fd uintptr) {
value := 0
if enabled {
value = 1
}
controlErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, value)
}); err != nil {
return err
}
return controlErr
}
}