mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
feat(host): complete launcher and gateway multi-host binding support
- add shared netbind planning for strict tcp4/tcp6 bind semantics - support launcher/gateway host env overrides and launcher-to-gateway forwarding - cover host binding and forwarding with network and subprocess env tests
This commit is contained in:
@@ -3,7 +3,6 @@ package gateway
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -11,15 +10,19 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/gateway"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
func resolveGatewayHostOverride(explicit bool, host string) (string, error) {
|
||||
host = strings.TrimSpace(host)
|
||||
if explicit && host == "" {
|
||||
return "", fmt.Errorf("the --host option cannot be empty")
|
||||
if !explicit {
|
||||
return "", nil
|
||||
}
|
||||
return host, nil
|
||||
normalized, err := netbind.NormalizeHostInput(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid --host value: %w", err)
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func NewGatewayCommand() *cobra.Command {
|
||||
|
||||
@@ -43,6 +43,7 @@ func TestResolveGatewayHostOverride(t *testing.T) {
|
||||
{name: "implicit empty host is allowed", explicit: false, host: "", wantHost: "", wantErr: false},
|
||||
{name: "explicit empty host rejected", explicit: true, host: " ", wantHost: "", wantErr: true},
|
||||
{name: "explicit localhost kept", explicit: true, host: " localhost ", wantHost: "localhost", wantErr: false},
|
||||
{name: "explicit multi host normalized", explicit: true, host: " [::1] , 127.0.0.1 ", wantHost: "::1,127.0.0.1", wantErr: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -465,7 +465,7 @@
|
||||
},
|
||||
"gateway": {
|
||||
"_comment": "Default log level is set to 'fatal'. Other available options are 'debug', 'info', 'warn' and 'error'.",
|
||||
"host": "127.0.0.1",
|
||||
"host": "localhost",
|
||||
"port": 18790,
|
||||
"hot_reload": false,
|
||||
"log_level": "fatal"
|
||||
|
||||
+36
-9
@@ -11,6 +11,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
@@ -86,6 +87,7 @@ type Manager struct {
|
||||
dispatchTask *asyncTask
|
||||
mux *dynamicServeMux
|
||||
httpServer *http.Server
|
||||
httpListeners []net.Listener
|
||||
mu sync.RWMutex
|
||||
placeholders sync.Map // "channel:chatID" → placeholderID (string)
|
||||
typingStops sync.Map // "channel:chatID" → func()
|
||||
@@ -474,6 +476,12 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
// It registers health endpoints from the health server and discovers channels
|
||||
// that implement WebhookHandler and/or HealthChecker to register their handlers.
|
||||
func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
|
||||
m.SetupHTTPServerListeners(nil, addr, healthServer)
|
||||
}
|
||||
|
||||
// SetupHTTPServerListeners creates a shared HTTP server on pre-opened listeners.
|
||||
// When listeners is empty it falls back to Addr-based ListenAndServe behavior.
|
||||
func (m *Manager) SetupHTTPServerListeners(listeners []net.Listener, addr string, healthServer *health.Server) {
|
||||
m.mux = newDynamicServeMux()
|
||||
|
||||
// Register health endpoints
|
||||
@@ -490,6 +498,7 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
}
|
||||
m.httpListeners = append([]net.Listener(nil), listeners...)
|
||||
}
|
||||
|
||||
// registerHTTPHandlersLocked registers webhook and health-check handlers for
|
||||
@@ -619,16 +628,33 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
|
||||
// Start shared HTTP server if configured
|
||||
if m.httpServer != nil {
|
||||
go func() {
|
||||
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
|
||||
"addr": m.httpServer.Addr,
|
||||
})
|
||||
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
if len(m.httpListeners) > 0 {
|
||||
for _, listener := range m.httpListeners {
|
||||
ln := listener
|
||||
go func() {
|
||||
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
|
||||
"addr": ln.Addr().String(),
|
||||
})
|
||||
if err := m.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"addr": ln.Addr().String(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
go func() {
|
||||
logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
|
||||
"addr": m.httpServer.Addr,
|
||||
})
|
||||
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel startup completed", map[string]any{
|
||||
@@ -655,6 +681,7 @@ func (m *Manager) StopAll(ctx context.Context) error {
|
||||
})
|
||||
}
|
||||
m.httpServer = nil
|
||||
m.httpListeners = nil
|
||||
}
|
||||
|
||||
// Cancel dispatcher
|
||||
|
||||
@@ -1082,7 +1082,10 @@ func LoadConfig(path string) (*Config, error) {
|
||||
if err = InitChannelList(cfg.Channels); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.Gateway.Host = resolveGatewayHostFromEnv(gatewayHostBeforeEnv)
|
||||
cfg.Gateway.Host, err = resolveGatewayHostFromEnv(gatewayHostBeforeEnv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid gateway host: %w", err)
|
||||
}
|
||||
|
||||
// Expand multi-key configs into separate entries for key-level failover
|
||||
cfg.ModelList = expandMultiKeyModels(cfg.ModelList)
|
||||
|
||||
@@ -39,7 +39,7 @@ const (
|
||||
EnvBinary = "PICOCLAW_BINARY"
|
||||
|
||||
// EnvGatewayHost overrides the host address for the gateway server.
|
||||
// Default: "127.0.0.1"
|
||||
// Default: "localhost"
|
||||
EnvGatewayHost = "PICOCLAW_GATEWAY_HOST"
|
||||
)
|
||||
|
||||
|
||||
+16
-107
@@ -2,12 +2,11 @@ package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
)
|
||||
|
||||
const DefaultGatewayLogLevel = "warn"
|
||||
@@ -52,119 +51,29 @@ func EffectiveGatewayLogLevel(cfg *Config) string {
|
||||
return normalizeGatewayLogLevel(cfg.Gateway.LogLevel)
|
||||
}
|
||||
|
||||
var (
|
||||
gatewayIPFamiliesOnce sync.Once
|
||||
gatewayHasIPv4 bool
|
||||
gatewayHasIPv6 bool
|
||||
)
|
||||
|
||||
func detectGatewayIPFamilies() (bool, bool) {
|
||||
gatewayIPFamiliesOnce.Do(func() {
|
||||
if ips, err := net.LookupIP("localhost"); err == nil {
|
||||
for _, ip := range ips {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
gatewayHasIPv4 = true
|
||||
continue
|
||||
}
|
||||
gatewayHasIPv6 = true
|
||||
}
|
||||
}
|
||||
|
||||
if gatewayHasIPv4 && gatewayHasIPv6 {
|
||||
return
|
||||
}
|
||||
|
||||
if addrs, err := net.InterfaceAddrs(); err == nil {
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok || ipnet.IP == nil {
|
||||
continue
|
||||
}
|
||||
if ipnet.IP.To4() != nil {
|
||||
gatewayHasIPv4 = true
|
||||
continue
|
||||
}
|
||||
gatewayHasIPv6 = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return gatewayHasIPv4, gatewayHasIPv6
|
||||
}
|
||||
|
||||
func selectAdaptiveGatewayLoopbackHost(hasIPv4, hasIPv6 bool) string {
|
||||
switch {
|
||||
case hasIPv4 && hasIPv6:
|
||||
return "localhost"
|
||||
case hasIPv6:
|
||||
return "::1"
|
||||
case hasIPv4:
|
||||
return "127.0.0.1"
|
||||
default:
|
||||
return "localhost"
|
||||
}
|
||||
}
|
||||
|
||||
func selectAdaptiveGatewayAnyHost(hasIPv4, hasIPv6 bool) string {
|
||||
switch {
|
||||
case hasIPv4 && hasIPv6:
|
||||
return "::"
|
||||
case hasIPv6:
|
||||
return "::"
|
||||
case hasIPv4:
|
||||
return "0.0.0.0"
|
||||
default:
|
||||
return "::"
|
||||
}
|
||||
}
|
||||
|
||||
func resolveAdaptiveGatewayLoopbackHost() string {
|
||||
hasIPv4, hasIPv6 := detectGatewayIPFamilies()
|
||||
return selectAdaptiveGatewayLoopbackHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func resolveAdaptiveGatewayAnyHost() string {
|
||||
hasIPv4, hasIPv6 := detectGatewayIPFamilies()
|
||||
return selectAdaptiveGatewayAnyHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func normalizeGatewayHost(host string) string {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
host = strings.TrimSpace(DefaultConfig().Gateway.Host)
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return resolveAdaptiveGatewayLoopbackHost()
|
||||
}
|
||||
|
||||
trimmed := strings.Trim(host, "[]")
|
||||
if ip := net.ParseIP(trimmed); ip != nil && ip.IsUnspecified() {
|
||||
return resolveAdaptiveGatewayAnyHost()
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
func resolveGatewayHostFromEnv(baseHost string) string {
|
||||
func resolveGatewayHostFromEnv(baseHost string) (string, error) {
|
||||
envHost, ok := os.LookupEnv(EnvGatewayHost)
|
||||
if !ok {
|
||||
return normalizeGatewayHost(baseHost)
|
||||
return normalizeGatewayHostInput(baseHost)
|
||||
}
|
||||
|
||||
envHost = strings.TrimSpace(envHost)
|
||||
if envHost == "" {
|
||||
return normalizeGatewayHost(baseHost)
|
||||
return normalizeGatewayHostInput(baseHost)
|
||||
}
|
||||
|
||||
return normalizeGatewayHost(envHost)
|
||||
return normalizeGatewayHostInput(envHost)
|
||||
}
|
||||
|
||||
func normalizeGatewayHostInput(host string) (string, error) {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
host = strings.TrimSpace(DefaultConfig().Gateway.Host)
|
||||
}
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
return netbind.NormalizeHostInput(host)
|
||||
}
|
||||
|
||||
// ResolveGatewayLogLevel reads the configured gateway log level without triggering
|
||||
|
||||
@@ -39,7 +39,10 @@ func TestLoadConfig_GatewayHostBlankEnvFallsBackToConfigHost(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
want := normalizeGatewayHost("localhost")
|
||||
want, err := normalizeGatewayHostInput("localhost")
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeGatewayHostInput() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.Host != want {
|
||||
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, want)
|
||||
}
|
||||
@@ -54,13 +57,16 @@ func TestLoadConfig_GatewayHostBlankEnvAndConfigFallsBackToDefault(t *testing.T)
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
|
||||
defaultHost := normalizeGatewayHost(DefaultConfig().Gateway.Host)
|
||||
defaultHost, err := normalizeGatewayHostInput(DefaultConfig().Gateway.Host)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeGatewayHostInput() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.Host != defaultHost {
|
||||
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, defaultHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_GatewayHostEnvWildcardUsesAdaptiveAnyHost(t *testing.T) {
|
||||
func TestLoadConfig_GatewayHostEnvPreservesExplicitWildcardHost(t *testing.T) {
|
||||
configPath := writeGatewayHostTestConfig(t, "localhost")
|
||||
t.Setenv(EnvGatewayHost, " 0.0.0.0 ")
|
||||
|
||||
@@ -69,8 +75,24 @@ func TestLoadConfig_GatewayHostEnvWildcardUsesAdaptiveAnyHost(t *testing.T) {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
|
||||
want := normalizeGatewayHost("0.0.0.0")
|
||||
want, err := normalizeGatewayHostInput("0.0.0.0")
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeGatewayHostInput() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.Host != want {
|
||||
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_GatewayHostEnvNormalizesMultiHostInput(t *testing.T) {
|
||||
configPath := writeGatewayHostTestConfig(t, "localhost")
|
||||
t.Setenv(EnvGatewayHost, " [::1] , 127.0.0.1 , ::1 ")
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.Host != "::1,127.0.0.1" {
|
||||
t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, "::1,127.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
+39
-8
@@ -44,6 +44,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
"github.com/sipeed/picoclaw/pkg/pid"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
@@ -161,13 +162,30 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
|
||||
logger.Infof("Log level set to %q", effectiveLogLevel)
|
||||
}
|
||||
|
||||
bindPlan, listenResult, err := openGatewayListeners(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening gateway listeners: %w", err)
|
||||
}
|
||||
|
||||
// Enforce singleton: write PID file with generated token.
|
||||
pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
pidData, err := pid.WritePidFile(homePath, bindPlan.ProbeHost, cfg.Gateway.Port)
|
||||
if err != nil {
|
||||
logger.Warnf("write pid file failed: %v", err)
|
||||
for _, ln := range listenResult.Listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
return fmt.Errorf("singleton check failed: %w", err)
|
||||
}
|
||||
defer pid.RemovePidFile(homePath)
|
||||
closeListeners := true
|
||||
defer func() {
|
||||
if !closeListeners {
|
||||
return
|
||||
}
|
||||
for _, ln := range listenResult.Listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
@@ -195,10 +213,11 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
|
||||
"skills_available": skillsInfo["available"],
|
||||
})
|
||||
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token)
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token, listenResult)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
closeListeners = false
|
||||
|
||||
// Setup manual reload channel for /reload endpoint
|
||||
manualReloadChan := make(chan struct{}, 1)
|
||||
@@ -219,8 +238,9 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
|
||||
runningServices.HealthServer.SetReloadFunc(reloadTrigger)
|
||||
agentLoop.SetReloadFunc(reloadTrigger)
|
||||
|
||||
listenAddr := net.JoinHostPort(cfg.Gateway.Host, strconv.Itoa(cfg.Gateway.Port))
|
||||
fmt.Printf("✓ Gateway started on %s\n", listenAddr)
|
||||
for _, bindHost := range listenResult.BindHosts {
|
||||
fmt.Printf("✓ Gateway started on %s\n", net.JoinHostPort(bindHost, strconv.Itoa(cfg.Gateway.Port)))
|
||||
}
|
||||
fmt.Println("Press Ctrl+C to stop")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -323,6 +343,7 @@ func setupAndStartServices(
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
authToken string,
|
||||
listenResult netbind.OpenResult,
|
||||
) (*services, error) {
|
||||
runningServices := &services{}
|
||||
|
||||
@@ -393,10 +414,20 @@ func setupAndStartServices(
|
||||
fmt.Println("⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
addr := net.JoinHostPort(cfg.Gateway.Host, strconv.Itoa(cfg.Gateway.Port))
|
||||
runningServices.authToken = authToken
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, authToken)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
runningServices.HealthServer = health.NewServer(listenResult.ProbeHost, cfg.Gateway.Port, authToken)
|
||||
|
||||
listenAddr := ""
|
||||
if len(listenResult.Listeners) > 0 {
|
||||
listenAddr = listenResult.Listeners[0].Addr().String()
|
||||
} else {
|
||||
listenAddr = net.JoinHostPort(listenResult.ProbeHost, strconv.Itoa(cfg.Gateway.Port))
|
||||
}
|
||||
runningServices.ChannelManager.SetupHTTPServerListeners(
|
||||
listenResult.Listeners,
|
||||
listenAddr,
|
||||
runningServices.HealthServer,
|
||||
)
|
||||
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("error starting channels: %w", err)
|
||||
@@ -412,7 +443,7 @@ func setupAndStartServices(
|
||||
voiceAgent.Start(vaCtx)
|
||||
}
|
||||
|
||||
healthAddr := net.JoinHostPort(cfg.Gateway.Host, strconv.Itoa(cfg.Gateway.Port))
|
||||
healthAddr := net.JoinHostPort(listenResult.ProbeHost, strconv.Itoa(cfg.Gateway.Port))
|
||||
fmt.Printf(
|
||||
"✓ Health endpoints available at http://%s/health, /ready and /reload (POST)\n",
|
||||
healthAddr,
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
)
|
||||
|
||||
func openGatewayListeners(host string, port int) (netbind.Plan, netbind.OpenResult, error) {
|
||||
plan, err := netbind.BuildPlan(host, netbind.DefaultLoopback)
|
||||
if err != nil {
|
||||
return netbind.Plan{}, netbind.OpenResult{}, err
|
||||
}
|
||||
|
||||
result, err := netbind.OpenPlan(plan, strconv.Itoa(port))
|
||||
if err != nil {
|
||||
return netbind.Plan{}, netbind.OpenResult{}, err
|
||||
}
|
||||
|
||||
return plan, result, nil
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
)
|
||||
|
||||
func TestOpenGatewayListeners_HonorsIPv6OnlyHost(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := netbind.DetectIPFamilies()
|
||||
if !hasIPv6 {
|
||||
t.Skip("IPv6 is unavailable in this environment")
|
||||
}
|
||||
|
||||
_, result, err := openGatewayListeners("::", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("openGatewayListeners() error = %v", err)
|
||||
}
|
||||
startGatewayTestHTTPServer(t, result.Listeners)
|
||||
port := mustGatewayAtoi(t, result.Port)
|
||||
|
||||
requireGatewayHTTPReachable(t, "::1", port)
|
||||
if hasIPv4 {
|
||||
requireGatewayHTTPUnreachable(t, "127.0.0.1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenGatewayListeners_SupportsExplicitMultiHost(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := netbind.DetectIPFamilies()
|
||||
if !hasIPv4 || !hasIPv6 {
|
||||
t.Skip("dual-stack loopback is unavailable in this environment")
|
||||
}
|
||||
|
||||
_, result, err := openGatewayListeners("127.0.0.1,::1", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("openGatewayListeners() error = %v", err)
|
||||
}
|
||||
startGatewayTestHTTPServer(t, result.Listeners)
|
||||
port := mustGatewayAtoi(t, result.Port)
|
||||
|
||||
requireGatewayHTTPReachable(t, "127.0.0.1", port)
|
||||
requireGatewayHTTPReachable(t, "::1", port)
|
||||
}
|
||||
|
||||
func startGatewayTestHTTPServer(t *testing.T, listeners []net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
server := &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}),
|
||||
}
|
||||
|
||||
errCh := make(chan error, len(listeners))
|
||||
for _, listener := range listeners {
|
||||
ln := listener
|
||||
go func() {
|
||||
errCh <- server.Serve(ln)
|
||||
}()
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = server.Shutdown(ctx)
|
||||
for range listeners {
|
||||
err := <-errCh
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Fatalf("server.Serve() error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func requireGatewayHTTPReachable(t *testing.T, host string, port int) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
err := gatewayHTTPGet(host, port)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("expected %s:%d to be reachable: %v", host, port, err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func requireGatewayHTTPUnreachable(t *testing.T, host string, port int) {
|
||||
t.Helper()
|
||||
if err := gatewayHTTPGet(host, port); err == nil {
|
||||
t.Fatalf("expected %s:%d to be unreachable", host, port)
|
||||
}
|
||||
}
|
||||
|
||||
func gatewayHTTPGet(host string, port int) error {
|
||||
client := &http.Client{
|
||||
Timeout: 300 * time.Millisecond,
|
||||
Transport: &http.Transport{
|
||||
Proxy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.New(resp.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustGatewayAtoi(t *testing.T, value string) int {
|
||||
t.Helper()
|
||||
n, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
t.Fatalf("Atoi(%q) error = %v", value, err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,580 @@
|
||||
package netbind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type DefaultMode int
|
||||
|
||||
const (
|
||||
DefaultLoopback DefaultMode = iota
|
||||
DefaultAny
|
||||
)
|
||||
|
||||
type groupKind int
|
||||
|
||||
const (
|
||||
groupAdaptiveLoopback groupKind = iota
|
||||
groupAdaptiveAny
|
||||
groupExact
|
||||
)
|
||||
|
||||
type exactBinding struct {
|
||||
host string
|
||||
network string
|
||||
v6Only bool
|
||||
}
|
||||
|
||||
type bindGroup struct {
|
||||
kind groupKind
|
||||
allowIPv4 bool
|
||||
allowIPv6 bool
|
||||
exact exactBinding
|
||||
}
|
||||
|
||||
type Plan struct {
|
||||
groups []bindGroup
|
||||
ProbeHost string
|
||||
}
|
||||
|
||||
type OpenResult struct {
|
||||
Listeners []net.Listener
|
||||
BindHosts []string
|
||||
Port string
|
||||
ProbeHost string
|
||||
}
|
||||
|
||||
type tokenKind int
|
||||
|
||||
const (
|
||||
tokenName tokenKind = iota
|
||||
tokenLocalhost
|
||||
tokenStar
|
||||
tokenIPv4
|
||||
tokenIPv6
|
||||
tokenIPv4Any
|
||||
tokenIPv6Any
|
||||
)
|
||||
|
||||
type hostToken struct {
|
||||
kind tokenKind
|
||||
canonical string
|
||||
key string
|
||||
}
|
||||
|
||||
var (
|
||||
ipFamiliesOnce sync.Once
|
||||
hasIPv4 bool
|
||||
hasIPv6 bool
|
||||
)
|
||||
|
||||
func DetectIPFamilies() (bool, bool) {
|
||||
ipFamiliesOnce.Do(func() {
|
||||
if ips, err := net.LookupIP("localhost"); err == nil {
|
||||
for _, ip := range ips {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
hasIPv4 = true
|
||||
continue
|
||||
}
|
||||
hasIPv6 = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasIPv4 && hasIPv6 {
|
||||
return
|
||||
}
|
||||
|
||||
if addrs, err := net.InterfaceAddrs(); err == nil {
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok || ipnet.IP == nil {
|
||||
continue
|
||||
}
|
||||
if ipnet.IP.To4() != nil {
|
||||
hasIPv4 = true
|
||||
continue
|
||||
}
|
||||
hasIPv6 = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return hasIPv4, hasIPv6
|
||||
}
|
||||
|
||||
func SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6 bool) string {
|
||||
switch {
|
||||
case hasIPv4 && hasIPv6:
|
||||
return "localhost"
|
||||
case hasIPv6:
|
||||
return "::1"
|
||||
case hasIPv4:
|
||||
return "127.0.0.1"
|
||||
default:
|
||||
return "localhost"
|
||||
}
|
||||
}
|
||||
|
||||
func SelectAdaptiveAnyHost(hasIPv4, hasIPv6 bool) string {
|
||||
switch {
|
||||
case hasIPv4 && hasIPv6:
|
||||
return "::"
|
||||
case hasIPv6:
|
||||
return "::"
|
||||
case hasIPv4:
|
||||
return "0.0.0.0"
|
||||
default:
|
||||
return "::"
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveAdaptiveLoopbackHost() string {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
return SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func ResolveAdaptiveAnyHost() string {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
return SelectAdaptiveAnyHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func IsLoopbackHost(host string) bool {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
ip := net.ParseIP(strings.Trim(host, "[]"))
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func IsUnspecifiedHost(host string) bool {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
ip := net.ParseIP(strings.Trim(host, "[]"))
|
||||
return ip != nil && ip.IsUnspecified()
|
||||
}
|
||||
|
||||
func NormalizeHostInput(raw string) (string, error) {
|
||||
tokens, err := parseHostTokens(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
parts = append(parts, token.canonical)
|
||||
}
|
||||
return strings.Join(parts, ","), nil
|
||||
}
|
||||
|
||||
func BuildPlan(raw string, defaultMode DefaultMode) (Plan, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return buildDefaultPlan(defaultMode), nil
|
||||
}
|
||||
|
||||
tokens, err := parseHostTokens(raw)
|
||||
if err != nil {
|
||||
return Plan{}, err
|
||||
}
|
||||
|
||||
for _, token := range tokens {
|
||||
if token.kind == tokenStar {
|
||||
return Plan{
|
||||
groups: []bindGroup{{kind: groupAdaptiveAny}},
|
||||
ProbeHost: ResolveAdaptiveLoopbackHost(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasIPv4Any := false
|
||||
hasIPv6Any := false
|
||||
for _, token := range tokens {
|
||||
switch token.kind {
|
||||
case tokenIPv4Any:
|
||||
hasIPv4Any = true
|
||||
case tokenIPv6Any:
|
||||
hasIPv6Any = true
|
||||
}
|
||||
}
|
||||
|
||||
allowLocalhostIPv4 := !hasIPv4Any
|
||||
allowLocalhostIPv6 := !hasIPv6Any
|
||||
|
||||
groups := make([]bindGroup, 0, len(tokens))
|
||||
seenExact := make(map[string]struct{}, len(tokens))
|
||||
addedLocalhost := false
|
||||
|
||||
for _, token := range tokens {
|
||||
switch token.kind {
|
||||
case tokenLocalhost:
|
||||
if addedLocalhost || (!allowLocalhostIPv4 && !allowLocalhostIPv6) {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupAdaptiveLoopback,
|
||||
allowIPv4: allowLocalhostIPv4,
|
||||
allowIPv6: allowLocalhostIPv6,
|
||||
})
|
||||
addedLocalhost = true
|
||||
case tokenIPv4Any:
|
||||
key := "exact:tcp4:0.0.0.0"
|
||||
if _, ok := seenExact[key]; ok {
|
||||
continue
|
||||
}
|
||||
seenExact[key] = struct{}{}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupExact,
|
||||
exact: exactBinding{
|
||||
host: "0.0.0.0",
|
||||
network: "tcp4",
|
||||
},
|
||||
})
|
||||
case tokenIPv6Any:
|
||||
key := "exact:tcp6:::"
|
||||
if _, ok := seenExact[key]; ok {
|
||||
continue
|
||||
}
|
||||
seenExact[key] = struct{}{}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupExact,
|
||||
exact: exactBinding{
|
||||
host: "::",
|
||||
network: "tcp6",
|
||||
v6Only: true,
|
||||
},
|
||||
})
|
||||
case tokenIPv4:
|
||||
if hasIPv4Any {
|
||||
continue
|
||||
}
|
||||
key := "exact:tcp4:" + strings.ToLower(token.canonical)
|
||||
if _, ok := seenExact[key]; ok {
|
||||
continue
|
||||
}
|
||||
seenExact[key] = struct{}{}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupExact,
|
||||
exact: exactBinding{
|
||||
host: token.canonical,
|
||||
network: "tcp4",
|
||||
},
|
||||
})
|
||||
case tokenIPv6:
|
||||
if hasIPv6Any {
|
||||
continue
|
||||
}
|
||||
key := "exact:tcp6:" + strings.ToLower(token.canonical)
|
||||
if _, ok := seenExact[key]; ok {
|
||||
continue
|
||||
}
|
||||
seenExact[key] = struct{}{}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupExact,
|
||||
exact: exactBinding{
|
||||
host: token.canonical,
|
||||
network: "tcp6",
|
||||
v6Only: true,
|
||||
},
|
||||
})
|
||||
case tokenName:
|
||||
key := "exact:tcp:" + token.key
|
||||
if _, ok := seenExact[key]; ok {
|
||||
continue
|
||||
}
|
||||
seenExact[key] = struct{}{}
|
||||
groups = append(groups, bindGroup{
|
||||
kind: groupExact,
|
||||
exact: exactBinding{
|
||||
host: token.canonical,
|
||||
network: "tcp",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
plan := Plan{groups: groups}
|
||||
plan.ProbeHost = probeHostForGroups(groups)
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func OpenPlan(plan Plan, port string) (OpenResult, error) {
|
||||
if port == "" {
|
||||
return OpenResult{}, errors.New("port cannot be empty")
|
||||
}
|
||||
|
||||
selectedPort := port
|
||||
listeners := make([]net.Listener, 0, len(plan.groups))
|
||||
bindHosts := make([]string, 0, len(plan.groups))
|
||||
bindSeen := make(map[string]struct{}, len(plan.groups))
|
||||
|
||||
closeAll := func() {
|
||||
for _, ln := range listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
}
|
||||
|
||||
for _, group := range plan.groups {
|
||||
groupListeners, groupHosts, actualPort, err := openGroup(group, selectedPort)
|
||||
if err != nil {
|
||||
closeAll()
|
||||
return OpenResult{}, err
|
||||
}
|
||||
if selectedPort == "0" && actualPort != "" {
|
||||
selectedPort = actualPort
|
||||
}
|
||||
listeners = append(listeners, groupListeners...)
|
||||
for _, host := range groupHosts {
|
||||
key := strings.ToLower(host)
|
||||
if _, ok := bindSeen[key]; ok {
|
||||
continue
|
||||
}
|
||||
bindSeen[key] = struct{}{}
|
||||
bindHosts = append(bindHosts, host)
|
||||
}
|
||||
}
|
||||
|
||||
return OpenResult{
|
||||
Listeners: listeners,
|
||||
BindHosts: bindHosts,
|
||||
Port: selectedPort,
|
||||
ProbeHost: plan.ProbeHost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildDefaultPlan(defaultMode DefaultMode) Plan {
|
||||
switch defaultMode {
|
||||
case DefaultAny:
|
||||
return Plan{
|
||||
groups: []bindGroup{{kind: groupAdaptiveAny}},
|
||||
ProbeHost: ResolveAdaptiveLoopbackHost(),
|
||||
}
|
||||
default:
|
||||
return Plan{
|
||||
groups: []bindGroup{{
|
||||
kind: groupAdaptiveLoopback,
|
||||
allowIPv4: true,
|
||||
allowIPv6: true,
|
||||
}},
|
||||
ProbeHost: ResolveAdaptiveLoopbackHost(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func probeHostForGroups(groups []bindGroup) string {
|
||||
hasIPv4Any := false
|
||||
hasIPv6Any := false
|
||||
for _, group := range groups {
|
||||
if group.kind == groupAdaptiveLoopback {
|
||||
switch {
|
||||
case group.allowIPv4 && group.allowIPv6:
|
||||
return ResolveAdaptiveLoopbackHost()
|
||||
case group.allowIPv6:
|
||||
return "::1"
|
||||
case group.allowIPv4:
|
||||
return "127.0.0.1"
|
||||
}
|
||||
}
|
||||
if group.kind == groupAdaptiveAny {
|
||||
return ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
if group.kind != groupExact {
|
||||
continue
|
||||
}
|
||||
switch group.exact.host {
|
||||
case "0.0.0.0":
|
||||
hasIPv4Any = true
|
||||
case "::":
|
||||
hasIPv6Any = true
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case hasIPv4Any && hasIPv6Any:
|
||||
return ResolveAdaptiveLoopbackHost()
|
||||
case hasIPv6Any:
|
||||
return "::1"
|
||||
case hasIPv4Any:
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if group.kind == groupExact {
|
||||
return group.exact.host
|
||||
}
|
||||
}
|
||||
return ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
|
||||
func parseHostTokens(raw string) ([]hostToken, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, errors.New("host cannot be empty")
|
||||
}
|
||||
|
||||
parts := strings.Split(raw, ",")
|
||||
tokens := make([]hostToken, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, part := range parts {
|
||||
token, err := parseHostToken(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := seen[token.key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[token.key] = struct{}{}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
return nil, errors.New("host cannot be empty")
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func parseHostToken(raw string) (hostToken, error) {
|
||||
host := strings.TrimSpace(raw)
|
||||
if host == "" {
|
||||
return hostToken{}, errors.New("host list contains an empty entry")
|
||||
}
|
||||
|
||||
if host == "*" {
|
||||
return hostToken{kind: tokenStar, canonical: "*", key: "*"}, nil
|
||||
}
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return hostToken{kind: tokenLocalhost, canonical: "localhost", key: "localhost"}, nil
|
||||
}
|
||||
|
||||
trimmed := strings.Trim(host, "[]")
|
||||
if ip := net.ParseIP(trimmed); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
canonical := ip4.String()
|
||||
kind := tokenIPv4
|
||||
if ip4.IsUnspecified() {
|
||||
kind = tokenIPv4Any
|
||||
}
|
||||
return hostToken{kind: kind, canonical: canonical, key: canonical}, nil
|
||||
}
|
||||
|
||||
canonical := ip.String()
|
||||
kind := tokenIPv6
|
||||
if ip.IsUnspecified() {
|
||||
kind = tokenIPv6Any
|
||||
}
|
||||
return hostToken{kind: kind, canonical: canonical, key: strings.ToLower(canonical)}, nil
|
||||
}
|
||||
|
||||
return hostToken{
|
||||
kind: tokenName,
|
||||
canonical: host,
|
||||
key: strings.ToLower(host),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func openGroup(group bindGroup, port string) ([]net.Listener, []string, string, error) {
|
||||
switch group.kind {
|
||||
case groupAdaptiveLoopback:
|
||||
return openAdaptiveLoopbackGroup(group.allowIPv6, group.allowIPv4, port)
|
||||
case groupAdaptiveAny:
|
||||
return openAdaptiveAnyGroup(port)
|
||||
case groupExact:
|
||||
ln, actualPort, err := openExactListener(group.exact, port)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
return []net.Listener{ln}, []string{group.exact.host}, actualPort, nil
|
||||
default:
|
||||
return nil, nil, "", fmt.Errorf("unsupported bind group kind: %d", group.kind)
|
||||
}
|
||||
}
|
||||
|
||||
func openAdaptiveLoopbackGroup(allowIPv6, allowIPv4 bool, port string) ([]net.Listener, []string, string, error) {
|
||||
if allowIPv6 && allowIPv4 {
|
||||
if ln6, actualPort, err6 := openExactListener(exactBinding{host: "::1", network: "tcp6", v6Only: true}, port); err6 == nil {
|
||||
if ln4, _, err4 := openExactListener(exactBinding{host: "127.0.0.1", network: "tcp4"}, actualPort); err4 == nil {
|
||||
return []net.Listener{ln6, ln4}, []string{"::1", "127.0.0.1"}, actualPort, nil
|
||||
}
|
||||
_ = ln6.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if allowIPv6 {
|
||||
ln6, actualPort, err := openExactListener(exactBinding{host: "::1", network: "tcp6", v6Only: true}, port)
|
||||
if err == nil {
|
||||
return []net.Listener{ln6}, []string{"::1"}, actualPort, nil
|
||||
}
|
||||
}
|
||||
|
||||
if allowIPv4 {
|
||||
ln4, actualPort, err := openExactListener(exactBinding{host: "127.0.0.1", network: "tcp4"}, port)
|
||||
if err == nil {
|
||||
return []net.Listener{ln4}, []string{"127.0.0.1"}, actualPort, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, "", fmt.Errorf("failed to open adaptive localhost listener on port %s", port)
|
||||
}
|
||||
|
||||
func openAdaptiveAnyGroup(port string) ([]net.Listener, []string, string, error) {
|
||||
// Intentionally bind tcp/:: here. Go's compatibility layer handles dual-stack
|
||||
// wildcard binding where the platform supports it, while tcp4 remains the
|
||||
// fallback for IPv4-only environments.
|
||||
if ln, actualPort, err := openExactListener(exactBinding{host: "::", network: "tcp"}, port); err == nil {
|
||||
return []net.Listener{ln}, []string{"::"}, actualPort, nil
|
||||
}
|
||||
|
||||
ln4, actualPort, err := openExactListener(exactBinding{host: "0.0.0.0", network: "tcp4"}, port)
|
||||
if err != nil {
|
||||
return nil, nil, "", fmt.Errorf("failed to open adaptive any-host listener on port %s", port)
|
||||
}
|
||||
return []net.Listener{ln4}, []string{"0.0.0.0"}, actualPort, nil
|
||||
}
|
||||
|
||||
func openExactListener(binding exactBinding, port string) (net.Listener, string, error) {
|
||||
listenConfig := net.ListenConfig{}
|
||||
if binding.network == "tcp6" && binding.v6Only {
|
||||
listenConfig.Control = applyIPv6OnlyControl(true)
|
||||
}
|
||||
|
||||
ln, err := listenConfig.Listen(context.Background(), binding.network, net.JoinHostPort(binding.host, port))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
actualPort, err := listenerPort(ln)
|
||||
if err != nil {
|
||||
_ = ln.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return ln, actualPort, nil
|
||||
}
|
||||
|
||||
func listenerPort(ln net.Listener) (string, error) {
|
||||
addr, ok := ln.Addr().(*net.TCPAddr)
|
||||
if ok {
|
||||
return strconv.Itoa(addr.Port), nil
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(ln.Addr().String())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return port, nil
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
package netbind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNormalizeHostInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "single host", raw: "127.0.0.1", want: "127.0.0.1"},
|
||||
{name: "trim and dedupe", raw: " [::1] , ::1 , 127.0.0.1 ", want: "::1,127.0.0.1"},
|
||||
{name: "star preserved", raw: "*,127.0.0.1", want: "*,127.0.0.1"},
|
||||
{name: "reject empty", raw: "127.0.0.1, ", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NormalizeHostInput(tt.raw)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("NormalizeHostInput() err = %v, wantErr %t", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("NormalizeHostInput() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPlan_DefaultAnyUsesLoopbackProbe(t *testing.T) {
|
||||
plan, err := BuildPlan("", DefaultAny)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
if plan.ProbeHost != ResolveAdaptiveLoopbackHost() {
|
||||
t.Fatalf("ProbeHost = %q, want %q", plan.ProbeHost, ResolveAdaptiveLoopbackHost())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPlan_LocalhostSupportsLoopbackCommunication(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
|
||||
plan, err := BuildPlan("localhost", DefaultLoopback)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
if hasIPv6 {
|
||||
requireHTTPReachable(t, "::1", port)
|
||||
}
|
||||
if hasIPv4 {
|
||||
requireHTTPReachable(t, "127.0.0.1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPlan_DefaultAnySupportsDualStackLoopback(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
|
||||
plan, err := BuildPlan("", DefaultAny)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
if hasIPv6 {
|
||||
requireHTTPReachable(t, "::1", port)
|
||||
}
|
||||
if hasIPv4 {
|
||||
requireHTTPReachable(t, "127.0.0.1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPlan_ExplicitIPv6AnyIsIPv6Only(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
if !hasIPv6 {
|
||||
t.Skip("IPv6 is unavailable in this environment")
|
||||
}
|
||||
|
||||
plan, err := BuildPlan("::", DefaultLoopback)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
requireHTTPReachable(t, "::1", port)
|
||||
if hasIPv4 {
|
||||
requireHTTPUnreachable(t, "127.0.0.1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPlan_ExplicitIPv4AnyIsIPv4Only(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
if !hasIPv4 {
|
||||
t.Skip("IPv4 is unavailable in this environment")
|
||||
}
|
||||
|
||||
plan, err := BuildPlan("0.0.0.0", DefaultLoopback)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
requireHTTPReachable(t, "127.0.0.1", port)
|
||||
if hasIPv6 {
|
||||
requireHTTPUnreachable(t, "::1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPlan_MultiHostSupportsExplicitIPv4AndIPv6(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
if !hasIPv4 || !hasIPv6 {
|
||||
t.Skip("dual-stack loopback is unavailable in this environment")
|
||||
}
|
||||
|
||||
plan, err := BuildPlan("127.0.0.1,::1", DefaultLoopback)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
requireHTTPReachable(t, "127.0.0.1", port)
|
||||
requireHTTPReachable(t, "::1", port)
|
||||
}
|
||||
|
||||
func TestOpenPlan_WildcardRulesKeepIPv4AndIPv6AnyHosts(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
if !hasIPv4 || !hasIPv6 {
|
||||
t.Skip("dual-stack loopback is unavailable in this environment")
|
||||
}
|
||||
|
||||
plan, err := BuildPlan("::,::1,0.0.0.0,127.0.0.1", DefaultLoopback)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildPlan() error = %v", err)
|
||||
}
|
||||
result, err := OpenPlan(plan, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenPlan() error = %v", err)
|
||||
}
|
||||
startTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
requireHTTPReachable(t, "127.0.0.1", port)
|
||||
requireHTTPReachable(t, "::1", port)
|
||||
if len(result.BindHosts) != 2 {
|
||||
t.Fatalf("len(BindHosts) = %d, want 2 (%#v)", len(result.BindHosts), result.BindHosts)
|
||||
}
|
||||
}
|
||||
|
||||
func startTestHTTPServer(t *testing.T, listeners []net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
server := &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}),
|
||||
}
|
||||
|
||||
errCh := make(chan error, len(listeners))
|
||||
for _, listener := range listeners {
|
||||
ln := listener
|
||||
go func() {
|
||||
errCh <- server.Serve(ln)
|
||||
}()
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = server.Shutdown(ctx)
|
||||
for range listeners {
|
||||
err := <-errCh
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Fatalf("server.Serve() error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func requireHTTPReachable(t *testing.T, host string, port int) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
err := httpGET(host, port)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("expected %s:%d to be reachable: %v", host, port, err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func requireHTTPUnreachable(t *testing.T, host string, port int) {
|
||||
t.Helper()
|
||||
|
||||
if err := httpGET(host, port); err == nil {
|
||||
t.Fatalf("expected %s:%d to be unreachable", host, port)
|
||||
}
|
||||
}
|
||||
|
||||
func httpGET(host string, port int) error {
|
||||
client := &http.Client{
|
||||
Timeout: 300 * time.Millisecond,
|
||||
Transport: &http.Transport{
|
||||
Proxy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.New(resp.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustAtoi(t *testing.T, value string) int {
|
||||
t.Helper()
|
||||
n, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
t.Fatalf("Atoi(%q) error = %v", value, err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
//go:build !windows
|
||||
|
||||
package netbind
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func applyIPv6OnlyControl(enabled bool) func(string, string, syscall.RawConn) error {
|
||||
return func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var controlErr error
|
||||
if err := rawConn.Control(func(fd uintptr) {
|
||||
value := 0
|
||||
if enabled {
|
||||
value = 1
|
||||
}
|
||||
controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, value)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return controlErr
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
//go:build windows
|
||||
|
||||
package netbind
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func applyIPv6OnlyControl(enabled bool) func(string, string, syscall.RawConn) error {
|
||||
return func(_, _ string, rawConn syscall.RawConn) error {
|
||||
var controlErr error
|
||||
if err := rawConn.Control(func(fd uintptr) {
|
||||
value := 0
|
||||
if enabled {
|
||||
value = 1
|
||||
}
|
||||
controlErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, value)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return controlErr
|
||||
}
|
||||
}
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
@@ -119,6 +120,7 @@ var (
|
||||
gatewayRestartGracePeriod = 5 * time.Second
|
||||
gatewayRestartForceKillWindow = 3 * time.Second
|
||||
gatewayRestartPollInterval = 100 * time.Millisecond
|
||||
gatewayExecCommand = exec.Command
|
||||
)
|
||||
|
||||
var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
@@ -262,7 +264,7 @@ func (h *Handler) getGatewayHealthForPidData(
|
||||
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
}
|
||||
if host == "" {
|
||||
host = resolveDefaultLoopbackHost()
|
||||
host = netbind.ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
|
||||
url := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/health"
|
||||
@@ -723,7 +725,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
execPath := utils.FindPicoclawBinary()
|
||||
logger.InfoC("gateway", fmt.Sprintf("Starting gateway process (%s)", execPath))
|
||||
|
||||
cmd = exec.Command(execPath, h.gatewayCommandArgs()...)
|
||||
cmd = gatewayExecCommand(execPath, h.gatewayCommandArgs()...)
|
||||
cmd.Env = os.Environ()
|
||||
// Forward the launcher's config path via the environment variable that
|
||||
// GetConfigPath() already reads, so the gateway sub-process uses the same
|
||||
@@ -731,17 +733,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
if h.configPath != "" {
|
||||
cmd.Env = append(cmd.Env, config.EnvConfig+"="+h.configPath)
|
||||
}
|
||||
gatewayHostOverride := h.gatewayHostOverrideForConfig(cfg)
|
||||
if h.serverHostExplicit && gatewayHostOverride == "" {
|
||||
logger.WarnC(
|
||||
"gateway",
|
||||
fmt.Sprintf(
|
||||
"Explicit launcher host %q was not forwarded to gateway because configured gateway host is %q; gateway keeps original bind host",
|
||||
strings.TrimSpace(h.serverHost),
|
||||
strings.TrimSpace(cfg.Gateway.Host),
|
||||
),
|
||||
)
|
||||
}
|
||||
gatewayHostOverride := h.gatewayHostOverride()
|
||||
if gatewayHostOverride != "" {
|
||||
cmd.Env = append(cmd.Env, config.EnvGatewayHost+"="+gatewayHostOverride)
|
||||
}
|
||||
|
||||
@@ -8,38 +8,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
)
|
||||
|
||||
func selectAdaptiveLoopbackHost(hasIPv4, hasIPv6 bool) string {
|
||||
return utils.SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func selectAdaptiveAnyHost(hasIPv4, hasIPv6 bool) string {
|
||||
return utils.SelectAdaptiveAnyHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func isLoopbackEquivalentHost(host string) bool {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
trimmed := strings.Trim(host, "[]")
|
||||
ip := net.ParseIP(trimmed)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func resolveDefaultLoopbackHost() string {
|
||||
return utils.ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
|
||||
func resolveDefaultAnyHost() string {
|
||||
return utils.ResolveAdaptiveAnyHost()
|
||||
}
|
||||
|
||||
func (h *Handler) effectiveLauncherPublic() bool {
|
||||
if h.serverHostExplicit {
|
||||
// -host takes precedence over -public and launcher-config public setting.
|
||||
@@ -58,64 +29,18 @@ func (h *Handler) effectiveLauncherPublic() bool {
|
||||
return h.serverPublic
|
||||
}
|
||||
|
||||
func canonicalLauncherBindHost(host string) string {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return resolveDefaultLoopbackHost()
|
||||
}
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return resolveDefaultLoopbackHost()
|
||||
}
|
||||
trimmed := strings.Trim(host, "[]")
|
||||
if ip := net.ParseIP(trimmed); ip != nil && ip.IsUnspecified() {
|
||||
return resolveDefaultAnyHost()
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func (h *Handler) launcherAndGatewayBindHostsAligned(cfg *config.Config) bool {
|
||||
if cfg == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// With -host specified, -public is ignored, so launcher baseline bind host is loopback.
|
||||
launcherHost := canonicalLauncherBindHost("")
|
||||
gatewayHost := canonicalLauncherBindHost(cfg.Gateway.Host)
|
||||
if isLoopbackEquivalentHost(launcherHost) && isLoopbackEquivalentHost(gatewayHost) {
|
||||
return true
|
||||
}
|
||||
|
||||
return launcherHost == gatewayHost
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayHostOverrideForConfig(cfg *config.Config) string {
|
||||
func (h *Handler) gatewayHostOverride() string {
|
||||
if h.serverHostExplicit {
|
||||
if h.launcherAndGatewayBindHostsAligned(cfg) {
|
||||
return strings.TrimSpace(h.serverHost)
|
||||
}
|
||||
return ""
|
||||
return strings.TrimSpace(h.serverHostInput)
|
||||
}
|
||||
|
||||
if h.effectiveLauncherPublic() {
|
||||
return resolveDefaultAnyHost()
|
||||
return "*"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayHostOverride() string {
|
||||
if !h.serverHostExplicit {
|
||||
return h.gatewayHostOverrideForConfig(nil)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return h.gatewayHostOverrideForConfig(cfg)
|
||||
}
|
||||
|
||||
func (h *Handler) effectiveGatewayBindHost(cfg *config.Config) string {
|
||||
if override := h.gatewayHostOverrideForConfig(cfg); override != "" {
|
||||
if override := h.gatewayHostOverride(); override != "" {
|
||||
return override
|
||||
}
|
||||
if cfg == nil {
|
||||
@@ -125,19 +50,11 @@ func (h *Handler) effectiveGatewayBindHost(cfg *config.Config) string {
|
||||
}
|
||||
|
||||
func gatewayProbeHost(bindHost string) string {
|
||||
bindHost = strings.TrimSpace(bindHost)
|
||||
if bindHost == "" {
|
||||
return resolveDefaultLoopbackHost()
|
||||
plan, err := netbind.BuildPlan(bindHost, netbind.DefaultLoopback)
|
||||
if err != nil || strings.TrimSpace(plan.ProbeHost) == "" {
|
||||
return netbind.ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
if strings.EqualFold(bindHost, "localhost") {
|
||||
return resolveDefaultLoopbackHost()
|
||||
}
|
||||
|
||||
trimmed := strings.Trim(bindHost, "[]")
|
||||
if ip := net.ParseIP(trimmed); ip != nil && ip.IsUnspecified() {
|
||||
return resolveDefaultLoopbackHost()
|
||||
}
|
||||
return bindHost
|
||||
return plan.ProbeHost
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayProxyURL() *url.URL {
|
||||
@@ -165,7 +82,7 @@ func requestHostName(r *http.Request) string {
|
||||
if strings.TrimSpace(r.Host) != "" {
|
||||
return r.Host
|
||||
}
|
||||
return resolveDefaultLoopbackHost()
|
||||
return netbind.ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
|
||||
func requestWSScheme(r *http.Request) string {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
)
|
||||
|
||||
@@ -27,8 +28,8 @@ func TestGatewayHostOverrideUsesExplicitRuntimePublic(t *testing.T) {
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
|
||||
if got := h.gatewayHostOverride(); got != resolveDefaultAnyHost() {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, resolveDefaultAnyHost())
|
||||
if got := h.gatewayHostOverride(); got != "*" {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, "*")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,78 +65,40 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectAdaptiveLoopbackHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasIPv4 bool
|
||||
hasIPv6 bool
|
||||
want string
|
||||
}{
|
||||
{name: "dual stack prefers localhost", hasIPv4: true, hasIPv6: true, want: "localhost"},
|
||||
{name: "ipv6 only", hasIPv4: false, hasIPv6: true, want: "::1"},
|
||||
{name: "ipv4 only", hasIPv4: true, hasIPv6: false, want: "127.0.0.1"},
|
||||
{name: "fallback", hasIPv4: false, hasIPv6: false, want: "localhost"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := selectAdaptiveLoopbackHost(tt.hasIPv4, tt.hasIPv6); got != tt.want {
|
||||
t.Fatalf("selectAdaptiveLoopbackHost(%t, %t) = %q, want %q", tt.hasIPv4, tt.hasIPv6, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectAdaptiveAnyHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasIPv4 bool
|
||||
hasIPv6 bool
|
||||
want string
|
||||
}{
|
||||
{name: "dual stack prefers ipv6 wildcard", hasIPv4: true, hasIPv6: true, want: "::"},
|
||||
{name: "ipv6 only", hasIPv4: false, hasIPv6: true, want: "::"},
|
||||
{name: "ipv4 only", hasIPv4: true, hasIPv6: false, want: "0.0.0.0"},
|
||||
{name: "fallback", hasIPv4: false, hasIPv6: false, want: "::"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := selectAdaptiveAnyHost(tt.hasIPv4, tt.hasIPv6); got != tt.want {
|
||||
t.Fatalf("selectAdaptiveAnyHost(%t, %t) = %q, want %q", tt.hasIPv4, tt.hasIPv6, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) {
|
||||
want := resolveDefaultLoopbackHost()
|
||||
want := "127.0.0.1"
|
||||
if got := gatewayProbeHost("0.0.0.0"); got != want {
|
||||
t.Fatalf("gatewayProbeHost() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesPreferredLoopbackForEmptyBind(t *testing.T) {
|
||||
want := resolveDefaultLoopbackHost()
|
||||
want := netbind.ResolveAdaptiveLoopbackHost()
|
||||
if got := gatewayProbeHost(""); got != want {
|
||||
t.Fatalf("gatewayProbeHost(empty) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesPreferredLoopbackForLocalhostBind(t *testing.T) {
|
||||
want := resolveDefaultLoopbackHost()
|
||||
want := netbind.ResolveAdaptiveLoopbackHost()
|
||||
if got := gatewayProbeHost("localhost"); got != want {
|
||||
t.Fatalf("gatewayProbeHost(localhost) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesLoopbackForIPv6WildcardBind(t *testing.T) {
|
||||
want := resolveDefaultLoopbackHost()
|
||||
want := "::1"
|
||||
if got := gatewayProbeHost("::"); got != want {
|
||||
t.Fatalf("gatewayProbeHost(::) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesFirstConcreteHostForMultiHostBind(t *testing.T) {
|
||||
if got := gatewayProbeHost("127.0.0.1,::1"); got != "127.0.0.1" {
|
||||
t.Fatalf("gatewayProbeHost(multi) = %q, want %q", got, "127.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
@@ -204,7 +167,7 @@ func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) {
|
||||
_ = statusCode
|
||||
_ = err
|
||||
|
||||
want := "http://" + net.JoinHostPort(resolveDefaultLoopbackHost(), "18791") + "/health"
|
||||
want := "http://" + net.JoinHostPort(netbind.ResolveAdaptiveLoopbackHost(), "18791") + "/health"
|
||||
if requestedURL != want {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, want)
|
||||
}
|
||||
@@ -310,23 +273,17 @@ func TestBuildWsURLUsesRequestHostNotGatewayBindLoopback(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGatewayHostOverrideWithExplicitHostAndAlignedGatewayHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
writeGatewayHostConfig(t, configPath, "127.0.0.1")
|
||||
|
||||
h := NewHandler(configPath)
|
||||
h := NewHandler(filepath.Join(t.TempDir(), "config.json"))
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
h.SetServerBindHost("0.0.0.0", true)
|
||||
|
||||
if got := h.gatewayHostOverride(); got != resolveDefaultAnyHost() {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, resolveDefaultAnyHost())
|
||||
if got := h.gatewayHostOverride(); got != "0.0.0.0" {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, "0.0.0.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayHostOverrideWithExplicitHostAndLocalhostGatewayHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
writeGatewayHostConfig(t, configPath, "localhost")
|
||||
|
||||
h := NewHandler(configPath)
|
||||
h := NewHandler(filepath.Join(t.TempDir(), "config.json"))
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
h.SetServerBindHost("::", true)
|
||||
|
||||
@@ -335,24 +292,18 @@ func TestGatewayHostOverrideWithExplicitHostAndLocalhostGatewayHost(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayHostOverrideWithExplicitHostAndMismatchedGatewayHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
writeGatewayHostConfig(t, configPath, "0.0.0.0")
|
||||
|
||||
h := NewHandler(configPath)
|
||||
func TestGatewayHostOverrideWithExplicitMultiHost(t *testing.T) {
|
||||
h := NewHandler(filepath.Join(t.TempDir(), "config.json"))
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
h.SetServerBindHost("192.168.1.10", true)
|
||||
h.SetServerBindHost("127.0.0.1,::1", true)
|
||||
|
||||
if got := h.gatewayHostOverride(); got != "" {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want empty", got)
|
||||
if got := h.gatewayHostOverride(); got != "127.0.0.1,::1" {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, "127.0.0.1,::1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayHostExplicitIgnoresPublicFlag(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
writeGatewayHostConfig(t, configPath, "127.0.0.1")
|
||||
|
||||
h := NewHandler(configPath)
|
||||
h := NewHandler(filepath.Join(t.TempDir(), "config.json"))
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
h.SetServerBindHost("127.0.0.1", true)
|
||||
|
||||
@@ -360,13 +311,3 @@ func TestGatewayHostExplicitIgnoresPublicFlag(t *testing.T) {
|
||||
t.Fatalf("effectiveLauncherPublic() = %t, want false when explicit host is set", got)
|
||||
}
|
||||
}
|
||||
|
||||
func writeGatewayHostConfig(t *testing.T, configPath, host string) {
|
||||
t.Helper()
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = host
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,6 +97,7 @@ func resetGatewayTestState(t *testing.T) {
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
originalProcessMatcher := gatewayProcessMatcher
|
||||
originalExecCommand := gatewayExecCommand
|
||||
originalRestartGracePeriod := gatewayRestartGracePeriod
|
||||
originalRestartForceKillWindow := gatewayRestartForceKillWindow
|
||||
originalRestartPollInterval := gatewayRestartPollInterval
|
||||
@@ -104,6 +105,7 @@ func resetGatewayTestState(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
gatewayProcessMatcher = originalProcessMatcher
|
||||
gatewayExecCommand = originalExecCommand
|
||||
gatewayRestartGracePeriod = originalRestartGracePeriod
|
||||
gatewayRestartForceKillWindow = originalRestartForceKillWindow
|
||||
gatewayRestartPollInterval = originalRestartPollInterval
|
||||
@@ -119,6 +121,158 @@ func resetGatewayTestState(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
type gatewayStartEnvSnapshot struct {
|
||||
GatewayHost string `json:"gateway_host"`
|
||||
GatewayHostSet bool `json:"gateway_host_set"`
|
||||
ConfigPath string `json:"config_path"`
|
||||
}
|
||||
|
||||
func TestGatewayStartHelperProcess(t *testing.T) {
|
||||
var envPath string
|
||||
for i, arg := range os.Args {
|
||||
if arg == "--" && i+2 < len(os.Args) && os.Args[i+1] == "gateway-env-helper" {
|
||||
envPath = os.Args[i+2]
|
||||
break
|
||||
}
|
||||
}
|
||||
if envPath == "" {
|
||||
t.Skip("helper process")
|
||||
}
|
||||
|
||||
host, ok := os.LookupEnv(config.EnvGatewayHost)
|
||||
raw, err := json.Marshal(gatewayStartEnvSnapshot{
|
||||
GatewayHost: host,
|
||||
GatewayHostSet: ok,
|
||||
ConfigPath: os.Getenv(config.EnvConfig),
|
||||
})
|
||||
if err != nil {
|
||||
_, _ = io.WriteString(os.Stderr, err.Error())
|
||||
os.Exit(2)
|
||||
}
|
||||
if err := os.WriteFile(envPath, raw, 0o600); err != nil {
|
||||
_, _ = io.WriteString(os.Stderr, err.Error())
|
||||
os.Exit(2)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func unsetGatewayStartEnvForTest(t *testing.T, key string) {
|
||||
t.Helper()
|
||||
|
||||
prev, hadPrev := os.LookupEnv(key)
|
||||
if err := os.Unsetenv(key); err != nil {
|
||||
t.Fatalf("Unsetenv(%q) error = %v", key, err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if hadPrev {
|
||||
_ = os.Setenv(key, prev)
|
||||
return
|
||||
}
|
||||
_ = os.Unsetenv(key)
|
||||
})
|
||||
}
|
||||
|
||||
func newGatewayStartTestHandler(t *testing.T) *Handler {
|
||||
t.Helper()
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
return h
|
||||
}
|
||||
|
||||
func startGatewayAndCaptureEnv(t *testing.T, h *Handler) gatewayStartEnvSnapshot {
|
||||
t.Helper()
|
||||
|
||||
unsetGatewayStartEnvForTest(t, config.EnvGatewayHost)
|
||||
|
||||
envPath := filepath.Join(t.TempDir(), "gateway-child-env.json")
|
||||
gatewayExecCommand = func(_ string, _ ...string) *exec.Cmd {
|
||||
return exec.Command(
|
||||
os.Args[0],
|
||||
"-test.run=TestGatewayStartHelperProcess",
|
||||
"--",
|
||||
"gateway-env-helper",
|
||||
envPath,
|
||||
)
|
||||
}
|
||||
|
||||
pid, err := h.startGatewayLocked("starting", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("startGatewayLocked() error = %v", err)
|
||||
}
|
||||
if pid <= 0 {
|
||||
t.Fatalf("startGatewayLocked() pid = %d, want > 0", pid)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for {
|
||||
raw, err := os.ReadFile(envPath)
|
||||
if err == nil {
|
||||
var snapshot gatewayStartEnvSnapshot
|
||||
if err := json.Unmarshal(raw, &snapshot); err != nil {
|
||||
t.Fatalf("Unmarshal(child env) error = %v", err)
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Fatalf("ReadFile(%q) error = %v", envPath, err)
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("timed out waiting for gateway child env snapshot %q", envPath)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGatewayLocked_ForwardsLauncherHostOverrideToGatewayEnv(t *testing.T) {
|
||||
h := newGatewayStartTestHandler(t)
|
||||
h.SetServerBindHost("127.0.0.1,::1", true)
|
||||
|
||||
snapshot := startGatewayAndCaptureEnv(t, h)
|
||||
if !snapshot.GatewayHostSet {
|
||||
t.Fatal("gateway host env was not set")
|
||||
}
|
||||
if snapshot.GatewayHost != "127.0.0.1,::1" {
|
||||
t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "127.0.0.1,::1")
|
||||
}
|
||||
if snapshot.ConfigPath != h.configPath {
|
||||
t.Fatalf("config env = %q, want %q", snapshot.ConfigPath, h.configPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGatewayLocked_ForwardsLauncherHostFromEnvironmentToGatewayEnv(t *testing.T) {
|
||||
h := newGatewayStartTestHandler(t)
|
||||
h.SetServerBindHost("::", true)
|
||||
|
||||
snapshot := startGatewayAndCaptureEnv(t, h)
|
||||
if !snapshot.GatewayHostSet {
|
||||
t.Fatal("gateway host env was not set")
|
||||
}
|
||||
if snapshot.GatewayHost != "::" {
|
||||
t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "::")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGatewayLocked_ForwardsWildcardHostForPublicLauncher(t *testing.T) {
|
||||
h := newGatewayStartTestHandler(t)
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
|
||||
snapshot := startGatewayAndCaptureEnv(t, h)
|
||||
if !snapshot.GatewayHostSet {
|
||||
t.Fatal("gateway host env was not set")
|
||||
}
|
||||
if snapshot.GatewayHost != "*" {
|
||||
t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "*")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
@@ -14,7 +14,7 @@ type Handler struct {
|
||||
serverPort int
|
||||
serverPublic bool
|
||||
serverPublicExplicit bool
|
||||
serverHost string
|
||||
serverHostInput string
|
||||
serverHostExplicit bool
|
||||
serverCIDRs []string
|
||||
debug bool
|
||||
@@ -32,7 +32,6 @@ func NewHandler(configPath string) *Handler {
|
||||
return &Handler{
|
||||
configPath: configPath,
|
||||
serverPort: launcherconfig.DefaultPort,
|
||||
serverHost: resolveDefaultLoopbackHost(),
|
||||
oauthFlows: make(map[string]*oauthFlow),
|
||||
oauthState: make(map[string]string),
|
||||
weixinFlows: make(map[string]*weixinFlow),
|
||||
@@ -45,28 +44,18 @@ func (h *Handler) SetServerOptions(port int, public bool, publicExplicit bool, a
|
||||
h.serverPort = port
|
||||
h.serverPublic = public
|
||||
h.serverPublicExplicit = publicExplicit
|
||||
h.serverHost = resolveDefaultLoopbackHost()
|
||||
if public {
|
||||
h.serverHost = resolveDefaultAnyHost()
|
||||
}
|
||||
h.serverHostInput = ""
|
||||
h.serverHostExplicit = false
|
||||
h.serverCIDRs = append([]string(nil), allowedCIDRs...)
|
||||
}
|
||||
|
||||
// SetServerBindHost stores the launcher's effective bind host.
|
||||
// When explicit is true, the value came from the -host flag.
|
||||
func (h *Handler) SetServerBindHost(host string, explicit bool) {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
host = resolveDefaultLoopbackHost()
|
||||
if h.serverPublic {
|
||||
host = resolveDefaultAnyHost()
|
||||
}
|
||||
explicit = false
|
||||
// When explicit is true, hostInput is the normalized -host / PICOCLAW_LAUNCHER_HOST value.
|
||||
func (h *Handler) SetServerBindHost(hostInput string, explicit bool) {
|
||||
h.serverHostInput = strings.TrimSpace(hostInput)
|
||||
if !explicit {
|
||||
h.serverHostInput = ""
|
||||
}
|
||||
host = canonicalLauncherBindHost(host)
|
||||
|
||||
h.serverHost = host
|
||||
h.serverHostExplicit = explicit
|
||||
}
|
||||
|
||||
|
||||
+86
-301
@@ -28,6 +28,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
"github.com/sipeed/picoclaw/web/backend/api"
|
||||
"github.com/sipeed/picoclaw/web/backend/dashboardauth"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
@@ -56,50 +57,6 @@ var (
|
||||
noBrowser *bool
|
||||
)
|
||||
|
||||
type launcherBindMode string
|
||||
|
||||
type launcherRuntimeBinding struct {
|
||||
mode launcherBindMode
|
||||
host string
|
||||
}
|
||||
|
||||
const (
|
||||
launcherBindModeAutoPrivate launcherBindMode = "auto-private"
|
||||
launcherBindModeAutoPublic launcherBindMode = "auto-public"
|
||||
launcherBindModeExplicitLiteral launcherBindMode = "explicit-literal"
|
||||
launcherBindModeExplicitAdaptiveAny launcherBindMode = "explicit-adaptive-any"
|
||||
launcherBindModeExplicitAdaptiveLocal launcherBindMode = "explicit-adaptive-localhost"
|
||||
)
|
||||
|
||||
func parseLauncherHostList(raw string) ([]string, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, errors.New("host cannot be empty")
|
||||
}
|
||||
|
||||
parts := strings.Split(raw, ",")
|
||||
hosts := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, part := range parts {
|
||||
host := strings.TrimSpace(part)
|
||||
if host == "" {
|
||||
return nil, errors.New("host list contains an empty entry")
|
||||
}
|
||||
key := strings.ToLower(host)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
hosts = append(hosts, host)
|
||||
}
|
||||
|
||||
if len(hosts) == 0 {
|
||||
return nil, errors.New("host cannot be empty")
|
||||
}
|
||||
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func shouldEnableLauncherFileLogging(enableConsole, debug bool) bool {
|
||||
return !enableConsole || debug
|
||||
}
|
||||
@@ -111,108 +68,38 @@ func dashboardTokenConfigHelpPath(source launcherconfig.DashboardTokenSource, la
|
||||
return launcherPath
|
||||
}
|
||||
|
||||
func resolveDefaultLauncherAnyHost() string {
|
||||
return utils.ResolveAdaptiveAnyHost()
|
||||
}
|
||||
|
||||
func resolveDefaultLauncherPrivateHost() string {
|
||||
return utils.ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
|
||||
func normalizeLauncherSpecialHost(host string) string {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return host
|
||||
}
|
||||
if host == "*" {
|
||||
return resolveDefaultLauncherAnyHost()
|
||||
}
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return resolveDefaultLauncherPrivateHost()
|
||||
}
|
||||
if ip := net.ParseIP(strings.Trim(host, "[]")); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func resolveLauncherBindMode(rawHost string, hostExplicit bool, effectivePublic bool) launcherBindMode {
|
||||
if !hostExplicit {
|
||||
if effectivePublic {
|
||||
return launcherBindModeAutoPublic
|
||||
func resolveLauncherHostInput(flagHost string, explicitFlag bool, envHost string) (string, bool, error) {
|
||||
if explicitFlag {
|
||||
normalized, err := netbind.NormalizeHostInput(flagHost)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
return launcherBindModeAutoPrivate
|
||||
}
|
||||
|
||||
rawHost = strings.TrimSpace(rawHost)
|
||||
if rawHost == "*" {
|
||||
return launcherBindModeExplicitAdaptiveAny
|
||||
}
|
||||
if strings.EqualFold(rawHost, "localhost") {
|
||||
return launcherBindModeExplicitAdaptiveLocal
|
||||
}
|
||||
return launcherBindModeExplicitLiteral
|
||||
}
|
||||
|
||||
func resolveLauncherBindHost(
|
||||
host string,
|
||||
explicitHost bool,
|
||||
envHost string,
|
||||
effectivePublic bool,
|
||||
) (string, bool, bool, error) {
|
||||
if explicitHost {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return "", false, false, errors.New("host cannot be empty")
|
||||
}
|
||||
// When -host is specified, -public is ignored.
|
||||
return normalizeLauncherSpecialHost(host), false, true, nil
|
||||
return normalized, true, nil
|
||||
}
|
||||
|
||||
envHost = strings.TrimSpace(envHost)
|
||||
if envHost != "" {
|
||||
// Environment host follows explicit override semantics.
|
||||
return normalizeLauncherSpecialHost(envHost), false, true, nil
|
||||
if envHost == "" {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
if effectivePublic {
|
||||
return resolveDefaultLauncherAnyHost(), true, false, nil
|
||||
normalized, err := netbind.NormalizeHostInput(envHost)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
return resolveDefaultLauncherPrivateHost(), false, false, nil
|
||||
return normalized, true, nil
|
||||
}
|
||||
|
||||
func isWildcardBindHost(host string) bool {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
trimmed := strings.Trim(host, "[]")
|
||||
ip := net.ParseIP(trimmed)
|
||||
return ip != nil && ip.IsUnspecified()
|
||||
}
|
||||
|
||||
func browserHostForLauncher(bindHost string) string {
|
||||
bindHost = strings.TrimSpace(bindHost)
|
||||
if bindHost == "" || isWildcardBindHost(bindHost) {
|
||||
return "localhost"
|
||||
}
|
||||
return bindHost
|
||||
}
|
||||
|
||||
func wildcardAdvertiseIP(bindHost, ipv4, ipv6 string) string {
|
||||
if !isWildcardBindHost(bindHost) {
|
||||
return ""
|
||||
func openLauncherListeners(hostInput string, public bool, port string) (netbind.OpenResult, error) {
|
||||
defaultMode := netbind.DefaultLoopback
|
||||
if strings.TrimSpace(hostInput) == "" && public {
|
||||
defaultMode = netbind.DefaultAny
|
||||
}
|
||||
|
||||
if v6 := strings.TrimSpace(ipv6); v6 != "" {
|
||||
return v6
|
||||
plan, err := netbind.BuildPlan(hostInput, defaultMode)
|
||||
if err != nil {
|
||||
return netbind.OpenResult{}, err
|
||||
}
|
||||
return strings.TrimSpace(ipv4)
|
||||
}
|
||||
|
||||
func advertiseIPForWildcardBindHost(bindHost string) string {
|
||||
return wildcardAdvertiseIP(bindHost, utils.GetLocalIPv4(), utils.GetLocalIPv6())
|
||||
return netbind.OpenPlan(plan, port)
|
||||
}
|
||||
|
||||
func appendUniqueHost(hosts []string, seen map[string]struct{}, host string) []string {
|
||||
@@ -228,124 +115,77 @@ func appendUniqueHost(hosts []string, seen map[string]struct{}, host string) []s
|
||||
return append(hosts, host)
|
||||
}
|
||||
|
||||
func launcherConsoleHosts(bindMode launcherBindMode, bindHost string, effectivePublic bool) []string {
|
||||
func hasWildcardBindHosts(bindHosts []string) bool {
|
||||
for _, bindHost := range bindHosts {
|
||||
if netbind.IsUnspecifiedHost(bindHost) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func wildcardAdvertiseIP(bindHosts []string, ipv4, ipv6 string) string {
|
||||
if !hasWildcardBindHosts(bindHosts) {
|
||||
return ""
|
||||
}
|
||||
|
||||
if v6 := strings.TrimSpace(ipv6); v6 != "" {
|
||||
return v6
|
||||
}
|
||||
return strings.TrimSpace(ipv4)
|
||||
}
|
||||
|
||||
func advertiseIPForWildcardBindHosts(bindHosts []string) string {
|
||||
return wildcardAdvertiseIP(bindHosts, utils.GetLocalIPv4(), utils.GetLocalIPv6())
|
||||
}
|
||||
|
||||
func launcherConsoleHosts(bindHosts []string, probeHost string) []string {
|
||||
hosts := make([]string, 0, 6)
|
||||
seen := make(map[string]struct{}, 6)
|
||||
|
||||
hosts = appendUniqueHost(hosts, seen, "localhost")
|
||||
hosts = appendUniqueHost(hosts, seen, probeHost)
|
||||
|
||||
switch bindMode {
|
||||
case launcherBindModeAutoPrivate, launcherBindModeExplicitAdaptiveLocal:
|
||||
hosts = appendUniqueHost(hosts, seen, "::1")
|
||||
hosts = appendUniqueHost(hosts, seen, "127.0.0.1")
|
||||
return hosts
|
||||
case launcherBindModeAutoPublic, launcherBindModeExplicitAdaptiveAny:
|
||||
hosts = appendUniqueHost(hosts, seen, "::1")
|
||||
hosts = appendUniqueHost(hosts, seen, "127.0.0.1")
|
||||
hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv6())
|
||||
hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv4())
|
||||
return hosts
|
||||
case launcherBindModeExplicitLiteral:
|
||||
trimmed := strings.Trim(strings.TrimSpace(bindHost), "[]")
|
||||
if ip := net.ParseIP(trimmed); ip != nil {
|
||||
if ip.IsUnspecified() {
|
||||
for _, bindHost := range bindHosts {
|
||||
switch {
|
||||
case netbind.IsUnspecifiedHost(bindHost):
|
||||
if ip := net.ParseIP(strings.Trim(bindHost, "[]")); ip != nil && ip.To4() != nil {
|
||||
hosts = appendUniqueHost(hosts, seen, "127.0.0.1")
|
||||
} else {
|
||||
hosts = appendUniqueHost(hosts, seen, "::1")
|
||||
}
|
||||
case netbind.IsLoopbackHost(bindHost):
|
||||
hosts = appendUniqueHost(hosts, seen, "localhost")
|
||||
if ip := net.ParseIP(strings.Trim(bindHost, "[]")); ip != nil {
|
||||
if ip.To4() != nil {
|
||||
hosts = appendUniqueHost(hosts, seen, "127.0.0.1")
|
||||
hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv4())
|
||||
return hosts
|
||||
} else {
|
||||
hosts = appendUniqueHost(hosts, seen, "::1")
|
||||
}
|
||||
hosts = appendUniqueHost(hosts, seen, "::1")
|
||||
hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv6())
|
||||
return hosts
|
||||
}
|
||||
hosts = appendUniqueHost(hosts, seen, ip.String())
|
||||
return hosts
|
||||
default:
|
||||
hosts = appendUniqueHost(hosts, seen, bindHost)
|
||||
}
|
||||
}
|
||||
|
||||
if effectivePublic && isWildcardBindHost(bindHost) {
|
||||
if hasWildcardBindHosts(bindHosts) {
|
||||
hosts = appendUniqueHost(hosts, seen, "localhost")
|
||||
hosts = appendUniqueHost(hosts, seen, "::1")
|
||||
hosts = appendUniqueHost(hosts, seen, "127.0.0.1")
|
||||
hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv6())
|
||||
hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv4())
|
||||
return hosts
|
||||
}
|
||||
|
||||
hosts = appendUniqueHost(hosts, seen, bindHost)
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
func openLauncherListener(network, host, port string) (net.Listener, error) {
|
||||
return net.Listen(network, net.JoinHostPort(host, port))
|
||||
}
|
||||
|
||||
func openLauncherPrivateListeners(port string) ([]net.Listener, string, error) {
|
||||
if ln6, err6 := openLauncherListener("tcp6", "::1", port); err6 == nil {
|
||||
if ln4, err4 := openLauncherListener("tcp4", "127.0.0.1", port); err4 == nil {
|
||||
return []net.Listener{ln6, ln4}, "localhost", nil
|
||||
}
|
||||
_ = ln6.Close()
|
||||
}
|
||||
|
||||
if ln6, err := openLauncherListener("tcp6", "::1", port); err == nil {
|
||||
return []net.Listener{ln6}, "::1", nil
|
||||
}
|
||||
|
||||
if ln4, err := openLauncherListener("tcp4", "127.0.0.1", port); err == nil {
|
||||
return []net.Listener{ln4}, "127.0.0.1", nil
|
||||
}
|
||||
|
||||
return nil, "", fmt.Errorf("failed to open private localhost listener on port %s", port)
|
||||
}
|
||||
|
||||
func openLauncherAnyListener(port string) ([]net.Listener, string, error) {
|
||||
// For auto-public and -host=* we intentionally bind :: on "tcp" first.
|
||||
// Go's compatibility layer will provide dual-stack behavior on environments where it is supported.
|
||||
if ln, err := openLauncherListener("tcp", "::", port); err == nil {
|
||||
return []net.Listener{ln}, "::", nil
|
||||
}
|
||||
|
||||
if ln4, err := openLauncherListener("tcp4", "0.0.0.0", port); err == nil {
|
||||
return []net.Listener{ln4}, "0.0.0.0", nil
|
||||
}
|
||||
|
||||
return nil, "", fmt.Errorf("failed to open adaptive any-host listener on port %s", port)
|
||||
}
|
||||
|
||||
func openLauncherLiteralListener(host, port string) ([]net.Listener, string, error) {
|
||||
host = strings.TrimSpace(host)
|
||||
trimmed := strings.Trim(host, "[]")
|
||||
network := "tcp"
|
||||
|
||||
if ip := net.ParseIP(trimmed); ip != nil {
|
||||
host = ip.String()
|
||||
if ip.To4() != nil {
|
||||
network = "tcp4"
|
||||
} else {
|
||||
network = "tcp6"
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
ln, err := openLauncherListener(network, host, port)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return []net.Listener{ln}, host, nil
|
||||
}
|
||||
|
||||
func openLauncherListeners(mode launcherBindMode, bindHost, port string) ([]net.Listener, string, error) {
|
||||
switch mode {
|
||||
case launcherBindModeAutoPrivate, launcherBindModeExplicitAdaptiveLocal:
|
||||
return openLauncherPrivateListeners(port)
|
||||
case launcherBindModeAutoPublic, launcherBindModeExplicitAdaptiveAny:
|
||||
return openLauncherAnyListener(port)
|
||||
case launcherBindModeExplicitLiteral:
|
||||
return openLauncherLiteralListener(bindHost, port)
|
||||
default:
|
||||
return nil, "", fmt.Errorf("unsupported launcher bind mode: %s", mode)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// maskSecret masks a secret for display. It always shows up to the first 3
|
||||
@@ -397,7 +237,7 @@ func main() {
|
||||
)
|
||||
fmt.Fprintf(os.Stderr, " Allow access from other devices on the local network\n")
|
||||
fmt.Fprintf(os.Stderr, " %s -host :: ./config.json\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, " Bind launcher host explicitly (dual-stack normalization applies)\n")
|
||||
fmt.Fprintf(os.Stderr, " Bind launcher host explicitly with exact host semantics\n")
|
||||
fmt.Fprintf(os.Stderr, " %s -console -d ./config.json\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, " Run in the terminal with debug logs enabled\n")
|
||||
}
|
||||
@@ -502,54 +342,19 @@ func main() {
|
||||
}
|
||||
envHost := strings.TrimSpace(os.Getenv(launcherconfig.EnvLauncherHost))
|
||||
|
||||
rawHostInput := strings.TrimSpace(*host)
|
||||
if !explicitHost {
|
||||
rawHostInput = envHost
|
||||
hostInput, hostOverrideActive, err := resolveLauncherHostInput(*host, explicitHost, envHost)
|
||||
if err != nil {
|
||||
logger.Fatalf("Invalid host %q: %v", firstNonEmpty(strings.TrimSpace(*host), envHost), err)
|
||||
}
|
||||
|
||||
hostExplicit := false
|
||||
effectiveHost := ""
|
||||
bindMode := launcherBindModeAutoPrivate
|
||||
bindTargets := make([]launcherRuntimeBinding, 0, 1)
|
||||
if rawHostInput != "" {
|
||||
hosts, parseErr := parseLauncherHostList(rawHostInput)
|
||||
if parseErr != nil {
|
||||
logger.Fatalf("Invalid host %q: %v", rawHostInput, parseErr)
|
||||
}
|
||||
hostExplicit = true
|
||||
if hostOverrideActive {
|
||||
effectivePublic = false
|
||||
for _, raw := range hosts {
|
||||
resolvedHost, _, _, resolveErr := resolveLauncherBindHost(raw, true, "", false)
|
||||
if resolveErr != nil {
|
||||
logger.Fatalf("Invalid host %q: %v", raw, resolveErr)
|
||||
}
|
||||
mode := resolveLauncherBindMode(raw, true, false)
|
||||
bindTargets = append(bindTargets, launcherRuntimeBinding{mode: mode, host: resolvedHost})
|
||||
}
|
||||
effectiveHost = bindTargets[0].host
|
||||
bindMode = bindTargets[0].mode
|
||||
} else {
|
||||
resolvedHost, resolvedPublic, resolvedExplicit, resolveErr := resolveLauncherBindHost(
|
||||
"",
|
||||
false,
|
||||
"",
|
||||
effectivePublic,
|
||||
)
|
||||
if resolveErr != nil {
|
||||
logger.Fatalf("Invalid default host: %v", resolveErr)
|
||||
}
|
||||
effectiveHost = resolvedHost
|
||||
effectivePublic = resolvedPublic
|
||||
hostExplicit = resolvedExplicit
|
||||
bindMode = resolveLauncherBindMode("", false, effectivePublic)
|
||||
bindTargets = append(bindTargets, launcherRuntimeBinding{mode: bindMode, host: effectiveHost})
|
||||
}
|
||||
|
||||
if !explicitHost && envHost != "" {
|
||||
if !explicitHost && hostOverrideActive {
|
||||
logger.InfoC("web", "Using launcher host from environment PICOCLAW_LAUNCHER_HOST")
|
||||
}
|
||||
|
||||
if hostExplicit && explicitPublic {
|
||||
if hostOverrideActive && explicitPublic {
|
||||
logger.InfoC("web", "Ignoring -public because launcher host was explicitly set")
|
||||
}
|
||||
|
||||
@@ -561,21 +366,11 @@ func main() {
|
||||
logger.Fatalf("Invalid port %q: %v", effectivePort, err)
|
||||
}
|
||||
|
||||
listeners := make([]net.Listener, 0, len(bindTargets))
|
||||
runtimeBindings := make([]launcherRuntimeBinding, 0, len(bindTargets))
|
||||
for _, target := range bindTargets {
|
||||
targetListeners, runtimeHost, listenErr := openLauncherListeners(target.mode, target.host, effectivePort)
|
||||
if listenErr != nil {
|
||||
for _, ln := range listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
logger.Fatalf("Failed to open launcher listener(s): %v", listenErr)
|
||||
}
|
||||
listeners = append(listeners, targetListeners...)
|
||||
runtimeBindings = append(runtimeBindings, launcherRuntimeBinding{mode: target.mode, host: runtimeHost})
|
||||
openResult, err := openLauncherListeners(hostInput, effectivePublic, effectivePort)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open launcher listener(s): %v", err)
|
||||
}
|
||||
effectiveHost = runtimeBindings[0].host
|
||||
bindMode = runtimeBindings[0].mode
|
||||
listeners := openResult.Listeners
|
||||
|
||||
dashboardToken, dashboardSigningKey, dashboardTokenSource, dashErr := launcherconfig.EnsureDashboardSecrets(
|
||||
launcherCfg,
|
||||
@@ -620,12 +415,8 @@ func main() {
|
||||
if _, err = apiHandler.EnsurePicoChannel(""); err != nil {
|
||||
logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err))
|
||||
}
|
||||
gatewayHostExplicit := hostExplicit && len(runtimeBindings) == 1
|
||||
if hostExplicit && len(runtimeBindings) > 1 {
|
||||
logger.WarnC("web", "Multiple launcher hosts are configured; gateway host override is disabled for this run")
|
||||
}
|
||||
apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)
|
||||
apiHandler.SetServerBindHost(effectiveHost, gatewayHostExplicit)
|
||||
apiHandler.SetServerBindHost(hostInput, hostOverrideActive)
|
||||
apiHandler.RegisterRoutes(mux)
|
||||
|
||||
// Frontend Embedded Assets
|
||||
@@ -652,13 +443,7 @@ func main() {
|
||||
|
||||
// Print startup banner and token (console mode only).
|
||||
if enableConsole || debug {
|
||||
consoleHosts := make([]string, 0, 8)
|
||||
consoleSeen := make(map[string]struct{}, 8)
|
||||
for _, binding := range runtimeBindings {
|
||||
for _, host := range launcherConsoleHosts(binding.mode, binding.host, effectivePublic) {
|
||||
consoleHosts = appendUniqueHost(consoleHosts, consoleSeen, host)
|
||||
}
|
||||
}
|
||||
consoleHosts := launcherConsoleHosts(openResult.BindHosts, openResult.ProbeHost)
|
||||
|
||||
fmt.Print(utils.Banner)
|
||||
fmt.Println()
|
||||
@@ -694,14 +479,14 @@ func main() {
|
||||
for _, ln := range listeners {
|
||||
logger.InfoC("web", fmt.Sprintf("Server will listen on http://%s", ln.Addr().String()))
|
||||
}
|
||||
if isWildcardBindHost(effectiveHost) {
|
||||
if ip := advertiseIPForWildcardBindHost(effectiveHost); ip != "" {
|
||||
if hasWildcardBindHosts(openResult.BindHosts) {
|
||||
if ip := advertiseIPForWildcardBindHosts(openResult.BindHosts); ip != "" {
|
||||
logger.InfoC("web", fmt.Sprintf("Public access enabled at http://%s", net.JoinHostPort(ip, effectivePort)))
|
||||
}
|
||||
}
|
||||
|
||||
// Share the local URL with the launcher runtime.
|
||||
serverAddr = fmt.Sprintf("http://%s", net.JoinHostPort(browserHostForLauncher(effectiveHost), effectivePort))
|
||||
serverAddr = fmt.Sprintf("http://%s", net.JoinHostPort(openResult.ProbeHost, effectivePort))
|
||||
if dashboardToken != "" {
|
||||
browserLaunchURL = serverAddr + "?token=" + url.QueryEscape(dashboardToken)
|
||||
} else {
|
||||
|
||||
+157
-219
@@ -1,8 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/netbind"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
)
|
||||
|
||||
@@ -42,21 +50,9 @@ func TestDashboardTokenConfigHelpPath(t *testing.T) {
|
||||
source launcherconfig.DashboardTokenSource
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "env token does not expose config path",
|
||||
source: launcherconfig.DashboardTokenSourceEnv,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "config token exposes config path",
|
||||
source: launcherconfig.DashboardTokenSourceConfig,
|
||||
want: launcherPath,
|
||||
},
|
||||
{
|
||||
name: "random token does not expose config path",
|
||||
source: launcherconfig.DashboardTokenSourceRandom,
|
||||
want: "",
|
||||
},
|
||||
{name: "env token does not expose config path", source: launcherconfig.DashboardTokenSourceEnv, want: ""},
|
||||
{name: "config token exposes config path", source: launcherconfig.DashboardTokenSourceConfig, want: launcherPath},
|
||||
{name: "random token does not expose config path", source: launcherconfig.DashboardTokenSourceRandom, want: ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -73,22 +69,17 @@ func TestMaskSecret(t *testing.T) {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
// Long token (>=12 chars): first 3 + 10 stars + last 4
|
||||
{"sdhjflsjdflksdf", "sdh**********ksdf"},
|
||||
{"abcdefghijklmnopqrstuvwxyz", "abc**********wxyz"},
|
||||
// Exactly 12 chars (3+4+5 hidden): suffix shown
|
||||
{"abcdefghijkl", "abc**********ijkl"},
|
||||
// 8 chars (minimum password length): suffix NOT shown — only prefix+stars
|
||||
{"abcdefgh", "abc**********"},
|
||||
// 11 chars (one below threshold): suffix NOT shown
|
||||
{"abcdefghijk", "abc**********"},
|
||||
// 4..3 chars: prefix shown, no suffix
|
||||
{"abcdefg", "abc**********"},
|
||||
{"abcd", "abc**********"},
|
||||
// <=3 chars: fully masked
|
||||
{"abc", "**********"},
|
||||
{"", "**********"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := maskSecret(tt.input); got != tt.want {
|
||||
t.Errorf("maskSecret(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
@@ -96,185 +87,46 @@ func TestMaskSecret(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLauncherHostList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "single host", raw: "127.0.0.1", want: []string{"127.0.0.1"}},
|
||||
{name: "multiple hosts", raw: "127.0.0.1, 192.168.2.5", want: []string{"127.0.0.1", "192.168.2.5"}},
|
||||
{name: "dedupe hosts", raw: "127.0.0.1,127.0.0.1", want: []string{"127.0.0.1"}},
|
||||
{name: "reject empty entry", raw: "127.0.0.1, ", wantErr: true},
|
||||
{name: "reject empty input", raw: " ", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parseLauncherHostList(tt.raw)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("parseLauncherHostList() err = %v, wantErr %t", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("len(got) = %d, want %d (%#v)", len(got), len(tt.want), got)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Fatalf("got[%d] = %q, want %q", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLauncherBindHost(t *testing.T) {
|
||||
func TestResolveLauncherHostInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
flagHost string
|
||||
explicitFlag bool
|
||||
envHost string
|
||||
explicitHost bool
|
||||
effectivePub bool
|
||||
wantHost string
|
||||
wantPublic bool
|
||||
wantExplicit bool
|
||||
wantActive bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "explicit host overrides public",
|
||||
host: "0.0.0.0",
|
||||
explicitHost: true,
|
||||
effectivePub: true,
|
||||
wantHost: "0.0.0.0",
|
||||
wantPublic: false,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "explicit host overrides env host",
|
||||
host: "127.0.0.1",
|
||||
envHost: "0.0.0.0",
|
||||
explicitHost: true,
|
||||
effectivePub: true,
|
||||
wantHost: "127.0.0.1",
|
||||
wantPublic: false,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "explicit host cannot be empty",
|
||||
host: " ",
|
||||
explicitHost: true,
|
||||
effectivePub: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "env host overrides public",
|
||||
envHost: "0.0.0.0",
|
||||
explicitHost: false,
|
||||
effectivePub: true,
|
||||
wantHost: "0.0.0.0",
|
||||
wantPublic: false,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "explicit localhost uses adaptive private host",
|
||||
host: "localhost",
|
||||
explicitHost: true,
|
||||
effectivePub: false,
|
||||
wantHost: resolveDefaultLauncherPrivateHost(),
|
||||
wantPublic: false,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "explicit star uses adaptive any host",
|
||||
host: "*",
|
||||
explicitHost: true,
|
||||
effectivePub: false,
|
||||
wantHost: resolveDefaultLauncherAnyHost(),
|
||||
wantPublic: false,
|
||||
wantExplicit: true,
|
||||
},
|
||||
{
|
||||
name: "public mode without explicit host",
|
||||
host: "",
|
||||
explicitHost: false,
|
||||
effectivePub: true,
|
||||
wantHost: resolveDefaultLauncherAnyHost(),
|
||||
wantPublic: true,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{
|
||||
name: "private mode without explicit host",
|
||||
host: "",
|
||||
explicitHost: false,
|
||||
effectivePub: false,
|
||||
wantHost: resolveDefaultLauncherPrivateHost(),
|
||||
wantPublic: false,
|
||||
wantExplicit: false,
|
||||
},
|
||||
{name: "flag host wins", flagHost: "127.0.0.1", explicitFlag: true, envHost: "::", wantHost: "127.0.0.1", wantActive: true},
|
||||
{name: "env host used when flag absent", envHost: "127.0.0.1,::1", wantHost: "127.0.0.1,::1", wantActive: true},
|
||||
{name: "blank env ignored", envHost: " ", wantHost: "", wantActive: false},
|
||||
{name: "invalid flag rejected", flagHost: "127.0.0.1, ", explicitFlag: true, wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotHost, gotPublic, gotExplicit, err := resolveLauncherBindHost(
|
||||
tt.host,
|
||||
tt.explicitHost,
|
||||
tt.envHost,
|
||||
tt.effectivePub,
|
||||
)
|
||||
gotHost, gotActive, err := resolveLauncherHostInput(tt.flagHost, tt.explicitFlag, tt.envHost)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("resolveLauncherBindHost() error = %v, wantErr %t", err, tt.wantErr)
|
||||
t.Fatalf("resolveLauncherHostInput() err = %v, wantErr %t", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
if gotHost != tt.wantHost {
|
||||
t.Fatalf("resolveLauncherBindHost() host = %q, want %q", gotHost, tt.wantHost)
|
||||
t.Fatalf("resolveLauncherHostInput() host = %q, want %q", gotHost, tt.wantHost)
|
||||
}
|
||||
if gotPublic != tt.wantPublic {
|
||||
t.Fatalf("resolveLauncherBindHost() public = %t, want %t", gotPublic, tt.wantPublic)
|
||||
}
|
||||
if gotExplicit != tt.wantExplicit {
|
||||
t.Fatalf("resolveLauncherBindHost() explicit = %t, want %t", gotExplicit, tt.wantExplicit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLauncherBindMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawHost string
|
||||
hostExplicit bool
|
||||
effectivePub bool
|
||||
wantMode launcherBindMode
|
||||
}{
|
||||
{name: "auto private", rawHost: "", hostExplicit: false, effectivePub: false, wantMode: launcherBindModeAutoPrivate},
|
||||
{name: "auto public", rawHost: "", hostExplicit: false, effectivePub: true, wantMode: launcherBindModeAutoPublic},
|
||||
{name: "explicit localhost", rawHost: "localhost", hostExplicit: true, effectivePub: false, wantMode: launcherBindModeExplicitAdaptiveLocal},
|
||||
{name: "explicit star", rawHost: "*", hostExplicit: true, effectivePub: false, wantMode: launcherBindModeExplicitAdaptiveAny},
|
||||
{name: "explicit literal", rawHost: "0.0.0.0", hostExplicit: true, effectivePub: false, wantMode: launcherBindModeExplicitLiteral},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveLauncherBindMode(tt.rawHost, tt.hostExplicit, tt.effectivePub); got != tt.wantMode {
|
||||
t.Fatalf("resolveLauncherBindMode() = %q, want %q", got, tt.wantMode)
|
||||
if gotActive != tt.wantActive {
|
||||
t.Fatalf("resolveLauncherHostInput() active = %t, want %t", gotActive, tt.wantActive)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLauncherConsoleHosts(t *testing.T) {
|
||||
t.Run("auto private includes dual loopback hints", func(t *testing.T) {
|
||||
hosts := launcherConsoleHosts(launcherBindModeAutoPrivate, "localhost", false)
|
||||
t.Run("wildcard exposes local loopback hints", func(t *testing.T) {
|
||||
hosts := launcherConsoleHosts([]string{"::"}, netbind.ResolveAdaptiveLoopbackHost())
|
||||
seen := make(map[string]bool, len(hosts))
|
||||
for _, host := range hosts {
|
||||
if seen[host] {
|
||||
t.Fatalf("duplicate host %q in %#v", host, hosts)
|
||||
}
|
||||
seen[host] = true
|
||||
}
|
||||
if !seen["localhost"] {
|
||||
@@ -288,63 +140,149 @@ func TestLauncherConsoleHosts(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit ipv4 wildcard excludes ipv6 loopback", func(t *testing.T) {
|
||||
hosts := launcherConsoleHosts(launcherBindModeExplicitLiteral, "0.0.0.0", false)
|
||||
seen := make(map[string]bool, len(hosts))
|
||||
for _, host := range hosts {
|
||||
seen[host] = true
|
||||
}
|
||||
if seen["::1"] {
|
||||
t.Fatalf("did not expect ::1 in %#v", hosts)
|
||||
}
|
||||
if !seen["127.0.0.1"] {
|
||||
t.Fatalf("expected 127.0.0.1 in %#v", hosts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit ipv6 host remains visible", func(t *testing.T) {
|
||||
hosts := launcherConsoleHosts(launcherBindModeExplicitLiteral, "::1", false)
|
||||
if len(hosts) != 2 {
|
||||
t.Fatalf("len(hosts) = %d, want 2 (%#v)", len(hosts), hosts)
|
||||
}
|
||||
if hosts[0] != "localhost" || hosts[1] != "::1" {
|
||||
t.Fatalf("hosts = %#v, want [localhost ::1]", hosts)
|
||||
hosts := launcherConsoleHosts([]string{"::1"}, "::1")
|
||||
if len(hosts) < 1 || hosts[0] != "::1" {
|
||||
t.Fatalf("hosts = %#v, want probe host first", hosts)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBrowserHostForLauncher(t *testing.T) {
|
||||
if got := browserHostForLauncher("0.0.0.0"); got != "localhost" {
|
||||
t.Fatalf("browserHostForLauncher(0.0.0.0) = %q, want %q", got, "localhost")
|
||||
}
|
||||
if got := browserHostForLauncher("::"); got != "localhost" {
|
||||
t.Fatalf("browserHostForLauncher(::) = %q, want %q", got, "localhost")
|
||||
}
|
||||
if got := browserHostForLauncher("192.168.1.10"); got != "192.168.1.10" {
|
||||
t.Fatalf("browserHostForLauncher(192.168.1.10) = %q, want %q", got, "192.168.1.10")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWildcardAdvertiseIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bindHost string
|
||||
ipv4 string
|
||||
ipv6 string
|
||||
want string
|
||||
name string
|
||||
bindHosts []string
|
||||
ipv4 string
|
||||
ipv6 string
|
||||
want string
|
||||
}{
|
||||
{name: "ipv4 wildcard prefers ipv6 when available", bindHost: "0.0.0.0", ipv4: "192.168.1.2", ipv6: "2001:db8::1", want: "2001:db8::1"},
|
||||
{name: "ipv6 wildcard uses ipv6", bindHost: "::", ipv4: "192.168.1.2", ipv6: "2001:db8::1", want: "2001:db8::1"},
|
||||
{name: "ipv6 wildcard falls back to ipv4", bindHost: "::", ipv4: "192.168.1.2", ipv6: "", want: "192.168.1.2"},
|
||||
{name: "ipv4 wildcard uses ipv6-only network", bindHost: "0.0.0.0", ipv4: "", ipv6: "2001:db8::1", want: "2001:db8::1"},
|
||||
{name: "non wildcard does not advertise", bindHost: "127.0.0.1", ipv4: "192.168.1.2", ipv6: "2001:db8::1", want: ""},
|
||||
{name: "ipv4 wildcard prefers ipv6 when available", bindHosts: []string{"0.0.0.0"}, ipv4: "192.168.1.2", ipv6: "2001:db8::1", want: "2001:db8::1"},
|
||||
{name: "ipv6 wildcard uses ipv6", bindHosts: []string{"::"}, ipv4: "192.168.1.2", ipv6: "2001:db8::1", want: "2001:db8::1"},
|
||||
{name: "ipv6 wildcard falls back to ipv4", bindHosts: []string{"::"}, ipv4: "192.168.1.2", ipv6: "", want: "192.168.1.2"},
|
||||
{name: "non wildcard does not advertise", bindHosts: []string{"127.0.0.1"}, ipv4: "192.168.1.2", ipv6: "2001:db8::1", want: ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := wildcardAdvertiseIP(tt.bindHost, tt.ipv4, tt.ipv6); got != tt.want {
|
||||
t.Fatalf("wildcardAdvertiseIP(%q, %q, %q) = %q, want %q", tt.bindHost, tt.ipv4, tt.ipv6, got, tt.want)
|
||||
if got := wildcardAdvertiseIP(tt.bindHosts, tt.ipv4, tt.ipv6); got != tt.want {
|
||||
t.Fatalf("wildcardAdvertiseIP(%#v, %q, %q) = %q, want %q", tt.bindHosts, tt.ipv4, tt.ipv6, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenLauncherListeners_HonorsIPv6OnlyHost(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := netbind.DetectIPFamilies()
|
||||
if !hasIPv6 {
|
||||
t.Skip("IPv6 is unavailable in this environment")
|
||||
}
|
||||
|
||||
result, err := openLauncherListeners("::", false, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("openLauncherListeners() error = %v", err)
|
||||
}
|
||||
startLauncherTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
requireLauncherHTTPReachable(t, "::1", port)
|
||||
if hasIPv4 {
|
||||
requireLauncherHTTPUnreachable(t, "127.0.0.1", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenLauncherListeners_SupportsExplicitMultiHost(t *testing.T) {
|
||||
hasIPv4, hasIPv6 := netbind.DetectIPFamilies()
|
||||
if !hasIPv4 || !hasIPv6 {
|
||||
t.Skip("dual-stack loopback is unavailable in this environment")
|
||||
}
|
||||
|
||||
result, err := openLauncherListeners("127.0.0.1,::1", false, "0")
|
||||
if err != nil {
|
||||
t.Fatalf("openLauncherListeners() error = %v", err)
|
||||
}
|
||||
startLauncherTestHTTPServer(t, result.Listeners)
|
||||
port := mustAtoi(t, result.Port)
|
||||
|
||||
requireLauncherHTTPReachable(t, "127.0.0.1", port)
|
||||
requireLauncherHTTPReachable(t, "::1", port)
|
||||
}
|
||||
|
||||
func startLauncherTestHTTPServer(t *testing.T, listeners []net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
server := &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}),
|
||||
}
|
||||
|
||||
errCh := make(chan error, len(listeners))
|
||||
for _, listener := range listeners {
|
||||
ln := listener
|
||||
go func() {
|
||||
errCh <- server.Serve(ln)
|
||||
}()
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = server.Shutdown(ctx)
|
||||
for range listeners {
|
||||
err := <-errCh
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Fatalf("server.Serve() error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func requireLauncherHTTPReachable(t *testing.T, host string, port int) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
err := launcherHTTPGet(host, port)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("expected %s:%d to be reachable: %v", host, port, err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func requireLauncherHTTPUnreachable(t *testing.T, host string, port int) {
|
||||
t.Helper()
|
||||
if err := launcherHTTPGet(host, port); err == nil {
|
||||
t.Fatalf("expected %s:%d to be unreachable", host, port)
|
||||
}
|
||||
}
|
||||
|
||||
func launcherHTTPGet(host string, port int) error {
|
||||
client := &http.Client{
|
||||
Timeout: 300 * time.Millisecond,
|
||||
Transport: &http.Transport{
|
||||
Proxy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.New(resp.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustAtoi(t *testing.T, value string) int {
|
||||
t.Helper()
|
||||
n, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
t.Fatalf("Atoi(%q) error = %v", value, err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
@@ -7,91 +7,11 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ipFamiliesOnce sync.Once
|
||||
hasIPv4 bool
|
||||
hasIPv6 bool
|
||||
)
|
||||
|
||||
func DetectIPFamilies() (bool, bool) {
|
||||
ipFamiliesOnce.Do(func() {
|
||||
if ips, err := net.LookupIP("localhost"); err == nil {
|
||||
for _, ip := range ips {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
hasIPv4 = true
|
||||
continue
|
||||
}
|
||||
hasIPv6 = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasIPv4 && hasIPv6 {
|
||||
return
|
||||
}
|
||||
|
||||
if addrs, err := net.InterfaceAddrs(); err == nil {
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok || ipnet.IP == nil {
|
||||
continue
|
||||
}
|
||||
if ipnet.IP.To4() != nil {
|
||||
hasIPv4 = true
|
||||
continue
|
||||
}
|
||||
hasIPv6 = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return hasIPv4, hasIPv6
|
||||
}
|
||||
|
||||
func SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6 bool) string {
|
||||
switch {
|
||||
case hasIPv4 && hasIPv6:
|
||||
return "localhost"
|
||||
case hasIPv6:
|
||||
return "::1"
|
||||
case hasIPv4:
|
||||
return "127.0.0.1"
|
||||
default:
|
||||
return "localhost"
|
||||
}
|
||||
}
|
||||
|
||||
func SelectAdaptiveAnyHost(hasIPv4, hasIPv6 bool) string {
|
||||
switch {
|
||||
case hasIPv4 && hasIPv6:
|
||||
return "::"
|
||||
case hasIPv6:
|
||||
return "::"
|
||||
case hasIPv4:
|
||||
return "0.0.0.0"
|
||||
default:
|
||||
return "::"
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveAdaptiveLoopbackHost() string {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
return SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
func ResolveAdaptiveAnyHost() string {
|
||||
hasIPv4, hasIPv6 := DetectIPFamilies()
|
||||
return SelectAdaptiveAnyHost(hasIPv4, hasIPv6)
|
||||
}
|
||||
|
||||
// GetPicoclawHome returns the picoclaw home directory.
|
||||
// Priority: $PICOCLAW_HOME > ~/.picoclaw
|
||||
func GetPicoclawHome() string {
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSelectAdaptiveLoopbackHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasIPv4 bool
|
||||
hasIPv6 bool
|
||||
want string
|
||||
}{
|
||||
{name: "dual stack", hasIPv4: true, hasIPv6: true, want: "localhost"},
|
||||
{name: "ipv6 only", hasIPv4: false, hasIPv6: true, want: "::1"},
|
||||
{name: "ipv4 only", hasIPv4: true, hasIPv6: false, want: "127.0.0.1"},
|
||||
{name: "fallback", hasIPv4: false, hasIPv6: false, want: "localhost"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := SelectAdaptiveLoopbackHost(tt.hasIPv4, tt.hasIPv6); got != tt.want {
|
||||
t.Fatalf("SelectAdaptiveLoopbackHost(%t, %t) = %q, want %q", tt.hasIPv4, tt.hasIPv6, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectAdaptiveAnyHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasIPv4 bool
|
||||
hasIPv6 bool
|
||||
want string
|
||||
}{
|
||||
{name: "dual stack", hasIPv4: true, hasIPv6: true, want: "::"},
|
||||
{name: "ipv6 only", hasIPv4: false, hasIPv6: true, want: "::"},
|
||||
{name: "ipv4 only", hasIPv4: true, hasIPv6: false, want: "0.0.0.0"},
|
||||
{name: "fallback", hasIPv4: false, hasIPv6: false, want: "::"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := SelectAdaptiveAnyHost(tt.hasIPv4, tt.hasIPv6); got != tt.want {
|
||||
t.Fatalf("SelectAdaptiveAnyHost(%t, %t) = %q, want %q", tt.hasIPv4, tt.hasIPv6, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAdaptiveHosts(t *testing.T) {
|
||||
loopback := ResolveAdaptiveLoopbackHost()
|
||||
if loopback == "" {
|
||||
t.Fatal("ResolveAdaptiveLoopbackHost() returned empty host")
|
||||
}
|
||||
|
||||
anyHost := ResolveAdaptiveAnyHost()
|
||||
if anyHost == "" {
|
||||
t.Fatal("ResolveAdaptiveAnyHost() returned empty host")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user