From 7b38d437ba7fe5197a8e459195ad39fb220891c9 Mon Sep 17 00:00:00 2001 From: lc6464 <64722907+lc6464@users.noreply.github.com> Date: Tue, 14 Apr 2026 09:10:44 +0800 Subject: [PATCH] feat(launcher): support multi-host bind and strict host semantics --- web/backend/api/gateway_host.go | 91 +------ web/backend/api/gateway_host_test.go | 37 +-- web/backend/app_runtime.go | 33 ++- web/backend/main.go | 388 ++++++++++++++++++--------- web/backend/main_test.go | 99 ++++++- web/backend/utils/runtime.go | 80 ++++++ web/backend/utils/runtime_test.go | 59 ++++ 7 files changed, 526 insertions(+), 261 deletions(-) create mode 100644 web/backend/utils/runtime_test.go diff --git a/web/backend/api/gateway_host.go b/web/backend/api/gateway_host.go index 6934c2652..055c90bdf 100644 --- a/web/backend/api/gateway_host.go +++ b/web/backend/api/gateway_host.go @@ -6,43 +6,17 @@ import ( "net/url" "strconv" "strings" - "sync" "github.com/sipeed/picoclaw/pkg/config" -) - -var ( - adaptiveIPFamiliesOnce sync.Once - adaptiveHasIPv4 bool - adaptiveHasIPv6 bool - lookupLocalhostIPs = func() ([]net.IP, error) { return net.LookupIP("localhost") } - listInterfaceAddrs = net.InterfaceAddrs + "github.com/sipeed/picoclaw/web/backend/utils" ) func selectAdaptiveLoopbackHost(hasIPv4, hasIPv6 bool) string { - switch { - case hasIPv4 && hasIPv6: - return "localhost" - case hasIPv6: - return "::1" - case hasIPv4: - return "127.0.0.1" - default: - return "localhost" - } + return utils.SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6) } func selectAdaptiveAnyHost(hasIPv4, hasIPv6 bool) string { - switch { - case hasIPv4 && hasIPv6: - return "::" - case hasIPv6: - return "::" - case hasIPv4: - return "0.0.0.0" - default: - return "::" - } + return utils.SelectAdaptiveAnyHost(hasIPv4, hasIPv6) } func isLoopbackEquivalentHost(host string) bool { @@ -58,63 +32,12 @@ func isLoopbackEquivalentHost(host string) bool { return ip != nil && ip.IsLoopback() } -func detectAdaptiveIPFamilies() (bool, bool) { - adaptiveIPFamiliesOnce.Do(func() { - if ips, err := lookupLocalhostIPs(); err == nil { - for _, ip := range ips { - if ip == nil { - continue - } - if ip.To4() != nil { - adaptiveHasIPv4 = true - continue - } - adaptiveHasIPv6 = true - } - } - - if adaptiveHasIPv4 && adaptiveHasIPv6 { - return - } - - if addrs, err := listInterfaceAddrs(); err == nil { - for _, addr := range addrs { - ipnet, ok := addr.(*net.IPNet) - if !ok || ipnet.IP == nil { - continue - } - if ipnet.IP.To4() != nil { - adaptiveHasIPv4 = true - continue - } - adaptiveHasIPv6 = true - } - } - }) - - return adaptiveHasIPv4, adaptiveHasIPv6 -} - -func resolveAdaptiveLoopbackHost() string { - hasIPv4, hasIPv6 := detectAdaptiveIPFamilies() - return selectAdaptiveLoopbackHost(hasIPv4, hasIPv6) -} - -func resolveAdaptiveAnyHost() string { - hasIPv4, hasIPv6 := detectAdaptiveIPFamilies() - return selectAdaptiveAnyHost(hasIPv4, hasIPv6) -} - func resolveDefaultLoopbackHost() string { - return resolveAdaptiveLoopbackHost() + return utils.ResolveAdaptiveLoopbackHost() } func resolveDefaultAnyHost() string { - return resolveAdaptiveAnyHost() -} - -func resolveLocalhostLoopbackHost() string { - return resolveAdaptiveLoopbackHost() + return utils.ResolveAdaptiveAnyHost() } func (h *Handler) effectiveLauncherPublic() bool { @@ -141,7 +64,7 @@ func canonicalLauncherBindHost(host string) string { return resolveDefaultLoopbackHost() } if strings.EqualFold(host, "localhost") { - return resolveLocalhostLoopbackHost() + return resolveDefaultLoopbackHost() } trimmed := strings.Trim(host, "[]") if ip := net.ParseIP(trimmed); ip != nil && ip.IsUnspecified() { @@ -207,7 +130,7 @@ func gatewayProbeHost(bindHost string) string { return resolveDefaultLoopbackHost() } if strings.EqualFold(bindHost, "localhost") { - return resolveLocalhostLoopbackHost() + return resolveDefaultLoopbackHost() } trimmed := strings.Trim(bindHost, "[]") diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go index 71de515f9..5f3181085 100644 --- a/web/backend/api/gateway_host_test.go +++ b/web/backend/api/gateway_host_test.go @@ -7,7 +7,6 @@ import ( "net/http" "net/http/httptest" "path/filepath" - "sync" "testing" "time" @@ -15,12 +14,6 @@ import ( "github.com/sipeed/picoclaw/web/backend/launcherconfig" ) -func resetAdaptiveIPFamiliesForTest() { - adaptiveIPFamiliesOnce = sync.Once{} - adaptiveHasIPv4 = false - adaptiveHasIPv6 = false -} - func TestGatewayHostOverrideUsesExplicitRuntimePublic(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") launcherPath := launcherconfig.PathForAppConfig(configPath) @@ -115,34 +108,6 @@ func TestSelectAdaptiveAnyHost(t *testing.T) { } } -func TestAdaptiveHostSelectionFallsBackToInterfaceAddrs(t *testing.T) { - oldLookup := lookupLocalhostIPs - oldList := listInterfaceAddrs - lookupLocalhostIPs = func() ([]net.IP, error) { - return nil, errors.New("lookup failed") - } - _, v4Net, err := net.ParseCIDR("192.0.2.10/24") - if err != nil { - t.Fatalf("ParseCIDR() error = %v", err) - } - listInterfaceAddrs = func() ([]net.Addr, error) { - return []net.Addr{v4Net}, nil - } - resetAdaptiveIPFamiliesForTest() - t.Cleanup(func() { - lookupLocalhostIPs = oldLookup - listInterfaceAddrs = oldList - resetAdaptiveIPFamiliesForTest() - }) - - if got := resolveDefaultAnyHost(); got != "0.0.0.0" { - t.Fatalf("resolveDefaultAnyHost() = %q, want %q", got, "0.0.0.0") - } - if got := resolveDefaultLoopbackHost(); got != "127.0.0.1" { - t.Fatalf("resolveDefaultLoopbackHost() = %q, want %q", got, "127.0.0.1") - } -} - func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) { want := resolveDefaultLoopbackHost() if got := gatewayProbeHost("0.0.0.0"); got != want { @@ -158,7 +123,7 @@ func TestGatewayProbeHostUsesPreferredLoopbackForEmptyBind(t *testing.T) { } func TestGatewayProbeHostUsesPreferredLoopbackForLocalhostBind(t *testing.T) { - want := resolveLocalhostLoopbackHost() + want := resolveDefaultLoopbackHost() if got := gatewayProbeHost("localhost"); got != want { t.Fatalf("gatewayProbeHost(localhost) = %q, want %q", got, want) } diff --git a/web/backend/app_runtime.go b/web/backend/app_runtime.go index ab564db2c..674c0d4e6 100644 --- a/web/backend/app_runtime.go +++ b/web/backend/app_runtime.go @@ -34,22 +34,29 @@ func shutdownApp() { apiHandler.Shutdown() } - if server != nil { - // Disable keep-alive to allow graceful shutdown - server.SetKeepAlivesEnabled(false) - + if len(servers) > 0 { ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() - if err := server.Shutdown(ctx); err != nil { - // Context deadline exceeded is expected if there are active connections - // This is not necessarily an error, so log it at info level - if errors.Is(err, context.DeadlineExceeded) { - logger.Infof("Server shutdown timeout after %v, forcing close", shutdownTimeout) - } else { - logger.Errorf("Server shutdown error: %v", err) + + for _, srv := range servers { + if srv == nil { + continue + } + + // Disable keep-alive to allow graceful shutdown + srv.SetKeepAlivesEnabled(false) + + if err := srv.Shutdown(ctx); err != nil { + // Context deadline exceeded is expected if there are active connections + // This is not necessarily an error, so log it at info level + if errors.Is(err, context.DeadlineExceeded) { + logger.Infof("Server shutdown timeout after %v, forcing close", shutdownTimeout) + } else { + logger.Errorf("Server shutdown error: %v", err) + } + } else { + logger.Infof("Server shutdown completed successfully") } - } else { - logger.Infof("Server shutdown completed successfully") } } } diff --git a/web/backend/main.go b/web/backend/main.go index e6cfa2247..6201c130a 100644 --- a/web/backend/main.go +++ b/web/backend/main.go @@ -23,7 +23,6 @@ import ( "path/filepath" "strconv" "strings" - "sync" "syscall" "time" @@ -47,11 +46,7 @@ const ( var ( appVersion = config.Version - launcherIPFamiliesOnce sync.Once - launcherHasIPv4 bool - launcherHasIPv6 bool - - server *http.Server + servers []*http.Server serverAddr string // browserLaunchURL is opened by openBrowser() (auto-open + tray "open console"). // Includes ?token= for same-machine dashboard login; keep serverAddr without secrets for other use. @@ -61,6 +56,50 @@ var ( noBrowser *bool ) +type launcherBindMode string + +type launcherRuntimeBinding struct { + mode launcherBindMode + host string +} + +const ( + launcherBindModeAutoPrivate launcherBindMode = "auto-private" + launcherBindModeAutoPublic launcherBindMode = "auto-public" + launcherBindModeExplicitLiteral launcherBindMode = "explicit-literal" + launcherBindModeExplicitAdaptiveAny launcherBindMode = "explicit-adaptive-any" + launcherBindModeExplicitAdaptiveLocal launcherBindMode = "explicit-adaptive-localhost" +) + +func parseLauncherHostList(raw string) ([]string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, errors.New("host cannot be empty") + } + + parts := strings.Split(raw, ",") + hosts := make([]string, 0, len(parts)) + seen := make(map[string]struct{}, len(parts)) + for _, part := range parts { + host := strings.TrimSpace(part) + if host == "" { + return nil, errors.New("host list contains an empty entry") + } + key := strings.ToLower(host) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + hosts = append(hosts, host) + } + + if len(hosts) == 0 { + return nil, errors.New("host cannot be empty") + } + + return hosts, nil +} + func shouldEnableLauncherFileLogging(enableConsole, debug bool) bool { return !enableConsole || debug } @@ -72,86 +111,12 @@ func dashboardTokenConfigHelpPath(source launcherconfig.DashboardTokenSource, la return launcherPath } -func detectLauncherIPFamilies() (bool, bool) { - launcherIPFamiliesOnce.Do(func() { - if ips, err := net.LookupIP("localhost"); err == nil { - for _, ip := range ips { - if ip == nil { - continue - } - if ip.To4() != nil { - launcherHasIPv4 = true - continue - } - launcherHasIPv6 = true - } - } - - if launcherHasIPv4 && launcherHasIPv6 { - return - } - - if addrs, err := net.InterfaceAddrs(); err == nil { - for _, addr := range addrs { - ipnet, ok := addr.(*net.IPNet) - if !ok || ipnet.IP == nil { - continue - } - if ipnet.IP.To4() != nil { - launcherHasIPv4 = true - continue - } - launcherHasIPv6 = true - } - } - }) - - return launcherHasIPv4, launcherHasIPv6 -} - -func selectAdaptiveLauncherLoopbackHost(hasIPv4, hasIPv6 bool) string { - switch { - case hasIPv4 && hasIPv6: - return "localhost" - case hasIPv6: - return "::1" - case hasIPv4: - return "127.0.0.1" - default: - return "localhost" - } -} - -func selectAdaptiveLauncherAnyHost(hasIPv4, hasIPv6 bool) string { - switch { - case hasIPv4 && hasIPv6: - return "::" - case hasIPv6: - return "::" - case hasIPv4: - return "0.0.0.0" - default: - return "::" - } -} - -func resolveDefaultLauncherLoopbackHost() string { - hasIPv4, hasIPv6 := detectLauncherIPFamilies() - return selectAdaptiveLauncherLoopbackHost(hasIPv4, hasIPv6) -} - func resolveDefaultLauncherAnyHost() string { - hasIPv4, hasIPv6 := detectLauncherIPFamilies() - return selectAdaptiveLauncherAnyHost(hasIPv4, hasIPv6) + return utils.ResolveAdaptiveAnyHost() } func resolveDefaultLauncherPrivateHost() string { - hasIPv4, hasIPv6 := detectLauncherIPFamilies() - if hasIPv4 && hasIPv6 { - // In dual-stack environments, use wildcard IPv6 bind so localhost can serve both families. - return selectAdaptiveLauncherAnyHost(hasIPv4, hasIPv6) - } - return selectAdaptiveLauncherLoopbackHost(hasIPv4, hasIPv6) + return utils.ResolveAdaptiveLoopbackHost() } func normalizeLauncherSpecialHost(host string) string { @@ -159,16 +124,36 @@ func normalizeLauncherSpecialHost(host string) string { if host == "" { return host } - if strings.EqualFold(host, "localhost") { - return resolveDefaultLauncherLoopbackHost() - } - trimmed := strings.Trim(host, "[]") - if ip := net.ParseIP(trimmed); ip != nil && ip.IsUnspecified() { + if host == "*" { return resolveDefaultLauncherAnyHost() } + if strings.EqualFold(host, "localhost") { + return resolveDefaultLauncherPrivateHost() + } + if ip := net.ParseIP(strings.Trim(host, "[]")); ip != nil { + return ip.String() + } return host } +func resolveLauncherBindMode(rawHost string, hostExplicit bool, effectivePublic bool) launcherBindMode { + if !hostExplicit { + if effectivePublic { + return launcherBindModeAutoPublic + } + return launcherBindModeAutoPrivate + } + + rawHost = strings.TrimSpace(rawHost) + if rawHost == "*" { + return launcherBindModeExplicitAdaptiveAny + } + if strings.EqualFold(rawHost, "localhost") { + return launcherBindModeExplicitAdaptiveLocal + } + return launcherBindModeExplicitLiteral +} + func resolveLauncherBindHost( host string, explicitHost bool, @@ -243,30 +228,126 @@ func appendUniqueHost(hosts []string, seen map[string]struct{}, host string) []s return append(hosts, host) } -func launcherConsoleHosts(bindHost string, hostExplicit bool, effectivePublic bool) []string { +func launcherConsoleHosts(bindMode launcherBindMode, bindHost string, effectivePublic bool) []string { hosts := make([]string, 0, 6) seen := make(map[string]struct{}, 6) hosts = appendUniqueHost(hosts, seen, "localhost") - if isWildcardBindHost(bindHost) { + switch bindMode { + case launcherBindModeAutoPrivate, launcherBindModeExplicitAdaptiveLocal: hosts = appendUniqueHost(hosts, seen, "::1") hosts = appendUniqueHost(hosts, seen, "127.0.0.1") - - if effectivePublic || hostExplicit { - hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv6()) - hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv4()) + return hosts + case launcherBindModeAutoPublic, launcherBindModeExplicitAdaptiveAny: + hosts = appendUniqueHost(hosts, seen, "::1") + hosts = appendUniqueHost(hosts, seen, "127.0.0.1") + hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv6()) + hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv4()) + return hosts + case launcherBindModeExplicitLiteral: + trimmed := strings.Trim(strings.TrimSpace(bindHost), "[]") + if ip := net.ParseIP(trimmed); ip != nil { + if ip.IsUnspecified() { + if ip.To4() != nil { + hosts = appendUniqueHost(hosts, seen, "127.0.0.1") + hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv4()) + return hosts + } + hosts = appendUniqueHost(hosts, seen, "::1") + hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv6()) + return hosts + } + hosts = appendUniqueHost(hosts, seen, ip.String()) + return hosts } + } + + if effectivePublic && isWildcardBindHost(bindHost) { + hosts = appendUniqueHost(hosts, seen, "::1") + hosts = appendUniqueHost(hosts, seen, "127.0.0.1") + hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv6()) + hosts = appendUniqueHost(hosts, seen, utils.GetLocalIPv4()) return hosts } - if hostExplicit { - hosts = appendUniqueHost(hosts, seen, bindHost) - } + hosts = appendUniqueHost(hosts, seen, bindHost) return hosts } +func openLauncherListener(network, host, port string) (net.Listener, error) { + return net.Listen(network, net.JoinHostPort(host, port)) +} + +func openLauncherPrivateListeners(port string) ([]net.Listener, string, error) { + if ln6, err6 := openLauncherListener("tcp6", "::1", port); err6 == nil { + if ln4, err4 := openLauncherListener("tcp4", "127.0.0.1", port); err4 == nil { + return []net.Listener{ln6, ln4}, "localhost", nil + } + _ = ln6.Close() + } + + if ln6, err := openLauncherListener("tcp6", "::1", port); err == nil { + return []net.Listener{ln6}, "::1", nil + } + + if ln4, err := openLauncherListener("tcp4", "127.0.0.1", port); err == nil { + return []net.Listener{ln4}, "127.0.0.1", nil + } + + return nil, "", fmt.Errorf("failed to open private localhost listener on port %s", port) +} + +func openLauncherAnyListener(port string) ([]net.Listener, string, error) { + // For auto-public and -host=* we intentionally bind :: on "tcp" first. + // Go's compatibility layer will provide dual-stack behavior on environments where it is supported. + if ln, err := openLauncherListener("tcp", "::", port); err == nil { + return []net.Listener{ln}, "::", nil + } + + if ln4, err := openLauncherListener("tcp4", "0.0.0.0", port); err == nil { + return []net.Listener{ln4}, "0.0.0.0", nil + } + + return nil, "", fmt.Errorf("failed to open adaptive any-host listener on port %s", port) +} + +func openLauncherLiteralListener(host, port string) ([]net.Listener, string, error) { + host = strings.TrimSpace(host) + trimmed := strings.Trim(host, "[]") + network := "tcp" + + if ip := net.ParseIP(trimmed); ip != nil { + host = ip.String() + if ip.To4() != nil { + network = "tcp4" + } else { + network = "tcp6" + } + } + + ln, err := openLauncherListener(network, host, port) + if err != nil { + return nil, "", err + } + + return []net.Listener{ln}, host, nil +} + +func openLauncherListeners(mode launcherBindMode, bindHost, port string) ([]net.Listener, string, error) { + switch mode { + case launcherBindModeAutoPrivate, launcherBindModeExplicitAdaptiveLocal: + return openLauncherPrivateListeners(port) + case launcherBindModeAutoPublic, launcherBindModeExplicitAdaptiveAny: + return openLauncherAnyListener(port) + case launcherBindModeExplicitLiteral: + return openLauncherLiteralListener(bindHost, port) + default: + return nil, "", fmt.Errorf("unsupported launcher bind mode: %s", mode) + } +} + // maskSecret masks a secret for display. It always shows up to the first 3 // runes. The last 4 runes are only appended when at least 5 runes remain // hidden in the middle (i.e. string length >= 12), so an 8-char minimum @@ -421,20 +502,47 @@ func main() { } envHost := strings.TrimSpace(os.Getenv(launcherconfig.EnvLauncherHost)) - effectiveHost, effectivePublic, hostExplicit, err := resolveLauncherBindHost( - *host, - explicitHost, - envHost, - effectivePublic, - ) - if err != nil { - logger.Fatalf("Invalid host %q: %v", *host, err) + rawHostInput := strings.TrimSpace(*host) + if !explicitHost { + rawHostInput = envHost } - effectiveAllowedCIDRs := append([]string(nil), launcherCfg.AllowedCIDRs...) - if len(effectiveAllowedCIDRs) == 0 && !effectivePublic && !hostExplicit && isWildcardBindHost(effectiveHost) { - effectiveAllowedCIDRs = []string{"127.0.0.1/32", "::1/128"} - logger.InfoC("web", "Applying loopback-only access policy for default dual-stack bind") + hostExplicit := false + effectiveHost := "" + bindMode := launcherBindModeAutoPrivate + bindTargets := make([]launcherRuntimeBinding, 0, 1) + if rawHostInput != "" { + hosts, parseErr := parseLauncherHostList(rawHostInput) + if parseErr != nil { + logger.Fatalf("Invalid host %q: %v", rawHostInput, parseErr) + } + hostExplicit = true + effectivePublic = false + for _, raw := range hosts { + resolvedHost, _, _, resolveErr := resolveLauncherBindHost(raw, true, "", false) + if resolveErr != nil { + logger.Fatalf("Invalid host %q: %v", raw, resolveErr) + } + mode := resolveLauncherBindMode(raw, true, false) + bindTargets = append(bindTargets, launcherRuntimeBinding{mode: mode, host: resolvedHost}) + } + effectiveHost = bindTargets[0].host + bindMode = bindTargets[0].mode + } else { + resolvedHost, resolvedPublic, resolvedExplicit, resolveErr := resolveLauncherBindHost( + "", + false, + "", + effectivePublic, + ) + if resolveErr != nil { + logger.Fatalf("Invalid default host: %v", resolveErr) + } + effectiveHost = resolvedHost + effectivePublic = resolvedPublic + hostExplicit = resolvedExplicit + bindMode = resolveLauncherBindMode("", false, effectivePublic) + bindTargets = append(bindTargets, launcherRuntimeBinding{mode: bindMode, host: effectiveHost}) } if !explicitHost && envHost != "" { @@ -453,6 +561,22 @@ func main() { logger.Fatalf("Invalid port %q: %v", effectivePort, err) } + listeners := make([]net.Listener, 0, len(bindTargets)) + runtimeBindings := make([]launcherRuntimeBinding, 0, len(bindTargets)) + for _, target := range bindTargets { + targetListeners, runtimeHost, listenErr := openLauncherListeners(target.mode, target.host, effectivePort) + if listenErr != nil { + for _, ln := range listeners { + _ = ln.Close() + } + logger.Fatalf("Failed to open launcher listener(s): %v", listenErr) + } + listeners = append(listeners, targetListeners...) + runtimeBindings = append(runtimeBindings, launcherRuntimeBinding{mode: target.mode, host: runtimeHost}) + } + effectiveHost = runtimeBindings[0].host + bindMode = runtimeBindings[0].mode + dashboardToken, dashboardSigningKey, dashboardTokenSource, dashErr := launcherconfig.EnsureDashboardSecrets( launcherCfg, ) @@ -480,9 +604,6 @@ func main() { logger.ErrorC("web", fmt.Sprintf("Warning: could not open auth store: %v", authStoreErr)) } - // Determine listen address - addr := net.JoinHostPort(effectiveHost, effectivePort) - // Initialize Server components mux := http.NewServeMux() @@ -499,14 +620,18 @@ func main() { if _, err = apiHandler.EnsurePicoChannel(""); err != nil { logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err)) } - apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, effectiveAllowedCIDRs) - apiHandler.SetServerBindHost(effectiveHost, hostExplicit) + gatewayHostExplicit := hostExplicit && len(runtimeBindings) == 1 + if hostExplicit && len(runtimeBindings) > 1 { + logger.WarnC("web", "Multiple launcher hosts are configured; gateway host override is disabled for this run") + } + apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs) + apiHandler.SetServerBindHost(effectiveHost, gatewayHostExplicit) apiHandler.RegisterRoutes(mux) // Frontend Embedded Assets registerEmbedRoutes(mux) - accessControlledMux, err := middleware.IPAllowlist(effectiveAllowedCIDRs, mux) + accessControlledMux, err := middleware.IPAllowlist(launcherCfg.AllowedCIDRs, mux) if err != nil { logger.Fatalf("Invalid allowed CIDR configuration: %v", err) } @@ -527,11 +652,19 @@ func main() { // Print startup banner and token (console mode only). if enableConsole || debug { + consoleHosts := make([]string, 0, 8) + consoleSeen := make(map[string]struct{}, 8) + for _, binding := range runtimeBindings { + for _, host := range launcherConsoleHosts(binding.mode, binding.host, effectivePublic) { + consoleHosts = appendUniqueHost(consoleHosts, consoleSeen, host) + } + } + fmt.Print(utils.Banner) fmt.Println() fmt.Println(" Open the following URL in your browser:") fmt.Println() - for _, host := range launcherConsoleHosts(effectiveHost, hostExplicit, effectivePublic) { + for _, host := range consoleHosts { fmt.Printf(" >> http://%s <<\n", net.JoinHostPort(host, effectivePort)) } fmt.Println() @@ -558,7 +691,9 @@ func main() { } // Log startup info to file - logger.InfoC("web", fmt.Sprintf("Server will listen on http://%s", net.JoinHostPort(effectiveHost, effectivePort))) + for _, ln := range listeners { + logger.InfoC("web", fmt.Sprintf("Server will listen on http://%s", ln.Addr().String())) + } if isWildcardBindHost(effectiveHost) { if ip := advertiseIPForWildcardBindHost(effectiveHost); ip != "" { logger.InfoC("web", fmt.Sprintf("Public access enabled at http://%s", net.JoinHostPort(ip, effectivePort))) @@ -581,14 +716,19 @@ func main() { apiHandler.TryAutoStartGateway() }() - // Start the Server in a goroutine - server = &http.Server{Addr: addr, Handler: handler} - go func() { - logger.InfoC("web", fmt.Sprintf("Server listening on %s", addr)) - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Fatalf("Server failed to start: %v", err) - } - }() + // Start the server(s) in goroutines. + servers = make([]*http.Server, 0, len(listeners)) + for _, ln := range listeners { + srv := &http.Server{Handler: handler} + servers = append(servers, srv) + + go func(s *http.Server, l net.Listener) { + logger.InfoC("web", fmt.Sprintf("Server listening on %s", l.Addr().String())) + if serveErr := s.Serve(l); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { + logger.Fatalf("Server failed to start on %s: %v", l.Addr().String(), serveErr) + } + }(srv, ln) + } defer shutdownApp() diff --git a/web/backend/main_test.go b/web/backend/main_test.go index 1ac3f0ccf..47df1c269 100644 --- a/web/backend/main_test.go +++ b/web/backend/main_test.go @@ -96,6 +96,41 @@ func TestMaskSecret(t *testing.T) { } } +func TestParseLauncherHostList(t *testing.T) { + tests := []struct { + name string + raw string + want []string + wantErr bool + }{ + {name: "single host", raw: "127.0.0.1", want: []string{"127.0.0.1"}}, + {name: "multiple hosts", raw: "127.0.0.1, 192.168.2.5", want: []string{"127.0.0.1", "192.168.2.5"}}, + {name: "dedupe hosts", raw: "127.0.0.1,127.0.0.1", want: []string{"127.0.0.1"}}, + {name: "reject empty entry", raw: "127.0.0.1, ", wantErr: true}, + {name: "reject empty input", raw: " ", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseLauncherHostList(tt.raw) + if (err != nil) != tt.wantErr { + t.Fatalf("parseLauncherHostList() err = %v, wantErr %t", err, tt.wantErr) + } + if tt.wantErr { + return + } + if len(got) != len(tt.want) { + t.Fatalf("len(got) = %d, want %d (%#v)", len(got), len(tt.want), got) + } + for i := range got { + if got[i] != tt.want[i] { + t.Fatalf("got[%d] = %q, want %q", i, got[i], tt.want[i]) + } + } + }) + } +} + func TestResolveLauncherBindHost(t *testing.T) { tests := []struct { name string @@ -113,7 +148,7 @@ func TestResolveLauncherBindHost(t *testing.T) { host: "0.0.0.0", explicitHost: true, effectivePub: true, - wantHost: resolveDefaultLauncherAnyHost(), + wantHost: "0.0.0.0", wantPublic: false, wantExplicit: true, }, @@ -139,6 +174,24 @@ func TestResolveLauncherBindHost(t *testing.T) { envHost: "0.0.0.0", explicitHost: false, effectivePub: true, + wantHost: "0.0.0.0", + wantPublic: false, + wantExplicit: true, + }, + { + name: "explicit localhost uses adaptive private host", + host: "localhost", + explicitHost: true, + effectivePub: false, + wantHost: resolveDefaultLauncherPrivateHost(), + wantPublic: false, + wantExplicit: true, + }, + { + name: "explicit star uses adaptive any host", + host: "*", + explicitHost: true, + effectivePub: false, wantHost: resolveDefaultLauncherAnyHost(), wantPublic: false, wantExplicit: true, @@ -190,9 +243,33 @@ func TestResolveLauncherBindHost(t *testing.T) { } } +func TestResolveLauncherBindMode(t *testing.T) { + tests := []struct { + name string + rawHost string + hostExplicit bool + effectivePub bool + wantMode launcherBindMode + }{ + {name: "auto private", rawHost: "", hostExplicit: false, effectivePub: false, wantMode: launcherBindModeAutoPrivate}, + {name: "auto public", rawHost: "", hostExplicit: false, effectivePub: true, wantMode: launcherBindModeAutoPublic}, + {name: "explicit localhost", rawHost: "localhost", hostExplicit: true, effectivePub: false, wantMode: launcherBindModeExplicitAdaptiveLocal}, + {name: "explicit star", rawHost: "*", hostExplicit: true, effectivePub: false, wantMode: launcherBindModeExplicitAdaptiveAny}, + {name: "explicit literal", rawHost: "0.0.0.0", hostExplicit: true, effectivePub: false, wantMode: launcherBindModeExplicitLiteral}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := resolveLauncherBindMode(tt.rawHost, tt.hostExplicit, tt.effectivePub); got != tt.wantMode { + t.Fatalf("resolveLauncherBindMode() = %q, want %q", got, tt.wantMode) + } + }) + } +} + func TestLauncherConsoleHosts(t *testing.T) { - t.Run("explicit wildcard dedupes localhost and includes loopback ipv6", func(t *testing.T) { - hosts := launcherConsoleHosts("0.0.0.0", true, false) + t.Run("auto private includes dual loopback hints", func(t *testing.T) { + hosts := launcherConsoleHosts(launcherBindModeAutoPrivate, "localhost", false) seen := make(map[string]bool, len(hosts)) for _, host := range hosts { if seen[host] { @@ -211,8 +288,22 @@ func TestLauncherConsoleHosts(t *testing.T) { } }) + t.Run("explicit ipv4 wildcard excludes ipv6 loopback", func(t *testing.T) { + hosts := launcherConsoleHosts(launcherBindModeExplicitLiteral, "0.0.0.0", false) + seen := make(map[string]bool, len(hosts)) + for _, host := range hosts { + seen[host] = true + } + if seen["::1"] { + t.Fatalf("did not expect ::1 in %#v", hosts) + } + if !seen["127.0.0.1"] { + t.Fatalf("expected 127.0.0.1 in %#v", hosts) + } + }) + t.Run("explicit ipv6 host remains visible", func(t *testing.T) { - hosts := launcherConsoleHosts("::1", true, false) + hosts := launcherConsoleHosts(launcherBindModeExplicitLiteral, "::1", false) if len(hosts) != 2 { t.Fatalf("len(hosts) = %d, want 2 (%#v)", len(hosts), hosts) } diff --git a/web/backend/utils/runtime.go b/web/backend/utils/runtime.go index 7cceff707..9b5516fc1 100644 --- a/web/backend/utils/runtime.go +++ b/web/backend/utils/runtime.go @@ -7,11 +7,91 @@ import ( "os/exec" "path/filepath" "runtime" + "sync" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" ) +var ( + ipFamiliesOnce sync.Once + hasIPv4 bool + hasIPv6 bool +) + +func DetectIPFamilies() (bool, bool) { + ipFamiliesOnce.Do(func() { + if ips, err := net.LookupIP("localhost"); err == nil { + for _, ip := range ips { + if ip == nil { + continue + } + if ip.To4() != nil { + hasIPv4 = true + continue + } + hasIPv6 = true + } + } + + if hasIPv4 && hasIPv6 { + return + } + + if addrs, err := net.InterfaceAddrs(); err == nil { + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok || ipnet.IP == nil { + continue + } + if ipnet.IP.To4() != nil { + hasIPv4 = true + continue + } + hasIPv6 = true + } + } + }) + + return hasIPv4, hasIPv6 +} + +func SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6 bool) string { + switch { + case hasIPv4 && hasIPv6: + return "localhost" + case hasIPv6: + return "::1" + case hasIPv4: + return "127.0.0.1" + default: + return "localhost" + } +} + +func SelectAdaptiveAnyHost(hasIPv4, hasIPv6 bool) string { + switch { + case hasIPv4 && hasIPv6: + return "::" + case hasIPv6: + return "::" + case hasIPv4: + return "0.0.0.0" + default: + return "::" + } +} + +func ResolveAdaptiveLoopbackHost() string { + hasIPv4, hasIPv6 := DetectIPFamilies() + return SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6) +} + +func ResolveAdaptiveAnyHost() string { + hasIPv4, hasIPv6 := DetectIPFamilies() + return SelectAdaptiveAnyHost(hasIPv4, hasIPv6) +} + // GetPicoclawHome returns the picoclaw home directory. // Priority: $PICOCLAW_HOME > ~/.picoclaw func GetPicoclawHome() string { diff --git a/web/backend/utils/runtime_test.go b/web/backend/utils/runtime_test.go new file mode 100644 index 000000000..dbcacdc9a --- /dev/null +++ b/web/backend/utils/runtime_test.go @@ -0,0 +1,59 @@ +package utils + +import "testing" + +func TestSelectAdaptiveLoopbackHost(t *testing.T) { + tests := []struct { + name string + hasIPv4 bool + hasIPv6 bool + want string + }{ + {name: "dual stack", hasIPv4: true, hasIPv6: true, want: "localhost"}, + {name: "ipv6 only", hasIPv4: false, hasIPv6: true, want: "::1"}, + {name: "ipv4 only", hasIPv4: true, hasIPv6: false, want: "127.0.0.1"}, + {name: "fallback", hasIPv4: false, hasIPv6: false, want: "localhost"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := SelectAdaptiveLoopbackHost(tt.hasIPv4, tt.hasIPv6); got != tt.want { + t.Fatalf("SelectAdaptiveLoopbackHost(%t, %t) = %q, want %q", tt.hasIPv4, tt.hasIPv6, got, tt.want) + } + }) + } +} + +func TestSelectAdaptiveAnyHost(t *testing.T) { + tests := []struct { + name string + hasIPv4 bool + hasIPv6 bool + want string + }{ + {name: "dual stack", hasIPv4: true, hasIPv6: true, want: "::"}, + {name: "ipv6 only", hasIPv4: false, hasIPv6: true, want: "::"}, + {name: "ipv4 only", hasIPv4: true, hasIPv6: false, want: "0.0.0.0"}, + {name: "fallback", hasIPv4: false, hasIPv6: false, want: "::"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := SelectAdaptiveAnyHost(tt.hasIPv4, tt.hasIPv6); got != tt.want { + t.Fatalf("SelectAdaptiveAnyHost(%t, %t) = %q, want %q", tt.hasIPv4, tt.hasIPv6, got, tt.want) + } + }) + } +} + +func TestResolveAdaptiveHosts(t *testing.T) { + loopback := ResolveAdaptiveLoopbackHost() + if loopback == "" { + t.Fatal("ResolveAdaptiveLoopbackHost() returned empty host") + } + + anyHost := ResolveAdaptiveAnyHost() + if anyHost == "" { + t.Fatal("ResolveAdaptiveAnyHost() returned empty host") + } +}