diff --git a/cmd/picoclaw/internal/gateway/command.go b/cmd/picoclaw/internal/gateway/command.go index 7fa588c5c..7dd03b495 100644 --- a/cmd/picoclaw/internal/gateway/command.go +++ b/cmd/picoclaw/internal/gateway/command.go @@ -2,19 +2,34 @@ package gateway import ( "fmt" + "os" "github.com/spf13/cobra" "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/gateway" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/netbind" "github.com/sipeed/picoclaw/pkg/utils" ) +func resolveGatewayHostOverride(explicit bool, host string) (string, error) { + if !explicit { + return "", nil + } + normalized, err := netbind.NormalizeHostInput(host) + if err != nil { + return "", fmt.Errorf("invalid --host value: %w", err) + } + return normalized, nil +} + func NewGatewayCommand() *cobra.Command { var debug bool var noTruncate bool var allowEmpty bool + var host string cmd := &cobra.Command{ Use: "gateway", @@ -33,7 +48,25 @@ func NewGatewayCommand() *cobra.Command { return nil }, - RunE: func(_ *cobra.Command, _ []string) error { + RunE: func(cmd *cobra.Command, _ []string) error { + resolvedHost, err := resolveGatewayHostOverride(cmd.Flags().Changed("host"), host) + if err != nil { + return err + } + if resolvedHost != "" { + prevHost, hadPrev := os.LookupEnv(config.EnvGatewayHost) + if err := os.Setenv(config.EnvGatewayHost, resolvedHost); err != nil { + return fmt.Errorf("failed to set %s: %w", config.EnvGatewayHost, err) + } + defer func() { + if hadPrev { + _ = os.Setenv(config.EnvGatewayHost, prevHost) + return + } + _ = os.Unsetenv(config.EnvGatewayHost) + }() + } + return gateway.Run(debug, internal.GetPicoclawHome(), internal.GetConfigPath(), allowEmpty) }, } @@ -47,6 +80,12 @@ func NewGatewayCommand() *cobra.Command { false, "Continue starting even when no default model is configured", ) + cmd.Flags().StringVar( + &host, + "host", + "", + "Host address for gateway binding (overrides gateway.host for this run)", + ) return cmd } diff --git a/cmd/picoclaw/internal/gateway/command_test.go b/cmd/picoclaw/internal/gateway/command_test.go index 839a7315a..825369abb 100644 --- a/cmd/picoclaw/internal/gateway/command_test.go +++ b/cmd/picoclaw/internal/gateway/command_test.go @@ -29,4 +29,38 @@ func TestNewGatewayCommand(t *testing.T) { assert.True(t, cmd.HasFlags()) assert.NotNil(t, cmd.Flags().Lookup("debug")) assert.NotNil(t, cmd.Flags().Lookup("allow-empty")) + assert.NotNil(t, cmd.Flags().Lookup("host")) +} + +func TestResolveGatewayHostOverride(t *testing.T) { + tests := []struct { + name string + explicit bool + host string + wantHost string + wantErr bool + }{ + {name: "implicit empty host is allowed", explicit: false, host: "", wantHost: "", wantErr: false}, + {name: "explicit empty host rejected", explicit: true, host: " ", wantHost: "", wantErr: true}, + {name: "explicit localhost kept", explicit: true, host: " localhost ", wantHost: "localhost", wantErr: false}, + { + name: "explicit multi host normalized", + explicit: true, + host: " [::1] , 127.0.0.1 ", + wantHost: "::1,127.0.0.1", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := resolveGatewayHostOverride(tt.explicit, tt.host) + if (err != nil) != tt.wantErr { + t.Fatalf("resolveGatewayHostOverride() err = %v, wantErr %t", err, tt.wantErr) + } + if got != tt.wantHost { + t.Fatalf("resolveGatewayHostOverride() host = %q, want %q", got, tt.wantHost) + } + }) + } } diff --git a/config/config.example.json b/config/config.example.json index 2d2d38496..d83c31076 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -472,7 +472,7 @@ }, "gateway": { "_comment": "Default log level is set to 'fatal'. Other available options are 'debug', 'info', 'warn' and 'error'.", - "host": "127.0.0.1", + "host": "localhost", "port": 18790, "hot_reload": false, "log_level": "fatal" diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 4d8e47c0f..928676cbc 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "math" + "net" "net/http" "sort" "sync" @@ -86,6 +87,7 @@ type Manager struct { dispatchTask *asyncTask mux *dynamicServeMux httpServer *http.Server + httpListeners []net.Listener mu sync.RWMutex placeholders sync.Map // "channel:chatID" → placeholderID (string) typingStops sync.Map // "channel:chatID" → func() @@ -474,6 +476,12 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error { // It registers health endpoints from the health server and discovers channels // that implement WebhookHandler and/or HealthChecker to register their handlers. func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) { + m.SetupHTTPServerListeners(nil, addr, healthServer) +} + +// SetupHTTPServerListeners creates a shared HTTP server on pre-opened listeners. +// When listeners is empty it falls back to Addr-based ListenAndServe behavior. +func (m *Manager) SetupHTTPServerListeners(listeners []net.Listener, addr string, healthServer *health.Server) { m.mux = newDynamicServeMux() // Register health endpoints @@ -490,6 +498,7 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) { ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, } + m.httpListeners = append([]net.Listener(nil), listeners...) } // registerHTTPHandlersLocked registers webhook and health-check handlers for @@ -619,16 +628,33 @@ func (m *Manager) StartAll(ctx context.Context) error { // Start shared HTTP server if configured if m.httpServer != nil { - go func() { - logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{ - "addr": m.httpServer.Addr, - }) - if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.FatalCF("channels", "Shared HTTP server error", map[string]any{ - "error": err.Error(), - }) + if len(m.httpListeners) > 0 { + for _, listener := range m.httpListeners { + ln := listener + go func() { + logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{ + "addr": ln.Addr().String(), + }) + if err := m.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + logger.FatalCF("channels", "Shared HTTP server error", map[string]any{ + "addr": ln.Addr().String(), + "error": err.Error(), + }) + } + }() } - }() + } else { + go func() { + logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{ + "addr": m.httpServer.Addr, + }) + if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.FatalCF("channels", "Shared HTTP server error", map[string]any{ + "error": err.Error(), + }) + } + }() + } } logger.InfoCF("channels", "Channel startup completed", map[string]any{ @@ -655,6 +681,7 @@ func (m *Manager) StopAll(ctx context.Context) error { }) } m.httpServer = nil + m.httpListeners = nil } // Cancel dispatcher diff --git a/pkg/config/config.go b/pkg/config/config.go index 683f68951..c928e8c5f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1136,6 +1136,8 @@ func LoadConfig(path string) (*Config, error) { applyLegacyBindingsMigration(data, cfg) + gatewayHostBeforeEnv := cfg.Gateway.Host + if err = env.Parse(cfg); err != nil { return nil, err } @@ -1144,6 +1146,10 @@ func LoadConfig(path string) (*Config, error) { if err = InitChannelList(cfg.Channels); err != nil { return nil, err } + cfg.Gateway.Host, err = resolveGatewayHostFromEnv(gatewayHostBeforeEnv) + if err != nil { + return nil, fmt.Errorf("invalid gateway host: %w", err) + } // Expand multi-key configs into separate entries for key-level failover cfg.ModelList = expandMultiKeyModels(cfg.ModelList) diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index ce69b4c98..719bbf0c6 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -503,7 +503,7 @@ func TestDefaultConfig_Temperature(t *testing.T) { func TestDefaultConfig_Gateway(t *testing.T) { cfg := DefaultConfig() - if cfg.Gateway.Host != "127.0.0.1" { + if cfg.Gateway.Host != "localhost" { t.Error("Gateway host should have default value") } if cfg.Gateway.Port == 0 { @@ -739,7 +739,7 @@ func TestConfig_Complete(t *testing.T) { if cfg.Agents.Defaults.MaxToolIterations == 0 { t.Error("MaxToolIterations should not be zero") } - if cfg.Gateway.Host != "127.0.0.1" { + if cfg.Gateway.Host != "localhost" { t.Error("Gateway host should have default value") } if cfg.Gateway.Port == 0 { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index d67b7a668..365bc0808 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -259,7 +259,7 @@ func DefaultConfig() *Config { }, }, Gateway: GatewayConfig{ - Host: "127.0.0.1", + Host: "localhost", Port: 18790, HotReload: false, LogLevel: DefaultGatewayLogLevel, diff --git a/pkg/config/envkeys.go b/pkg/config/envkeys.go index 615769d3c..5a2590299 100644 --- a/pkg/config/envkeys.go +++ b/pkg/config/envkeys.go @@ -39,7 +39,7 @@ const ( EnvBinary = "PICOCLAW_BINARY" // EnvGatewayHost overrides the host address for the gateway server. - // Default: "127.0.0.1" + // Default: "localhost" EnvGatewayHost = "PICOCLAW_GATEWAY_HOST" ) diff --git a/pkg/config/gateway.go b/pkg/config/gateway.go index e9f4085d3..392a4ca5e 100644 --- a/pkg/config/gateway.go +++ b/pkg/config/gateway.go @@ -3,8 +3,10 @@ package config import ( "encoding/json" "os" + "strings" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/netbind" ) const DefaultGatewayLogLevel = "warn" @@ -49,6 +51,31 @@ func EffectiveGatewayLogLevel(cfg *Config) string { return normalizeGatewayLogLevel(cfg.Gateway.LogLevel) } +func resolveGatewayHostFromEnv(baseHost string) (string, error) { + envHost, ok := os.LookupEnv(EnvGatewayHost) + if !ok { + return normalizeGatewayHostInput(baseHost) + } + + envHost = strings.TrimSpace(envHost) + if envHost == "" { + return normalizeGatewayHostInput(baseHost) + } + + return normalizeGatewayHostInput(envHost) +} + +func normalizeGatewayHostInput(host string) (string, error) { + host = strings.TrimSpace(host) + if host == "" { + host = strings.TrimSpace(DefaultConfig().Gateway.Host) + } + if host == "" { + host = "localhost" + } + return netbind.NormalizeHostInput(host) +} + // ResolveGatewayLogLevel reads the configured gateway log level without triggering // the full config loader, so startup code can apply logging before config load logs run. // The PICOCLAW_LOG_LEVEL environment variable overrides the file value. diff --git a/pkg/config/gateway_host_env_test.go b/pkg/config/gateway_host_env_test.go new file mode 100644 index 000000000..40fabb1a3 --- /dev/null +++ b/pkg/config/gateway_host_env_test.go @@ -0,0 +1,98 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "testing" +) + +func writeGatewayHostTestConfig(t *testing.T, host string) string { + t.Helper() + + configPath := filepath.Join(t.TempDir(), "config.json") + raw := fmt.Sprintf(`{"version":2,"gateway":{"host":%q,"port":18790}}`, host) + if err := os.WriteFile(configPath, []byte(raw), 0o600); err != nil { + t.Fatalf("WriteFile(configPath): %v", err) + } + return configPath +} + +func TestLoadConfig_GatewayHostEnvTrimmed(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, "127.0.0.1") + t.Setenv(EnvGatewayHost, " ::1 ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Gateway.Host != "::1" { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, "::1") + } +} + +func TestLoadConfig_GatewayHostBlankEnvFallsBackToConfigHost(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, " localhost ") + t.Setenv(EnvGatewayHost, " ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + want, err := normalizeGatewayHostInput("localhost") + if err != nil { + t.Fatalf("normalizeGatewayHostInput() error: %v", err) + } + if cfg.Gateway.Host != want { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, want) + } +} + +func TestLoadConfig_GatewayHostBlankEnvAndConfigFallsBackToDefault(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, " ") + t.Setenv(EnvGatewayHost, " ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + + defaultHost, err := normalizeGatewayHostInput(DefaultConfig().Gateway.Host) + if err != nil { + t.Fatalf("normalizeGatewayHostInput() error: %v", err) + } + if cfg.Gateway.Host != defaultHost { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, defaultHost) + } +} + +func TestLoadConfig_GatewayHostEnvPreservesExplicitWildcardHost(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, "localhost") + t.Setenv(EnvGatewayHost, " 0.0.0.0 ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + + want, err := normalizeGatewayHostInput("0.0.0.0") + if err != nil { + t.Fatalf("normalizeGatewayHostInput() error: %v", err) + } + if cfg.Gateway.Host != want { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, want) + } +} + +func TestLoadConfig_GatewayHostEnvNormalizesMultiHostInput(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, "localhost") + t.Setenv(EnvGatewayHost, " [::1] , 127.0.0.1 , ::1 ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Gateway.Host != "::1,127.0.0.1" { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, "::1,127.0.0.1") + } +} diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index a5afb0eb8..039f45075 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -3,10 +3,12 @@ package gateway import ( "context" "fmt" + "net" "os" "os/signal" "path/filepath" "sort" + "strconv" "strings" "sync" "sync/atomic" @@ -42,6 +44,7 @@ import ( "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/netbind" "github.com/sipeed/picoclaw/pkg/pid" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" @@ -159,13 +162,30 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr logger.Infof("Log level set to %q", effectiveLogLevel) } + bindPlan, listenResult, err := openGatewayListeners(cfg.Gateway.Host, cfg.Gateway.Port) + if err != nil { + return fmt.Errorf("error opening gateway listeners: %w", err) + } + // Enforce singleton: write PID file with generated token. - pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port) + pidData, err := pid.WritePidFile(homePath, bindPlan.ProbeHost, cfg.Gateway.Port) if err != nil { logger.Warnf("write pid file failed: %v", err) + for _, ln := range listenResult.Listeners { + _ = ln.Close() + } return fmt.Errorf("singleton check failed: %w", err) } defer pid.RemovePidFile(homePath) + closeListeners := true + defer func() { + if !closeListeners { + return + } + for _, ln := range listenResult.Listeners { + _ = ln.Close() + } + }() provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup) if err != nil { @@ -193,10 +213,11 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr "skills_available": skillsInfo["available"], }) - runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token) + runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token, listenResult) if err != nil { return err } + closeListeners = false // Setup manual reload channel for /reload endpoint manualReloadChan := make(chan struct{}, 1) @@ -217,7 +238,9 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr runningServices.HealthServer.SetReloadFunc(reloadTrigger) agentLoop.SetReloadFunc(reloadTrigger) - fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) + for _, bindHost := range listenResult.BindHosts { + fmt.Printf("✓ Gateway started on %s\n", net.JoinHostPort(bindHost, strconv.Itoa(cfg.Gateway.Port))) + } fmt.Println("Press Ctrl+C to stop") ctx, cancel := context.WithCancel(context.Background()) @@ -320,6 +343,7 @@ func setupAndStartServices( agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, authToken string, + listenResult netbind.OpenResult, ) (*services, error) { runningServices := &services{} @@ -390,10 +414,20 @@ func setupAndStartServices( fmt.Println("⚠ Warning: No channels enabled") } - addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) runningServices.authToken = authToken - runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, authToken) - runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer) + runningServices.HealthServer = health.NewServer(listenResult.ProbeHost, cfg.Gateway.Port, authToken) + + var listenAddr string + if len(listenResult.Listeners) > 0 { + listenAddr = listenResult.Listeners[0].Addr().String() + } else { + listenAddr = net.JoinHostPort(listenResult.ProbeHost, strconv.Itoa(cfg.Gateway.Port)) + } + runningServices.ChannelManager.SetupHTTPServerListeners( + listenResult.Listeners, + listenAddr, + runningServices.HealthServer, + ) if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil { return nil, fmt.Errorf("error starting channels: %w", err) @@ -409,10 +443,10 @@ func setupAndStartServices( voiceAgent.Start(vaCtx) } + healthAddr := net.JoinHostPort(listenResult.ProbeHost, strconv.Itoa(cfg.Gateway.Port)) fmt.Printf( - "✓ Health endpoints available at http://%s:%d/health, /ready and /reload (POST)\n", - cfg.Gateway.Host, - cfg.Gateway.Port, + "✓ Health endpoints available at http://%s/health, /ready and /reload (POST)\n", + healthAddr, ) stateManager := state.NewManager(cfg.WorkspacePath()) diff --git a/pkg/gateway/listen.go b/pkg/gateway/listen.go new file mode 100644 index 000000000..99be63096 --- /dev/null +++ b/pkg/gateway/listen.go @@ -0,0 +1,21 @@ +package gateway + +import ( + "strconv" + + "github.com/sipeed/picoclaw/pkg/netbind" +) + +func openGatewayListeners(host string, port int) (netbind.Plan, netbind.OpenResult, error) { + plan, err := netbind.BuildPlan(host, netbind.DefaultLoopback) + if err != nil { + return netbind.Plan{}, netbind.OpenResult{}, err + } + + result, err := netbind.OpenPlan(plan, strconv.Itoa(port)) + if err != nil { + return netbind.Plan{}, netbind.OpenResult{}, err + } + + return plan, result, nil +} diff --git a/pkg/gateway/listen_test.go b/pkg/gateway/listen_test.go new file mode 100644 index 000000000..9b932f852 --- /dev/null +++ b/pkg/gateway/listen_test.go @@ -0,0 +1,130 @@ +package gateway + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "strconv" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/netbind" +) + +func TestOpenGatewayListeners_HonorsIPv6OnlyHost(t *testing.T) { + hasIPv4, hasIPv6 := netbind.DetectIPFamilies() + if !hasIPv6 { + t.Skip("IPv6 is unavailable in this environment") + } + + _, result, err := openGatewayListeners("::", 0) + if err != nil { + t.Fatalf("openGatewayListeners() error = %v", err) + } + startGatewayTestHTTPServer(t, result.Listeners) + port := mustGatewayAtoi(t, result.Port) + + requireGatewayHTTPReachable(t, "::1", port) + if hasIPv4 { + requireGatewayHTTPUnreachable(t, "127.0.0.1", port) + } +} + +func TestOpenGatewayListeners_SupportsExplicitMultiHost(t *testing.T) { + hasIPv4, hasIPv6 := netbind.DetectIPFamilies() + if !hasIPv4 || !hasIPv6 { + t.Skip("dual-stack loopback is unavailable in this environment") + } + + _, result, err := openGatewayListeners("127.0.0.1,::1", 0) + if err != nil { + t.Fatalf("openGatewayListeners() error = %v", err) + } + startGatewayTestHTTPServer(t, result.Listeners) + port := mustGatewayAtoi(t, result.Port) + + requireGatewayHTTPReachable(t, "127.0.0.1", port) + requireGatewayHTTPReachable(t, "::1", port) +} + +func startGatewayTestHTTPServer(t *testing.T, listeners []net.Listener) { + t.Helper() + + server := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "ok") + }), + } + + errCh := make(chan error, len(listeners)) + for _, listener := range listeners { + ln := listener + go func() { + errCh <- server.Serve(ln) + }() + } + + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Shutdown(ctx) + for range listeners { + err := <-errCh + if err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Fatalf("server.Serve() error = %v", err) + } + } + }) +} + +func requireGatewayHTTPReachable(t *testing.T, host string, port int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for { + err := gatewayHTTPGet(host, port) + if err == nil { + return + } + if time.Now().After(deadline) { + t.Fatalf("expected %s:%d to be reachable: %v", host, port, err) + } + time.Sleep(50 * time.Millisecond) + } +} + +func requireGatewayHTTPUnreachable(t *testing.T, host string, port int) { + t.Helper() + if err := gatewayHTTPGet(host, port); err == nil { + t.Fatalf("expected %s:%d to be unreachable", host, port) + } +} + +func gatewayHTTPGet(host string, port int) error { + client := &http.Client{ + Timeout: 300 * time.Millisecond, + Transport: &http.Transport{ + Proxy: nil, + }, + } + + resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + return nil +} + +func mustGatewayAtoi(t *testing.T, value string) int { + t.Helper() + n, err := strconv.Atoi(value) + if err != nil { + t.Fatalf("Atoi(%q) error = %v", value, err) + } + return n +} diff --git a/pkg/health/server.go b/pkg/health/server.go index a152d8ab1..22346490c 100644 --- a/pkg/health/server.go +++ b/pkg/health/server.go @@ -4,10 +4,11 @@ import ( "context" "crypto/subtle" "encoding/json" - "fmt" "maps" + "net" "net/http" "os" + "strconv" "sync" "time" ) @@ -49,7 +50,7 @@ func NewServer(host string, port int, token string) *Server { mux.HandleFunc("/ready", s.readyHandler) mux.HandleFunc("/reload", s.reloadHandler) - addr := fmt.Sprintf("%s:%d", host, port) + addr := net.JoinHostPort(host, strconv.Itoa(port)) s.server = &http.Server{ Addr: addr, Handler: mux, diff --git a/pkg/health/server_test.go b/pkg/health/server_test.go index c4982fff9..31dbc37c0 100644 --- a/pkg/health/server_test.go +++ b/pkg/health/server_test.go @@ -305,6 +305,16 @@ func TestNewServer(t *testing.T) { } } +func TestNewServer_IPv6ListenAddrFormatting(t *testing.T) { + s := NewServer("::", 18790, "") + if s.server == nil { + t.Fatal("server should be initialized") + } + if s.server.Addr != "[::]:18790" { + t.Fatalf("server.Addr = %q, want %q", s.server.Addr, "[::]:18790") + } +} + func TestStartContext_Cancellation(t *testing.T) { s := NewServer("127.0.0.1", 0, "") diff --git a/pkg/netbind/netbind.go b/pkg/netbind/netbind.go new file mode 100644 index 000000000..ae6cacf49 --- /dev/null +++ b/pkg/netbind/netbind.go @@ -0,0 +1,606 @@ +package netbind + +import ( + "context" + "errors" + "fmt" + "net" + "strconv" + "strings" + "sync" +) + +type DefaultMode int + +const ( + DefaultLoopback DefaultMode = iota + DefaultAny +) + +type groupKind int + +const ( + groupAdaptiveLoopback groupKind = iota + groupAdaptiveAny + groupExact +) + +type exactBinding struct { + host string + network string + v6Only bool +} + +type bindGroup struct { + kind groupKind + allowIPv4 bool + allowIPv6 bool + exact exactBinding +} + +type Plan struct { + groups []bindGroup + ProbeHost string +} + +type OpenResult struct { + Listeners []net.Listener + BindHosts []string + Port string + ProbeHost string +} + +type tokenKind int + +const ( + tokenName tokenKind = iota + tokenLocalhost + tokenStar + tokenIPv4 + tokenIPv6 + tokenIPv4Any + tokenIPv6Any +) + +type hostToken struct { + kind tokenKind + canonical string + key string +} + +var ( + ipFamiliesOnce sync.Once + hasIPv4 bool + hasIPv6 bool +) + +func DetectIPFamilies() (bool, bool) { + ipFamiliesOnce.Do(func() { + if ips, err := net.LookupIP("localhost"); err == nil { + for _, ip := range ips { + if ip == nil { + continue + } + if ip.To4() != nil { + hasIPv4 = true + continue + } + hasIPv6 = true + } + } + + if hasIPv4 && hasIPv6 { + return + } + + if addrs, err := net.InterfaceAddrs(); err == nil { + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok || ipnet.IP == nil { + continue + } + if ipnet.IP.To4() != nil { + hasIPv4 = true + continue + } + hasIPv6 = true + } + } + }) + + return hasIPv4, hasIPv6 +} + +func SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6 bool) string { + switch { + case hasIPv4 && hasIPv6: + return "localhost" + case hasIPv6: + return "::1" + case hasIPv4: + return "127.0.0.1" + default: + return "localhost" + } +} + +func SelectAdaptiveAnyHost(hasIPv4, hasIPv6 bool) string { + switch { + case hasIPv4 && hasIPv6: + return "::" + case hasIPv6: + return "::" + case hasIPv4: + return "0.0.0.0" + default: + return "::" + } +} + +func ResolveAdaptiveLoopbackHost() string { + hasIPv4, hasIPv6 := DetectIPFamilies() + return SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6) +} + +func ResolveAdaptiveAnyHost() string { + hasIPv4, hasIPv6 := DetectIPFamilies() + return SelectAdaptiveAnyHost(hasIPv4, hasIPv6) +} + +func IsLoopbackHost(host string) bool { + host = strings.TrimSpace(host) + if host == "" { + return false + } + if strings.EqualFold(host, "localhost") { + return true + } + ip := net.ParseIP(strings.Trim(host, "[]")) + return ip != nil && ip.IsLoopback() +} + +func IsUnspecifiedHost(host string) bool { + host = strings.TrimSpace(host) + if host == "" { + return false + } + ip := net.ParseIP(strings.Trim(host, "[]")) + return ip != nil && ip.IsUnspecified() +} + +func NormalizeHostInput(raw string) (string, error) { + tokens, err := parseHostTokens(raw) + if err != nil { + return "", err + } + + parts := make([]string, 0, len(tokens)) + for _, token := range tokens { + parts = append(parts, token.canonical) + } + return strings.Join(parts, ","), nil +} + +func BuildPlan(raw string, defaultMode DefaultMode) (Plan, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return buildDefaultPlan(defaultMode), nil + } + + tokens, err := parseHostTokens(raw) + if err != nil { + return Plan{}, err + } + + for _, token := range tokens { + if token.kind == tokenStar { + return Plan{ + groups: []bindGroup{{kind: groupAdaptiveAny}}, + ProbeHost: ResolveAdaptiveLoopbackHost(), + }, nil + } + } + + hasIPv4Any := false + hasIPv6Any := false + for _, token := range tokens { + switch token.kind { + case tokenIPv4Any: + hasIPv4Any = true + case tokenIPv6Any: + hasIPv6Any = true + } + } + + allowLocalhostIPv4 := !hasIPv4Any + allowLocalhostIPv6 := !hasIPv6Any + + groups := make([]bindGroup, 0, len(tokens)) + seenExact := make(map[string]struct{}, len(tokens)) + addedLocalhost := false + + for _, token := range tokens { + switch token.kind { + case tokenLocalhost: + if addedLocalhost || (!allowLocalhostIPv4 && !allowLocalhostIPv6) { + continue + } + groups = append(groups, bindGroup{ + kind: groupAdaptiveLoopback, + allowIPv4: allowLocalhostIPv4, + allowIPv6: allowLocalhostIPv6, + }) + addedLocalhost = true + case tokenIPv4Any: + key := "exact:tcp4:0.0.0.0" + if _, ok := seenExact[key]; ok { + continue + } + seenExact[key] = struct{}{} + groups = append(groups, bindGroup{ + kind: groupExact, + exact: exactBinding{ + host: "0.0.0.0", + network: "tcp4", + }, + }) + case tokenIPv6Any: + key := "exact:tcp6:::" + if _, ok := seenExact[key]; ok { + continue + } + seenExact[key] = struct{}{} + groups = append(groups, bindGroup{ + kind: groupExact, + exact: exactBinding{ + host: "::", + network: "tcp6", + v6Only: true, + }, + }) + case tokenIPv4: + if hasIPv4Any { + continue + } + key := "exact:tcp4:" + strings.ToLower(token.canonical) + if _, ok := seenExact[key]; ok { + continue + } + seenExact[key] = struct{}{} + groups = append(groups, bindGroup{ + kind: groupExact, + exact: exactBinding{ + host: token.canonical, + network: "tcp4", + }, + }) + case tokenIPv6: + if hasIPv6Any { + continue + } + key := "exact:tcp6:" + strings.ToLower(token.canonical) + if _, ok := seenExact[key]; ok { + continue + } + seenExact[key] = struct{}{} + groups = append(groups, bindGroup{ + kind: groupExact, + exact: exactBinding{ + host: token.canonical, + network: "tcp6", + v6Only: true, + }, + }) + case tokenName: + key := "exact:tcp:" + token.key + if _, ok := seenExact[key]; ok { + continue + } + seenExact[key] = struct{}{} + groups = append(groups, bindGroup{ + kind: groupExact, + exact: exactBinding{ + host: token.canonical, + network: "tcp", + }, + }) + } + } + + plan := Plan{groups: groups} + plan.ProbeHost = probeHostForGroups(groups) + return plan, nil +} + +func OpenPlan(plan Plan, port string) (OpenResult, error) { + if port == "" { + return OpenResult{}, errors.New("port cannot be empty") + } + + selectedPort := port + listeners := make([]net.Listener, 0, len(plan.groups)) + bindHosts := make([]string, 0, len(plan.groups)) + bindSeen := make(map[string]struct{}, len(plan.groups)) + + closeAll := func() { + for _, ln := range listeners { + _ = ln.Close() + } + } + + for _, group := range plan.groups { + groupListeners, groupHosts, actualPort, err := openGroup(group, selectedPort) + if err != nil { + closeAll() + return OpenResult{}, err + } + if selectedPort == "0" && actualPort != "" { + selectedPort = actualPort + } + listeners = append(listeners, groupListeners...) + for _, host := range groupHosts { + key := strings.ToLower(host) + if _, ok := bindSeen[key]; ok { + continue + } + bindSeen[key] = struct{}{} + bindHosts = append(bindHosts, host) + } + } + + return OpenResult{ + Listeners: listeners, + BindHosts: bindHosts, + Port: selectedPort, + ProbeHost: plan.ProbeHost, + }, nil +} + +func buildDefaultPlan(defaultMode DefaultMode) Plan { + switch defaultMode { + case DefaultAny: + return Plan{ + groups: []bindGroup{{kind: groupAdaptiveAny}}, + ProbeHost: ResolveAdaptiveLoopbackHost(), + } + default: + return Plan{ + groups: []bindGroup{{ + kind: groupAdaptiveLoopback, + allowIPv4: true, + allowIPv6: true, + }}, + ProbeHost: ResolveAdaptiveLoopbackHost(), + } + } +} + +func probeHostForGroups(groups []bindGroup) string { + hasIPv4Any := false + hasIPv6Any := false + for _, group := range groups { + if group.kind == groupAdaptiveLoopback { + switch { + case group.allowIPv4 && group.allowIPv6: + return ResolveAdaptiveLoopbackHost() + case group.allowIPv6: + return "::1" + case group.allowIPv4: + return "127.0.0.1" + } + } + if group.kind == groupAdaptiveAny { + return ResolveAdaptiveLoopbackHost() + } + if group.kind != groupExact { + continue + } + switch group.exact.host { + case "0.0.0.0": + hasIPv4Any = true + case "::": + hasIPv6Any = true + } + } + + switch { + case hasIPv4Any && hasIPv6Any: + return ResolveAdaptiveLoopbackHost() + case hasIPv6Any: + return "::1" + case hasIPv4Any: + return "127.0.0.1" + } + + for _, group := range groups { + if group.kind == groupExact { + return group.exact.host + } + } + return ResolveAdaptiveLoopbackHost() +} + +func parseHostTokens(raw string) ([]hostToken, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, errors.New("host cannot be empty") + } + + parts := strings.Split(raw, ",") + tokens := make([]hostToken, 0, len(parts)) + seen := make(map[string]struct{}, len(parts)) + for _, part := range parts { + token, err := parseHostToken(part) + if err != nil { + return nil, err + } + if _, ok := seen[token.key]; ok { + continue + } + seen[token.key] = struct{}{} + tokens = append(tokens, token) + } + + if len(tokens) == 0 { + return nil, errors.New("host cannot be empty") + } + + return tokens, nil +} + +func parseHostToken(raw string) (hostToken, error) { + host := strings.TrimSpace(raw) + if host == "" { + return hostToken{}, errors.New("host list contains an empty entry") + } + + if host == "*" { + return hostToken{kind: tokenStar, canonical: "*", key: "*"}, nil + } + if strings.EqualFold(host, "localhost") { + return hostToken{kind: tokenLocalhost, canonical: "localhost", key: "localhost"}, nil + } + + trimmed := strings.Trim(host, "[]") + if ip := net.ParseIP(trimmed); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + canonical := ip4.String() + kind := tokenIPv4 + if ip4.IsUnspecified() { + kind = tokenIPv4Any + } + return hostToken{kind: kind, canonical: canonical, key: canonical}, nil + } + + canonical := ip.String() + kind := tokenIPv6 + if ip.IsUnspecified() { + kind = tokenIPv6Any + } + return hostToken{kind: kind, canonical: canonical, key: strings.ToLower(canonical)}, nil + } + + return hostToken{ + kind: tokenName, + canonical: host, + key: strings.ToLower(host), + }, nil +} + +func openGroup(group bindGroup, port string) ([]net.Listener, []string, string, error) { + switch group.kind { + case groupAdaptiveLoopback: + return openAdaptiveLoopbackGroup(group.allowIPv6, group.allowIPv4, port) + case groupAdaptiveAny: + return openAdaptiveAnyGroup(port) + case groupExact: + ln, actualPort, err := openExactListener(group.exact, port) + if err != nil { + return nil, nil, "", err + } + return []net.Listener{ln}, []string{group.exact.host}, actualPort, nil + default: + return nil, nil, "", fmt.Errorf("unsupported bind group kind: %d", group.kind) + } +} + +func openAdaptiveLoopbackGroup(allowIPv6, allowIPv4 bool, port string) ([]net.Listener, []string, string, error) { + if allowIPv6 && allowIPv4 { + if ln6, actualPort, err6 := openExactListener( + exactBinding{host: "::1", network: "tcp6", v6Only: true}, + port, + ); err6 == nil { + if ln4, _, err4 := openExactListener( + exactBinding{host: "127.0.0.1", network: "tcp4"}, + actualPort, + ); err4 == nil { + return []net.Listener{ln6, ln4}, []string{"::1", "127.0.0.1"}, actualPort, nil + } + _ = ln6.Close() + } + } + + if allowIPv6 { + ln6, actualPort, err := openExactListener(exactBinding{host: "::1", network: "tcp6", v6Only: true}, port) + if err == nil { + return []net.Listener{ln6}, []string{"::1"}, actualPort, nil + } + } + + if allowIPv4 { + ln4, actualPort, err := openExactListener(exactBinding{host: "127.0.0.1", network: "tcp4"}, port) + if err == nil { + return []net.Listener{ln4}, []string{"127.0.0.1"}, actualPort, nil + } + } + + return nil, nil, "", fmt.Errorf("failed to open adaptive localhost listener on port %s", port) +} + +func openAdaptiveAnyGroup(port string) ([]net.Listener, []string, string, error) { + hasIPv4, hasIPv6 := DetectIPFamilies() + + if hasIPv4 && hasIPv6 { + if ln6, actualPort, err6 := openExactListener( + exactBinding{host: "::", network: "tcp6", v6Only: true}, + port, + ); err6 == nil { + if ln4, _, err4 := openExactListener( + exactBinding{host: "0.0.0.0", network: "tcp4"}, + actualPort, + ); err4 == nil { + return []net.Listener{ln6, ln4}, []string{"::", "0.0.0.0"}, actualPort, nil + } + _ = ln6.Close() + } + } + + if hasIPv6 { + ln6, actualPort, err := openExactListener(exactBinding{host: "::", network: "tcp6", v6Only: true}, port) + if err == nil { + return []net.Listener{ln6}, []string{"::"}, actualPort, nil + } + } + + if hasIPv4 { + ln4, actualPort, err := openExactListener(exactBinding{host: "0.0.0.0", network: "tcp4"}, port) + if err == nil { + return []net.Listener{ln4}, []string{"0.0.0.0"}, actualPort, nil + } + } + + return nil, nil, "", fmt.Errorf("failed to open adaptive any-host listener on port %s", port) +} + +func openExactListener(binding exactBinding, port string) (net.Listener, string, error) { + listenConfig := net.ListenConfig{} + if binding.network == "tcp6" && binding.v6Only { + listenConfig.Control = applyIPv6OnlyControl(true) + } + + ln, err := listenConfig.Listen(context.Background(), binding.network, net.JoinHostPort(binding.host, port)) + if err != nil { + return nil, "", err + } + + actualPort, err := listenerPort(ln) + if err != nil { + _ = ln.Close() + return nil, "", err + } + + return ln, actualPort, nil +} + +func listenerPort(ln net.Listener) (string, error) { + addr, ok := ln.Addr().(*net.TCPAddr) + if ok { + return strconv.Itoa(addr.Port), nil + } + + _, port, err := net.SplitHostPort(ln.Addr().String()) + if err != nil { + return "", err + } + return port, nil +} diff --git a/pkg/netbind/netbind_test.go b/pkg/netbind/netbind_test.go new file mode 100644 index 000000000..20b7ff141 --- /dev/null +++ b/pkg/netbind/netbind_test.go @@ -0,0 +1,280 @@ +package netbind + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "strconv" + "testing" + "time" +) + +func TestNormalizeHostInput(t *testing.T) { + tests := []struct { + name string + raw string + want string + wantErr bool + }{ + {name: "single host", raw: "127.0.0.1", want: "127.0.0.1"}, + {name: "trim and dedupe", raw: " [::1] , ::1 , 127.0.0.1 ", want: "::1,127.0.0.1"}, + {name: "star preserved", raw: "*,127.0.0.1", want: "*,127.0.0.1"}, + {name: "reject empty", raw: "127.0.0.1, ", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeHostInput(tt.raw) + if (err != nil) != tt.wantErr { + t.Fatalf("NormalizeHostInput() err = %v, wantErr %t", err, tt.wantErr) + } + if tt.wantErr { + return + } + if got != tt.want { + t.Fatalf("NormalizeHostInput() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestBuildPlan_DefaultAnyUsesLoopbackProbe(t *testing.T) { + plan, err := BuildPlan("", DefaultAny) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + if plan.ProbeHost != ResolveAdaptiveLoopbackHost() { + t.Fatalf("ProbeHost = %q, want %q", plan.ProbeHost, ResolveAdaptiveLoopbackHost()) + } +} + +func TestOpenPlan_LocalhostSupportsLoopbackCommunication(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + + plan, err := BuildPlan("localhost", DefaultLoopback) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + if hasIPv6 { + requireHTTPReachable(t, "::1", port) + } + if hasIPv4 { + requireHTTPReachable(t, "127.0.0.1", port) + } +} + +func TestOpenPlan_DefaultAnySupportsDualStackLoopback(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + + plan, err := BuildPlan("", DefaultAny) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + if hasIPv6 { + requireHTTPReachable(t, "::1", port) + } + if hasIPv4 { + requireHTTPReachable(t, "127.0.0.1", port) + } + + switch { + case hasIPv4 && hasIPv6: + if len(result.BindHosts) != 2 { + t.Fatalf("len(BindHosts) = %d, want 2 (%#v)", len(result.BindHosts), result.BindHosts) + } + case hasIPv6 || hasIPv4: + if len(result.BindHosts) != 1 { + t.Fatalf("len(BindHosts) = %d, want 1 (%#v)", len(result.BindHosts), result.BindHosts) + } + } +} + +func TestOpenPlan_ExplicitIPv6AnyIsIPv6Only(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + if !hasIPv6 { + t.Skip("IPv6 is unavailable in this environment") + } + + plan, err := BuildPlan("::", DefaultLoopback) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireHTTPReachable(t, "::1", port) + if hasIPv4 { + requireHTTPUnreachable(t, "127.0.0.1", port) + } +} + +func TestOpenPlan_ExplicitIPv4AnyIsIPv4Only(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + if !hasIPv4 { + t.Skip("IPv4 is unavailable in this environment") + } + + plan, err := BuildPlan("0.0.0.0", DefaultLoopback) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireHTTPReachable(t, "127.0.0.1", port) + if hasIPv6 { + requireHTTPUnreachable(t, "::1", port) + } +} + +func TestOpenPlan_MultiHostSupportsExplicitIPv4AndIPv6(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + if !hasIPv4 || !hasIPv6 { + t.Skip("dual-stack loopback is unavailable in this environment") + } + + plan, err := BuildPlan("127.0.0.1,::1", DefaultLoopback) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireHTTPReachable(t, "127.0.0.1", port) + requireHTTPReachable(t, "::1", port) +} + +func TestOpenPlan_WildcardRulesKeepIPv4AndIPv6AnyHosts(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + if !hasIPv4 || !hasIPv6 { + t.Skip("dual-stack loopback is unavailable in this environment") + } + + plan, err := BuildPlan("::,::1,0.0.0.0,127.0.0.1", DefaultLoopback) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireHTTPReachable(t, "127.0.0.1", port) + requireHTTPReachable(t, "::1", port) + if len(result.BindHosts) != 2 { + t.Fatalf("len(BindHosts) = %d, want 2 (%#v)", len(result.BindHosts), result.BindHosts) + } +} + +func startTestHTTPServer(t *testing.T, listeners []net.Listener) { + t.Helper() + + server := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "ok") + }), + } + + errCh := make(chan error, len(listeners)) + for _, listener := range listeners { + ln := listener + go func() { + errCh <- server.Serve(ln) + }() + } + + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Shutdown(ctx) + for range listeners { + err := <-errCh + if err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Fatalf("server.Serve() error = %v", err) + } + } + }) +} + +func requireHTTPReachable(t *testing.T, host string, port int) { + t.Helper() + + deadline := time.Now().Add(2 * time.Second) + for { + err := httpGET(host, port) + if err == nil { + return + } + if time.Now().After(deadline) { + t.Fatalf("expected %s:%d to be reachable: %v", host, port, err) + } + time.Sleep(50 * time.Millisecond) + } +} + +func requireHTTPUnreachable(t *testing.T, host string, port int) { + t.Helper() + + if err := httpGET(host, port); err == nil { + t.Fatalf("expected %s:%d to be unreachable", host, port) + } +} + +func httpGET(host string, port int) error { + client := &http.Client{ + Timeout: 300 * time.Millisecond, + Transport: &http.Transport{ + Proxy: nil, + }, + } + + resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + return nil +} + +func mustAtoi(t *testing.T, value string) int { + t.Helper() + n, err := strconv.Atoi(value) + if err != nil { + t.Fatalf("Atoi(%q) error = %v", value, err) + } + return n +} diff --git a/pkg/netbind/socket_v6only_unix.go b/pkg/netbind/socket_v6only_unix.go new file mode 100644 index 000000000..20cf7bbce --- /dev/null +++ b/pkg/netbind/socket_v6only_unix.go @@ -0,0 +1,25 @@ +//go:build !windows + +package netbind + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func applyIPv6OnlyControl(enabled bool) func(string, string, syscall.RawConn) error { + return func(_, _ string, rawConn syscall.RawConn) error { + var controlErr error + if err := rawConn.Control(func(fd uintptr) { + value := 0 + if enabled { + value = 1 + } + controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, value) + }); err != nil { + return err + } + return controlErr + } +} diff --git a/pkg/netbind/socket_v6only_windows.go b/pkg/netbind/socket_v6only_windows.go new file mode 100644 index 000000000..006b4e1ac --- /dev/null +++ b/pkg/netbind/socket_v6only_windows.go @@ -0,0 +1,25 @@ +//go:build windows + +package netbind + +import ( + "syscall" + + "golang.org/x/sys/windows" +) + +func applyIPv6OnlyControl(enabled bool) func(string, string, syscall.RawConn) error { + return func(_, _ string, rawConn syscall.RawConn) error { + var controlErr error + if err := rawConn.Control(func(fd uintptr) { + value := 0 + if enabled { + value = 1 + } + controlErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, value) + }); err != nil { + return err + } + return controlErr + } +} diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 0dec45cba..fa5652323 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -21,6 +21,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/netbind" ppid "github.com/sipeed/picoclaw/pkg/pid" "github.com/sipeed/picoclaw/web/backend/utils" ) @@ -119,6 +120,7 @@ var ( gatewayRestartGracePeriod = 5 * time.Second gatewayRestartForceKillWindow = 3 * time.Second gatewayRestartPollInterval = 100 * time.Millisecond + gatewayExecCommand = exec.Command ) var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) { @@ -262,7 +264,7 @@ func (h *Handler) getGatewayHealthForPidData( host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg)) } if host == "" { - host = "127.0.0.1" + host = netbind.ResolveAdaptiveLoopbackHost() } url := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/health" @@ -723,7 +725,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int execPath := utils.FindPicoclawBinary() logger.InfoC("gateway", fmt.Sprintf("Starting gateway process (%s)", execPath)) - cmd = exec.Command(execPath, h.gatewayCommandArgs()...) + cmd = gatewayExecCommand(execPath, h.gatewayCommandArgs()...) cmd.Env = os.Environ() // Forward the launcher's config path via the environment variable that // GetConfigPath() already reads, so the gateway sub-process uses the same @@ -731,8 +733,9 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int if h.configPath != "" { cmd.Env = append(cmd.Env, config.EnvConfig+"="+h.configPath) } - if host := h.gatewayHostOverride(); host != "" { - cmd.Env = append(cmd.Env, config.EnvGatewayHost+"="+host) + gatewayHostOverride := h.gatewayHostOverride() + if gatewayHostOverride != "" { + cmd.Env = append(cmd.Env, config.EnvGatewayHost+"="+gatewayHostOverride) } stdoutPipe, err := cmd.StdoutPipe() diff --git a/web/backend/api/gateway_host.go b/web/backend/api/gateway_host.go index f8e8eadba..c6c2073e2 100644 --- a/web/backend/api/gateway_host.go +++ b/web/backend/api/gateway_host.go @@ -8,9 +8,15 @@ import ( "strings" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/netbind" ) func (h *Handler) effectiveLauncherPublic() bool { + if h.serverHostExplicit { + // -host takes precedence over -public and launcher-config public setting. + return false + } + if h.serverPublicExplicit { return h.serverPublic } @@ -24,8 +30,11 @@ func (h *Handler) effectiveLauncherPublic() bool { } func (h *Handler) gatewayHostOverride() string { + if h.serverHostExplicit { + return strings.TrimSpace(h.serverHostInput) + } if h.effectiveLauncherPublic() { - return "0.0.0.0" + return "*" } return "" } @@ -41,10 +50,11 @@ func (h *Handler) effectiveGatewayBindHost(cfg *config.Config) string { } func gatewayProbeHost(bindHost string) string { - if bindHost == "" || bindHost == "0.0.0.0" { - return "127.0.0.1" + plan, err := netbind.BuildPlan(bindHost, netbind.DefaultLoopback) + if err != nil || strings.TrimSpace(plan.ProbeHost) == "" { + return netbind.ResolveAdaptiveLoopbackHost() } - return bindHost + return plan.ProbeHost } func (h *Handler) gatewayProxyURL() *url.URL { @@ -72,7 +82,7 @@ func requestHostName(r *http.Request) string { if strings.TrimSpace(r.Host) != "" { return r.Host } - return "127.0.0.1" + return netbind.ResolveAdaptiveLoopbackHost() } func requestWSScheme(r *http.Request) string { diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go index 7150b6fee..d0fc26d7b 100644 --- a/web/backend/api/gateway_host_test.go +++ b/web/backend/api/gateway_host_test.go @@ -3,6 +3,7 @@ package api import ( "crypto/tls" "errors" + "net" "net/http" "net/http/httptest" "path/filepath" @@ -10,6 +11,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/netbind" "github.com/sipeed/picoclaw/web/backend/launcherconfig" ) @@ -26,8 +28,8 @@ func TestGatewayHostOverrideUsesExplicitRuntimePublic(t *testing.T) { h := NewHandler(configPath) h.SetServerOptions(18800, true, true, nil) - if got := h.gatewayHostOverride(); got != "0.0.0.0" { - t.Fatalf("gatewayHostOverride() = %q, want %q", got, "0.0.0.0") + if got := h.gatewayHostOverride(); got != "*" { + t.Fatalf("gatewayHostOverride() = %q, want %q", got, "*") } } @@ -64,8 +66,36 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) { } func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) { - if got := gatewayProbeHost("0.0.0.0"); got != "127.0.0.1" { - t.Fatalf("gatewayProbeHost() = %q, want %q", got, "127.0.0.1") + want := "127.0.0.1" + if got := gatewayProbeHost("0.0.0.0"); got != want { + t.Fatalf("gatewayProbeHost() = %q, want %q", got, want) + } +} + +func TestGatewayProbeHostUsesPreferredLoopbackForEmptyBind(t *testing.T) { + want := netbind.ResolveAdaptiveLoopbackHost() + if got := gatewayProbeHost(""); got != want { + t.Fatalf("gatewayProbeHost(empty) = %q, want %q", got, want) + } +} + +func TestGatewayProbeHostUsesPreferredLoopbackForLocalhostBind(t *testing.T) { + want := netbind.ResolveAdaptiveLoopbackHost() + if got := gatewayProbeHost("localhost"); got != want { + t.Fatalf("gatewayProbeHost(localhost) = %q, want %q", got, want) + } +} + +func TestGatewayProbeHostUsesLoopbackForIPv6WildcardBind(t *testing.T) { + want := "::1" + if got := gatewayProbeHost("::"); got != want { + t.Fatalf("gatewayProbeHost(::) = %q, want %q", got, want) + } +} + +func TestGatewayProbeHostUsesFirstConcreteHostForMultiHostBind(t *testing.T) { + if got := gatewayProbeHost("127.0.0.1,::1"); got != "127.0.0.1" { + t.Fatalf("gatewayProbeHost(multi) = %q, want %q", got, "127.0.0.1") } } @@ -137,8 +167,9 @@ func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) { _ = statusCode _ = err - if requestedURL != "http://127.0.0.1:18791/health" { - t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health") + want := "http://" + net.JoinHostPort(netbind.ResolveAdaptiveLoopbackHost(), "18791") + "/health" + if requestedURL != want { + t.Fatalf("health url = %q, want %q", requestedURL, want) } } @@ -240,3 +271,43 @@ func TestBuildWsURLUsesRequestHostNotGatewayBindLoopback(t *testing.T) { t.Fatalf("buildWsURL() = %q, want %q", got, "ws://localhost:18800/pico/ws") } } + +func TestGatewayHostOverrideWithExplicitHostAndAlignedGatewayHost(t *testing.T) { + h := NewHandler(filepath.Join(t.TempDir(), "config.json")) + h.SetServerOptions(18800, false, false, nil) + h.SetServerBindHost("0.0.0.0", true) + + if got := h.gatewayHostOverride(); got != "0.0.0.0" { + t.Fatalf("gatewayHostOverride() = %q, want %q", got, "0.0.0.0") + } +} + +func TestGatewayHostOverrideWithExplicitHostAndLocalhostGatewayHost(t *testing.T) { + h := NewHandler(filepath.Join(t.TempDir(), "config.json")) + h.SetServerOptions(18800, false, false, nil) + h.SetServerBindHost("::", true) + + if got := h.gatewayHostOverride(); got != "::" { + t.Fatalf("gatewayHostOverride() = %q, want %q", got, "::") + } +} + +func TestGatewayHostOverrideWithExplicitMultiHost(t *testing.T) { + h := NewHandler(filepath.Join(t.TempDir(), "config.json")) + h.SetServerOptions(18800, false, false, nil) + h.SetServerBindHost("127.0.0.1,::1", true) + + if got := h.gatewayHostOverride(); got != "127.0.0.1,::1" { + t.Fatalf("gatewayHostOverride() = %q, want %q", got, "127.0.0.1,::1") + } +} + +func TestGatewayHostExplicitIgnoresPublicFlag(t *testing.T) { + h := NewHandler(filepath.Join(t.TempDir(), "config.json")) + h.SetServerOptions(18800, true, true, nil) + h.SetServerBindHost("127.0.0.1", true) + + if got := h.effectiveLauncherPublic(); got { + t.Fatalf("effectiveLauncherPublic() = %t, want false when explicit host is set", got) + } +} diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go index d300b657c..78bf34a63 100644 --- a/web/backend/api/gateway_test.go +++ b/web/backend/api/gateway_test.go @@ -97,6 +97,7 @@ func resetGatewayTestState(t *testing.T) { originalHealthGet := gatewayHealthGet originalProcessMatcher := gatewayProcessMatcher + originalExecCommand := gatewayExecCommand originalRestartGracePeriod := gatewayRestartGracePeriod originalRestartForceKillWindow := gatewayRestartForceKillWindow originalRestartPollInterval := gatewayRestartPollInterval @@ -104,6 +105,7 @@ func resetGatewayTestState(t *testing.T) { t.Cleanup(func() { gatewayHealthGet = originalHealthGet gatewayProcessMatcher = originalProcessMatcher + gatewayExecCommand = originalExecCommand gatewayRestartGracePeriod = originalRestartGracePeriod gatewayRestartForceKillWindow = originalRestartForceKillWindow gatewayRestartPollInterval = originalRestartPollInterval @@ -119,6 +121,159 @@ func resetGatewayTestState(t *testing.T) { }) } +type gatewayStartEnvSnapshot struct { + GatewayHost string `json:"gateway_host"` + GatewayHostSet bool `json:"gateway_host_set"` + ConfigPath string `json:"config_path"` +} + +func TestGatewayStartHelperProcess(t *testing.T) { + var envPath string + for i, arg := range os.Args { + if arg == "--" && i+2 < len(os.Args) && os.Args[i+1] == "gateway-env-helper" { + envPath = os.Args[i+2] + break + } + } + if envPath == "" { + t.Skip("helper process") + } + + host, ok := os.LookupEnv(config.EnvGatewayHost) + raw, err := json.Marshal(gatewayStartEnvSnapshot{ + GatewayHost: host, + GatewayHostSet: ok, + ConfigPath: os.Getenv(config.EnvConfig), + }) + if err != nil { + _, _ = io.WriteString(os.Stderr, err.Error()) + os.Exit(2) + } + if err := os.WriteFile(envPath, raw, 0o600); err != nil { + _, _ = io.WriteString(os.Stderr, err.Error()) + os.Exit(2) + } + os.Exit(0) +} + +func unsetGatewayStartEnvForTest(t *testing.T, key string) { + t.Helper() + + prev, hadPrev := os.LookupEnv(key) + if err := os.Unsetenv(key); err != nil { + t.Fatalf("Unsetenv(%q) error = %v", key, err) + } + t.Cleanup(func() { + if hadPrev { + _ = os.Setenv(key, prev) + return + } + _ = os.Unsetenv(key) + }) +} + +func newGatewayStartTestHandler(t *testing.T) *Handler { + t.Helper() + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + h.SetServerOptions(18800, false, false, nil) + return h +} + +func startGatewayAndCaptureEnv(t *testing.T, h *Handler) gatewayStartEnvSnapshot { + t.Helper() + + unsetGatewayStartEnvForTest(t, config.EnvGatewayHost) + + envPath := filepath.Join(t.TempDir(), "gateway-child-env.json") + gatewayExecCommand = func(_ string, _ ...string) *exec.Cmd { + return exec.Command( + os.Args[0], + "-test.run=TestGatewayStartHelperProcess", + "--", + "gateway-env-helper", + envPath, + ) + } + + pid, err := h.startGatewayLocked("starting", 0) + if err != nil { + t.Fatalf("startGatewayLocked() error = %v", err) + } + if pid <= 0 { + t.Fatalf("startGatewayLocked() pid = %d, want > 0", pid) + } + + deadline := time.Now().Add(3 * time.Second) + for { + raw, err := os.ReadFile(envPath) + if err == nil { + var snapshot gatewayStartEnvSnapshot + err = json.Unmarshal(raw, &snapshot) + if err != nil { + t.Fatalf("Unmarshal(child env) error = %v", err) + } + return snapshot + } + if !os.IsNotExist(err) { + t.Fatalf("ReadFile(%q) error = %v", envPath, err) + } + if time.Now().After(deadline) { + t.Fatalf("timed out waiting for gateway child env snapshot %q", envPath) + } + time.Sleep(20 * time.Millisecond) + } +} + +func TestStartGatewayLocked_ForwardsLauncherHostOverrideToGatewayEnv(t *testing.T) { + h := newGatewayStartTestHandler(t) + h.SetServerBindHost("127.0.0.1,::1", true) + + snapshot := startGatewayAndCaptureEnv(t, h) + if !snapshot.GatewayHostSet { + t.Fatal("gateway host env was not set") + } + if snapshot.GatewayHost != "127.0.0.1,::1" { + t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "127.0.0.1,::1") + } + if snapshot.ConfigPath != h.configPath { + t.Fatalf("config env = %q, want %q", snapshot.ConfigPath, h.configPath) + } +} + +func TestStartGatewayLocked_ForwardsLauncherHostFromEnvironmentToGatewayEnv(t *testing.T) { + h := newGatewayStartTestHandler(t) + h.SetServerBindHost("::", true) + + snapshot := startGatewayAndCaptureEnv(t, h) + if !snapshot.GatewayHostSet { + t.Fatal("gateway host env was not set") + } + if snapshot.GatewayHost != "::" { + t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "::") + } +} + +func TestStartGatewayLocked_ForwardsWildcardHostForPublicLauncher(t *testing.T) { + h := newGatewayStartTestHandler(t) + h.SetServerOptions(18800, true, true, nil) + + snapshot := startGatewayAndCaptureEnv(t, h) + if !snapshot.GatewayHostSet { + t.Fatal("gateway host env was not set") + } + if snapshot.GatewayHost != "*" { + t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "*") + } +} + func TestGatewayStartReady_NoDefaultModel(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) diff --git a/web/backend/api/router.go b/web/backend/api/router.go index c6781baf1..76f63607e 100644 --- a/web/backend/api/router.go +++ b/web/backend/api/router.go @@ -2,6 +2,7 @@ package api import ( "net/http" + "strings" "sync" "github.com/sipeed/picoclaw/web/backend/launcherconfig" @@ -13,6 +14,8 @@ type Handler struct { serverPort int serverPublic bool serverPublicExplicit bool + serverHostInput string + serverHostExplicit bool serverCIDRs []string debug bool oauthMu sync.Mutex @@ -41,9 +44,21 @@ func (h *Handler) SetServerOptions(port int, public bool, publicExplicit bool, a h.serverPort = port h.serverPublic = public h.serverPublicExplicit = publicExplicit + h.serverHostInput = "" + h.serverHostExplicit = false h.serverCIDRs = append([]string(nil), allowedCIDRs...) } +// SetServerBindHost stores the launcher's effective bind host. +// When explicit is true, hostInput is the normalized -host / PICOCLAW_LAUNCHER_HOST value. +func (h *Handler) SetServerBindHost(hostInput string, explicit bool) { + h.serverHostInput = strings.TrimSpace(hostInput) + if !explicit { + h.serverHostInput = "" + } + h.serverHostExplicit = explicit +} + func (h *Handler) SetDebug(debug bool) { h.debug = debug } diff --git a/web/backend/app_runtime.go b/web/backend/app_runtime.go index ab564db2c..a06396526 100644 --- a/web/backend/app_runtime.go +++ b/web/backend/app_runtime.go @@ -34,22 +34,30 @@ func shutdownApp() { apiHandler.Shutdown() } - if server != nil { - // Disable keep-alive to allow graceful shutdown - server.SetKeepAlivesEnabled(false) - - ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) - defer cancel() - if err := server.Shutdown(ctx); err != nil { - // Context deadline exceeded is expected if there are active connections - // This is not necessarily an error, so log it at info level - if errors.Is(err, context.DeadlineExceeded) { - logger.Infof("Server shutdown timeout after %v, forcing close", shutdownTimeout) - } else { - logger.Errorf("Server shutdown error: %v", err) + if len(servers) > 0 { + for _, srv := range servers { + if srv == nil { + continue + } + + // Disable keep-alive to allow graceful shutdown + srv.SetKeepAlivesEnabled(false) + + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + err := srv.Shutdown(ctx) + cancel() + + if err != nil { + // Context deadline exceeded is expected if there are active connections + // This is not necessarily an error, so log it at info level + if errors.Is(err, context.DeadlineExceeded) { + logger.Infof("Server shutdown timeout after %v, forcing close", shutdownTimeout) + } else { + logger.Errorf("Server shutdown error: %v", err) + } + } else { + logger.Infof("Server shutdown completed successfully") } - } else { - logger.Infof("Server shutdown completed successfully") } } } diff --git a/web/backend/launcherconfig/config.go b/web/backend/launcherconfig/config.go index 60c369f4f..b6faa63fe 100644 --- a/web/backend/launcherconfig/config.go +++ b/web/backend/launcherconfig/config.go @@ -16,6 +16,10 @@ const ( FileName = "launcher-config.json" // DefaultPort is the default port for the web launcher. DefaultPort = 18800 + // EnvLauncherToken overrides launcher dashboard token. + EnvLauncherToken = "PICOCLAW_LAUNCHER_TOKEN" + // EnvLauncherHost overrides launcher listen host. + EnvLauncherHost = "PICOCLAW_LAUNCHER_HOST" // dashboardSigningKeyBytes is the HMAC-SHA256 key size (256 bits). dashboardSigningKeyBytes = 32 @@ -59,7 +63,7 @@ func Validate(cfg Config) error { // EnsureDashboardSecrets returns signing key bytes and the effective dashboard token for this // process. The signing key is freshly random each call; the token comes from -// PICOCLAW_LAUNCHER_TOKEN when set, otherwise launcher-config.json launcher_token, +// EnvLauncherToken when set, otherwise launcher-config.json launcher_token, // otherwise a new random token. func EnsureDashboardSecrets( cfg Config, @@ -69,7 +73,7 @@ func EnsureDashboardSecrets( return "", nil, "", err } - effectiveToken = strings.TrimSpace(os.Getenv("PICOCLAW_LAUNCHER_TOKEN")) + effectiveToken = strings.TrimSpace(os.Getenv(EnvLauncherToken)) if effectiveToken != "" { return effectiveToken, signingKey, DashboardTokenSourceEnv, nil } diff --git a/web/backend/main.go b/web/backend/main.go index c5d25f6ef..57409f03a 100644 --- a/web/backend/main.go +++ b/web/backend/main.go @@ -15,17 +15,20 @@ import ( "errors" "flag" "fmt" + "net" "net/http" "net/url" "os" "os/signal" "path/filepath" "strconv" + "strings" "syscall" "time" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/netbind" "github.com/sipeed/picoclaw/web/backend/api" "github.com/sipeed/picoclaw/web/backend/dashboardauth" "github.com/sipeed/picoclaw/web/backend/launcherconfig" @@ -44,7 +47,7 @@ const ( var ( appVersion = config.Version - server *http.Server + servers []*http.Server serverAddr string // browserLaunchURL is opened by openBrowser() (auto-open + tray "open console"). // Includes ?token= for same-machine dashboard login; keep serverAddr without secrets for other use. @@ -65,6 +68,255 @@ func dashboardTokenConfigHelpPath(source launcherconfig.DashboardTokenSource, la return launcherPath } +func resolveLauncherHostInput(flagHost string, explicitFlag bool, envHost string) (string, bool, error) { + if explicitFlag { + normalized, err := netbind.NormalizeHostInput(flagHost) + if err != nil { + return "", false, err + } + return normalized, true, nil + } + + envHost = strings.TrimSpace(envHost) + if envHost == "" { + return "", false, nil + } + + normalized, err := netbind.NormalizeHostInput(envHost) + if err != nil { + return "", false, err + } + return normalized, true, nil +} + +func openLauncherListeners(hostInput string, public bool, port string) (netbind.OpenResult, error) { + defaultMode := netbind.DefaultLoopback + if strings.TrimSpace(hostInput) == "" && public { + defaultMode = netbind.DefaultAny + } + + plan, err := netbind.BuildPlan(hostInput, defaultMode) + if err != nil { + return netbind.OpenResult{}, err + } + return netbind.OpenPlan(plan, port) +} + +func appendUniqueHost(hosts []string, seen map[string]struct{}, host string) []string { + host = strings.TrimSpace(host) + if host == "" { + return hosts + } + key := strings.ToLower(host) + if _, ok := seen[key]; ok { + return hosts + } + seen[key] = struct{}{} + return append(hosts, host) +} + +func hasWildcardBindHosts(bindHosts []string) bool { + for _, bindHost := range bindHosts { + if netbind.IsUnspecifiedHost(bindHost) { + return true + } + } + return false +} + +func wildcardBindHostFamilies(bindHosts []string) (hasIPv4, hasIPv6 bool) { + for _, bindHost := range bindHosts { + host := strings.TrimSpace(bindHost) + if host == "" { + continue + } + + if !netbind.IsUnspecifiedHost(host) { + continue + } + + ip := net.ParseIP(strings.Trim(host, "[]")) + if ip == nil { + continue + } + if ip.To4() != nil { + hasIPv4 = true + continue + } + hasIPv6 = true + } + + return hasIPv4, hasIPv6 +} + +func wildcardAdvertiseIP(bindHosts []string, ipv4, ipv6 string) string { + hasIPv4Wildcard, hasIPv6Wildcard := wildcardBindHostFamilies(bindHosts) + v4 := strings.TrimSpace(ipv4) + v6 := strings.TrimSpace(ipv6) + + switch { + case hasIPv4Wildcard && hasIPv6Wildcard: + if v6 != "" { + return v6 + } + return v4 + case hasIPv6Wildcard: + return v6 + case hasIPv4Wildcard: + return v4 + default: + return "" + } +} + +func advertiseIPForWildcardBindHosts(bindHosts []string) string { + return wildcardAdvertiseIP(bindHosts, utils.GetLocalIPv4(), utils.GetLocalIPv6()) +} + +func appendLauncherConsoleHostList(hosts []string, seen map[string]struct{}, values []string) []string { + for _, value := range values { + hosts = appendUniqueHost(hosts, seen, value) + } + return hosts +} + +func shouldShowLocalhostConsoleEntry(hostInput string) bool { + normalizedHostInput := strings.TrimSpace(hostInput) + if normalizedHostInput == "" { + return true + } + + for token := range strings.SplitSeq(normalizedHostInput, ",") { + token = strings.TrimSpace(token) + if token == "" { + continue + } + if token == "*" || strings.EqualFold(token, "localhost") { + return true + } + + ip := net.ParseIP(strings.Trim(token, "[]")) + if ip == nil { + continue + } + if ip4 := ip.To4(); ip4 != nil { + if ip4.String() == "127.0.0.1" || ip4.String() == "0.0.0.0" { + return true + } + continue + } + if ip.String() == "::1" || ip.String() == "::" { + return true + } + } + + return false +} + +func isConsoleDisplayGlobalIPv6(ip net.IP) bool { + if ip == nil || ip.IsLoopback() || ip.To4() != nil { + return false + } + ip = ip.To16() + if ip == nil { + return false + } + return ip[0]&0xe0 == 0x20 +} + +func launcherConsoleHostsWithLocalAddrs( + hostInput string, + public bool, + ipv4s []string, + globalIPv6s []string, +) []string { + hosts := make([]string, 0, 8) + seen := make(map[string]struct{}, 8) + + if shouldShowLocalhostConsoleEntry(hostInput) { + hosts = appendUniqueHost(hosts, seen, "localhost") + } + + normalizedHostInput := strings.TrimSpace(hostInput) + if normalizedHostInput == "" { + if public { + hosts = appendLauncherConsoleHostList(hosts, seen, globalIPv6s) + hosts = appendLauncherConsoleHostList(hosts, seen, ipv4s) + } + return hosts + } + + hasStar := false + hasIPv4Any := false + hasIPv6Any := false + for _, token := range strings.Split(normalizedHostInput, ",") { + switch strings.TrimSpace(token) { + case "*": + hasStar = true + case "0.0.0.0": + hasIPv4Any = true + case "::": + hasIPv6Any = true + } + } + + if hasStar { + hosts = appendLauncherConsoleHostList(hosts, seen, globalIPv6s) + hosts = appendLauncherConsoleHostList(hosts, seen, ipv4s) + return hosts + } + + for _, token := range strings.Split(normalizedHostInput, ",") { + token = strings.TrimSpace(token) + if token == "" || strings.EqualFold(token, "localhost") || netbind.IsLoopbackHost(token) { + continue + } + + ip := net.ParseIP(strings.Trim(token, "[]")) + switch { + case token == "::": + hosts = appendLauncherConsoleHostList(hosts, seen, globalIPv6s) + case token == "0.0.0.0": + hosts = appendLauncherConsoleHostList(hosts, seen, ipv4s) + case ip != nil && ip.To4() != nil: + if hasIPv4Any { + continue + } + hosts = appendUniqueHost(hosts, seen, ip.String()) + case ip != nil: + if hasIPv6Any { + continue + } + if isConsoleDisplayGlobalIPv6(ip) { + hosts = appendUniqueHost(hosts, seen, ip.String()) + } + default: + hosts = appendUniqueHost(hosts, seen, token) + } + } + + return hosts +} + +func launcherConsoleHosts(hostInput string, public bool) []string { + return launcherConsoleHostsWithLocalAddrs( + hostInput, + public, + utils.GetLocalIPv4s(), + utils.GetGlobalIPv6s(), + ) +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + value = strings.TrimSpace(value) + if value != "" { + return value + } + } + return "" +} + // maskSecret masks a secret for display. It always shows up to the first 3 // runes. The last 4 runes are only appended when at least 5 runes remain // hidden in the middle (i.e. string length >= 12), so an 8-char minimum @@ -85,7 +337,8 @@ func maskSecret(s string) string { func main() { port := flag.String("port", "18800", "Port to listen on") - public := flag.Bool("public", false, "Listen on all interfaces (0.0.0.0) instead of localhost only") + host := flag.String("host", "", "Host to listen on (overrides -public when set)") + public := flag.Bool("public", false, "Listen on all interfaces (dual-stack) instead of localhost only") noBrowser = flag.Bool("no-browser", false, "Do not auto-open browser on startup") lang := flag.String("lang", "", "Language: en (English) or zh (Chinese). Default: auto-detect from system locale") console := flag.Bool("console", false, "Console mode, no GUI") @@ -112,6 +365,8 @@ func main() { os.Args[0], ) fmt.Fprintf(os.Stderr, " Allow access from other devices on the local network\n") + fmt.Fprintf(os.Stderr, " %s -host :: ./config.json\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " Bind launcher host explicitly with exact host semantics\n") fmt.Fprintf(os.Stderr, " %s -console -d ./config.json\n", os.Args[0]) fmt.Fprintf(os.Stderr, " Run in the terminal with debug logs enabled\n") } @@ -175,8 +430,9 @@ func main() { logger.DebugC( "web", fmt.Sprintf( - "Launcher flags: console=%t public=%t no_browser=%t config=%s", + "Launcher flags: console=%t host=%q public=%t no_browser=%t config=%s", enableConsole, + *host, *public, *noBrowser, absPath, @@ -186,10 +442,13 @@ func main() { var explicitPort bool var explicitPublic bool + var explicitHost bool flag.Visit(func(f *flag.Flag) { switch f.Name { case "port": explicitPort = true + case "host": + explicitHost = true case "public": explicitPublic = true } @@ -210,6 +469,23 @@ func main() { if !explicitPublic { effectivePublic = launcherCfg.Public } + envHost := strings.TrimSpace(os.Getenv(launcherconfig.EnvLauncherHost)) + + hostInput, hostOverrideActive, err := resolveLauncherHostInput(*host, explicitHost, envHost) + if err != nil { + logger.Fatalf("Invalid host %q: %v", firstNonEmpty(strings.TrimSpace(*host), envHost), err) + } + if hostOverrideActive { + effectivePublic = false + } + + if !explicitHost && hostOverrideActive { + logger.InfoC("web", "Using launcher host from environment PICOCLAW_LAUNCHER_HOST") + } + + if hostOverrideActive && explicitPublic { + logger.InfoC("web", "Ignoring -public because launcher host was explicitly set") + } portNum, err := strconv.Atoi(effectivePort) if err != nil || portNum < 1 || portNum > 65535 { @@ -219,6 +495,12 @@ func main() { logger.Fatalf("Invalid port %q: %v", effectivePort, err) } + openResult, err := openLauncherListeners(hostInput, effectivePublic, effectivePort) + if err != nil { + logger.Fatalf("Failed to open launcher listener(s): %v", err) + } + listeners := openResult.Listeners + dashboardToken, dashboardSigningKey, dashboardTokenSource, dashErr := launcherconfig.EnsureDashboardSecrets( launcherCfg, ) @@ -246,14 +528,6 @@ func main() { logger.ErrorC("web", fmt.Sprintf("Warning: could not open auth store: %v", authStoreErr)) } - // Determine listen address - var addr string - if effectivePublic { - addr = "0.0.0.0:" + effectivePort - } else { - addr = "127.0.0.1:" + effectivePort - } - // Initialize Server components mux := http.NewServeMux() @@ -271,6 +545,7 @@ func main() { logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err)) } apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs) + apiHandler.SetServerBindHost(hostInput, hostOverrideActive) apiHandler.RegisterRoutes(mux) // Frontend Embedded Assets @@ -297,15 +572,14 @@ func main() { // Print startup banner and token (console mode only). if enableConsole || debug { + consoleHosts := launcherConsoleHosts(hostInput, effectivePublic) + fmt.Print(utils.Banner) fmt.Println() fmt.Println(" Open the following URL in your browser:") fmt.Println() - fmt.Printf(" >> http://localhost:%s <<\n", effectivePort) - if effectivePublic { - if ip := utils.GetLocalIP(); ip != "" { - fmt.Printf(" >> http://%s:%s <<\n", ip, effectivePort) - } + for _, host := range consoleHosts { + fmt.Printf(" >> http://%s <<\n", net.JoinHostPort(host, effectivePort)) } fmt.Println() switch dashboardTokenSource { @@ -331,15 +605,17 @@ func main() { } // Log startup info to file - logger.InfoC("web", fmt.Sprintf("Server will listen on http://localhost:%s", effectivePort)) - if effectivePublic { - if ip := utils.GetLocalIP(); ip != "" { - logger.InfoC("web", fmt.Sprintf("Public access enabled at http://%s:%s", ip, effectivePort)) + for _, ln := range listeners { + logger.InfoC("web", fmt.Sprintf("Server will listen on http://%s", ln.Addr().String())) + } + if hasWildcardBindHosts(openResult.BindHosts) { + if ip := advertiseIPForWildcardBindHosts(openResult.BindHosts); ip != "" { + logger.InfoC("web", fmt.Sprintf("Public access enabled at http://%s", net.JoinHostPort(ip, effectivePort))) } } // Share the local URL with the launcher runtime. - serverAddr = fmt.Sprintf("http://localhost:%s", effectivePort) + serverAddr = fmt.Sprintf("http://%s", net.JoinHostPort(openResult.ProbeHost, effectivePort)) if dashboardToken != "" { browserLaunchURL = serverAddr + "?token=" + url.QueryEscape(dashboardToken) } else { @@ -354,14 +630,19 @@ func main() { apiHandler.TryAutoStartGateway() }() - // Start the Server in a goroutine - server = &http.Server{Addr: addr, Handler: handler} - go func() { - logger.InfoC("web", fmt.Sprintf("Server listening on %s", addr)) - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Fatalf("Server failed to start: %v", err) - } - }() + // Start the server(s) in goroutines. + servers = make([]*http.Server, 0, len(listeners)) + for _, ln := range listeners { + srv := &http.Server{Handler: handler} + servers = append(servers, srv) + + go func(s *http.Server, l net.Listener) { + logger.InfoC("web", fmt.Sprintf("Server listening on %s", l.Addr().String())) + if serveErr := s.Serve(l); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { + logger.Fatalf("Server failed to start on %s: %v", l.Addr().String(), serveErr) + } + }(srv, ln) + } defer shutdownApp() diff --git a/web/backend/main_test.go b/web/backend/main_test.go index 82bf12b40..6df5370b1 100644 --- a/web/backend/main_test.go +++ b/web/backend/main_test.go @@ -1,8 +1,17 @@ package main import ( + "context" + "errors" + "io" + "net" + "net/http" + "strconv" + "strings" "testing" + "time" + "github.com/sipeed/picoclaw/pkg/netbind" "github.com/sipeed/picoclaw/web/backend/launcherconfig" ) @@ -73,25 +82,351 @@ func TestMaskSecret(t *testing.T) { input string want string }{ - // Long token (>=12 chars): first 3 + 10 stars + last 4 {"sdhjflsjdflksdf", "sdh**********ksdf"}, {"abcdefghijklmnopqrstuvwxyz", "abc**********wxyz"}, - // Exactly 12 chars (3+4+5 hidden): suffix shown {"abcdefghijkl", "abc**********ijkl"}, - // 8 chars (minimum password length): suffix NOT shown — only prefix+stars {"abcdefgh", "abc**********"}, - // 11 chars (one below threshold): suffix NOT shown {"abcdefghijk", "abc**********"}, - // 4..3 chars: prefix shown, no suffix {"abcdefg", "abc**********"}, {"abcd", "abc**********"}, - // <=3 chars: fully masked {"abc", "**********"}, {"", "**********"}, } + for _, tt := range tests { if got := maskSecret(tt.input); got != tt.want { t.Errorf("maskSecret(%q) = %q, want %q", tt.input, got, tt.want) } } } + +func TestResolveLauncherHostInput(t *testing.T) { + tests := []struct { + name string + flagHost string + explicitFlag bool + envHost string + wantHost string + wantActive bool + wantErr bool + }{ + { + name: "flag host wins", + flagHost: "127.0.0.1", + explicitFlag: true, + envHost: "::", + wantHost: "127.0.0.1", + wantActive: true, + }, + {name: "env host used when flag absent", envHost: "127.0.0.1,::1", wantHost: "127.0.0.1,::1", wantActive: true}, + {name: "blank env ignored", envHost: " ", wantHost: "", wantActive: false}, + {name: "invalid flag rejected", flagHost: "127.0.0.1, ", explicitFlag: true, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotHost, gotActive, err := resolveLauncherHostInput(tt.flagHost, tt.explicitFlag, tt.envHost) + if (err != nil) != tt.wantErr { + t.Fatalf("resolveLauncherHostInput() err = %v, wantErr %t", err, tt.wantErr) + } + if tt.wantErr { + return + } + if gotHost != tt.wantHost { + t.Fatalf("resolveLauncherHostInput() host = %q, want %q", gotHost, tt.wantHost) + } + if gotActive != tt.wantActive { + t.Fatalf("resolveLauncherHostInput() active = %t, want %t", gotActive, tt.wantActive) + } + }) + } +} + +func TestLauncherConsoleHosts(t *testing.T) { + t.Run("default loopback shows localhost only", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "", + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) + + t.Run("explicit loopback hosts collapse to localhost", func(t *testing.T) { + tests := []struct { + name string + hostInput string + }{ + {name: "ipv6 loopback", hostInput: "::1"}, + {name: "ipv4 loopback", hostInput: "127.0.0.1"}, + {name: "localhost", hostInput: "localhost"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + tt.hostInput, + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) + } + }) + + t.Run("public wildcard shows localhost then ipv6 and ipv4", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "", + true, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost", "2001:db8::1", "2001:db8::2", "192.168.1.2", "10.0.0.8"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) + + t.Run("explicit ipv6 any shows localhost then ipv6 variants", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "::", + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost", "2001:db8::1", "2001:db8::2"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + + for _, host := range hosts { + if host == "::1" || host == "127.0.0.1" || strings.HasPrefix(strings.ToLower(host), "fe80:") { + t.Fatalf("hosts = %#v, loopback IPs must not be displayed", hosts) + } + } + }) + + t.Run("explicit ipv4 any shows localhost then lan ipv4", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "0.0.0.0", + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost", "192.168.1.2", "10.0.0.8"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) + + t.Run("explicit wildcard star shows localhost first", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "*", + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost", "2001:db8::1", "2001:db8::2", "192.168.1.2", "10.0.0.8"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) + + t.Run("explicit multi-address binding without local tokens hides localhost", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "192.168.1.2,10.0.0.8,2001:db8::1,2001:db8::2,fe80::1", + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"192.168.1.2", "10.0.0.8", "2001:db8::1", "2001:db8::2"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) +} + +func TestWildcardAdvertiseIP(t *testing.T) { + tests := []struct { + name string + bindHosts []string + ipv4 string + ipv6 string + want string + }{ + { + name: "ipv4 wildcard uses ipv4", + bindHosts: []string{"0.0.0.0"}, + ipv4: "192.168.1.2", + ipv6: "2001:db8::1", + want: "192.168.1.2", + }, + { + name: "dual wildcard prefers ipv6", + bindHosts: []string{"0.0.0.0", "::"}, + ipv4: "192.168.1.2", + ipv6: "2001:db8::1", + want: "2001:db8::1", + }, + { + name: "ipv6 wildcard uses ipv6", + bindHosts: []string{"::"}, + ipv4: "192.168.1.2", + ipv6: "2001:db8::1", + want: "2001:db8::1", + }, + { + name: "dual wildcard falls back to ipv4 when ipv6 missing", + bindHosts: []string{"0.0.0.0", "::"}, + ipv4: "192.168.1.2", + ipv6: "", + want: "192.168.1.2", + }, + { + name: "ipv6 wildcard without ipv6 does not advertise ipv4", + bindHosts: []string{"::"}, + ipv4: "192.168.1.2", + ipv6: "", + want: "", + }, + { + name: "non wildcard does not advertise", + bindHosts: []string{"127.0.0.1"}, + ipv4: "192.168.1.2", + ipv6: "2001:db8::1", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := wildcardAdvertiseIP(tt.bindHosts, tt.ipv4, tt.ipv6); got != tt.want { + t.Fatalf("wildcardAdvertiseIP(%#v, %q, %q) = %q, want %q", tt.bindHosts, tt.ipv4, tt.ipv6, got, tt.want) + } + }) + } +} + +func TestOpenLauncherListeners_HonorsIPv6OnlyHost(t *testing.T) { + hasIPv4, hasIPv6 := netbind.DetectIPFamilies() + if !hasIPv6 { + t.Skip("IPv6 is unavailable in this environment") + } + + result, err := openLauncherListeners("::", false, "0") + if err != nil { + t.Fatalf("openLauncherListeners() error = %v", err) + } + startLauncherTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireLauncherHTTPReachable(t, "::1", port) + if hasIPv4 { + requireLauncherHTTPUnreachable(t, "127.0.0.1", port) + } +} + +func TestOpenLauncherListeners_SupportsExplicitMultiHost(t *testing.T) { + hasIPv4, hasIPv6 := netbind.DetectIPFamilies() + if !hasIPv4 || !hasIPv6 { + t.Skip("dual-stack loopback is unavailable in this environment") + } + + result, err := openLauncherListeners("127.0.0.1,::1", false, "0") + if err != nil { + t.Fatalf("openLauncherListeners() error = %v", err) + } + startLauncherTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireLauncherHTTPReachable(t, "127.0.0.1", port) + requireLauncherHTTPReachable(t, "::1", port) +} + +func startLauncherTestHTTPServer(t *testing.T, listeners []net.Listener) { + t.Helper() + + server := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "ok") + }), + } + + errCh := make(chan error, len(listeners)) + for _, listener := range listeners { + ln := listener + go func() { + errCh <- server.Serve(ln) + }() + } + + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Shutdown(ctx) + for range listeners { + err := <-errCh + if err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Fatalf("server.Serve() error = %v", err) + } + } + }) +} + +func requireLauncherHTTPReachable(t *testing.T, host string, port int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for { + err := launcherHTTPGet(host, port) + if err == nil { + return + } + if time.Now().After(deadline) { + t.Fatalf("expected %s:%d to be reachable: %v", host, port, err) + } + time.Sleep(50 * time.Millisecond) + } +} + +func requireLauncherHTTPUnreachable(t *testing.T, host string, port int) { + t.Helper() + if err := launcherHTTPGet(host, port); err == nil { + t.Fatalf("expected %s:%d to be unreachable", host, port) + } +} + +func launcherHTTPGet(host string, port int) error { + client := &http.Client{ + Timeout: 300 * time.Millisecond, + Transport: &http.Transport{ + Proxy: nil, + }, + } + + resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + return nil +} + +func mustAtoi(t *testing.T, value string) int { + t.Helper() + n, err := strconv.Atoi(value) + if err != nil { + t.Fatalf("Atoi(%q) error = %v", value, err) + } + return n +} diff --git a/web/backend/utils/runtime.go b/web/backend/utils/runtime.go index 0b9e30979..8899a664b 100644 --- a/web/backend/utils/runtime.go +++ b/web/backend/utils/runtime.go @@ -7,6 +7,7 @@ import ( "os/exec" "path/filepath" "runtime" + "strings" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" @@ -54,18 +55,93 @@ func FindPicoclawBinary() string { return "picoclaw" } -// GetLocalIP returns the local IP address of the machine. -func GetLocalIP() string { +func appendUniqueIP(addrs []string, seen map[string]struct{}, value string) []string { + value = strings.TrimSpace(value) + if value == "" { + return addrs + } + if _, ok := seen[value]; ok { + return addrs + } + seen[value] = struct{}{} + return append(addrs, value) +} + +// GetLocalIPv4s returns all non-loopback local IPv4 addresses. +func GetLocalIPv4s() []string { addrs, err := net.InterfaceAddrs() if err != nil { - return "" + return nil } + results := make([]string, 0, 4) + seen := make(map[string]struct{}, 4) for _, a := range addrs { - if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil { - return ipnet.IP.String() + ipnet, ok := a.(*net.IPNet) + if !ok || ipnet.IP == nil || ipnet.IP.IsLoopback() { + continue + } + if ip4 := ipnet.IP.To4(); ip4 != nil { + results = appendUniqueIP(results, seen, ip4.String()) } } - return "" + return results +} + +func isDisplayGlobalIPv6(ip net.IP) bool { + if ip == nil || ip.IsLoopback() || ip.To4() != nil { + return false + } + ip = ip.To16() + if ip == nil { + return false + } + // Only show IPv6 global unicast addresses in 2000::/3. + return ip[0]&0xe0 == 0x20 +} + +// GetGlobalIPv6s returns all IPv6 global unicast addresses. +func GetGlobalIPv6s() []string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil + } + results := make([]string, 0, 4) + seen := make(map[string]struct{}, 4) + for _, a := range addrs { + ipnet, ok := a.(*net.IPNet) + if !ok || ipnet.IP == nil { + continue + } + ip := ipnet.IP + if !isDisplayGlobalIPv6(ip) { + continue + } + results = appendUniqueIP(results, seen, ip.String()) + } + return results +} + +// GetLocalIPv4 returns the first non-loopback local IPv4 address. +func GetLocalIPv4() string { + addrs := GetLocalIPv4s() + if len(addrs) == 0 { + return "" + } + return addrs[0] +} + +// GetLocalIPv6 returns the first IPv6 global unicast address. +func GetLocalIPv6() string { + addrs := GetGlobalIPv6s() + if len(addrs) == 0 { + return "" + } + return addrs[0] +} + +// GetLocalIP returns a non-loopback local IPv4 address for backward compatibility. +func GetLocalIP() string { + return GetLocalIPv4() } // OpenBrowser automatically opens the given URL in the default browser.