fix(web): improve Pico URL and origin handling behind proxies

- read client scheme from X-Forwarded-Proto and RFC 7239 Forwarded
- derive client-visible ports from forwarded host information
- add coverage for HTTPS origins without explicit ports
- verify behavior when proxies omit forwarded protocol headers
This commit is contained in:
wenjie
2026-04-16 18:31:42 +08:00
parent 4b76196e2c
commit d002e1517b
3 changed files with 86 additions and 20 deletions
+36 -14
View File
@@ -85,8 +85,22 @@ func requestHostName(r *http.Request) string {
return netbind.ResolveAdaptiveLoopbackHost()
}
func forwardedProtoFirst(r *http.Request) string {
raw := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto"))
if raw == "" {
raw = forwardedRFC7239Proto(r)
}
if raw == "" {
return ""
}
if i := strings.IndexByte(raw, ','); i >= 0 {
raw = strings.TrimSpace(raw[:i])
}
return strings.ToLower(raw)
}
func requestWSScheme(r *http.Request) string {
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
if forwarded := forwardedProtoFirst(r); forwarded != "" {
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
if proto == "https" || proto == "wss" {
return "wss"
@@ -105,7 +119,7 @@ func requestWSScheme(r *http.Request) string {
// requestHTTPScheme returns http or https for URLs that are not WebSockets (e.g. SSE).
func requestHTTPScheme(r *http.Request) string {
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
if forwarded := forwardedProtoFirst(r); forwarded != "" {
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
if proto == "https" || proto == "wss" {
return "https"
@@ -117,6 +131,7 @@ func requestHTTPScheme(r *http.Request) string {
if r.TLS != nil {
return "https"
}
return "http"
}
@@ -138,6 +153,14 @@ func forwardedHostFirst(r *http.Request) string {
// forwardedRFC7239Host parses host= from the first Forwarded header element (RFC 7239).
func forwardedRFC7239Host(r *http.Request) string {
return forwardedRFC7239Param(r, "host")
}
func forwardedRFC7239Proto(r *http.Request) string {
return forwardedRFC7239Param(r, "proto")
}
func forwardedRFC7239Param(r *http.Request, key string) string {
v := strings.TrimSpace(r.Header.Get("Forwarded"))
if v == "" {
return ""
@@ -146,7 +169,7 @@ func forwardedRFC7239Host(r *http.Request) string {
for _, part := range strings.Split(first, ";") {
part = strings.TrimSpace(part)
low := strings.ToLower(part)
if !strings.HasPrefix(low, "host=") {
if !strings.HasPrefix(low, key+"=") {
continue
}
val := strings.TrimSpace(part[strings.IndexByte(part, '=')+1:])
@@ -177,13 +200,21 @@ func clientVisiblePort(r *http.Request, serverListenPort int) string {
if p := forwardedPortFirst(r); p != "" {
return p
}
if fwdHost := forwardedHostFirst(r); fwdHost != "" {
if _, port, err := net.SplitHostPort(fwdHost); err == nil && port != "" {
return port
}
}
if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" {
return port
}
if strings.TrimSpace(r.Host) == "" && forwardedHostFirst(r) == "" {
return strconv.Itoa(serverListenPort)
}
if requestHTTPScheme(r) == "https" {
return "443"
}
return strconv.Itoa(serverListenPort)
return "80"
}
// joinClientVisibleHostPort builds host:port for absolute URLs returned to the browser.
@@ -205,16 +236,7 @@ func (h *Handler) picoWebUIAddr(r *http.Request) string {
if fwdHost := forwardedHostFirst(r); fwdHost != "" {
return joinClientVisibleHostPort(r, fwdHost, wsPort)
}
host := requestHostName(r)
// Use clientVisiblePort only when an explicit port is present in headers
// or Host header — do not infer from TLS/scheme, as serverPort takes priority.
if p := forwardedPortFirst(r); p != "" {
return net.JoinHostPort(host, p)
}
if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" {
return net.JoinHostPort(host, port)
}
return net.JoinHostPort(host, strconv.Itoa(wsPort))
return joinClientVisibleHostPort(r, requestHostName(r), wsPort)
}
func (h *Handler) buildWsURL(r *http.Request) string {
+23 -6
View File
@@ -185,8 +185,8 @@ func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
req.Host = "chat.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
if got := h.buildWsURL(req); got != "wss://chat.example.com:18800/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws")
if got := h.buildWsURL(req); got != "wss://chat.example.com:443/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:443/pico/ws")
}
}
@@ -202,8 +202,8 @@ func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) {
req.Host = "secure.example.com"
req.TLS = &tls.ConnectionState{}
if got := h.buildWsURL(req); got != "wss://secure.example.com:18800/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws")
if got := h.buildWsURL(req); got != "wss://secure.example.com:443/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:443/pico/ws")
}
}
@@ -254,8 +254,25 @@ func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) {
req.TLS = &tls.ConnectionState{}
req.Header.Set("X-Forwarded-Proto", "http")
if got := h.buildWsURL(req); got != "ws://chat.example.com:18800/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws")
if got := h.buildWsURL(req); got != "ws://chat.example.com:80/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:80/pico/ws")
}
}
func TestBuildWsURLDoesNotTrustOriginWhenProxyOmitsForwardedProto(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil)
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
if got := h.buildWsURL(req); got != "ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws" {
t.Fatalf(
"buildWsURL() = %q, want %q",
got,
"ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws",
)
}
}
+27
View File
@@ -756,6 +756,33 @@ func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) {
}
}
func TestValidPicoProxyOriginAcceptsHTTPSOriginWithoutExplicitPort(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
if !h.validPicoProxyOrigin(req) {
t.Fatal("validPicoProxyOrigin() = false, want true")
}
}
func TestValidPicoProxyOriginRejectsHTTPSOriginWhenProxyOmitsForwardedProto(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
if h.validPicoProxyOrigin(req) {
t.Fatal("validPicoProxyOrigin() = true, want false")
}
}
func mustGatewayTestPort(t *testing.T, rawURL string) int {
t.Helper()