package api import ( "crypto/tls" "errors" "net" "net/http" "net/http/httptest" "path/filepath" "sync" "testing" "time" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/web/backend/launcherconfig" ) func resetAdaptiveIPFamiliesForTest() { adaptiveIPFamiliesOnce = sync.Once{} adaptiveHasIPv4 = false adaptiveHasIPv6 = false } func TestGatewayHostOverrideUsesExplicitRuntimePublic(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") launcherPath := launcherconfig.PathForAppConfig(configPath) if err := launcherconfig.Save(launcherPath, launcherconfig.Config{ Port: 18800, Public: false, }); err != nil { t.Fatalf("launcherconfig.Save() error = %v", err) } h := NewHandler(configPath) h.SetServerOptions(18800, true, true, nil) if got := h.gatewayHostOverride(); got != resolveDefaultAnyHost() { t.Fatalf("gatewayHostOverride() = %q, want %q", got, resolveDefaultAnyHost()) } } func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") launcherPath := launcherconfig.PathForAppConfig(configPath) if err := launcherconfig.Save(launcherPath, launcherconfig.Config{ Port: 18800, Public: true, }); err != nil { t.Fatalf("launcherconfig.Save() error = %v", err) } h := NewHandler(configPath) h.SetServerOptions(18800, false, false, nil) cfg := config.DefaultConfig() cfg.Gateway.Host = "127.0.0.1" cfg.Gateway.Port = 18790 req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil) req.Host = "192.168.1.9:18800" if got := h.buildWsURL(req); got != "ws://192.168.1.9:18800/pico/ws" { t.Fatalf("buildWsURL() = %q, want %q", got, "ws://192.168.1.9:18800/pico/ws") } if got := h.buildPicoEventsURL(req); got != "http://192.168.1.9:18800/pico/events" { t.Fatalf("buildPicoEventsURL() = %q, want %q", got, "http://192.168.1.9:18800/pico/events") } if got := h.buildPicoSendURL(req); got != "http://192.168.1.9:18800/pico/send" { t.Fatalf("buildPicoSendURL() = %q, want %q", got, "http://192.168.1.9:18800/pico/send") } } func TestSelectAdaptiveLoopbackHost(t *testing.T) { tests := []struct { name string hasIPv4 bool hasIPv6 bool want string }{ {name: "dual stack prefers localhost", hasIPv4: true, hasIPv6: true, want: "localhost"}, {name: "ipv6 only", hasIPv4: false, hasIPv6: true, want: "::1"}, {name: "ipv4 only", hasIPv4: true, hasIPv6: false, want: "127.0.0.1"}, {name: "fallback", hasIPv4: false, hasIPv6: false, want: "localhost"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := selectAdaptiveLoopbackHost(tt.hasIPv4, tt.hasIPv6); got != tt.want { t.Fatalf("selectAdaptiveLoopbackHost(%t, %t) = %q, want %q", tt.hasIPv4, tt.hasIPv6, got, tt.want) } }) } } func TestSelectAdaptiveAnyHost(t *testing.T) { tests := []struct { name string hasIPv4 bool hasIPv6 bool want string }{ {name: "dual stack prefers ipv6 wildcard", hasIPv4: true, hasIPv6: true, want: "::"}, {name: "ipv6 only", hasIPv4: false, hasIPv6: true, want: "::"}, {name: "ipv4 only", hasIPv4: true, hasIPv6: false, want: "0.0.0.0"}, {name: "fallback", hasIPv4: false, hasIPv6: false, want: "::"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := selectAdaptiveAnyHost(tt.hasIPv4, tt.hasIPv6); got != tt.want { t.Fatalf("selectAdaptiveAnyHost(%t, %t) = %q, want %q", tt.hasIPv4, tt.hasIPv6, got, tt.want) } }) } } func TestAdaptiveHostSelectionFallsBackToInterfaceAddrs(t *testing.T) { oldLookup := lookupLocalhostIPs oldList := listInterfaceAddrs lookupLocalhostIPs = func() ([]net.IP, error) { return nil, errors.New("lookup failed") } _, v4Net, err := net.ParseCIDR("192.0.2.10/24") if err != nil { t.Fatalf("ParseCIDR() error = %v", err) } listInterfaceAddrs = func() ([]net.Addr, error) { return []net.Addr{v4Net}, nil } resetAdaptiveIPFamiliesForTest() t.Cleanup(func() { lookupLocalhostIPs = oldLookup listInterfaceAddrs = oldList resetAdaptiveIPFamiliesForTest() }) if got := resolveDefaultAnyHost(); got != "0.0.0.0" { t.Fatalf("resolveDefaultAnyHost() = %q, want %q", got, "0.0.0.0") } if got := resolveDefaultLoopbackHost(); got != "127.0.0.1" { t.Fatalf("resolveDefaultLoopbackHost() = %q, want %q", got, "127.0.0.1") } } func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) { want := resolveDefaultLoopbackHost() if got := gatewayProbeHost("0.0.0.0"); got != want { t.Fatalf("gatewayProbeHost() = %q, want %q", got, want) } } func TestGatewayProbeHostUsesPreferredLoopbackForEmptyBind(t *testing.T) { want := resolveDefaultLoopbackHost() if got := gatewayProbeHost(""); got != want { t.Fatalf("gatewayProbeHost(empty) = %q, want %q", got, want) } } func TestGatewayProbeHostUsesPreferredLoopbackForLocalhostBind(t *testing.T) { want := resolveLocalhostLoopbackHost() if got := gatewayProbeHost("localhost"); got != want { t.Fatalf("gatewayProbeHost(localhost) = %q, want %q", got, want) } } func TestGatewayProbeHostUsesLoopbackForIPv6WildcardBind(t *testing.T) { want := resolveDefaultLoopbackHost() if got := gatewayProbeHost("::"); got != want { t.Fatalf("gatewayProbeHost(::) = %q, want %q", got, want) } } func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) cfg := config.DefaultConfig() cfg.Gateway.Host = "192.168.1.10" cfg.Gateway.Port = 18791 if err := config.SaveConfig(configPath, cfg); err != nil { t.Fatalf("SaveConfig() error = %v", err) } if got := h.gatewayProxyURL().String(); got != "http://192.168.1.10:18791" { t.Fatalf("gatewayProxyURL() = %q, want %q", got, "http://192.168.1.10:18791") } } func TestGetGatewayHealthUsesConfiguredHost(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) cfg := config.DefaultConfig() cfg.Gateway.Host = "192.168.1.10" cfg.Gateway.Port = 18791 originalHealthGet := gatewayHealthGet t.Cleanup(func() { gatewayHealthGet = originalHealthGet }) var requestedURL string gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) { requestedURL = url return nil, errors.New("probe failed") } _, statusCode, err := h.getGatewayHealth(cfg, time.Second) _ = statusCode _ = err if requestedURL != "http://192.168.1.10:18791/health" { t.Fatalf("health url = %q, want %q", requestedURL, "http://192.168.1.10:18791/health") } } func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) h.SetServerOptions(18800, true, true, nil) cfg := config.DefaultConfig() cfg.Gateway.Host = "127.0.0.1" cfg.Gateway.Port = 18791 originalHealthGet := gatewayHealthGet t.Cleanup(func() { gatewayHealthGet = originalHealthGet }) var requestedURL string gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) { requestedURL = url return nil, errors.New("probe failed") } _, statusCode, err := h.getGatewayHealth(cfg, time.Second) _ = statusCode _ = err want := "http://" + net.JoinHostPort(resolveDefaultLoopbackHost(), "18791") + "/health" if requestedURL != want { t.Fatalf("health url = %q, want %q", requestedURL, want) } } func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) cfg := config.DefaultConfig() cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil) req.Host = "chat.example.com" req.Header.Set("X-Forwarded-Proto", "https") if got := h.buildWsURL(req); got != "wss://chat.example.com:18800/pico/ws" { t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws") } } func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) cfg := config.DefaultConfig() cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil) req.Host = "secure.example.com" req.TLS = &tls.ConnectionState{} if got := h.buildWsURL(req); got != "wss://secure.example.com:18800/pico/ws" { t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws") } } func TestBuildPicoURLsPreferXForwardedHost(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") launcherPath := launcherconfig.PathForAppConfig(configPath) if err := launcherconfig.Save(launcherPath, launcherconfig.Config{ Port: 18800, Public: true, }); err != nil { t.Fatalf("launcherconfig.Save() error = %v", err) } h := NewHandler(configPath) h.SetServerOptions(18800, false, false, nil) cfg := config.DefaultConfig() cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/token", nil) req.Host = "127.0.0.1:18800" req.Header.Set("X-Forwarded-Host", "vscode-tunnel.example.com") req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("X-Forwarded-Port", "443") if got := h.buildPicoEventsURL(req); got != "https://vscode-tunnel.example.com:443/pico/events" { t.Fatalf("buildPicoEventsURL() = %q, want %q", got, "https://vscode-tunnel.example.com:443/pico/events") } if got := h.buildPicoSendURL(req); got != "https://vscode-tunnel.example.com:443/pico/send" { t.Fatalf("buildPicoSendURL() = %q, want %q", got, "https://vscode-tunnel.example.com:443/pico/send") } if got := h.buildWsURL(req); got != "wss://vscode-tunnel.example.com:443/pico/ws" { t.Fatalf("buildWsURL() = %q, want %q", got, "wss://vscode-tunnel.example.com:443/pico/ws") } } func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) cfg := config.DefaultConfig() cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil) req.Host = "chat.example.com" req.TLS = &tls.ConnectionState{} req.Header.Set("X-Forwarded-Proto", "http") if got := h.buildWsURL(req); got != "ws://chat.example.com:18800/pico/ws" { t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws") } } func TestBuildWsURLUsesRequestHostNotGatewayBindLoopback(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) h.SetServerOptions(18800, false, false, nil) req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/token", nil) req.Host = "localhost:18800" if got := h.buildWsURL(req); got != "ws://localhost:18800/pico/ws" { t.Fatalf("buildWsURL() = %q, want %q", got, "ws://localhost:18800/pico/ws") } } func TestGatewayHostOverrideWithExplicitHostAndAlignedGatewayHost(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") writeGatewayHostConfig(t, configPath, "127.0.0.1") h := NewHandler(configPath) h.SetServerOptions(18800, false, false, nil) h.SetServerBindHost("0.0.0.0", true) if got := h.gatewayHostOverride(); got != resolveDefaultAnyHost() { t.Fatalf("gatewayHostOverride() = %q, want %q", got, resolveDefaultAnyHost()) } } func TestGatewayHostOverrideWithExplicitHostAndLocalhostGatewayHost(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") writeGatewayHostConfig(t, configPath, "localhost") h := NewHandler(configPath) h.SetServerOptions(18800, false, false, nil) h.SetServerBindHost("::", true) if got := h.gatewayHostOverride(); got != "::" { t.Fatalf("gatewayHostOverride() = %q, want %q", got, "::") } } func TestGatewayHostOverrideWithExplicitHostAndMismatchedGatewayHost(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") writeGatewayHostConfig(t, configPath, "0.0.0.0") h := NewHandler(configPath) h.SetServerOptions(18800, false, false, nil) h.SetServerBindHost("192.168.1.10", true) if got := h.gatewayHostOverride(); got != "" { t.Fatalf("gatewayHostOverride() = %q, want empty", got) } } func TestGatewayHostExplicitIgnoresPublicFlag(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") writeGatewayHostConfig(t, configPath, "127.0.0.1") h := NewHandler(configPath) h.SetServerOptions(18800, true, true, nil) h.SetServerBindHost("127.0.0.1", true) if got := h.effectiveLauncherPublic(); got { t.Fatalf("effectiveLauncherPublic() = %t, want false when explicit host is set", got) } } func writeGatewayHostConfig(t *testing.T, configPath, host string) { t.Helper() cfg := config.DefaultConfig() cfg.Gateway.Host = host if err := config.SaveConfig(configPath, cfg); err != nil { t.Fatalf("SaveConfig() error = %v", err) } }