mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(web): improve Pico URL and origin handling behind proxies
- read client scheme from X-Forwarded-Proto and RFC 7239 Forwarded - derive client-visible ports from forwarded host information - add coverage for HTTPS origins without explicit ports - verify behavior when proxies omit forwarded protocol headers
This commit is contained in:
@@ -85,8 +85,22 @@ func requestHostName(r *http.Request) string {
|
||||
return netbind.ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
|
||||
func forwardedProtoFirst(r *http.Request) string {
|
||||
raw := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto"))
|
||||
if raw == "" {
|
||||
raw = forwardedRFC7239Proto(r)
|
||||
}
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if i := strings.IndexByte(raw, ','); i >= 0 {
|
||||
raw = strings.TrimSpace(raw[:i])
|
||||
}
|
||||
return strings.ToLower(raw)
|
||||
}
|
||||
|
||||
func requestWSScheme(r *http.Request) string {
|
||||
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
|
||||
if forwarded := forwardedProtoFirst(r); forwarded != "" {
|
||||
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
|
||||
if proto == "https" || proto == "wss" {
|
||||
return "wss"
|
||||
@@ -105,7 +119,7 @@ func requestWSScheme(r *http.Request) string {
|
||||
|
||||
// requestHTTPScheme returns http or https for URLs that are not WebSockets (e.g. SSE).
|
||||
func requestHTTPScheme(r *http.Request) string {
|
||||
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
|
||||
if forwarded := forwardedProtoFirst(r); forwarded != "" {
|
||||
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
|
||||
if proto == "https" || proto == "wss" {
|
||||
return "https"
|
||||
@@ -117,6 +131,7 @@ func requestHTTPScheme(r *http.Request) string {
|
||||
if r.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
|
||||
return "http"
|
||||
}
|
||||
|
||||
@@ -138,6 +153,14 @@ func forwardedHostFirst(r *http.Request) string {
|
||||
|
||||
// forwardedRFC7239Host parses host= from the first Forwarded header element (RFC 7239).
|
||||
func forwardedRFC7239Host(r *http.Request) string {
|
||||
return forwardedRFC7239Param(r, "host")
|
||||
}
|
||||
|
||||
func forwardedRFC7239Proto(r *http.Request) string {
|
||||
return forwardedRFC7239Param(r, "proto")
|
||||
}
|
||||
|
||||
func forwardedRFC7239Param(r *http.Request, key string) string {
|
||||
v := strings.TrimSpace(r.Header.Get("Forwarded"))
|
||||
if v == "" {
|
||||
return ""
|
||||
@@ -146,7 +169,7 @@ func forwardedRFC7239Host(r *http.Request) string {
|
||||
for _, part := range strings.Split(first, ";") {
|
||||
part = strings.TrimSpace(part)
|
||||
low := strings.ToLower(part)
|
||||
if !strings.HasPrefix(low, "host=") {
|
||||
if !strings.HasPrefix(low, key+"=") {
|
||||
continue
|
||||
}
|
||||
val := strings.TrimSpace(part[strings.IndexByte(part, '=')+1:])
|
||||
@@ -177,13 +200,21 @@ func clientVisiblePort(r *http.Request, serverListenPort int) string {
|
||||
if p := forwardedPortFirst(r); p != "" {
|
||||
return p
|
||||
}
|
||||
if fwdHost := forwardedHostFirst(r); fwdHost != "" {
|
||||
if _, port, err := net.SplitHostPort(fwdHost); err == nil && port != "" {
|
||||
return port
|
||||
}
|
||||
}
|
||||
if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" {
|
||||
return port
|
||||
}
|
||||
if strings.TrimSpace(r.Host) == "" && forwardedHostFirst(r) == "" {
|
||||
return strconv.Itoa(serverListenPort)
|
||||
}
|
||||
if requestHTTPScheme(r) == "https" {
|
||||
return "443"
|
||||
}
|
||||
return strconv.Itoa(serverListenPort)
|
||||
return "80"
|
||||
}
|
||||
|
||||
// joinClientVisibleHostPort builds host:port for absolute URLs returned to the browser.
|
||||
@@ -205,16 +236,7 @@ func (h *Handler) picoWebUIAddr(r *http.Request) string {
|
||||
if fwdHost := forwardedHostFirst(r); fwdHost != "" {
|
||||
return joinClientVisibleHostPort(r, fwdHost, wsPort)
|
||||
}
|
||||
host := requestHostName(r)
|
||||
// Use clientVisiblePort only when an explicit port is present in headers
|
||||
// or Host header — do not infer from TLS/scheme, as serverPort takes priority.
|
||||
if p := forwardedPortFirst(r); p != "" {
|
||||
return net.JoinHostPort(host, p)
|
||||
}
|
||||
if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" {
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
return net.JoinHostPort(host, strconv.Itoa(wsPort))
|
||||
return joinClientVisibleHostPort(r, requestHostName(r), wsPort)
|
||||
}
|
||||
|
||||
func (h *Handler) buildWsURL(r *http.Request) string {
|
||||
|
||||
@@ -185,8 +185,8 @@ func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
|
||||
req.Host = "chat.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
if got := h.buildWsURL(req); got != "wss://chat.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws")
|
||||
if got := h.buildWsURL(req); got != "wss://chat.example.com:443/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:443/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,8 +202,8 @@ func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) {
|
||||
req.Host = "secure.example.com"
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
|
||||
if got := h.buildWsURL(req); got != "wss://secure.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws")
|
||||
if got := h.buildWsURL(req); got != "wss://secure.example.com:443/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:443/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -254,8 +254,25 @@ func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) {
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
|
||||
if got := h.buildWsURL(req); got != "ws://chat.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws")
|
||||
if got := h.buildWsURL(req); got != "ws://chat.example.com:80/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:80/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLDoesNotTrustOriginWhenProxyOmitsForwardedProto(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil)
|
||||
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
|
||||
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
|
||||
|
||||
if got := h.buildWsURL(req); got != "ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws" {
|
||||
t.Fatalf(
|
||||
"buildWsURL() = %q, want %q",
|
||||
got,
|
||||
"ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -756,6 +756,33 @@ func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidPicoProxyOriginAcceptsHTTPSOriginWithoutExplicitPort(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
|
||||
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
|
||||
|
||||
if !h.validPicoProxyOrigin(req) {
|
||||
t.Fatal("validPicoProxyOrigin() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidPicoProxyOriginRejectsHTTPSOriginWhenProxyOmitsForwardedProto(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
|
||||
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
|
||||
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
|
||||
|
||||
if h.validPicoProxyOrigin(req) {
|
||||
t.Fatal("validPicoProxyOrigin() = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func mustGatewayTestPort(t *testing.T, rawURL string) int {
|
||||
t.Helper()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user