From 448027c02ae571aa9fe6a22e5f7dd5924cbe52ae Mon Sep 17 00:00:00 2001 From: lc6464 <64722907+lc6464@users.noreply.github.com> Date: Mon, 13 Apr 2026 21:33:22 +0800 Subject: [PATCH] fix(host): align launcher and gateway host normalization semantics --- cmd/picoclaw/internal/gateway/command.go | 19 ++- cmd/picoclaw/internal/gateway/command_test.go | 26 ++++ pkg/config/config.go | 3 + pkg/config/gateway.go | 28 +++++ pkg/config/gateway_host_env_test.go | 61 ++++++++++ web/backend/api/gateway.go | 15 ++- web/backend/api/gateway_host.go | 114 ++++++++++++++++-- web/backend/api/gateway_host_test.go | 55 +++++++++ web/backend/main.go | 25 +++- web/backend/main_test.go | 24 ++++ web/backend/utils/runtime.go | 32 ++++- 11 files changed, 380 insertions(+), 22 deletions(-) create mode 100644 pkg/config/gateway_host_env_test.go diff --git a/cmd/picoclaw/internal/gateway/command.go b/cmd/picoclaw/internal/gateway/command.go index 5d81cb24e..5487a20bb 100644 --- a/cmd/picoclaw/internal/gateway/command.go +++ b/cmd/picoclaw/internal/gateway/command.go @@ -14,6 +14,14 @@ import ( "github.com/sipeed/picoclaw/pkg/utils" ) +func resolveGatewayHostOverride(explicit bool, host string) (string, error) { + host = strings.TrimSpace(host) + if explicit && host == "" { + return "", fmt.Errorf("the --host option cannot be empty") + } + return host, nil +} + func NewGatewayCommand() *cobra.Command { var debug bool var noTruncate bool @@ -37,11 +45,14 @@ func NewGatewayCommand() *cobra.Command { return nil }, - RunE: func(_ *cobra.Command, _ []string) error { - host = strings.TrimSpace(host) - if host != "" { + RunE: func(cmd *cobra.Command, _ []string) error { + resolvedHost, err := resolveGatewayHostOverride(cmd.Flags().Changed("host"), host) + if err != nil { + return err + } + if resolvedHost != "" { prevHost, hadPrev := os.LookupEnv(config.EnvGatewayHost) - if err := os.Setenv(config.EnvGatewayHost, host); err != nil { + if err := os.Setenv(config.EnvGatewayHost, resolvedHost); err != nil { return fmt.Errorf("failed to set %s: %w", config.EnvGatewayHost, err) } defer func() { diff --git a/cmd/picoclaw/internal/gateway/command_test.go b/cmd/picoclaw/internal/gateway/command_test.go index 6be5f0ba3..b53d5253c 100644 --- a/cmd/picoclaw/internal/gateway/command_test.go +++ b/cmd/picoclaw/internal/gateway/command_test.go @@ -31,3 +31,29 @@ func TestNewGatewayCommand(t *testing.T) { assert.NotNil(t, cmd.Flags().Lookup("allow-empty")) assert.NotNil(t, cmd.Flags().Lookup("host")) } + +func TestResolveGatewayHostOverride(t *testing.T) { + tests := []struct { + name string + explicit bool + host string + wantHost string + wantErr bool + }{ + {name: "implicit empty host is allowed", explicit: false, host: "", wantHost: "", wantErr: false}, + {name: "explicit empty host rejected", explicit: true, host: " ", wantHost: "", wantErr: true}, + {name: "explicit localhost kept", explicit: true, host: " localhost ", wantHost: "localhost", wantErr: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := resolveGatewayHostOverride(tt.explicit, tt.host) + if (err != nil) != tt.wantErr { + t.Fatalf("resolveGatewayHostOverride() err = %v, wantErr %t", err, tt.wantErr) + } + if got != tt.wantHost { + t.Fatalf("resolveGatewayHostOverride() host = %q, want %q", got, tt.wantHost) + } + }) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 9488fd96c..07e52de97 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1073,6 +1073,8 @@ func LoadConfig(path string) (*Config, error) { applyLegacyBindingsMigration(data, cfg) + gatewayHostBeforeEnv := cfg.Gateway.Host + if err = env.Parse(cfg); err != nil { return nil, err } @@ -1080,6 +1082,7 @@ func LoadConfig(path string) (*Config, error) { if err = InitChannelList(cfg.Channels); err != nil { return nil, err } + cfg.Gateway.Host = resolveGatewayHostFromEnv(gatewayHostBeforeEnv) // Expand multi-key configs into separate entries for key-level failover cfg.ModelList = expandMultiKeyModels(cfg.ModelList) diff --git a/pkg/config/gateway.go b/pkg/config/gateway.go index e9f4085d3..5cae346cc 100644 --- a/pkg/config/gateway.go +++ b/pkg/config/gateway.go @@ -3,6 +3,7 @@ package config import ( "encoding/json" "os" + "strings" "github.com/sipeed/picoclaw/pkg/logger" ) @@ -49,6 +50,33 @@ func EffectiveGatewayLogLevel(cfg *Config) string { return normalizeGatewayLogLevel(cfg.Gateway.LogLevel) } +func normalizeGatewayHost(host string) string { + host = strings.TrimSpace(host) + if host != "" { + return host + } + + defaultHost := strings.TrimSpace(DefaultConfig().Gateway.Host) + if defaultHost == "" { + return "127.0.0.1" + } + return defaultHost +} + +func resolveGatewayHostFromEnv(baseHost string) string { + envHost, ok := os.LookupEnv(EnvGatewayHost) + if !ok { + return normalizeGatewayHost(baseHost) + } + + envHost = strings.TrimSpace(envHost) + if envHost == "" { + return normalizeGatewayHost(baseHost) + } + + return envHost +} + // ResolveGatewayLogLevel reads the configured gateway log level without triggering // the full config loader, so startup code can apply logging before config load logs run. // The PICOCLAW_LOG_LEVEL environment variable overrides the file value. diff --git a/pkg/config/gateway_host_env_test.go b/pkg/config/gateway_host_env_test.go new file mode 100644 index 000000000..3754eefdf --- /dev/null +++ b/pkg/config/gateway_host_env_test.go @@ -0,0 +1,61 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" +) + +func writeGatewayHostTestConfig(t *testing.T, host string) string { + t.Helper() + + configPath := filepath.Join(t.TempDir(), "config.json") + raw := fmt.Sprintf(`{"version":2,"gateway":{"host":%q,"port":18790}}`, host) + if err := os.WriteFile(configPath, []byte(raw), 0o600); err != nil { + t.Fatalf("WriteFile(configPath): %v", err) + } + return configPath +} + +func TestLoadConfig_GatewayHostEnvTrimmed(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, "127.0.0.1") + t.Setenv(EnvGatewayHost, " ::1 ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Gateway.Host != "::1" { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, "::1") + } +} + +func TestLoadConfig_GatewayHostBlankEnvFallsBackToConfigHost(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, " localhost ") + t.Setenv(EnvGatewayHost, " ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Gateway.Host != "localhost" { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, "localhost") + } +} + +func TestLoadConfig_GatewayHostBlankEnvAndConfigFallsBackToDefault(t *testing.T) { + configPath := writeGatewayHostTestConfig(t, " ") + t.Setenv(EnvGatewayHost, " ") + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + + defaultHost := strings.TrimSpace(DefaultConfig().Gateway.Host) + if cfg.Gateway.Host != defaultHost { + t.Fatalf("cfg.Gateway.Host = %q, want %q", cfg.Gateway.Host, defaultHost) + } +} diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 0dec45cba..28b5f3540 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -731,8 +731,19 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int if h.configPath != "" { cmd.Env = append(cmd.Env, config.EnvConfig+"="+h.configPath) } - if host := h.gatewayHostOverride(); host != "" { - cmd.Env = append(cmd.Env, config.EnvGatewayHost+"="+host) + gatewayHostOverride := h.gatewayHostOverrideForConfig(cfg) + if h.serverHostExplicit && gatewayHostOverride == "" { + logger.WarnC( + "gateway", + fmt.Sprintf( + "Explicit launcher host %q was not forwarded to gateway because configured gateway host is %q; gateway keeps original bind host", + strings.TrimSpace(h.serverHost), + strings.TrimSpace(cfg.Gateway.Host), + ), + ) + } + if gatewayHostOverride != "" { + cmd.Env = append(cmd.Env, config.EnvGatewayHost+"="+gatewayHostOverride) } stdoutPipe, err := cmd.StdoutPipe() diff --git a/web/backend/api/gateway_host.go b/web/backend/api/gateway_host.go index 19f65d34e..a5aa33c32 100644 --- a/web/backend/api/gateway_host.go +++ b/web/backend/api/gateway_host.go @@ -6,10 +6,76 @@ import ( "net/url" "strconv" "strings" + "sync" "github.com/sipeed/picoclaw/pkg/config" ) +var ( + adaptiveLoopbackHostOnce sync.Once + adaptiveLoopbackHost string +) + +func selectAdaptiveLoopbackHost(hasIPv4, hasIPv6 bool) string { + switch { + case hasIPv4 && hasIPv6: + return "localhost" + case hasIPv6: + return "::1" + case hasIPv4: + return "127.0.0.1" + default: + return "127.0.0.1" + } +} + +func isLoopbackEquivalentHost(host string) bool { + host = strings.TrimSpace(host) + if host == "" { + return false + } + if strings.EqualFold(host, "localhost") { + return true + } + trimmed := strings.Trim(host, "[]") + ip := net.ParseIP(trimmed) + return ip != nil && ip.IsLoopback() +} + +func resolveAdaptiveLoopbackHost() string { + adaptiveLoopbackHostOnce.Do(func() { + ips, err := net.LookupIP("localhost") + if err != nil { + adaptiveLoopbackHost = selectAdaptiveLoopbackHost(false, false) + return + } + + hasIPv4 := false + hasIPv6 := false + for _, ip := range ips { + if ip == nil { + continue + } + if ip.To4() != nil { + hasIPv4 = true + continue + } + hasIPv6 = true + } + + adaptiveLoopbackHost = selectAdaptiveLoopbackHost(hasIPv4, hasIPv6) + }) + return adaptiveLoopbackHost +} + +func resolveDefaultLoopbackHost() string { + return resolveAdaptiveLoopbackHost() +} + +func resolveLocalhostLoopbackHost() string { + return resolveAdaptiveLoopbackHost() +} + func (h *Handler) effectiveLauncherPublic() bool { if h.serverHostExplicit { // -host takes precedence over -public and launcher-config public setting. @@ -30,27 +96,33 @@ func (h *Handler) effectiveLauncherPublic() bool { func canonicalLauncherBindHost(host string) string { host = strings.TrimSpace(host) - if host == "" || strings.EqualFold(host, "localhost") { - return "127.0.0.1" + if host == "" { + return resolveDefaultLoopbackHost() + } + if strings.EqualFold(host, "localhost") { + return resolveLocalhostLoopbackHost() } return host } -func (h *Handler) launcherAndGatewayBindHostsAligned() bool { - cfg, err := config.LoadConfig(h.configPath) - if err != nil || cfg == nil { +func (h *Handler) launcherAndGatewayBindHostsAligned(cfg *config.Config) bool { + if cfg == nil { return false } // With -host specified, -public is ignored, so launcher's legacy bind host is loopback. launcherHost := canonicalLauncherBindHost("127.0.0.1") gatewayHost := canonicalLauncherBindHost(cfg.Gateway.Host) + if isLoopbackEquivalentHost(launcherHost) && isLoopbackEquivalentHost(gatewayHost) { + return true + } + return launcherHost == gatewayHost } -func (h *Handler) gatewayHostOverride() string { +func (h *Handler) gatewayHostOverrideForConfig(cfg *config.Config) string { if h.serverHostExplicit { - if h.launcherAndGatewayBindHostsAligned() { + if h.launcherAndGatewayBindHostsAligned(cfg) { return strings.TrimSpace(h.serverHost) } return "" @@ -62,8 +134,20 @@ func (h *Handler) gatewayHostOverride() string { return "" } +func (h *Handler) gatewayHostOverride() string { + if !h.serverHostExplicit { + return h.gatewayHostOverrideForConfig(nil) + } + + cfg, err := config.LoadConfig(h.configPath) + if err != nil { + return "" + } + return h.gatewayHostOverrideForConfig(cfg) +} + func (h *Handler) effectiveGatewayBindHost(cfg *config.Config) string { - if override := h.gatewayHostOverride(); override != "" { + if override := h.gatewayHostOverrideForConfig(cfg); override != "" { return override } if cfg == nil { @@ -73,7 +157,19 @@ func (h *Handler) effectiveGatewayBindHost(cfg *config.Config) string { } func gatewayProbeHost(bindHost string) string { - if bindHost == "" || bindHost == "0.0.0.0" { + bindHost = strings.TrimSpace(bindHost) + if bindHost == "" { + return resolveDefaultLoopbackHost() + } + if strings.EqualFold(bindHost, "localhost") { + return resolveLocalhostLoopbackHost() + } + + trimmed := strings.Trim(bindHost, "[]") + if ip := net.ParseIP(trimmed); ip != nil && ip.IsUnspecified() { + if ip.To4() == nil { + return "::1" + } return "127.0.0.1" } return bindHost diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go index c71d1a24d..56d4a9ca8 100644 --- a/web/backend/api/gateway_host_test.go +++ b/web/backend/api/gateway_host_test.go @@ -63,12 +63,54 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) { } } +func TestSelectAdaptiveLoopbackHost(t *testing.T) { + tests := []struct { + name string + hasIPv4 bool + hasIPv6 bool + want string + }{ + {name: "dual stack prefers localhost", hasIPv4: true, hasIPv6: true, want: "localhost"}, + {name: "ipv6 only", hasIPv4: false, hasIPv6: true, want: "::1"}, + {name: "ipv4 only", hasIPv4: true, hasIPv6: false, want: "127.0.0.1"}, + {name: "fallback", hasIPv4: false, hasIPv6: false, want: "127.0.0.1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := selectAdaptiveLoopbackHost(tt.hasIPv4, tt.hasIPv6); got != tt.want { + t.Fatalf("selectAdaptiveLoopbackHost(%t, %t) = %q, want %q", tt.hasIPv4, tt.hasIPv6, got, tt.want) + } + }) + } +} + func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) { if got := gatewayProbeHost("0.0.0.0"); got != "127.0.0.1" { t.Fatalf("gatewayProbeHost() = %q, want %q", got, "127.0.0.1") } } +func TestGatewayProbeHostUsesPreferredLoopbackForEmptyBind(t *testing.T) { + want := resolveDefaultLoopbackHost() + if got := gatewayProbeHost(""); got != want { + t.Fatalf("gatewayProbeHost(empty) = %q, want %q", got, want) + } +} + +func TestGatewayProbeHostUsesPreferredLoopbackForLocalhostBind(t *testing.T) { + want := resolveLocalhostLoopbackHost() + if got := gatewayProbeHost("localhost"); got != want { + t.Fatalf("gatewayProbeHost(localhost) = %q, want %q", got, want) + } +} + +func TestGatewayProbeHostUsesLoopbackForIPv6WildcardBind(t *testing.T) { + if got := gatewayProbeHost("::"); got != "::1" { + t.Fatalf("gatewayProbeHost(::) = %q, want %q", got, "::1") + } +} + func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) @@ -254,6 +296,19 @@ func TestGatewayHostOverrideWithExplicitHostAndAlignedGatewayHost(t *testing.T) } } +func TestGatewayHostOverrideWithExplicitHostAndLocalhostGatewayHost(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + writeGatewayHostConfig(t, configPath, "localhost") + + h := NewHandler(configPath) + h.SetServerOptions(18800, false, false, nil) + h.SetServerBindHost("::", true) + + if got := h.gatewayHostOverride(); got != "::" { + t.Fatalf("gatewayHostOverride() = %q, want %q", got, "::") + } +} + func TestGatewayHostOverrideWithExplicitHostAndMismatchedGatewayHost(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") writeGatewayHostConfig(t, configPath, "0.0.0.0") diff --git a/web/backend/main.go b/web/backend/main.go index 088fda3d5..41251d1bf 100644 --- a/web/backend/main.go +++ b/web/backend/main.go @@ -108,6 +108,21 @@ func browserHostForLauncher(bindHost string) string { return bindHost } +func wildcardAdvertiseIP(bindHost, ipv4, ipv6 string) string { + switch strings.TrimSpace(bindHost) { + case "0.0.0.0": + return strings.TrimSpace(ipv4) + case "::": + return strings.TrimSpace(ipv6) + default: + return "" + } +} + +func advertiseIPForWildcardBindHost(bindHost string) string { + return wildcardAdvertiseIP(bindHost, utils.GetLocalIPv4(), utils.GetLocalIPv6()) +} + // maskSecret masks a secret for display. It always shows up to the first 3 // runes. The last 4 runes are only appended when at least 5 runes remain // hidden in the middle (i.e. string length >= 12), so an 8-char minimum @@ -157,7 +172,7 @@ func main() { ) fmt.Fprintf(os.Stderr, " Allow access from other devices on the local network\n") fmt.Fprintf(os.Stderr, " %s -host 0.0.0.0 ./config.json\n", os.Args[0]) - fmt.Fprintf(os.Stderr, " Bind launcher and gateway host explicitly\n") + fmt.Fprintf(os.Stderr, " Bind launcher host explicitly (gateway forwarding follows compatibility rules)\n") fmt.Fprintf(os.Stderr, " %s -console -d ./config.json\n", os.Args[0]) fmt.Fprintf(os.Stderr, " Run in the terminal with debug logs enabled\n") } @@ -368,8 +383,8 @@ func main() { fmt.Println() fmt.Printf(" >> http://localhost:%s <<\n", effectivePort) if isWildcardBindHost(effectiveHost) { - if ip := utils.GetLocalIP(); ip != "" { - fmt.Printf(" >> http://%s:%s <<\n", ip, effectivePort) + if ip := advertiseIPForWildcardBindHost(effectiveHost); ip != "" { + fmt.Printf(" >> http://%s <<\n", net.JoinHostPort(ip, effectivePort)) } } if hostExplicit { @@ -401,8 +416,8 @@ func main() { // Log startup info to file logger.InfoC("web", fmt.Sprintf("Server will listen on http://%s", net.JoinHostPort(effectiveHost, effectivePort))) if isWildcardBindHost(effectiveHost) { - if ip := utils.GetLocalIP(); ip != "" { - logger.InfoC("web", fmt.Sprintf("Public access enabled at http://%s:%s", ip, effectivePort)) + if ip := advertiseIPForWildcardBindHost(effectiveHost); ip != "" { + logger.InfoC("web", fmt.Sprintf("Public access enabled at http://%s", net.JoinHostPort(ip, effectivePort))) } } diff --git a/web/backend/main_test.go b/web/backend/main_test.go index 40555dbe1..6f68e61ac 100644 --- a/web/backend/main_test.go +++ b/web/backend/main_test.go @@ -201,3 +201,27 @@ func TestBrowserHostForLauncher(t *testing.T) { t.Fatalf("browserHostForLauncher(192.168.1.10) = %q, want %q", got, "192.168.1.10") } } + +func TestWildcardAdvertiseIP(t *testing.T) { + tests := []struct { + name string + bindHost string + ipv4 string + ipv6 string + want string + }{ + {name: "ipv4 wildcard uses ipv4", bindHost: "0.0.0.0", ipv4: "192.168.1.2", ipv6: "2001:db8::1", want: "192.168.1.2"}, + {name: "ipv6 wildcard uses ipv6", bindHost: "::", ipv4: "192.168.1.2", ipv6: "2001:db8::1", want: "2001:db8::1"}, + {name: "ipv6 wildcard with no ipv6 address", bindHost: "::", ipv4: "192.168.1.2", ipv6: "", want: ""}, + {name: "ipv4 wildcard with no ipv4 address", bindHost: "0.0.0.0", ipv4: "", ipv6: "2001:db8::1", want: ""}, + {name: "non wildcard does not advertise", bindHost: "127.0.0.1", ipv4: "192.168.1.2", ipv6: "2001:db8::1", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := wildcardAdvertiseIP(tt.bindHost, tt.ipv4, tt.ipv6); got != tt.want { + t.Fatalf("wildcardAdvertiseIP(%q, %q, %q) = %q, want %q", tt.bindHost, tt.ipv4, tt.ipv6, got, tt.want) + } + }) + } +} diff --git a/web/backend/utils/runtime.go b/web/backend/utils/runtime.go index 0b9e30979..7cceff707 100644 --- a/web/backend/utils/runtime.go +++ b/web/backend/utils/runtime.go @@ -54,8 +54,8 @@ func FindPicoclawBinary() string { return "picoclaw" } -// GetLocalIP returns the local IP address of the machine. -func GetLocalIP() string { +// GetLocalIPv4 returns a non-loopback local IPv4 address. +func GetLocalIPv4() string { addrs, err := net.InterfaceAddrs() if err != nil { return "" @@ -68,6 +68,34 @@ func GetLocalIP() string { return "" } +// GetLocalIPv6 returns a non-loopback local IPv6 address. +func GetLocalIPv6() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "" + } + for _, a := range addrs { + ipnet, ok := a.(*net.IPNet) + if !ok || ipnet.IP == nil { + continue + } + ip := ipnet.IP + if ip.IsLoopback() || ip.To4() != nil { + continue + } + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + continue + } + return ip.String() + } + return "" +} + +// GetLocalIP returns a non-loopback local IPv4 address for backward compatibility. +func GetLocalIP() string { + return GetLocalIPv4() +} + // OpenBrowser automatically opens the given URL in the default browser. func OpenBrowser(url string) error { switch runtime.GOOS {