diff --git a/cmd/picoclaw/internal/gateway/command.go b/cmd/picoclaw/internal/gateway/command.go index 7fa588c5c..7dd03b495 100644 --- a/cmd/picoclaw/internal/gateway/command.go +++ b/cmd/picoclaw/internal/gateway/command.go @@ -2,19 +2,34 @@ package gateway import ( "fmt" + "os" "github.com/spf13/cobra" "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/gateway" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/netbind" "github.com/sipeed/picoclaw/pkg/utils" ) +func resolveGatewayHostOverride(explicit bool, host string) (string, error) { + if !explicit { + return "", nil + } + normalized, err := netbind.NormalizeHostInput(host) + if err != nil { + return "", fmt.Errorf("invalid --host value: %w", err) + } + return normalized, nil +} + func NewGatewayCommand() *cobra.Command { var debug bool var noTruncate bool var allowEmpty bool + var host string cmd := &cobra.Command{ Use: "gateway", @@ -33,7 +48,25 @@ func NewGatewayCommand() *cobra.Command { return nil }, - RunE: func(_ *cobra.Command, _ []string) error { + RunE: func(cmd *cobra.Command, _ []string) error { + resolvedHost, err := resolveGatewayHostOverride(cmd.Flags().Changed("host"), host) + if err != nil { + return err + } + if resolvedHost != "" { + prevHost, hadPrev := os.LookupEnv(config.EnvGatewayHost) + if err := os.Setenv(config.EnvGatewayHost, resolvedHost); err != nil { + return fmt.Errorf("failed to set %s: %w", config.EnvGatewayHost, err) + } + defer func() { + if hadPrev { + _ = os.Setenv(config.EnvGatewayHost, prevHost) + return + } + _ = os.Unsetenv(config.EnvGatewayHost) + }() + } + return gateway.Run(debug, internal.GetPicoclawHome(), internal.GetConfigPath(), allowEmpty) }, } @@ -47,6 +80,12 @@ func NewGatewayCommand() *cobra.Command { false, "Continue starting even when no default model is configured", ) + cmd.Flags().StringVar( + &host, + "host", + "", + "Host address for gateway binding (overrides gateway.host for this run)", + ) return cmd } diff --git a/cmd/picoclaw/internal/gateway/command_test.go b/cmd/picoclaw/internal/gateway/command_test.go index 839a7315a..825369abb 100644 --- a/cmd/picoclaw/internal/gateway/command_test.go +++ b/cmd/picoclaw/internal/gateway/command_test.go @@ -29,4 +29,38 @@ func TestNewGatewayCommand(t *testing.T) { assert.True(t, cmd.HasFlags()) assert.NotNil(t, cmd.Flags().Lookup("debug")) assert.NotNil(t, cmd.Flags().Lookup("allow-empty")) + assert.NotNil(t, cmd.Flags().Lookup("host")) +} + +func TestResolveGatewayHostOverride(t *testing.T) { + tests := []struct { + name string + explicit bool + host string + wantHost string + wantErr bool + }{ + {name: "implicit empty host is allowed", explicit: false, host: "", wantHost: "", wantErr: false}, + {name: "explicit empty host rejected", explicit: true, host: " ", wantHost: "", wantErr: true}, + {name: "explicit localhost kept", explicit: true, host: " localhost ", wantHost: "localhost", wantErr: false}, + { + name: "explicit multi host normalized", + explicit: true, + host: " [::1] , 127.0.0.1 ", + wantHost: "::1,127.0.0.1", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := resolveGatewayHostOverride(tt.explicit, tt.host) + if (err != nil) != tt.wantErr { + t.Fatalf("resolveGatewayHostOverride() err = %v, wantErr %t", err, tt.wantErr) + } + if got != tt.wantHost { + t.Fatalf("resolveGatewayHostOverride() host = %q, want %q", got, tt.wantHost) + } + }) + } } diff --git a/config/config.example.json b/config/config.example.json index d56b1cff7..cd966e498 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -477,7 +477,7 @@ }, "gateway": { "_comment": "Default log level is set to 'fatal'. Other available options are 'debug', 'info', 'warn' and 'error'.", - "host": "127.0.0.1", + "host": "localhost", "port": 18790, "hot_reload": false, "log_level": "fatal" diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 4d8e47c0f..928676cbc 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "math" + "net" "net/http" "sort" "sync" @@ -86,6 +87,7 @@ type Manager struct { dispatchTask *asyncTask mux *dynamicServeMux httpServer *http.Server + httpListeners []net.Listener mu sync.RWMutex placeholders sync.Map // "channel:chatID" → placeholderID (string) typingStops sync.Map // "channel:chatID" → func() @@ -474,6 +476,12 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error { // It registers health endpoints from the health server and discovers channels // that implement WebhookHandler and/or HealthChecker to register their handlers. func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) { + m.SetupHTTPServerListeners(nil, addr, healthServer) +} + +// SetupHTTPServerListeners creates a shared HTTP server on pre-opened listeners. +// When listeners is empty it falls back to Addr-based ListenAndServe behavior. +func (m *Manager) SetupHTTPServerListeners(listeners []net.Listener, addr string, healthServer *health.Server) { m.mux = newDynamicServeMux() // Register health endpoints @@ -490,6 +498,7 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) { ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, } + m.httpListeners = append([]net.Listener(nil), listeners...) } // registerHTTPHandlersLocked registers webhook and health-check handlers for @@ -619,16 +628,33 @@ func (m *Manager) StartAll(ctx context.Context) error { // Start shared HTTP server if configured if m.httpServer != nil { - go func() { - logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{ - "addr": m.httpServer.Addr, - }) - if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.FatalCF("channels", "Shared HTTP server error", map[string]any{ - "error": err.Error(), - }) + if len(m.httpListeners) > 0 { + for _, listener := range m.httpListeners { + ln := listener + go func() { + logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{ + "addr": ln.Addr().String(), + }) + if err := m.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + logger.FatalCF("channels", "Shared HTTP server error", map[string]any{ + "addr": ln.Addr().String(), + "error": err.Error(), + }) + } + }() } - }() + } else { + go func() { + logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{ + "addr": m.httpServer.Addr, + }) + if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.FatalCF("channels", "Shared HTTP server error", map[string]any{ + "error": err.Error(), + }) + } + }() + } } logger.InfoCF("channels", "Channel startup completed", map[string]any{ @@ -655,6 +681,7 @@ func (m *Manager) StopAll(ctx context.Context) error { }) } m.httpServer = nil + m.httpListeners = nil } // Cancel dispatcher diff --git a/pkg/config/config.go b/pkg/config/config.go index ae6a5cdb0..ab631107d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1143,6 +1143,8 @@ func LoadConfig(path string) (*Config, error) { applyLegacyBindingsMigration(data, cfg) + gatewayHostBeforeEnv := cfg.Gateway.Host + if err = env.Parse(cfg); err != nil { return nil, err } @@ -1151,6 +1153,10 @@ func LoadConfig(path string) (*Config, error) { if err = InitChannelList(cfg.Channels); err != nil { return nil, err } + cfg.Gateway.Host, err = resolveGatewayHostFromEnv(gatewayHostBeforeEnv) + if err != nil { + return nil, fmt.Errorf("invalid gateway host: %w", err) + } // Expand multi-key configs into separate entries for key-level failover cfg.ModelList = expandMultiKeyModels(cfg.ModelList) diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 0bd8ee907..67411140c 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -503,7 +503,7 @@ func TestDefaultConfig_Temperature(t *testing.T) { func TestDefaultConfig_Gateway(t *testing.T) { cfg := DefaultConfig() - if cfg.Gateway.Host != "127.0.0.1" { + if cfg.Gateway.Host != "localhost" { t.Error("Gateway host should have default value") } if cfg.Gateway.Port == 0 { @@ -739,7 +739,7 @@ func TestConfig_Complete(t *testing.T) { if cfg.Agents.Defaults.MaxToolIterations == 0 { t.Error("MaxToolIterations should not be zero") } - if cfg.Gateway.Host != "127.0.0.1" { + if cfg.Gateway.Host != "localhost" { t.Error("Gateway host should have default value") } if cfg.Gateway.Port == 0 { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 6740c772e..f2f5c44c7 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -259,7 +259,7 @@ func DefaultConfig() *Config { }, }, Gateway: GatewayConfig{ - Host: "127.0.0.1", + Host: "localhost", Port: 18790, HotReload: false, LogLevel: DefaultGatewayLogLevel, diff --git a/pkg/config/envkeys.go b/pkg/config/envkeys.go index 615769d3c..5a2590299 100644 --- a/pkg/config/envkeys.go +++ b/pkg/config/envkeys.go @@ -39,7 +39,7 @@ const ( EnvBinary = "PICOCLAW_BINARY" // EnvGatewayHost overrides the host address for the gateway server. - // Default: "127.0.0.1" + // Default: "localhost" EnvGatewayHost = "PICOCLAW_GATEWAY_HOST" ) diff --git a/pkg/config/gateway.go b/pkg/config/gateway.go index e9f4085d3..392a4ca5e 100644 --- a/pkg/config/gateway.go +++ b/pkg/config/gateway.go @@ -3,8 +3,10 @@ package config import ( "encoding/json" "os" + "strings" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/netbind" ) const DefaultGatewayLogLevel = "warn" @@ -49,6 +51,31 @@ func EffectiveGatewayLogLevel(cfg *Config) string { return normalizeGatewayLogLevel(cfg.Gateway.LogLevel) } +func resolveGatewayHostFromEnv(baseHost string) (string, error) { + envHost, ok := os.LookupEnv(EnvGatewayHost) + if !ok { + return normalizeGatewayHostInput(baseHost) + } + + envHost = strings.TrimSpace(envHost) + if envHost == "" { + return normalizeGatewayHostInput(baseHost) + } + + return normalizeGatewayHostInput(envHost) +} + +func normalizeGatewayHostInput(host string) (string, error) { + host = strings.TrimSpace(host) + if host == "" { + host = strings.TrimSpace(DefaultConfig().Gateway.Host) + } + if host == "" { + host = "localhost" + } + return netbind.NormalizeHostInput(host) +} + // ResolveGatewayLogLevel reads the configured gateway log level without triggering // the full config loader, so startup code can apply logging before config load logs run. // The PICOCLAW_LOG_LEVEL environment variable overrides the file value. diff --git a/pkg/config/gateway_host_env_test.go b/pkg/config/gateway_host_env_test.go new file mode 100644 index 000000000..40fabb1a3 --- /dev/null +++ b/pkg/config/gateway_host_env_test.go @@ -0,0 +1,98 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "testing" +) + +func writeGatewayHostTestConfig(t *testing.T, host string) string { + t.Helper() + + configPath := filepath.Join(t.TempDir(), "config.json") + raw := fmt.Sprintf(`{"version":2,"gateway":{"host":%q,"port":18790}}`, host) + if err := os.WriteFile(configPath, []byte(raw), 0o600); err != nil { + t.Fatalf("WriteFile(configPath): %v", err) + } + return configPath +} + +func TestLoadConfig_GatewayHostEnvTrimmed(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, "127.0.0.1") + t.Setenv(EnvGatewayHost, " ::1 ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Gateway.Host != "::1" { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, "::1") + } +} + +func TestLoadConfig_GatewayHostBlankEnvFallsBackToConfigHost(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, " localhost ") + t.Setenv(EnvGatewayHost, " ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + want, err := normalizeGatewayHostInput("localhost") + if err != nil { + t.Fatalf("normalizeGatewayHostInput() error: %v", err) + } + if cfg.Gateway.Host != want { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, want) + } +} + +func TestLoadConfig_GatewayHostBlankEnvAndConfigFallsBackToDefault(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, " ") + t.Setenv(EnvGatewayHost, " ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + + defaultHost, err := normalizeGatewayHostInput(DefaultConfig().Gateway.Host) + if err != nil { + t.Fatalf("normalizeGatewayHostInput() error: %v", err) + } + if cfg.Gateway.Host != defaultHost { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, defaultHost) + } +} + +func TestLoadConfig_GatewayHostEnvPreservesExplicitWildcardHost(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, "localhost") + t.Setenv(EnvGatewayHost, " 0.0.0.0 ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + + want, err := normalizeGatewayHostInput("0.0.0.0") + if err != nil { + t.Fatalf("normalizeGatewayHostInput() error: %v", err) + } + if cfg.Gateway.Host != want { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, want) + } +} + +func TestLoadConfig_GatewayHostEnvNormalizesMultiHostInput(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, "localhost") + t.Setenv(EnvGatewayHost, " [::1] , 127.0.0.1 , ::1 ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Gateway.Host != "::1,127.0.0.1" { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, "::1,127.0.0.1") + } +} diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index a5afb0eb8..039f45075 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -3,10 +3,12 @@ package gateway import ( "context" "fmt" + "net" "os" "os/signal" "path/filepath" "sort" + "strconv" "strings" "sync" "sync/atomic" @@ -42,6 +44,7 @@ import ( "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/netbind" "github.com/sipeed/picoclaw/pkg/pid" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" @@ -159,13 +162,30 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr logger.Infof("Log level set to %q", effectiveLogLevel) } + bindPlan, listenResult, err := openGatewayListeners(cfg.Gateway.Host, cfg.Gateway.Port) + if err != nil { + return fmt.Errorf("error opening gateway listeners: %w", err) + } + // Enforce singleton: write PID file with generated token. - pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port) + pidData, err := pid.WritePidFile(homePath, bindPlan.ProbeHost, cfg.Gateway.Port) if err != nil { logger.Warnf("write pid file failed: %v", err) + for _, ln := range listenResult.Listeners { + _ = ln.Close() + } return fmt.Errorf("singleton check failed: %w", err) } defer pid.RemovePidFile(homePath) + closeListeners := true + defer func() { + if !closeListeners { + return + } + for _, ln := range listenResult.Listeners { + _ = ln.Close() + } + }() provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup) if err != nil { @@ -193,10 +213,11 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr "skills_available": skillsInfo["available"], }) - runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token) + runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token, listenResult) if err != nil { return err } + closeListeners = false // Setup manual reload channel for /reload endpoint manualReloadChan := make(chan struct{}, 1) @@ -217,7 +238,9 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr runningServices.HealthServer.SetReloadFunc(reloadTrigger) agentLoop.SetReloadFunc(reloadTrigger) - fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) + for _, bindHost := range listenResult.BindHosts { + fmt.Printf("✓ Gateway started on %s\n", net.JoinHostPort(bindHost, strconv.Itoa(cfg.Gateway.Port))) + } fmt.Println("Press Ctrl+C to stop") ctx, cancel := context.WithCancel(context.Background()) @@ -320,6 +343,7 @@ func setupAndStartServices( agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, authToken string, + listenResult netbind.OpenResult, ) (*services, error) { runningServices := &services{} @@ -390,10 +414,20 @@ func setupAndStartServices( fmt.Println("⚠ Warning: No channels enabled") } - addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) runningServices.authToken = authToken - runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, authToken) - runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer) + runningServices.HealthServer = health.NewServer(listenResult.ProbeHost, cfg.Gateway.Port, authToken) + + var listenAddr string + if len(listenResult.Listeners) > 0 { + listenAddr = listenResult.Listeners[0].Addr().String() + } else { + listenAddr = net.JoinHostPort(listenResult.ProbeHost, strconv.Itoa(cfg.Gateway.Port)) + } + runningServices.ChannelManager.SetupHTTPServerListeners( + listenResult.Listeners, + listenAddr, + runningServices.HealthServer, + ) if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil { return nil, fmt.Errorf("error starting channels: %w", err) @@ -409,10 +443,10 @@ func setupAndStartServices( voiceAgent.Start(vaCtx) } + healthAddr := net.JoinHostPort(listenResult.ProbeHost, strconv.Itoa(cfg.Gateway.Port)) fmt.Printf( - "✓ Health endpoints available at http://%s:%d/health, /ready and /reload (POST)\n", - cfg.Gateway.Host, - cfg.Gateway.Port, + "✓ Health endpoints available at http://%s/health, /ready and /reload (POST)\n", + healthAddr, ) stateManager := state.NewManager(cfg.WorkspacePath()) diff --git a/pkg/gateway/listen.go b/pkg/gateway/listen.go new file mode 100644 index 000000000..99be63096 --- /dev/null +++ b/pkg/gateway/listen.go @@ -0,0 +1,21 @@ +package gateway + +import ( + "strconv" + + "github.com/sipeed/picoclaw/pkg/netbind" +) + +func openGatewayListeners(host string, port int) (netbind.Plan, netbind.OpenResult, error) { + plan, err := netbind.BuildPlan(host, netbind.DefaultLoopback) + if err != nil { + return netbind.Plan{}, netbind.OpenResult{}, err + } + + result, err := netbind.OpenPlan(plan, strconv.Itoa(port)) + if err != nil { + return netbind.Plan{}, netbind.OpenResult{}, err + } + + return plan, result, nil +} diff --git a/pkg/gateway/listen_test.go b/pkg/gateway/listen_test.go new file mode 100644 index 000000000..9b932f852 --- /dev/null +++ b/pkg/gateway/listen_test.go @@ -0,0 +1,130 @@ +package gateway + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "strconv" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/netbind" +) + +func TestOpenGatewayListeners_HonorsIPv6OnlyHost(t *testing.T) { + hasIPv4, hasIPv6 := netbind.DetectIPFamilies() + if !hasIPv6 { + t.Skip("IPv6 is unavailable in this environment") + } + + _, result, err := openGatewayListeners("::", 0) + if err != nil { + t.Fatalf("openGatewayListeners() error = %v", err) + } + startGatewayTestHTTPServer(t, result.Listeners) + port := mustGatewayAtoi(t, result.Port) + + requireGatewayHTTPReachable(t, "::1", port) + if hasIPv4 { + requireGatewayHTTPUnreachable(t, "127.0.0.1", port) + } +} + +func TestOpenGatewayListeners_SupportsExplicitMultiHost(t *testing.T) { + hasIPv4, hasIPv6 := netbind.DetectIPFamilies() + if !hasIPv4 || !hasIPv6 { + t.Skip("dual-stack loopback is unavailable in this environment") + } + + _, result, err := openGatewayListeners("127.0.0.1,::1", 0) + if err != nil { + t.Fatalf("openGatewayListeners() error = %v", err) + } + startGatewayTestHTTPServer(t, result.Listeners) + port := mustGatewayAtoi(t, result.Port) + + requireGatewayHTTPReachable(t, "127.0.0.1", port) + requireGatewayHTTPReachable(t, "::1", port) +} + +func startGatewayTestHTTPServer(t *testing.T, listeners []net.Listener) { + t.Helper() + + server := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "ok") + }), + } + + errCh := make(chan error, len(listeners)) + for _, listener := range listeners { + ln := listener + go func() { + errCh <- server.Serve(ln) + }() + } + + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Shutdown(ctx) + for range listeners { + err := <-errCh + if err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Fatalf("server.Serve() error = %v", err) + } + } + }) +} + +func requireGatewayHTTPReachable(t *testing.T, host string, port int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for { + err := gatewayHTTPGet(host, port) + if err == nil { + return + } + if time.Now().After(deadline) { + t.Fatalf("expected %s:%d to be reachable: %v", host, port, err) + } + time.Sleep(50 * time.Millisecond) + } +} + +func requireGatewayHTTPUnreachable(t *testing.T, host string, port int) { + t.Helper() + if err := gatewayHTTPGet(host, port); err == nil { + t.Fatalf("expected %s:%d to be unreachable", host, port) + } +} + +func gatewayHTTPGet(host string, port int) error { + client := &http.Client{ + Timeout: 300 * time.Millisecond, + Transport: &http.Transport{ + Proxy: nil, + }, + } + + resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + return nil +} + +func mustGatewayAtoi(t *testing.T, value string) int { + t.Helper() + n, err := strconv.Atoi(value) + if err != nil { + t.Fatalf("Atoi(%q) error = %v", value, err) + } + return n +} diff --git a/pkg/health/server.go b/pkg/health/server.go index a152d8ab1..22346490c 100644 --- a/pkg/health/server.go +++ b/pkg/health/server.go @@ -4,10 +4,11 @@ import ( "context" "crypto/subtle" "encoding/json" - "fmt" "maps" + "net" "net/http" "os" + "strconv" "sync" "time" ) @@ -49,7 +50,7 @@ func NewServer(host string, port int, token string) *Server { mux.HandleFunc("/ready", s.readyHandler) mux.HandleFunc("/reload", s.reloadHandler) - addr := fmt.Sprintf("%s:%d", host, port) + addr := net.JoinHostPort(host, strconv.Itoa(port)) s.server = &http.Server{ Addr: addr, Handler: mux, diff --git a/pkg/health/server_test.go b/pkg/health/server_test.go index c4982fff9..31dbc37c0 100644 --- a/pkg/health/server_test.go +++ b/pkg/health/server_test.go @@ -305,6 +305,16 @@ func TestNewServer(t *testing.T) { } } +func TestNewServer_IPv6ListenAddrFormatting(t *testing.T) { + s := NewServer("::", 18790, "") + if s.server == nil { + t.Fatal("server should be initialized") + } + if s.server.Addr != "[::]:18790" { + t.Fatalf("server.Addr = %q, want %q", s.server.Addr, "[::]:18790") + } +} + func TestStartContext_Cancellation(t *testing.T) { s := NewServer("127.0.0.1", 0, "") diff --git a/pkg/netbind/netbind.go b/pkg/netbind/netbind.go new file mode 100644 index 000000000..ae6cacf49 --- /dev/null +++ b/pkg/netbind/netbind.go @@ -0,0 +1,606 @@ +package netbind + +import ( + "context" + "errors" + "fmt" + "net" + "strconv" + "strings" + "sync" +) + +type DefaultMode int + +const ( + DefaultLoopback DefaultMode = iota + DefaultAny +) + +type groupKind int + +const ( + groupAdaptiveLoopback groupKind = iota + groupAdaptiveAny + groupExact +) + +type exactBinding struct { + host string + network string + v6Only bool +} + +type bindGroup struct { + kind groupKind + allowIPv4 bool + allowIPv6 bool + exact exactBinding +} + +type Plan struct { + groups []bindGroup + ProbeHost string +} + +type OpenResult struct { + Listeners []net.Listener + BindHosts []string + Port string + ProbeHost string +} + +type tokenKind int + +const ( + tokenName tokenKind = iota + tokenLocalhost + tokenStar + tokenIPv4 + tokenIPv6 + tokenIPv4Any + tokenIPv6Any +) + +type hostToken struct { + kind tokenKind + canonical string + key string +} + +var ( + ipFamiliesOnce sync.Once + hasIPv4 bool + hasIPv6 bool +) + +func DetectIPFamilies() (bool, bool) { + ipFamiliesOnce.Do(func() { + if ips, err := net.LookupIP("localhost"); err == nil { + for _, ip := range ips { + if ip == nil { + continue + } + if ip.To4() != nil { + hasIPv4 = true + continue + } + hasIPv6 = true + } + } + + if hasIPv4 && hasIPv6 { + return + } + + if addrs, err := net.InterfaceAddrs(); err == nil { + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok || ipnet.IP == nil { + continue + } + if ipnet.IP.To4() != nil { + hasIPv4 = true + continue + } + hasIPv6 = true + } + } + }) + + return hasIPv4, hasIPv6 +} + +func SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6 bool) string { + switch { + case hasIPv4 && hasIPv6: + return "localhost" + case hasIPv6: + return "::1" + case hasIPv4: + return "127.0.0.1" + default: + return "localhost" + } +} + +func SelectAdaptiveAnyHost(hasIPv4, hasIPv6 bool) string { + switch { + case hasIPv4 && hasIPv6: + return "::" + case hasIPv6: + return "::" + case hasIPv4: + return "0.0.0.0" + default: + return "::" + } +} + +func ResolveAdaptiveLoopbackHost() string { + hasIPv4, hasIPv6 := DetectIPFamilies() + return SelectAdaptiveLoopbackHost(hasIPv4, hasIPv6) +} + +func ResolveAdaptiveAnyHost() string { + hasIPv4, hasIPv6 := DetectIPFamilies() + return SelectAdaptiveAnyHost(hasIPv4, hasIPv6) +} + +func IsLoopbackHost(host string) bool { + host = strings.TrimSpace(host) + if host == "" { + return false + } + if strings.EqualFold(host, "localhost") { + return true + } + ip := net.ParseIP(strings.Trim(host, "[]")) + return ip != nil && ip.IsLoopback() +} + +func IsUnspecifiedHost(host string) bool { + host = strings.TrimSpace(host) + if host == "" { + return false + } + ip := net.ParseIP(strings.Trim(host, "[]")) + return ip != nil && ip.IsUnspecified() +} + +func NormalizeHostInput(raw string) (string, error) { + tokens, err := parseHostTokens(raw) + if err != nil { + return "", err + } + + parts := make([]string, 0, len(tokens)) + for _, token := range tokens { + parts = append(parts, token.canonical) + } + return strings.Join(parts, ","), nil +} + +func BuildPlan(raw string, defaultMode DefaultMode) (Plan, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return buildDefaultPlan(defaultMode), nil + } + + tokens, err := parseHostTokens(raw) + if err != nil { + return Plan{}, err + } + + for _, token := range tokens { + if token.kind == tokenStar { + return Plan{ + groups: []bindGroup{{kind: groupAdaptiveAny}}, + ProbeHost: ResolveAdaptiveLoopbackHost(), + }, nil + } + } + + hasIPv4Any := false + hasIPv6Any := false + for _, token := range tokens { + switch token.kind { + case tokenIPv4Any: + hasIPv4Any = true + case tokenIPv6Any: + hasIPv6Any = true + } + } + + allowLocalhostIPv4 := !hasIPv4Any + allowLocalhostIPv6 := !hasIPv6Any + + groups := make([]bindGroup, 0, len(tokens)) + seenExact := make(map[string]struct{}, len(tokens)) + addedLocalhost := false + + for _, token := range tokens { + switch token.kind { + case tokenLocalhost: + if addedLocalhost || (!allowLocalhostIPv4 && !allowLocalhostIPv6) { + continue + } + groups = append(groups, bindGroup{ + kind: groupAdaptiveLoopback, + allowIPv4: allowLocalhostIPv4, + allowIPv6: allowLocalhostIPv6, + }) + addedLocalhost = true + case tokenIPv4Any: + key := "exact:tcp4:0.0.0.0" + if _, ok := seenExact[key]; ok { + continue + } + seenExact[key] = struct{}{} + groups = append(groups, bindGroup{ + kind: groupExact, + exact: exactBinding{ + host: "0.0.0.0", + network: "tcp4", + }, + }) + case tokenIPv6Any: + key := "exact:tcp6:::" + if _, ok := seenExact[key]; ok { + continue + } + seenExact[key] = struct{}{} + groups = append(groups, bindGroup{ + kind: groupExact, + exact: exactBinding{ + host: "::", + network: "tcp6", + v6Only: true, + }, + }) + case tokenIPv4: + if hasIPv4Any { + continue + } + key := "exact:tcp4:" + strings.ToLower(token.canonical) + if _, ok := seenExact[key]; ok { + continue + } + seenExact[key] = struct{}{} + groups = append(groups, bindGroup{ + kind: groupExact, + exact: exactBinding{ + host: token.canonical, + network: "tcp4", + }, + }) + case tokenIPv6: + if hasIPv6Any { + continue + } + key := "exact:tcp6:" + strings.ToLower(token.canonical) + if _, ok := seenExact[key]; ok { + continue + } + seenExact[key] = struct{}{} + groups = append(groups, bindGroup{ + kind: groupExact, + exact: exactBinding{ + host: token.canonical, + network: "tcp6", + v6Only: true, + }, + }) + case tokenName: + key := "exact:tcp:" + token.key + if _, ok := seenExact[key]; ok { + continue + } + seenExact[key] = struct{}{} + groups = append(groups, bindGroup{ + kind: groupExact, + exact: exactBinding{ + host: token.canonical, + network: "tcp", + }, + }) + } + } + + plan := Plan{groups: groups} + plan.ProbeHost = probeHostForGroups(groups) + return plan, nil +} + +func OpenPlan(plan Plan, port string) (OpenResult, error) { + if port == "" { + return OpenResult{}, errors.New("port cannot be empty") + } + + selectedPort := port + listeners := make([]net.Listener, 0, len(plan.groups)) + bindHosts := make([]string, 0, len(plan.groups)) + bindSeen := make(map[string]struct{}, len(plan.groups)) + + closeAll := func() { + for _, ln := range listeners { + _ = ln.Close() + } + } + + for _, group := range plan.groups { + groupListeners, groupHosts, actualPort, err := openGroup(group, selectedPort) + if err != nil { + closeAll() + return OpenResult{}, err + } + if selectedPort == "0" && actualPort != "" { + selectedPort = actualPort + } + listeners = append(listeners, groupListeners...) + for _, host := range groupHosts { + key := strings.ToLower(host) + if _, ok := bindSeen[key]; ok { + continue + } + bindSeen[key] = struct{}{} + bindHosts = append(bindHosts, host) + } + } + + return OpenResult{ + Listeners: listeners, + BindHosts: bindHosts, + Port: selectedPort, + ProbeHost: plan.ProbeHost, + }, nil +} + +func buildDefaultPlan(defaultMode DefaultMode) Plan { + switch defaultMode { + case DefaultAny: + return Plan{ + groups: []bindGroup{{kind: groupAdaptiveAny}}, + ProbeHost: ResolveAdaptiveLoopbackHost(), + } + default: + return Plan{ + groups: []bindGroup{{ + kind: groupAdaptiveLoopback, + allowIPv4: true, + allowIPv6: true, + }}, + ProbeHost: ResolveAdaptiveLoopbackHost(), + } + } +} + +func probeHostForGroups(groups []bindGroup) string { + hasIPv4Any := false + hasIPv6Any := false + for _, group := range groups { + if group.kind == groupAdaptiveLoopback { + switch { + case group.allowIPv4 && group.allowIPv6: + return ResolveAdaptiveLoopbackHost() + case group.allowIPv6: + return "::1" + case group.allowIPv4: + return "127.0.0.1" + } + } + if group.kind == groupAdaptiveAny { + return ResolveAdaptiveLoopbackHost() + } + if group.kind != groupExact { + continue + } + switch group.exact.host { + case "0.0.0.0": + hasIPv4Any = true + case "::": + hasIPv6Any = true + } + } + + switch { + case hasIPv4Any && hasIPv6Any: + return ResolveAdaptiveLoopbackHost() + case hasIPv6Any: + return "::1" + case hasIPv4Any: + return "127.0.0.1" + } + + for _, group := range groups { + if group.kind == groupExact { + return group.exact.host + } + } + return ResolveAdaptiveLoopbackHost() +} + +func parseHostTokens(raw string) ([]hostToken, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, errors.New("host cannot be empty") + } + + parts := strings.Split(raw, ",") + tokens := make([]hostToken, 0, len(parts)) + seen := make(map[string]struct{}, len(parts)) + for _, part := range parts { + token, err := parseHostToken(part) + if err != nil { + return nil, err + } + if _, ok := seen[token.key]; ok { + continue + } + seen[token.key] = struct{}{} + tokens = append(tokens, token) + } + + if len(tokens) == 0 { + return nil, errors.New("host cannot be empty") + } + + return tokens, nil +} + +func parseHostToken(raw string) (hostToken, error) { + host := strings.TrimSpace(raw) + if host == "" { + return hostToken{}, errors.New("host list contains an empty entry") + } + + if host == "*" { + return hostToken{kind: tokenStar, canonical: "*", key: "*"}, nil + } + if strings.EqualFold(host, "localhost") { + return hostToken{kind: tokenLocalhost, canonical: "localhost", key: "localhost"}, nil + } + + trimmed := strings.Trim(host, "[]") + if ip := net.ParseIP(trimmed); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + canonical := ip4.String() + kind := tokenIPv4 + if ip4.IsUnspecified() { + kind = tokenIPv4Any + } + return hostToken{kind: kind, canonical: canonical, key: canonical}, nil + } + + canonical := ip.String() + kind := tokenIPv6 + if ip.IsUnspecified() { + kind = tokenIPv6Any + } + return hostToken{kind: kind, canonical: canonical, key: strings.ToLower(canonical)}, nil + } + + return hostToken{ + kind: tokenName, + canonical: host, + key: strings.ToLower(host), + }, nil +} + +func openGroup(group bindGroup, port string) ([]net.Listener, []string, string, error) { + switch group.kind { + case groupAdaptiveLoopback: + return openAdaptiveLoopbackGroup(group.allowIPv6, group.allowIPv4, port) + case groupAdaptiveAny: + return openAdaptiveAnyGroup(port) + case groupExact: + ln, actualPort, err := openExactListener(group.exact, port) + if err != nil { + return nil, nil, "", err + } + return []net.Listener{ln}, []string{group.exact.host}, actualPort, nil + default: + return nil, nil, "", fmt.Errorf("unsupported bind group kind: %d", group.kind) + } +} + +func openAdaptiveLoopbackGroup(allowIPv6, allowIPv4 bool, port string) ([]net.Listener, []string, string, error) { + if allowIPv6 && allowIPv4 { + if ln6, actualPort, err6 := openExactListener( + exactBinding{host: "::1", network: "tcp6", v6Only: true}, + port, + ); err6 == nil { + if ln4, _, err4 := openExactListener( + exactBinding{host: "127.0.0.1", network: "tcp4"}, + actualPort, + ); err4 == nil { + return []net.Listener{ln6, ln4}, []string{"::1", "127.0.0.1"}, actualPort, nil + } + _ = ln6.Close() + } + } + + if allowIPv6 { + ln6, actualPort, err := openExactListener(exactBinding{host: "::1", network: "tcp6", v6Only: true}, port) + if err == nil { + return []net.Listener{ln6}, []string{"::1"}, actualPort, nil + } + } + + if allowIPv4 { + ln4, actualPort, err := openExactListener(exactBinding{host: "127.0.0.1", network: "tcp4"}, port) + if err == nil { + return []net.Listener{ln4}, []string{"127.0.0.1"}, actualPort, nil + } + } + + return nil, nil, "", fmt.Errorf("failed to open adaptive localhost listener on port %s", port) +} + +func openAdaptiveAnyGroup(port string) ([]net.Listener, []string, string, error) { + hasIPv4, hasIPv6 := DetectIPFamilies() + + if hasIPv4 && hasIPv6 { + if ln6, actualPort, err6 := openExactListener( + exactBinding{host: "::", network: "tcp6", v6Only: true}, + port, + ); err6 == nil { + if ln4, _, err4 := openExactListener( + exactBinding{host: "0.0.0.0", network: "tcp4"}, + actualPort, + ); err4 == nil { + return []net.Listener{ln6, ln4}, []string{"::", "0.0.0.0"}, actualPort, nil + } + _ = ln6.Close() + } + } + + if hasIPv6 { + ln6, actualPort, err := openExactListener(exactBinding{host: "::", network: "tcp6", v6Only: true}, port) + if err == nil { + return []net.Listener{ln6}, []string{"::"}, actualPort, nil + } + } + + if hasIPv4 { + ln4, actualPort, err := openExactListener(exactBinding{host: "0.0.0.0", network: "tcp4"}, port) + if err == nil { + return []net.Listener{ln4}, []string{"0.0.0.0"}, actualPort, nil + } + } + + return nil, nil, "", fmt.Errorf("failed to open adaptive any-host listener on port %s", port) +} + +func openExactListener(binding exactBinding, port string) (net.Listener, string, error) { + listenConfig := net.ListenConfig{} + if binding.network == "tcp6" && binding.v6Only { + listenConfig.Control = applyIPv6OnlyControl(true) + } + + ln, err := listenConfig.Listen(context.Background(), binding.network, net.JoinHostPort(binding.host, port)) + if err != nil { + return nil, "", err + } + + actualPort, err := listenerPort(ln) + if err != nil { + _ = ln.Close() + return nil, "", err + } + + return ln, actualPort, nil +} + +func listenerPort(ln net.Listener) (string, error) { + addr, ok := ln.Addr().(*net.TCPAddr) + if ok { + return strconv.Itoa(addr.Port), nil + } + + _, port, err := net.SplitHostPort(ln.Addr().String()) + if err != nil { + return "", err + } + return port, nil +} diff --git a/pkg/netbind/netbind_test.go b/pkg/netbind/netbind_test.go new file mode 100644 index 000000000..20b7ff141 --- /dev/null +++ b/pkg/netbind/netbind_test.go @@ -0,0 +1,280 @@ +package netbind + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "strconv" + "testing" + "time" +) + +func TestNormalizeHostInput(t *testing.T) { + tests := []struct { + name string + raw string + want string + wantErr bool + }{ + {name: "single host", raw: "127.0.0.1", want: "127.0.0.1"}, + {name: "trim and dedupe", raw: " [::1] , ::1 , 127.0.0.1 ", want: "::1,127.0.0.1"}, + {name: "star preserved", raw: "*,127.0.0.1", want: "*,127.0.0.1"}, + {name: "reject empty", raw: "127.0.0.1, ", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeHostInput(tt.raw) + if (err != nil) != tt.wantErr { + t.Fatalf("NormalizeHostInput() err = %v, wantErr %t", err, tt.wantErr) + } + if tt.wantErr { + return + } + if got != tt.want { + t.Fatalf("NormalizeHostInput() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestBuildPlan_DefaultAnyUsesLoopbackProbe(t *testing.T) { + plan, err := BuildPlan("", DefaultAny) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + if plan.ProbeHost != ResolveAdaptiveLoopbackHost() { + t.Fatalf("ProbeHost = %q, want %q", plan.ProbeHost, ResolveAdaptiveLoopbackHost()) + } +} + +func TestOpenPlan_LocalhostSupportsLoopbackCommunication(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + + plan, err := BuildPlan("localhost", DefaultLoopback) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + if hasIPv6 { + requireHTTPReachable(t, "::1", port) + } + if hasIPv4 { + requireHTTPReachable(t, "127.0.0.1", port) + } +} + +func TestOpenPlan_DefaultAnySupportsDualStackLoopback(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + + plan, err := BuildPlan("", DefaultAny) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + if hasIPv6 { + requireHTTPReachable(t, "::1", port) + } + if hasIPv4 { + requireHTTPReachable(t, "127.0.0.1", port) + } + + switch { + case hasIPv4 && hasIPv6: + if len(result.BindHosts) != 2 { + t.Fatalf("len(BindHosts) = %d, want 2 (%#v)", len(result.BindHosts), result.BindHosts) + } + case hasIPv6 || hasIPv4: + if len(result.BindHosts) != 1 { + t.Fatalf("len(BindHosts) = %d, want 1 (%#v)", len(result.BindHosts), result.BindHosts) + } + } +} + +func TestOpenPlan_ExplicitIPv6AnyIsIPv6Only(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + if !hasIPv6 { + t.Skip("IPv6 is unavailable in this environment") + } + + plan, err := BuildPlan("::", DefaultLoopback) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireHTTPReachable(t, "::1", port) + if hasIPv4 { + requireHTTPUnreachable(t, "127.0.0.1", port) + } +} + +func TestOpenPlan_ExplicitIPv4AnyIsIPv4Only(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + if !hasIPv4 { + t.Skip("IPv4 is unavailable in this environment") + } + + plan, err := BuildPlan("0.0.0.0", DefaultLoopback) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireHTTPReachable(t, "127.0.0.1", port) + if hasIPv6 { + requireHTTPUnreachable(t, "::1", port) + } +} + +func TestOpenPlan_MultiHostSupportsExplicitIPv4AndIPv6(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + if !hasIPv4 || !hasIPv6 { + t.Skip("dual-stack loopback is unavailable in this environment") + } + + plan, err := BuildPlan("127.0.0.1,::1", DefaultLoopback) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireHTTPReachable(t, "127.0.0.1", port) + requireHTTPReachable(t, "::1", port) +} + +func TestOpenPlan_WildcardRulesKeepIPv4AndIPv6AnyHosts(t *testing.T) { + hasIPv4, hasIPv6 := DetectIPFamilies() + if !hasIPv4 || !hasIPv6 { + t.Skip("dual-stack loopback is unavailable in this environment") + } + + plan, err := BuildPlan("::,::1,0.0.0.0,127.0.0.1", DefaultLoopback) + if err != nil { + t.Fatalf("BuildPlan() error = %v", err) + } + result, err := OpenPlan(plan, "0") + if err != nil { + t.Fatalf("OpenPlan() error = %v", err) + } + startTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireHTTPReachable(t, "127.0.0.1", port) + requireHTTPReachable(t, "::1", port) + if len(result.BindHosts) != 2 { + t.Fatalf("len(BindHosts) = %d, want 2 (%#v)", len(result.BindHosts), result.BindHosts) + } +} + +func startTestHTTPServer(t *testing.T, listeners []net.Listener) { + t.Helper() + + server := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "ok") + }), + } + + errCh := make(chan error, len(listeners)) + for _, listener := range listeners { + ln := listener + go func() { + errCh <- server.Serve(ln) + }() + } + + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Shutdown(ctx) + for range listeners { + err := <-errCh + if err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Fatalf("server.Serve() error = %v", err) + } + } + }) +} + +func requireHTTPReachable(t *testing.T, host string, port int) { + t.Helper() + + deadline := time.Now().Add(2 * time.Second) + for { + err := httpGET(host, port) + if err == nil { + return + } + if time.Now().After(deadline) { + t.Fatalf("expected %s:%d to be reachable: %v", host, port, err) + } + time.Sleep(50 * time.Millisecond) + } +} + +func requireHTTPUnreachable(t *testing.T, host string, port int) { + t.Helper() + + if err := httpGET(host, port); err == nil { + t.Fatalf("expected %s:%d to be unreachable", host, port) + } +} + +func httpGET(host string, port int) error { + client := &http.Client{ + Timeout: 300 * time.Millisecond, + Transport: &http.Transport{ + Proxy: nil, + }, + } + + resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + return nil +} + +func mustAtoi(t *testing.T, value string) int { + t.Helper() + n, err := strconv.Atoi(value) + if err != nil { + t.Fatalf("Atoi(%q) error = %v", value, err) + } + return n +} diff --git a/pkg/netbind/socket_v6only_unix.go b/pkg/netbind/socket_v6only_unix.go new file mode 100644 index 000000000..20cf7bbce --- /dev/null +++ b/pkg/netbind/socket_v6only_unix.go @@ -0,0 +1,25 @@ +//go:build !windows + +package netbind + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func applyIPv6OnlyControl(enabled bool) func(string, string, syscall.RawConn) error { + return func(_, _ string, rawConn syscall.RawConn) error { + var controlErr error + if err := rawConn.Control(func(fd uintptr) { + value := 0 + if enabled { + value = 1 + } + controlErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, value) + }); err != nil { + return err + } + return controlErr + } +} diff --git a/pkg/netbind/socket_v6only_windows.go b/pkg/netbind/socket_v6only_windows.go new file mode 100644 index 000000000..006b4e1ac --- /dev/null +++ b/pkg/netbind/socket_v6only_windows.go @@ -0,0 +1,25 @@ +//go:build windows + +package netbind + +import ( + "syscall" + + "golang.org/x/sys/windows" +) + +func applyIPv6OnlyControl(enabled bool) func(string, string, syscall.RawConn) error { + return func(_, _ string, rawConn syscall.RawConn) error { + var controlErr error + if err := rawConn.Control(func(fd uintptr) { + value := 0 + if enabled { + value = 1 + } + controlErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, value) + }); err != nil { + return err + } + return controlErr + } +} diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 0dec45cba..fa5652323 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -21,6 +21,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/netbind" ppid "github.com/sipeed/picoclaw/pkg/pid" "github.com/sipeed/picoclaw/web/backend/utils" ) @@ -119,6 +120,7 @@ var ( gatewayRestartGracePeriod = 5 * time.Second gatewayRestartForceKillWindow = 3 * time.Second gatewayRestartPollInterval = 100 * time.Millisecond + gatewayExecCommand = exec.Command ) var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) { @@ -262,7 +264,7 @@ func (h *Handler) getGatewayHealthForPidData( host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg)) } if host == "" { - host = "127.0.0.1" + host = netbind.ResolveAdaptiveLoopbackHost() } url := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/health" @@ -723,7 +725,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int execPath := utils.FindPicoclawBinary() logger.InfoC("gateway", fmt.Sprintf("Starting gateway process (%s)", execPath)) - cmd = exec.Command(execPath, h.gatewayCommandArgs()...) + cmd = gatewayExecCommand(execPath, h.gatewayCommandArgs()...) cmd.Env = os.Environ() // Forward the launcher's config path via the environment variable that // GetConfigPath() already reads, so the gateway sub-process uses the same @@ -731,8 +733,9 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int if h.configPath != "" { cmd.Env = append(cmd.Env, config.EnvConfig+"="+h.configPath) } - if host := h.gatewayHostOverride(); host != "" { - cmd.Env = append(cmd.Env, config.EnvGatewayHost+"="+host) + gatewayHostOverride := h.gatewayHostOverride() + if gatewayHostOverride != "" { + cmd.Env = append(cmd.Env, config.EnvGatewayHost+"="+gatewayHostOverride) } stdoutPipe, err := cmd.StdoutPipe() diff --git a/web/backend/api/gateway_host.go b/web/backend/api/gateway_host.go index f8e8eadba..c6c2073e2 100644 --- a/web/backend/api/gateway_host.go +++ b/web/backend/api/gateway_host.go @@ -8,9 +8,15 @@ import ( "strings" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/netbind" ) func (h *Handler) effectiveLauncherPublic() bool { + if h.serverHostExplicit { + // -host takes precedence over -public and launcher-config public setting. + return false + } + if h.serverPublicExplicit { return h.serverPublic } @@ -24,8 +30,11 @@ func (h *Handler) effectiveLauncherPublic() bool { } func (h *Handler) gatewayHostOverride() string { + if h.serverHostExplicit { + return strings.TrimSpace(h.serverHostInput) + } if h.effectiveLauncherPublic() { - return "0.0.0.0" + return "*" } return "" } @@ -41,10 +50,11 @@ func (h *Handler) effectiveGatewayBindHost(cfg *config.Config) string { } func gatewayProbeHost(bindHost string) string { - if bindHost == "" || bindHost == "0.0.0.0" { - return "127.0.0.1" + plan, err := netbind.BuildPlan(bindHost, netbind.DefaultLoopback) + if err != nil || strings.TrimSpace(plan.ProbeHost) == "" { + return netbind.ResolveAdaptiveLoopbackHost() } - return bindHost + return plan.ProbeHost } func (h *Handler) gatewayProxyURL() *url.URL { @@ -72,7 +82,7 @@ func requestHostName(r *http.Request) string { if strings.TrimSpace(r.Host) != "" { return r.Host } - return "127.0.0.1" + return netbind.ResolveAdaptiveLoopbackHost() } func requestWSScheme(r *http.Request) string { diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go index 7150b6fee..d0fc26d7b 100644 --- a/web/backend/api/gateway_host_test.go +++ b/web/backend/api/gateway_host_test.go @@ -3,6 +3,7 @@ package api import ( "crypto/tls" "errors" + "net" "net/http" "net/http/httptest" "path/filepath" @@ -10,6 +11,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/netbind" "github.com/sipeed/picoclaw/web/backend/launcherconfig" ) @@ -26,8 +28,8 @@ func TestGatewayHostOverrideUsesExplicitRuntimePublic(t *testing.T) { h := NewHandler(configPath) h.SetServerOptions(18800, true, true, nil) - if got := h.gatewayHostOverride(); got != "0.0.0.0" { - t.Fatalf("gatewayHostOverride() = %q, want %q", got, "0.0.0.0") + if got := h.gatewayHostOverride(); got != "*" { + t.Fatalf("gatewayHostOverride() = %q, want %q", got, "*") } } @@ -64,8 +66,36 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) { } func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) { - if got := gatewayProbeHost("0.0.0.0"); got != "127.0.0.1" { - t.Fatalf("gatewayProbeHost() = %q, want %q", got, "127.0.0.1") + want := "127.0.0.1" + if got := gatewayProbeHost("0.0.0.0"); got != want { + t.Fatalf("gatewayProbeHost() = %q, want %q", got, want) + } +} + +func TestGatewayProbeHostUsesPreferredLoopbackForEmptyBind(t *testing.T) { + want := netbind.ResolveAdaptiveLoopbackHost() + if got := gatewayProbeHost(""); got != want { + t.Fatalf("gatewayProbeHost(empty) = %q, want %q", got, want) + } +} + +func TestGatewayProbeHostUsesPreferredLoopbackForLocalhostBind(t *testing.T) { + want := netbind.ResolveAdaptiveLoopbackHost() + if got := gatewayProbeHost("localhost"); got != want { + t.Fatalf("gatewayProbeHost(localhost) = %q, want %q", got, want) + } +} + +func TestGatewayProbeHostUsesLoopbackForIPv6WildcardBind(t *testing.T) { + want := "::1" + if got := gatewayProbeHost("::"); got != want { + t.Fatalf("gatewayProbeHost(::) = %q, want %q", got, want) + } +} + +func TestGatewayProbeHostUsesFirstConcreteHostForMultiHostBind(t *testing.T) { + if got := gatewayProbeHost("127.0.0.1,::1"); got != "127.0.0.1" { + t.Fatalf("gatewayProbeHost(multi) = %q, want %q", got, "127.0.0.1") } } @@ -137,8 +167,9 @@ func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) { _ = statusCode _ = err - if requestedURL != "http://127.0.0.1:18791/health" { - t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health") + want := "http://" + net.JoinHostPort(netbind.ResolveAdaptiveLoopbackHost(), "18791") + "/health" + if requestedURL != want { + t.Fatalf("health url = %q, want %q", requestedURL, want) } } @@ -240,3 +271,43 @@ func TestBuildWsURLUsesRequestHostNotGatewayBindLoopback(t *testing.T) { t.Fatalf("buildWsURL() = %q, want %q", got, "ws://localhost:18800/pico/ws") } } + +func TestGatewayHostOverrideWithExplicitHostAndAlignedGatewayHost(t *testing.T) { + h := NewHandler(filepath.Join(t.TempDir(), "config.json")) + h.SetServerOptions(18800, false, false, nil) + h.SetServerBindHost("0.0.0.0", true) + + if got := h.gatewayHostOverride(); got != "0.0.0.0" { + t.Fatalf("gatewayHostOverride() = %q, want %q", got, "0.0.0.0") + } +} + +func TestGatewayHostOverrideWithExplicitHostAndLocalhostGatewayHost(t *testing.T) { + h := NewHandler(filepath.Join(t.TempDir(), "config.json")) + h.SetServerOptions(18800, false, false, nil) + h.SetServerBindHost("::", true) + + if got := h.gatewayHostOverride(); got != "::" { + t.Fatalf("gatewayHostOverride() = %q, want %q", got, "::") + } +} + +func TestGatewayHostOverrideWithExplicitMultiHost(t *testing.T) { + h := NewHandler(filepath.Join(t.TempDir(), "config.json")) + h.SetServerOptions(18800, false, false, nil) + h.SetServerBindHost("127.0.0.1,::1", true) + + if got := h.gatewayHostOverride(); got != "127.0.0.1,::1" { + t.Fatalf("gatewayHostOverride() = %q, want %q", got, "127.0.0.1,::1") + } +} + +func TestGatewayHostExplicitIgnoresPublicFlag(t *testing.T) { + h := NewHandler(filepath.Join(t.TempDir(), "config.json")) + h.SetServerOptions(18800, true, true, nil) + h.SetServerBindHost("127.0.0.1", true) + + if got := h.effectiveLauncherPublic(); got { + t.Fatalf("effectiveLauncherPublic() = %t, want false when explicit host is set", got) + } +} diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go index d300b657c..78bf34a63 100644 --- a/web/backend/api/gateway_test.go +++ b/web/backend/api/gateway_test.go @@ -97,6 +97,7 @@ func resetGatewayTestState(t *testing.T) { originalHealthGet := gatewayHealthGet originalProcessMatcher := gatewayProcessMatcher + originalExecCommand := gatewayExecCommand originalRestartGracePeriod := gatewayRestartGracePeriod originalRestartForceKillWindow := gatewayRestartForceKillWindow originalRestartPollInterval := gatewayRestartPollInterval @@ -104,6 +105,7 @@ func resetGatewayTestState(t *testing.T) { t.Cleanup(func() { gatewayHealthGet = originalHealthGet gatewayProcessMatcher = originalProcessMatcher + gatewayExecCommand = originalExecCommand gatewayRestartGracePeriod = originalRestartGracePeriod gatewayRestartForceKillWindow = originalRestartForceKillWindow gatewayRestartPollInterval = originalRestartPollInterval @@ -119,6 +121,159 @@ func resetGatewayTestState(t *testing.T) { }) } +type gatewayStartEnvSnapshot struct { + GatewayHost string `json:"gateway_host"` + GatewayHostSet bool `json:"gateway_host_set"` + ConfigPath string `json:"config_path"` +} + +func TestGatewayStartHelperProcess(t *testing.T) { + var envPath string + for i, arg := range os.Args { + if arg == "--" && i+2 < len(os.Args) && os.Args[i+1] == "gateway-env-helper" { + envPath = os.Args[i+2] + break + } + } + if envPath == "" { + t.Skip("helper process") + } + + host, ok := os.LookupEnv(config.EnvGatewayHost) + raw, err := json.Marshal(gatewayStartEnvSnapshot{ + GatewayHost: host, + GatewayHostSet: ok, + ConfigPath: os.Getenv(config.EnvConfig), + }) + if err != nil { + _, _ = io.WriteString(os.Stderr, err.Error()) + os.Exit(2) + } + if err := os.WriteFile(envPath, raw, 0o600); err != nil { + _, _ = io.WriteString(os.Stderr, err.Error()) + os.Exit(2) + } + os.Exit(0) +} + +func unsetGatewayStartEnvForTest(t *testing.T, key string) { + t.Helper() + + prev, hadPrev := os.LookupEnv(key) + if err := os.Unsetenv(key); err != nil { + t.Fatalf("Unsetenv(%q) error = %v", key, err) + } + t.Cleanup(func() { + if hadPrev { + _ = os.Setenv(key, prev) + return + } + _ = os.Unsetenv(key) + }) +} + +func newGatewayStartTestHandler(t *testing.T) *Handler { + t.Helper() + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + h.SetServerOptions(18800, false, false, nil) + return h +} + +func startGatewayAndCaptureEnv(t *testing.T, h *Handler) gatewayStartEnvSnapshot { + t.Helper() + + unsetGatewayStartEnvForTest(t, config.EnvGatewayHost) + + envPath := filepath.Join(t.TempDir(), "gateway-child-env.json") + gatewayExecCommand = func(_ string, _ ...string) *exec.Cmd { + return exec.Command( + os.Args[0], + "-test.run=TestGatewayStartHelperProcess", + "--", + "gateway-env-helper", + envPath, + ) + } + + pid, err := h.startGatewayLocked("starting", 0) + if err != nil { + t.Fatalf("startGatewayLocked() error = %v", err) + } + if pid <= 0 { + t.Fatalf("startGatewayLocked() pid = %d, want > 0", pid) + } + + deadline := time.Now().Add(3 * time.Second) + for { + raw, err := os.ReadFile(envPath) + if err == nil { + var snapshot gatewayStartEnvSnapshot + err = json.Unmarshal(raw, &snapshot) + if err != nil { + t.Fatalf("Unmarshal(child env) error = %v", err) + } + return snapshot + } + if !os.IsNotExist(err) { + t.Fatalf("ReadFile(%q) error = %v", envPath, err) + } + if time.Now().After(deadline) { + t.Fatalf("timed out waiting for gateway child env snapshot %q", envPath) + } + time.Sleep(20 * time.Millisecond) + } +} + +func TestStartGatewayLocked_ForwardsLauncherHostOverrideToGatewayEnv(t *testing.T) { + h := newGatewayStartTestHandler(t) + h.SetServerBindHost("127.0.0.1,::1", true) + + snapshot := startGatewayAndCaptureEnv(t, h) + if !snapshot.GatewayHostSet { + t.Fatal("gateway host env was not set") + } + if snapshot.GatewayHost != "127.0.0.1,::1" { + t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "127.0.0.1,::1") + } + if snapshot.ConfigPath != h.configPath { + t.Fatalf("config env = %q, want %q", snapshot.ConfigPath, h.configPath) + } +} + +func TestStartGatewayLocked_ForwardsLauncherHostFromEnvironmentToGatewayEnv(t *testing.T) { + h := newGatewayStartTestHandler(t) + h.SetServerBindHost("::", true) + + snapshot := startGatewayAndCaptureEnv(t, h) + if !snapshot.GatewayHostSet { + t.Fatal("gateway host env was not set") + } + if snapshot.GatewayHost != "::" { + t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "::") + } +} + +func TestStartGatewayLocked_ForwardsWildcardHostForPublicLauncher(t *testing.T) { + h := newGatewayStartTestHandler(t) + h.SetServerOptions(18800, true, true, nil) + + snapshot := startGatewayAndCaptureEnv(t, h) + if !snapshot.GatewayHostSet { + t.Fatal("gateway host env was not set") + } + if snapshot.GatewayHost != "*" { + t.Fatalf("gateway host env = %q, want %q", snapshot.GatewayHost, "*") + } +} + func TestGatewayStartReady_NoDefaultModel(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) diff --git a/web/backend/api/router.go b/web/backend/api/router.go index c6781baf1..76f63607e 100644 --- a/web/backend/api/router.go +++ b/web/backend/api/router.go @@ -2,6 +2,7 @@ package api import ( "net/http" + "strings" "sync" "github.com/sipeed/picoclaw/web/backend/launcherconfig" @@ -13,6 +14,8 @@ type Handler struct { serverPort int serverPublic bool serverPublicExplicit bool + serverHostInput string + serverHostExplicit bool serverCIDRs []string debug bool oauthMu sync.Mutex @@ -41,9 +44,21 @@ func (h *Handler) SetServerOptions(port int, public bool, publicExplicit bool, a h.serverPort = port h.serverPublic = public h.serverPublicExplicit = publicExplicit + h.serverHostInput = "" + h.serverHostExplicit = false h.serverCIDRs = append([]string(nil), allowedCIDRs...) } +// SetServerBindHost stores the launcher's effective bind host. +// When explicit is true, hostInput is the normalized -host / PICOCLAW_LAUNCHER_HOST value. +func (h *Handler) SetServerBindHost(hostInput string, explicit bool) { + h.serverHostInput = strings.TrimSpace(hostInput) + if !explicit { + h.serverHostInput = "" + } + h.serverHostExplicit = explicit +} + func (h *Handler) SetDebug(debug bool) { h.debug = debug } diff --git a/web/backend/app_runtime.go b/web/backend/app_runtime.go index ab564db2c..a06396526 100644 --- a/web/backend/app_runtime.go +++ b/web/backend/app_runtime.go @@ -34,22 +34,30 @@ func shutdownApp() { apiHandler.Shutdown() } - if server != nil { - // Disable keep-alive to allow graceful shutdown - server.SetKeepAlivesEnabled(false) - - ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) - defer cancel() - if err := server.Shutdown(ctx); err != nil { - // Context deadline exceeded is expected if there are active connections - // This is not necessarily an error, so log it at info level - if errors.Is(err, context.DeadlineExceeded) { - logger.Infof("Server shutdown timeout after %v, forcing close", shutdownTimeout) - } else { - logger.Errorf("Server shutdown error: %v", err) + if len(servers) > 0 { + for _, srv := range servers { + if srv == nil { + continue + } + + // Disable keep-alive to allow graceful shutdown + srv.SetKeepAlivesEnabled(false) + + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + err := srv.Shutdown(ctx) + cancel() + + if err != nil { + // Context deadline exceeded is expected if there are active connections + // This is not necessarily an error, so log it at info level + if errors.Is(err, context.DeadlineExceeded) { + logger.Infof("Server shutdown timeout after %v, forcing close", shutdownTimeout) + } else { + logger.Errorf("Server shutdown error: %v", err) + } + } else { + logger.Infof("Server shutdown completed successfully") } - } else { - logger.Infof("Server shutdown completed successfully") } } } diff --git a/web/backend/launcherconfig/config.go b/web/backend/launcherconfig/config.go index 60c369f4f..b6faa63fe 100644 --- a/web/backend/launcherconfig/config.go +++ b/web/backend/launcherconfig/config.go @@ -16,6 +16,10 @@ const ( FileName = "launcher-config.json" // DefaultPort is the default port for the web launcher. DefaultPort = 18800 + // EnvLauncherToken overrides launcher dashboard token. + EnvLauncherToken = "PICOCLAW_LAUNCHER_TOKEN" + // EnvLauncherHost overrides launcher listen host. + EnvLauncherHost = "PICOCLAW_LAUNCHER_HOST" // dashboardSigningKeyBytes is the HMAC-SHA256 key size (256 bits). dashboardSigningKeyBytes = 32 @@ -59,7 +63,7 @@ func Validate(cfg Config) error { // EnsureDashboardSecrets returns signing key bytes and the effective dashboard token for this // process. The signing key is freshly random each call; the token comes from -// PICOCLAW_LAUNCHER_TOKEN when set, otherwise launcher-config.json launcher_token, +// EnvLauncherToken when set, otherwise launcher-config.json launcher_token, // otherwise a new random token. func EnsureDashboardSecrets( cfg Config, @@ -69,7 +73,7 @@ func EnsureDashboardSecrets( return "", nil, "", err } - effectiveToken = strings.TrimSpace(os.Getenv("PICOCLAW_LAUNCHER_TOKEN")) + effectiveToken = strings.TrimSpace(os.Getenv(EnvLauncherToken)) if effectiveToken != "" { return effectiveToken, signingKey, DashboardTokenSourceEnv, nil } diff --git a/web/backend/main.go b/web/backend/main.go index c5d25f6ef..7f776ff3f 100644 --- a/web/backend/main.go +++ b/web/backend/main.go @@ -15,17 +15,20 @@ import ( "errors" "flag" "fmt" + "net" "net/http" "net/url" "os" "os/signal" "path/filepath" "strconv" + "strings" "syscall" "time" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/netbind" "github.com/sipeed/picoclaw/web/backend/api" "github.com/sipeed/picoclaw/web/backend/dashboardauth" "github.com/sipeed/picoclaw/web/backend/launcherconfig" @@ -44,7 +47,7 @@ const ( var ( appVersion = config.Version - server *http.Server + servers []*http.Server serverAddr string // browserLaunchURL is opened by openBrowser() (auto-open + tray "open console"). // Includes ?token= for same-machine dashboard login; keep serverAddr without secrets for other use. @@ -65,6 +68,255 @@ func dashboardTokenConfigHelpPath(source launcherconfig.DashboardTokenSource, la return launcherPath } +func resolveLauncherHostInput(flagHost string, explicitFlag bool, envHost string) (string, bool, error) { + if explicitFlag { + normalized, err := netbind.NormalizeHostInput(flagHost) + if err != nil { + return "", false, err + } + return normalized, true, nil + } + + envHost = strings.TrimSpace(envHost) + if envHost == "" { + return "", false, nil + } + + normalized, err := netbind.NormalizeHostInput(envHost) + if err != nil { + return "", false, err + } + return normalized, true, nil +} + +func openLauncherListeners(hostInput string, public bool, port string) (netbind.OpenResult, error) { + defaultMode := netbind.DefaultLoopback + if strings.TrimSpace(hostInput) == "" && public { + defaultMode = netbind.DefaultAny + } + + plan, err := netbind.BuildPlan(hostInput, defaultMode) + if err != nil { + return netbind.OpenResult{}, err + } + return netbind.OpenPlan(plan, port) +} + +func appendUniqueHost(hosts []string, seen map[string]struct{}, host string) []string { + host = strings.TrimSpace(host) + if host == "" { + return hosts + } + key := strings.ToLower(host) + if _, ok := seen[key]; ok { + return hosts + } + seen[key] = struct{}{} + return append(hosts, host) +} + +func hasWildcardBindHosts(bindHosts []string) bool { + for _, bindHost := range bindHosts { + if netbind.IsUnspecifiedHost(bindHost) { + return true + } + } + return false +} + +func wildcardBindHostFamilies(bindHosts []string) (hasIPv4, hasIPv6 bool) { + for _, bindHost := range bindHosts { + host := strings.TrimSpace(bindHost) + if host == "" { + continue + } + + if !netbind.IsUnspecifiedHost(host) { + continue + } + + ip := net.ParseIP(strings.Trim(host, "[]")) + if ip == nil { + continue + } + if ip.To4() != nil { + hasIPv4 = true + continue + } + hasIPv6 = true + } + + return hasIPv4, hasIPv6 +} + +func wildcardAdvertiseIP(bindHosts []string, ipv4, ipv6 string) string { + hasIPv4Wildcard, hasIPv6Wildcard := wildcardBindHostFamilies(bindHosts) + v4 := strings.TrimSpace(ipv4) + v6 := strings.TrimSpace(ipv6) + + switch { + case hasIPv4Wildcard && hasIPv6Wildcard: + if v6 != "" { + return v6 + } + return v4 + case hasIPv6Wildcard: + return v6 + case hasIPv4Wildcard: + return v4 + default: + return "" + } +} + +func advertiseIPForWildcardBindHosts(bindHosts []string) string { + return wildcardAdvertiseIP(bindHosts, utils.GetLocalIPv4(), utils.GetLocalIPv6()) +} + +func appendLauncherConsoleHostList(hosts []string, seen map[string]struct{}, values []string) []string { + for _, value := range values { + hosts = appendUniqueHost(hosts, seen, value) + } + return hosts +} + +func shouldShowLocalhostConsoleEntry(hostInput string) bool { + normalizedHostInput := strings.TrimSpace(hostInput) + if normalizedHostInput == "" { + return true + } + + for token := range strings.SplitSeq(normalizedHostInput, ",") { + token = strings.TrimSpace(token) + if token == "" { + continue + } + if token == "*" || strings.EqualFold(token, "localhost") { + return true + } + + ip := net.ParseIP(strings.Trim(token, "[]")) + if ip == nil { + continue + } + if ip4 := ip.To4(); ip4 != nil { + if ip4.String() == "127.0.0.1" || ip4.String() == "0.0.0.0" { + return true + } + continue + } + if ip.String() == "::1" || ip.String() == "::" { + return true + } + } + + return false +} + +func isConsoleDisplayGlobalIPv6(ip net.IP) bool { + if ip == nil || ip.IsLoopback() || ip.To4() != nil { + return false + } + ip = ip.To16() + if ip == nil { + return false + } + return ip[0]&0xe0 == 0x20 +} + +func launcherConsoleHostsWithLocalAddrs( + hostInput string, + public bool, + ipv4s []string, + globalIPv6s []string, +) []string { + hosts := make([]string, 0, 8) + seen := make(map[string]struct{}, 8) + + if shouldShowLocalhostConsoleEntry(hostInput) { + hosts = appendUniqueHost(hosts, seen, "localhost") + } + + normalizedHostInput := strings.TrimSpace(hostInput) + if normalizedHostInput == "" { + if public { + hosts = appendLauncherConsoleHostList(hosts, seen, globalIPv6s) + hosts = appendLauncherConsoleHostList(hosts, seen, ipv4s) + } + return hosts + } + + hasStar := false + hasIPv4Any := false + hasIPv6Any := false + for _, token := range strings.Split(normalizedHostInput, ",") { + switch strings.TrimSpace(token) { + case "*": + hasStar = true + case "0.0.0.0": + hasIPv4Any = true + case "::": + hasIPv6Any = true + } + } + + if hasStar { + hosts = appendLauncherConsoleHostList(hosts, seen, globalIPv6s) + hosts = appendLauncherConsoleHostList(hosts, seen, ipv4s) + return hosts + } + + for _, token := range strings.Split(normalizedHostInput, ",") { + token = strings.TrimSpace(token) + if token == "" || strings.EqualFold(token, "localhost") || netbind.IsLoopbackHost(token) { + continue + } + + ip := net.ParseIP(strings.Trim(token, "[]")) + switch { + case token == "::": + hosts = appendLauncherConsoleHostList(hosts, seen, globalIPv6s) + case token == "0.0.0.0": + hosts = appendLauncherConsoleHostList(hosts, seen, ipv4s) + case ip != nil && ip.To4() != nil: + if hasIPv4Any { + continue + } + hosts = appendUniqueHost(hosts, seen, ip.String()) + case ip != nil: + if hasIPv6Any { + continue + } + if isConsoleDisplayGlobalIPv6(ip) { + hosts = appendUniqueHost(hosts, seen, ip.String()) + } + default: + hosts = appendUniqueHost(hosts, seen, token) + } + } + + return hosts +} + +func launcherConsoleHosts(hostInput string, public bool) []string { + return launcherConsoleHostsWithLocalAddrs( + hostInput, + public, + utils.GetLocalIPv4s(), + utils.GetGlobalIPv6s(), + ) +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + value = strings.TrimSpace(value) + if value != "" { + return value + } + } + return "" +} + // maskSecret masks a secret for display. It always shows up to the first 3 // runes. The last 4 runes are only appended when at least 5 runes remain // hidden in the middle (i.e. string length >= 12), so an 8-char minimum @@ -85,7 +337,8 @@ func maskSecret(s string) string { func main() { port := flag.String("port", "18800", "Port to listen on") - public := flag.Bool("public", false, "Listen on all interfaces (0.0.0.0) instead of localhost only") + host := flag.String("host", "", "Host to listen on (overrides -public when set)") + public := flag.Bool("public", false, "Listen on all interfaces (dual-stack) instead of localhost only") noBrowser = flag.Bool("no-browser", false, "Do not auto-open browser on startup") lang := flag.String("lang", "", "Language: en (English) or zh (Chinese). Default: auto-detect from system locale") console := flag.Bool("console", false, "Console mode, no GUI") @@ -112,6 +365,8 @@ func main() { os.Args[0], ) fmt.Fprintf(os.Stderr, " Allow access from other devices on the local network\n") + fmt.Fprintf(os.Stderr, " %s -host :: ./config.json\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " Bind launcher host explicitly with exact host semantics\n") fmt.Fprintf(os.Stderr, " %s -console -d ./config.json\n", os.Args[0]) fmt.Fprintf(os.Stderr, " Run in the terminal with debug logs enabled\n") } @@ -175,8 +430,9 @@ func main() { logger.DebugC( "web", fmt.Sprintf( - "Launcher flags: console=%t public=%t no_browser=%t config=%s", + "Launcher flags: console=%t host=%q public=%t no_browser=%t config=%s", enableConsole, + *host, *public, *noBrowser, absPath, @@ -186,10 +442,13 @@ func main() { var explicitPort bool var explicitPublic bool + var explicitHost bool flag.Visit(func(f *flag.Flag) { switch f.Name { case "port": explicitPort = true + case "host": + explicitHost = true case "public": explicitPublic = true } @@ -210,6 +469,23 @@ func main() { if !explicitPublic { effectivePublic = launcherCfg.Public } + envHost := strings.TrimSpace(os.Getenv(launcherconfig.EnvLauncherHost)) + + hostInput, hostOverrideActive, err := resolveLauncherHostInput(*host, explicitHost, envHost) + if err != nil { + logger.Fatalf("Invalid host %q: %v", firstNonEmpty(strings.TrimSpace(*host), envHost), err) + } + if hostOverrideActive { + effectivePublic = false + } + + if !explicitHost && hostOverrideActive { + logger.InfoC("web", "Using launcher host from environment PICOCLAW_LAUNCHER_HOST") + } + + if hostOverrideActive && explicitPublic { + logger.InfoC("web", "Ignoring -public because launcher host was explicitly set") + } portNum, err := strconv.Atoi(effectivePort) if err != nil || portNum < 1 || portNum > 65535 { @@ -219,7 +495,13 @@ func main() { logger.Fatalf("Invalid port %q: %v", effectivePort, err) } - dashboardToken, dashboardSigningKey, dashboardTokenSource, dashErr := launcherconfig.EnsureDashboardSecrets( + openResult, err := openLauncherListeners(hostInput, effectivePublic, effectivePort) + if err != nil { + logger.Fatalf("Failed to open launcher listener(s): %v", err) + } + listeners := openResult.Listeners + + dashboardToken, dashboardSigningKey, _, dashErr := launcherconfig.EnsureDashboardSecrets( launcherCfg, ) if dashErr != nil { @@ -227,6 +509,7 @@ func main() { } dashboardSessionCookie := middleware.SessionCookieValue(dashboardSigningKey, dashboardToken) + fmt.Println("dashboardToken: ", dashboardToken) // Open the bcrypt password store (creates the DB file on first run). authStore, authStoreErr := dashboardauth.New(picoHome) var passwordStore api.PasswordStore @@ -246,14 +529,6 @@ func main() { logger.ErrorC("web", fmt.Sprintf("Warning: could not open auth store: %v", authStoreErr)) } - // Determine listen address - var addr string - if effectivePublic { - addr = "0.0.0.0:" + effectivePort - } else { - addr = "127.0.0.1:" + effectivePort - } - // Initialize Server components mux := http.NewServeMux() @@ -271,6 +546,7 @@ func main() { logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err)) } apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs) + apiHandler.SetServerBindHost(hostInput, hostOverrideActive) apiHandler.RegisterRoutes(mux) // Frontend Embedded Assets @@ -297,49 +573,30 @@ func main() { // Print startup banner and token (console mode only). if enableConsole || debug { + consoleHosts := launcherConsoleHosts(hostInput, effectivePublic) + fmt.Print(utils.Banner) fmt.Println() fmt.Println(" Open the following URL in your browser:") fmt.Println() - fmt.Printf(" >> http://localhost:%s <<\n", effectivePort) - if effectivePublic { - if ip := utils.GetLocalIP(); ip != "" { - fmt.Printf(" >> http://%s:%s <<\n", ip, effectivePort) - } + for _, host := range consoleHosts { + fmt.Printf(" >> http://%s <<\n", net.JoinHostPort(host, effectivePort)) } fmt.Println() - switch dashboardTokenSource { - case launcherconfig.DashboardTokenSourceRandom: - fmt.Printf(" Dashboard password (this run): %s\n", maskSecret(dashboardToken)) - case launcherconfig.DashboardTokenSourceEnv: - fmt.Printf(" Dashboard password: from environment variable PICOCLAW_LAUNCHER_TOKEN\n") - case launcherconfig.DashboardTokenSourceConfig: - fmt.Printf(" Dashboard password: configured in %s\n", launcherPath) - } - fmt.Println() - } - - switch dashboardTokenSource { - case launcherconfig.DashboardTokenSourceEnv: - logger.InfoC("web", "Dashboard password: environment PICOCLAW_LAUNCHER_TOKEN") - case launcherconfig.DashboardTokenSourceConfig: - logger.InfoC("web", fmt.Sprintf("Dashboard password: configured in %s", launcherPath)) - case launcherconfig.DashboardTokenSourceRandom: - if !enableConsole { - logger.InfoC("web", "Dashboard password (this run): "+maskSecret(dashboardToken)) - } } // Log startup info to file - logger.InfoC("web", fmt.Sprintf("Server will listen on http://localhost:%s", effectivePort)) - if effectivePublic { - if ip := utils.GetLocalIP(); ip != "" { - logger.InfoC("web", fmt.Sprintf("Public access enabled at http://%s:%s", ip, effectivePort)) + for _, ln := range listeners { + logger.InfoC("web", fmt.Sprintf("Server will listen on http://%s", ln.Addr().String())) + } + if hasWildcardBindHosts(openResult.BindHosts) { + if ip := advertiseIPForWildcardBindHosts(openResult.BindHosts); ip != "" { + logger.InfoC("web", fmt.Sprintf("Public access enabled at http://%s", net.JoinHostPort(ip, effectivePort))) } } // Share the local URL with the launcher runtime. - serverAddr = fmt.Sprintf("http://localhost:%s", effectivePort) + serverAddr = fmt.Sprintf("http://%s", net.JoinHostPort(openResult.ProbeHost, effectivePort)) if dashboardToken != "" { browserLaunchURL = serverAddr + "?token=" + url.QueryEscape(dashboardToken) } else { @@ -354,14 +611,19 @@ func main() { apiHandler.TryAutoStartGateway() }() - // Start the Server in a goroutine - server = &http.Server{Addr: addr, Handler: handler} - go func() { - logger.InfoC("web", fmt.Sprintf("Server listening on %s", addr)) - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Fatalf("Server failed to start: %v", err) - } - }() + // Start the server(s) in goroutines. + servers = make([]*http.Server, 0, len(listeners)) + for _, ln := range listeners { + srv := &http.Server{Handler: handler} + servers = append(servers, srv) + + go func(s *http.Server, l net.Listener) { + logger.InfoC("web", fmt.Sprintf("Server listening on %s", l.Addr().String())) + if serveErr := s.Serve(l); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { + logger.Fatalf("Server failed to start on %s: %v", l.Addr().String(), serveErr) + } + }(srv, ln) + } defer shutdownApp() diff --git a/web/backend/main_test.go b/web/backend/main_test.go index 82bf12b40..6df5370b1 100644 --- a/web/backend/main_test.go +++ b/web/backend/main_test.go @@ -1,8 +1,17 @@ package main import ( + "context" + "errors" + "io" + "net" + "net/http" + "strconv" + "strings" "testing" + "time" + "github.com/sipeed/picoclaw/pkg/netbind" "github.com/sipeed/picoclaw/web/backend/launcherconfig" ) @@ -73,25 +82,351 @@ func TestMaskSecret(t *testing.T) { input string want string }{ - // Long token (>=12 chars): first 3 + 10 stars + last 4 {"sdhjflsjdflksdf", "sdh**********ksdf"}, {"abcdefghijklmnopqrstuvwxyz", "abc**********wxyz"}, - // Exactly 12 chars (3+4+5 hidden): suffix shown {"abcdefghijkl", "abc**********ijkl"}, - // 8 chars (minimum password length): suffix NOT shown — only prefix+stars {"abcdefgh", "abc**********"}, - // 11 chars (one below threshold): suffix NOT shown {"abcdefghijk", "abc**********"}, - // 4..3 chars: prefix shown, no suffix {"abcdefg", "abc**********"}, {"abcd", "abc**********"}, - // <=3 chars: fully masked {"abc", "**********"}, {"", "**********"}, } + for _, tt := range tests { if got := maskSecret(tt.input); got != tt.want { t.Errorf("maskSecret(%q) = %q, want %q", tt.input, got, tt.want) } } } + +func TestResolveLauncherHostInput(t *testing.T) { + tests := []struct { + name string + flagHost string + explicitFlag bool + envHost string + wantHost string + wantActive bool + wantErr bool + }{ + { + name: "flag host wins", + flagHost: "127.0.0.1", + explicitFlag: true, + envHost: "::", + wantHost: "127.0.0.1", + wantActive: true, + }, + {name: "env host used when flag absent", envHost: "127.0.0.1,::1", wantHost: "127.0.0.1,::1", wantActive: true}, + {name: "blank env ignored", envHost: " ", wantHost: "", wantActive: false}, + {name: "invalid flag rejected", flagHost: "127.0.0.1, ", explicitFlag: true, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotHost, gotActive, err := resolveLauncherHostInput(tt.flagHost, tt.explicitFlag, tt.envHost) + if (err != nil) != tt.wantErr { + t.Fatalf("resolveLauncherHostInput() err = %v, wantErr %t", err, tt.wantErr) + } + if tt.wantErr { + return + } + if gotHost != tt.wantHost { + t.Fatalf("resolveLauncherHostInput() host = %q, want %q", gotHost, tt.wantHost) + } + if gotActive != tt.wantActive { + t.Fatalf("resolveLauncherHostInput() active = %t, want %t", gotActive, tt.wantActive) + } + }) + } +} + +func TestLauncherConsoleHosts(t *testing.T) { + t.Run("default loopback shows localhost only", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "", + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) + + t.Run("explicit loopback hosts collapse to localhost", func(t *testing.T) { + tests := []struct { + name string + hostInput string + }{ + {name: "ipv6 loopback", hostInput: "::1"}, + {name: "ipv4 loopback", hostInput: "127.0.0.1"}, + {name: "localhost", hostInput: "localhost"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + tt.hostInput, + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) + } + }) + + t.Run("public wildcard shows localhost then ipv6 and ipv4", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "", + true, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost", "2001:db8::1", "2001:db8::2", "192.168.1.2", "10.0.0.8"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) + + t.Run("explicit ipv6 any shows localhost then ipv6 variants", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "::", + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost", "2001:db8::1", "2001:db8::2"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + + for _, host := range hosts { + if host == "::1" || host == "127.0.0.1" || strings.HasPrefix(strings.ToLower(host), "fe80:") { + t.Fatalf("hosts = %#v, loopback IPs must not be displayed", hosts) + } + } + }) + + t.Run("explicit ipv4 any shows localhost then lan ipv4", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "0.0.0.0", + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost", "192.168.1.2", "10.0.0.8"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) + + t.Run("explicit wildcard star shows localhost first", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "*", + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"localhost", "2001:db8::1", "2001:db8::2", "192.168.1.2", "10.0.0.8"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) + + t.Run("explicit multi-address binding without local tokens hides localhost", func(t *testing.T) { + hosts := launcherConsoleHostsWithLocalAddrs( + "192.168.1.2,10.0.0.8,2001:db8::1,2001:db8::2,fe80::1", + false, + []string{"192.168.1.2", "10.0.0.8"}, + []string{"2001:db8::1", "2001:db8::2"}, + ) + want := []string{"192.168.1.2", "10.0.0.8", "2001:db8::1", "2001:db8::2"} + if strings.Join(hosts, ",") != strings.Join(want, ",") { + t.Fatalf("hosts = %#v, want %#v", hosts, want) + } + }) +} + +func TestWildcardAdvertiseIP(t *testing.T) { + tests := []struct { + name string + bindHosts []string + ipv4 string + ipv6 string + want string + }{ + { + name: "ipv4 wildcard uses ipv4", + bindHosts: []string{"0.0.0.0"}, + ipv4: "192.168.1.2", + ipv6: "2001:db8::1", + want: "192.168.1.2", + }, + { + name: "dual wildcard prefers ipv6", + bindHosts: []string{"0.0.0.0", "::"}, + ipv4: "192.168.1.2", + ipv6: "2001:db8::1", + want: "2001:db8::1", + }, + { + name: "ipv6 wildcard uses ipv6", + bindHosts: []string{"::"}, + ipv4: "192.168.1.2", + ipv6: "2001:db8::1", + want: "2001:db8::1", + }, + { + name: "dual wildcard falls back to ipv4 when ipv6 missing", + bindHosts: []string{"0.0.0.0", "::"}, + ipv4: "192.168.1.2", + ipv6: "", + want: "192.168.1.2", + }, + { + name: "ipv6 wildcard without ipv6 does not advertise ipv4", + bindHosts: []string{"::"}, + ipv4: "192.168.1.2", + ipv6: "", + want: "", + }, + { + name: "non wildcard does not advertise", + bindHosts: []string{"127.0.0.1"}, + ipv4: "192.168.1.2", + ipv6: "2001:db8::1", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := wildcardAdvertiseIP(tt.bindHosts, tt.ipv4, tt.ipv6); got != tt.want { + t.Fatalf("wildcardAdvertiseIP(%#v, %q, %q) = %q, want %q", tt.bindHosts, tt.ipv4, tt.ipv6, got, tt.want) + } + }) + } +} + +func TestOpenLauncherListeners_HonorsIPv6OnlyHost(t *testing.T) { + hasIPv4, hasIPv6 := netbind.DetectIPFamilies() + if !hasIPv6 { + t.Skip("IPv6 is unavailable in this environment") + } + + result, err := openLauncherListeners("::", false, "0") + if err != nil { + t.Fatalf("openLauncherListeners() error = %v", err) + } + startLauncherTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireLauncherHTTPReachable(t, "::1", port) + if hasIPv4 { + requireLauncherHTTPUnreachable(t, "127.0.0.1", port) + } +} + +func TestOpenLauncherListeners_SupportsExplicitMultiHost(t *testing.T) { + hasIPv4, hasIPv6 := netbind.DetectIPFamilies() + if !hasIPv4 || !hasIPv6 { + t.Skip("dual-stack loopback is unavailable in this environment") + } + + result, err := openLauncherListeners("127.0.0.1,::1", false, "0") + if err != nil { + t.Fatalf("openLauncherListeners() error = %v", err) + } + startLauncherTestHTTPServer(t, result.Listeners) + port := mustAtoi(t, result.Port) + + requireLauncherHTTPReachable(t, "127.0.0.1", port) + requireLauncherHTTPReachable(t, "::1", port) +} + +func startLauncherTestHTTPServer(t *testing.T, listeners []net.Listener) { + t.Helper() + + server := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "ok") + }), + } + + errCh := make(chan error, len(listeners)) + for _, listener := range listeners { + ln := listener + go func() { + errCh <- server.Serve(ln) + }() + } + + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Shutdown(ctx) + for range listeners { + err := <-errCh + if err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Fatalf("server.Serve() error = %v", err) + } + } + }) +} + +func requireLauncherHTTPReachable(t *testing.T, host string, port int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for { + err := launcherHTTPGet(host, port) + if err == nil { + return + } + if time.Now().After(deadline) { + t.Fatalf("expected %s:%d to be reachable: %v", host, port, err) + } + time.Sleep(50 * time.Millisecond) + } +} + +func requireLauncherHTTPUnreachable(t *testing.T, host string, port int) { + t.Helper() + if err := launcherHTTPGet(host, port); err == nil { + t.Fatalf("expected %s:%d to be unreachable", host, port) + } +} + +func launcherHTTPGet(host string, port int) error { + client := &http.Client{ + Timeout: 300 * time.Millisecond, + Transport: &http.Transport{ + Proxy: nil, + }, + } + + resp, err := client.Get("http://" + net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + return nil +} + +func mustAtoi(t *testing.T, value string) int { + t.Helper() + n, err := strconv.Atoi(value) + if err != nil { + t.Fatalf("Atoi(%q) error = %v", value, err) + } + return n +} diff --git a/web/backend/utils/runtime.go b/web/backend/utils/runtime.go index 0b9e30979..8899a664b 100644 --- a/web/backend/utils/runtime.go +++ b/web/backend/utils/runtime.go @@ -7,6 +7,7 @@ import ( "os/exec" "path/filepath" "runtime" + "strings" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" @@ -54,18 +55,93 @@ func FindPicoclawBinary() string { return "picoclaw" } -// GetLocalIP returns the local IP address of the machine. -func GetLocalIP() string { +func appendUniqueIP(addrs []string, seen map[string]struct{}, value string) []string { + value = strings.TrimSpace(value) + if value == "" { + return addrs + } + if _, ok := seen[value]; ok { + return addrs + } + seen[value] = struct{}{} + return append(addrs, value) +} + +// GetLocalIPv4s returns all non-loopback local IPv4 addresses. +func GetLocalIPv4s() []string { addrs, err := net.InterfaceAddrs() if err != nil { - return "" + return nil } + results := make([]string, 0, 4) + seen := make(map[string]struct{}, 4) for _, a := range addrs { - if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil { - return ipnet.IP.String() + ipnet, ok := a.(*net.IPNet) + if !ok || ipnet.IP == nil || ipnet.IP.IsLoopback() { + continue + } + if ip4 := ipnet.IP.To4(); ip4 != nil { + results = appendUniqueIP(results, seen, ip4.String()) } } - return "" + return results +} + +func isDisplayGlobalIPv6(ip net.IP) bool { + if ip == nil || ip.IsLoopback() || ip.To4() != nil { + return false + } + ip = ip.To16() + if ip == nil { + return false + } + // Only show IPv6 global unicast addresses in 2000::/3. + return ip[0]&0xe0 == 0x20 +} + +// GetGlobalIPv6s returns all IPv6 global unicast addresses. +func GetGlobalIPv6s() []string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil + } + results := make([]string, 0, 4) + seen := make(map[string]struct{}, 4) + for _, a := range addrs { + ipnet, ok := a.(*net.IPNet) + if !ok || ipnet.IP == nil { + continue + } + ip := ipnet.IP + if !isDisplayGlobalIPv6(ip) { + continue + } + results = appendUniqueIP(results, seen, ip.String()) + } + return results +} + +// GetLocalIPv4 returns the first non-loopback local IPv4 address. +func GetLocalIPv4() string { + addrs := GetLocalIPv4s() + if len(addrs) == 0 { + return "" + } + return addrs[0] +} + +// GetLocalIPv6 returns the first IPv6 global unicast address. +func GetLocalIPv6() string { + addrs := GetGlobalIPv6s() + if len(addrs) == 0 { + return "" + } + return addrs[0] +} + +// GetLocalIP returns a non-loopback local IPv4 address for backward compatibility. +func GetLocalIP() string { + return GetLocalIPv4() } // OpenBrowser automatically opens the given URL in the default browser. diff --git a/web/frontend/src/api/launcher-auth.ts b/web/frontend/src/api/launcher-auth.ts index ed2e30687..d6bd93c4d 100644 --- a/web/frontend/src/api/launcher-auth.ts +++ b/web/frontend/src/api/launcher-auth.ts @@ -41,9 +41,7 @@ export async function postLauncherDashboardLogout(): Promise { return res.ok } -export type SetupResult = - | { ok: true } - | { ok: false; error: string } +export type SetupResult = { ok: true } | { ok: false; error: string } export async function postLauncherDashboardSetup( password: string, @@ -53,7 +51,10 @@ export async function postLauncherDashboardSetup( method: "POST", headers: { "Content-Type": "application/json" }, credentials: "same-origin", - body: JSON.stringify({ password: password.trim(), confirm: confirm.trim() }), + body: JSON.stringify({ + password: password.trim(), + confirm: confirm.trim(), + }), }) if (res.ok) return { ok: true } let msg = "Unknown error" diff --git a/web/frontend/src/components/agent/hub/market-skill-card.tsx b/web/frontend/src/components/agent/hub/market-skill-card.tsx index f3ee426a1..99b00db92 100644 --- a/web/frontend/src/components/agent/hub/market-skill-card.tsx +++ b/web/frontend/src/components/agent/hub/market-skill-card.tsx @@ -18,6 +18,11 @@ import { CardHeader, CardTitle, } from "@/components/ui/card" +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip" export function MarketSkillCard({ result, @@ -36,6 +41,17 @@ export function MarketSkillCard({ }) { const { t } = useTranslation() + const installDisabledReason = (() => { + if (installPending) + return t("pages.agent.skills.marketplace_installDisabled.installing") + if (result.installed) + return t("pages.agent.skills.marketplace_installDisabled.installed") + if (!canInstall) + return t("pages.agent.skills.marketplace_installDisabled.cannotInstall") + return t("pages.agent.skills.marketplace_install_action") + })() + const installDisabled = !canInstall || result.installed || installPending + return (
- + + + + + + + {installDisabledReason} + {result.installed && installedSkill ? ( - {gwError ?? t("header.gateway.action.stop")} + + {gwError ?? t("header.gateway.action.stop")} + ) : ( - + {/* Wrap in span so the tooltip still fires when the button is disabled */} - {(gwError || (!canStart && startReason)) ? ( + {gwError || (!canStart && startReason) ? ( {gwError ?? startReason} ) : null} diff --git a/web/frontend/src/components/chat/chat-composer.tsx b/web/frontend/src/components/chat/chat-composer.tsx index b0b25d1db..58612d846 100644 --- a/web/frontend/src/components/chat/chat-composer.tsx +++ b/web/frontend/src/components/chat/chat-composer.tsx @@ -7,6 +7,18 @@ import { Button } from "@/components/ui/button" import { cn } from "@/lib/utils" import type { ChatAttachment } from "@/store/chat" +export type ChatInputDisabledReason = + | "gatewayUnknown" + | "gatewayStarting" + | "gatewayRestarting" + | "gatewayStopping" + | "gatewayStopped" + | "gatewayError" + | "websocketConnecting" + | "websocketDisconnected" + | "websocketError" + | "noDefaultModel" + interface ChatComposerProps { input: string attachments: ChatAttachment[] @@ -14,8 +26,7 @@ interface ChatComposerProps { onAddImages: () => void onRemoveAttachment: (index: number) => void onSend: () => void - isConnected: boolean - hasDefaultModel: boolean + inputDisabledReason: ChatInputDisabledReason | null canSend: boolean } @@ -26,12 +37,16 @@ export function ChatComposer({ onAddImages, onRemoveAttachment, onSend, - isConnected, - hasDefaultModel, + inputDisabledReason, canSend, }: ChatComposerProps) { const { t } = useTranslation() - const canInput = isConnected && hasDefaultModel + const canInput = inputDisabledReason === null + const disabledMessage = + inputDisabledReason === null + ? null + : t(`chat.disabledPlaceholder.${inputDisabledReason}`) + const placeholder = disabledMessage ?? t("chat.placeholder") const handleKeyDown = (e: KeyboardEvent) => { if (e.nativeEvent.isComposing) return @@ -74,8 +89,9 @@ export function ChatComposer({ value={input} onChange={(e) => onInputChange(e.target.value)} onKeyDown={handleKeyDown} - placeholder={t("chat.placeholder")} + placeholder={placeholder} disabled={!canInput} + title={disabledMessage || undefined} className={cn( "placeholder:text-muted-foreground/50 max-h-[200px] min-h-[60px] resize-none border-0 bg-transparent px-2 py-1 text-[15px] shadow-none transition-colors focus-visible:ring-0 focus-visible:outline-none dark:bg-transparent", !canInput && "cursor-not-allowed", @@ -83,6 +99,11 @@ export function ChatComposer({ minRows={1} maxRows={8} /> + {!canInput && disabledMessage && ( +
+ {disabledMessage} +
+ )}
@@ -100,15 +121,17 @@ export function ChatComposer({
- + {canInput ? ( + + ) : null}
diff --git a/web/frontend/src/components/chat/chat-page.tsx b/web/frontend/src/components/chat/chat-page.tsx index e8e07a801..4129d812a 100644 --- a/web/frontend/src/components/chat/chat-page.tsx +++ b/web/frontend/src/components/chat/chat-page.tsx @@ -4,7 +4,10 @@ import { useTranslation } from "react-i18next" import { toast } from "sonner" import { AssistantMessage } from "@/components/chat/assistant-message" -import { ChatComposer } from "@/components/chat/chat-composer" +import { + ChatComposer, + type ChatInputDisabledReason, +} from "@/components/chat/chat-composer" import { ChatEmptyState } from "@/components/chat/chat-empty-state" import { ModelSelector } from "@/components/chat/model-selector" import { SessionHistoryMenu } from "@/components/chat/session-history-menu" @@ -16,7 +19,9 @@ import { useChatModels } from "@/hooks/use-chat-models" import { useGateway } from "@/hooks/use-gateway" import { usePicoChat } from "@/hooks/use-pico-chat" import { useSessionHistory } from "@/hooks/use-session-history" +import type { ConnectionState } from "@/store/chat" import type { ChatAttachment } from "@/store/chat" +import type { GatewayState } from "@/store/gateway" const MAX_IMAGE_SIZE_BYTES = 7 * 1024 * 1024 const MAX_IMAGE_SIZE_LABEL = "7 MB" @@ -44,6 +49,58 @@ function readFileAsDataUrl(file: File): Promise { }) } +function resolveChatInputDisabledReason({ + hasDefaultModel, + connectionState, + gatewayState, +}: { + hasDefaultModel: boolean + connectionState: ConnectionState + gatewayState: GatewayState +}): ChatInputDisabledReason | null { + if (gatewayState === "unknown") { + return "gatewayUnknown" + } + + if (gatewayState === "starting") { + return "gatewayStarting" + } + + if (gatewayState === "restarting") { + return "gatewayRestarting" + } + + if (gatewayState === "stopping") { + return "gatewayStopping" + } + + if (gatewayState === "stopped") { + return "gatewayStopped" + } + + if (gatewayState === "error") { + return "gatewayError" + } + + if (connectionState === "connecting") { + return "websocketConnecting" + } + + if (connectionState === "error") { + return "websocketError" + } + + if (connectionState === "disconnected") { + return "websocketDisconnected" + } + + if (!hasDefaultModel) { + return "noDefaultModel" + } + + return null +} + export function ChatPage() { const { t } = useTranslation() const scrollRef = useRef(null) @@ -65,7 +122,6 @@ export function ChatPage() { const { state: gwState } = useGateway() const isGatewayRunning = gwState === "running" - const isChatConnected = connectionState === "connected" const { defaultModelName, @@ -75,7 +131,13 @@ export function ChatPage() { localModels, handleSetDefault, } = useChatModels({ isConnected: isGatewayRunning }) - const canSend = isChatConnected && Boolean(defaultModelName) + const hasDefaultModel = Boolean(defaultModelName) + const inputDisabledReason = resolveChatInputDisabledReason({ + hasDefaultModel, + connectionState, + gatewayState: gwState, + }) + const canInput = inputDisabledReason === null const { sessions, @@ -110,7 +172,7 @@ export function ChatPage() { }, [messages, isTyping, isAtBottom]) const handleSend = () => { - if ((!input.trim() && attachments.length === 0) || !canSend) return + if ((!input.trim() && attachments.length === 0) || !canInput) return if ( sendMessage({ content: input, @@ -123,7 +185,7 @@ export function ChatPage() { } const handleAddImages = () => { - if (!canSend) return + if (!canInput) return fileInputRef.current?.click() } @@ -180,7 +242,8 @@ export function ChatPage() { } } - const canSubmit = canSend && (Boolean(input.trim()) || attachments.length > 0) + const canSubmit = + canInput && (Boolean(input.trim()) || attachments.length > 0) return (
@@ -278,8 +341,7 @@ export function ChatPage() { onAddImages={handleAddImages} onRemoveAttachment={handleRemoveAttachment} onSend={handleSend} - isConnected={isChatConnected} - hasDefaultModel={Boolean(defaultModelName)} + inputDisabledReason={inputDisabledReason} canSend={canSubmit} />
diff --git a/web/frontend/src/components/models/model-card.tsx b/web/frontend/src/components/models/model-card.tsx index 3489f22e7..44730bb57 100644 --- a/web/frontend/src/components/models/model-card.tsx +++ b/web/frontend/src/components/models/model-card.tsx @@ -10,6 +10,11 @@ import { useTranslation } from "react-i18next" import type { ModelInfo } from "@/api/models" import { Button } from "@/components/ui/button" +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip" interface ModelCardProps { model: ModelInfo @@ -33,6 +38,23 @@ export function ModelCard({ const canSetDefault = model.available && !model.is_default && !model.is_virtual + const setDefaultLabel = t("models.action.setDefault") + const setDefaultDisabledReason = (() => { + if (settingDefault) return t("models.action.setDefaultDisabled.setting") + if (!model.available) + return t("models.action.setDefaultDisabled.unavailable") + if (model.is_default) return t("models.action.setDefaultDisabled.isDefault") + if (model.is_virtual) return t("models.action.setDefaultDisabled.isVirtual") + return setDefaultLabel + })() + + const editLabel = t("models.action.edit") + const deleteLabel = t("models.action.delete") + const deleteDisabledReason = model.is_default + ? t("models.action.deleteDisabled.isDefault") + : deleteLabel + const deleteDisabled = model.is_default + return (
) : ( - + + + + + + + {setDefaultDisabledReason} + )} - + + + + + + + {deleteDisabledReason} +
diff --git a/web/frontend/src/features/chat/controller.ts b/web/frontend/src/features/chat/controller.ts index c5c93d2e8..28ef491fa 100644 --- a/web/frontend/src/features/chat/controller.ts +++ b/web/frontend/src/features/chat/controller.ts @@ -12,10 +12,7 @@ import { generateSessionId, readStoredSessionId, } from "@/features/chat/state" -import { - invalidateSocket, - isCurrentSocket, -} from "@/features/chat/websocket" +import { invalidateSocket, isCurrentSocket } from "@/features/chat/websocket" import i18n from "@/i18n" import { type ChatAttachment, diff --git a/web/frontend/src/features/chat/protocol.ts b/web/frontend/src/features/chat/protocol.ts index a7edfc21b..717b42f84 100644 --- a/web/frontend/src/features/chat/protocol.ts +++ b/web/frontend/src/features/chat/protocol.ts @@ -1,10 +1,7 @@ import { toast } from "sonner" import { normalizeUnixTimestamp } from "@/features/chat/state" -import { - type AssistantMessageKind, - updateChatStore, -} from "@/store/chat" +import { type AssistantMessageKind, updateChatStore } from "@/store/chat" export interface PicoMessage { type: string diff --git a/web/frontend/src/hooks/use-gateway.ts b/web/frontend/src/hooks/use-gateway.ts index 31bee0e91..cbf132941 100644 --- a/web/frontend/src/hooks/use-gateway.ts +++ b/web/frontend/src/hooks/use-gateway.ts @@ -77,5 +77,15 @@ export function useGateway() { } }, [state]) - return { state, loading, canStart, startReason, restartRequired, start, stop, restart, error } + return { + state, + loading, + canStart, + startReason, + restartRequired, + start, + stop, + restart, + error, + } } diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json index b5ba80533..4bc585f3c 100644 --- a/web/frontend/src/i18n/locales/en.json +++ b/web/frontend/src/i18n/locales/en.json @@ -39,6 +39,18 @@ "welcome": "How can I help you today?", "welcomeDesc": "Ask me about weather, settings, or any other tasks. I'm here to assist you.", "placeholder": "Start a new message...\nPress Enter to send, Shift + Enter for a new line", + "disabledPlaceholder": { + "gatewayUnknown": "Unable to chat: Gateway status is still being checked. Please wait, then refresh the page or restart Launcher if needed.", + "gatewayStarting": "Unable to chat: Gateway is starting. Wait for startup to complete, then try again.", + "gatewayRestarting": "Unable to chat: Gateway is restarting. Please wait for restart to finish.", + "gatewayStopping": "Unable to chat: Gateway is stopping. Wait for it to stop, then start Gateway again.", + "gatewayStopped": "Unable to chat: Gateway is not started. Click Start Gateway in the top bar, then retry.", + "gatewayError": "Unable to chat: Gateway is in an error state. Check logs, then restart Gateway or Launcher.", + "websocketConnecting": "Connecting to chat service... Please wait.", + "websocketDisconnected": "Unable to chat: WebSocket connection is disconnected. Check network and gateway status, then refresh the page or restart Launcher.", + "websocketError": "Unable to chat: WebSocket connection failed. Check network and gateway status, then retry.", + "noDefaultModel": "Unable to chat: No default model is selected. Set a default model on the Models page." + }, "newChat": "New Chat", "notConnected": "Gateway is not running. Start it to chat.", "thinking": { @@ -56,6 +68,10 @@ "deleteSession": "Delete session", "messagesCount": "{{count}} messages", "noModel": "Select model", + "inputDisabled": { + "notConnected": "Gateway is not running. Start it to chat.", + "noModel": "No default model configured. Go to Models page to set one." + }, "attachImage": "Add images", "removeImage": "Remove image", "uploadedImage": "Uploaded image", @@ -200,7 +216,16 @@ "action": { "edit": "Edit API key", "setDefault": "Set as default", - "delete": "Delete model" + "delete": "Delete model", + "setDefaultDisabled": { + "setting": "Setting as default...", + "unavailable": "Cannot set unavailable model as default", + "isDefault": "Already the default model", + "isVirtual": "Cannot set virtual model as default" + }, + "deleteDisabled": { + "isDefault": "Cannot delete the default model" + } }, "defaultOnSave": { "label": "Default Model", @@ -488,6 +513,11 @@ "version": "Installed Version", "lines": "Line Count", "characters": "Character Count" + }, + "marketplace_installDisabled": { + "installing": "Installing...", + "installed": "Already installed", + "cannotInstall": "Cannot install: related tool is not enabled" } }, "tools": { diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json index 710dfa437..0177bc08a 100644 --- a/web/frontend/src/i18n/locales/zh.json +++ b/web/frontend/src/i18n/locales/zh.json @@ -39,6 +39,18 @@ "welcome": "今天我能为您做些什么?", "welcomeDesc": "您可以询问我天气、设置或其他任何任务,我随时为您效劳。", "placeholder": "输入新消息...\n按 Enter 发送,Shift + Enter 换行", + "disabledPlaceholder": { + "gatewayUnknown": "无法对话:网关状态仍在检测中。请稍候重试,如仍无效请刷新页面或重启 Launcher。", + "gatewayStarting": "无法对话:网关正在启动。请等待启动完成后重试。", + "gatewayRestarting": "无法对话:网关正在重启。请等待重启完成。", + "gatewayStopping": "无法对话:网关正在停止。请等待停止完成后重新启动服务。", + "gatewayStopped": "无法对话:网关服务未启动。请点击顶部栏的“启动服务”后重试。", + "gatewayError": "无法对话:网关处于错误状态。请检查日志后重启网关或 Launcher。", + "websocketConnecting": "正在连接聊天服务,请稍候。", + "websocketDisconnected": "无法对话:WebSocket 连接已断开。请检查网络与服务状态,然后刷新页面或重启 Launcher。", + "websocketError": "无法对话:WebSocket 连接失败。请检查网络与服务状态后重试。", + "noDefaultModel": "无法对话:尚未设置默认模型。请前往模型页面设置默认模型。" + }, "newChat": "新建对话", "notConnected": "服务未运行,请先启动以进行对话。", "thinking": { @@ -56,6 +68,10 @@ "deleteSession": "删除会话", "messagesCount": "{{count}} 条消息", "noModel": "选择模型", + "inputDisabled": { + "notConnected": "服务未运行,请先启动以进行对话。", + "noModel": "未设置默认模型,请前往模型页面进行配置。" + }, "attachImage": "添加图片", "removeImage": "移除图片", "uploadedImage": "已上传图片", @@ -200,7 +216,16 @@ "action": { "edit": "编辑 API Key", "setDefault": "设为默认", - "delete": "删除模型" + "delete": "删除模型", + "setDefaultDisabled": { + "setting": "正在设为默认...", + "unavailable": "无法将不可用的模型设为默认", + "isDefault": "该模型已是默认模型", + "isVirtual": "无法将虚拟模型设为默认" + }, + "deleteDisabled": { + "isDefault": "无法删除默认模型" + } }, "defaultOnSave": { "label": "默认模型", @@ -488,6 +513,11 @@ "version": "已安装版本", "lines": "行数", "characters": "字符数" + }, + "marketplace_installDisabled": { + "installing": "正在安装...", + "installed": "已安装", + "cannotInstall": "无法安装:相关工具未启用" } }, "tools": { diff --git a/web/frontend/src/routes/__root.tsx b/web/frontend/src/routes/__root.tsx index b5af5de45..60d45ef84 100644 --- a/web/frontend/src/routes/__root.tsx +++ b/web/frontend/src/routes/__root.tsx @@ -53,7 +53,9 @@ const RootLayout = () => { globalThis.location.assign("/launcher-login") } else { setAuthError( - err instanceof Error ? err.message : "Auth service unavailable, please try to delete the launcher-auth.db at picoclaw home directory and restart the application.", + err instanceof Error + ? err.message + : "Auth service unavailable, please try to delete the launcher-auth.db at picoclaw home directory and restart the application.", ) } }) diff --git a/web/frontend/src/routes/launcher-login.tsx b/web/frontend/src/routes/launcher-login.tsx index c5626fbb0..caa548c79 100644 --- a/web/frontend/src/routes/launcher-login.tsx +++ b/web/frontend/src/routes/launcher-login.tsx @@ -3,7 +3,10 @@ import { createFileRoute } from "@tanstack/react-router" import * as React from "react" import { useTranslation } from "react-i18next" -import { postLauncherDashboardLogin, getLauncherAuthStatus } from "@/api/launcher-auth" +import { + getLauncherAuthStatus, + postLauncherDashboardLogin, +} from "@/api/launcher-auth" import { Button } from "@/components/ui/button" import { Card, @@ -37,7 +40,9 @@ function LauncherLoginPage() { globalThis.location.assign("/launcher-setup") } }) - .catch(() => { /* network error — stay on login page */ }) + .catch(() => { + /* network error — stay on login page */ + }) }, []) const loginWithToken = React.useCallback( diff --git a/web/frontend/src/routes/launcher-setup.tsx b/web/frontend/src/routes/launcher-setup.tsx index 876af94fb..87c934a09 100644 --- a/web/frontend/src/routes/launcher-setup.tsx +++ b/web/frontend/src/routes/launcher-setup.tsx @@ -6,141 +6,141 @@ import { useTranslation } from "react-i18next" import { postLauncherDashboardSetup } from "@/api/launcher-auth" import { Button } from "@/components/ui/button" import { - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle, + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, } from "@/components/ui/card" import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuTrigger, + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, } from "@/components/ui/dropdown-menu" import { Input } from "@/components/ui/input" import { Label } from "@/components/ui/label" import { useTheme } from "@/hooks/use-theme" function LauncherSetupPage() { - const { t, i18n } = useTranslation() - const { theme, toggleTheme } = useTheme() - const [password, setPassword] = React.useState("") - const [confirm, setConfirm] = React.useState("") - const [submitting, setSubmitting] = React.useState(false) - const [error, setError] = React.useState("") + const { t, i18n } = useTranslation() + const { theme, toggleTheme } = useTheme() + const [password, setPassword] = React.useState("") + const [confirm, setConfirm] = React.useState("") + const [submitting, setSubmitting] = React.useState(false) + const [error, setError] = React.useState("") - const onSubmit = async (e: React.FormEvent) => { - e.preventDefault() - setError("") - if (password !== confirm) { - setError(t("launcherSetup.errorMismatch")) - return - } - setSubmitting(true) - try { - const result = await postLauncherDashboardSetup(password, confirm) - if (result.ok) { - globalThis.location.assign("/launcher-login") - return - } - setError(result.error) - } catch { - setError(t("launcherSetup.errorNetwork")) - } finally { - setSubmitting(false) - } + const onSubmit = async (e: React.FormEvent) => { + e.preventDefault() + setError("") + if (password !== confirm) { + setError(t("launcherSetup.errorMismatch")) + return } + setSubmitting(true) + try { + const result = await postLauncherDashboardSetup(password, confirm) + if (result.ok) { + globalThis.location.assign("/launcher-login") + return + } + setError(result.error) + } catch { + setError(t("launcherSetup.errorNetwork")) + } finally { + setSubmitting(false) + } + } - return ( -
-
- - - - - - i18n.changeLanguage("en")}> - English - - i18n.changeLanguage("zh")}> - 简体中文 - - - - -
+ return ( +
+
+ + + + + + i18n.changeLanguage("en")}> + English + + i18n.changeLanguage("zh")}> + 简体中文 + + + + +
-
- - - {t("launcherSetup.title")} - {t("launcherSetup.description")} - - -
-
- - setPassword(e.target.value)} - placeholder={t("launcherSetup.passwordPlaceholder")} - /> -
-
- - setConfirm(e.target.value)} - placeholder={t("launcherSetup.confirmPlaceholder")} - /> -
- - {error ? ( -

- {error} -

- ) : null} -
-
-
-
-
- ) +
+ + + {t("launcherSetup.title")} + {t("launcherSetup.description")} + + +
+
+ + setPassword(e.target.value)} + placeholder={t("launcherSetup.passwordPlaceholder")} + /> +
+
+ + setConfirm(e.target.value)} + placeholder={t("launcherSetup.confirmPlaceholder")} + /> +
+ + {error ? ( +

+ {error} +

+ ) : null} +
+
+
+
+
+ ) } export const Route = createFileRoute("/launcher-setup")({ - component: LauncherSetupPage, + component: LauncherSetupPage, })