diff --git a/web/backend/api/gateway_host.go b/web/backend/api/gateway_host.go index c6c2073e2..03af7a9d3 100644 --- a/web/backend/api/gateway_host.go +++ b/web/backend/api/gateway_host.go @@ -85,8 +85,22 @@ func requestHostName(r *http.Request) string { return netbind.ResolveAdaptiveLoopbackHost() } +func forwardedProtoFirst(r *http.Request) string { + raw := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")) + if raw == "" { + raw = forwardedRFC7239Proto(r) + } + if raw == "" { + return "" + } + if i := strings.IndexByte(raw, ','); i >= 0 { + raw = strings.TrimSpace(raw[:i]) + } + return strings.ToLower(raw) +} + func requestWSScheme(r *http.Request) string { - if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" { + if forwarded := forwardedProtoFirst(r); forwarded != "" { proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0])) if proto == "https" || proto == "wss" { return "wss" @@ -105,7 +119,7 @@ func requestWSScheme(r *http.Request) string { // requestHTTPScheme returns http or https for URLs that are not WebSockets (e.g. SSE). func requestHTTPScheme(r *http.Request) string { - if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" { + if forwarded := forwardedProtoFirst(r); forwarded != "" { proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0])) if proto == "https" || proto == "wss" { return "https" @@ -117,6 +131,7 @@ func requestHTTPScheme(r *http.Request) string { if r.TLS != nil { return "https" } + return "http" } @@ -138,6 +153,14 @@ func forwardedHostFirst(r *http.Request) string { // forwardedRFC7239Host parses host= from the first Forwarded header element (RFC 7239). func forwardedRFC7239Host(r *http.Request) string { + return forwardedRFC7239Param(r, "host") +} + +func forwardedRFC7239Proto(r *http.Request) string { + return forwardedRFC7239Param(r, "proto") +} + +func forwardedRFC7239Param(r *http.Request, key string) string { v := strings.TrimSpace(r.Header.Get("Forwarded")) if v == "" { return "" @@ -146,7 +169,7 @@ func forwardedRFC7239Host(r *http.Request) string { for _, part := range strings.Split(first, ";") { part = strings.TrimSpace(part) low := strings.ToLower(part) - if !strings.HasPrefix(low, "host=") { + if !strings.HasPrefix(low, key+"=") { continue } val := strings.TrimSpace(part[strings.IndexByte(part, '=')+1:]) @@ -177,13 +200,21 @@ func clientVisiblePort(r *http.Request, serverListenPort int) string { if p := forwardedPortFirst(r); p != "" { return p } + if fwdHost := forwardedHostFirst(r); fwdHost != "" { + if _, port, err := net.SplitHostPort(fwdHost); err == nil && port != "" { + return port + } + } if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" { return port } + if strings.TrimSpace(r.Host) == "" && forwardedHostFirst(r) == "" { + return strconv.Itoa(serverListenPort) + } if requestHTTPScheme(r) == "https" { return "443" } - return strconv.Itoa(serverListenPort) + return "80" } // joinClientVisibleHostPort builds host:port for absolute URLs returned to the browser. @@ -205,16 +236,7 @@ func (h *Handler) picoWebUIAddr(r *http.Request) string { if fwdHost := forwardedHostFirst(r); fwdHost != "" { return joinClientVisibleHostPort(r, fwdHost, wsPort) } - host := requestHostName(r) - // Use clientVisiblePort only when an explicit port is present in headers - // or Host header — do not infer from TLS/scheme, as serverPort takes priority. - if p := forwardedPortFirst(r); p != "" { - return net.JoinHostPort(host, p) - } - if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" { - return net.JoinHostPort(host, port) - } - return net.JoinHostPort(host, strconv.Itoa(wsPort)) + return joinClientVisibleHostPort(r, requestHostName(r), wsPort) } func (h *Handler) buildWsURL(r *http.Request) string { diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go index c9802b30b..54d1010d2 100644 --- a/web/backend/api/gateway_host_test.go +++ b/web/backend/api/gateway_host_test.go @@ -185,8 +185,8 @@ func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) { req.Host = "chat.example.com" req.Header.Set("X-Forwarded-Proto", "https") - if got := h.buildWsURL(req); got != "wss://chat.example.com:18800/pico/ws" { - t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws") + if got := h.buildWsURL(req); got != "wss://chat.example.com:443/pico/ws" { + t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:443/pico/ws") } } @@ -202,8 +202,8 @@ func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) { req.Host = "secure.example.com" req.TLS = &tls.ConnectionState{} - if got := h.buildWsURL(req); got != "wss://secure.example.com:18800/pico/ws" { - t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws") + if got := h.buildWsURL(req); got != "wss://secure.example.com:443/pico/ws" { + t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:443/pico/ws") } } @@ -254,8 +254,25 @@ func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) { req.TLS = &tls.ConnectionState{} req.Header.Set("X-Forwarded-Proto", "http") - if got := h.buildWsURL(req); got != "ws://chat.example.com:18800/pico/ws" { - t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws") + if got := h.buildWsURL(req); got != "ws://chat.example.com:80/pico/ws" { + t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:80/pico/ws") + } +} + +func TestBuildWsURLDoesNotTrustOriginWhenProxyOmitsForwardedProto(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil) + req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" + req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") + + if got := h.buildWsURL(req); got != "ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws" { + t.Fatalf( + "buildWsURL() = %q, want %q", + got, + "ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws", + ) } } diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 146f9e697..34b011127 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -756,6 +756,33 @@ func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) { } } +func TestValidPicoProxyOriginAcceptsHTTPSOriginWithoutExplicitPort(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) + req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") + + if !h.validPicoProxyOrigin(req) { + t.Fatal("validPicoProxyOrigin() = false, want true") + } +} + +func TestValidPicoProxyOriginRejectsHTTPSOriginWhenProxyOmitsForwardedProto(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) + req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" + req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") + + if h.validPicoProxyOrigin(req) { + t.Fatal("validPicoProxyOrigin() = true, want false") + } +} + func mustGatewayTestPort(t *testing.T, rawURL string) int { t.Helper()