mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(web): stop pinning Pico WebSocket origins during setup
- remove request-origin seeding from `EnsurePicoChannel` - keep `allow_origins` empty by default for auto-configured Pico channels - relax launcher Pico WebSocket proxy origin validation - update Pico backend tests for the new setup and proxy behavior
This commit is contained in:
@@ -732,7 +732,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
|||||||
gateway.logs.Reset()
|
gateway.logs.Reset()
|
||||||
|
|
||||||
// Ensure Pico Channel is configured before starting gateway
|
// Ensure Pico Channel is configured before starting gateway
|
||||||
changed, err := h.EnsurePicoChannel("")
|
changed, err := h.EnsurePicoChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err))
|
logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err))
|
||||||
// Non-fatal: gateway can still start without pico channel
|
// Non-fatal: gateway can still start without pico channel
|
||||||
|
|||||||
+4
-70
@@ -5,11 +5,8 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
@@ -58,52 +55,6 @@ func (h *Handler) createWsProxy(origProtocol string, upstreamProtocol string) *h
|
|||||||
return wsProxy
|
return wsProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
func canonicalOrigin(raw string) (string, bool) {
|
|
||||||
u, err := url.Parse(strings.TrimSpace(raw))
|
|
||||||
if err != nil || u == nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
scheme := strings.ToLower(strings.TrimSpace(u.Scheme))
|
|
||||||
if scheme != "http" && scheme != "https" {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
host := strings.TrimSpace(u.Hostname())
|
|
||||||
if host == "" {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
port := u.Port()
|
|
||||||
if port == "" {
|
|
||||||
if scheme == "https" {
|
|
||||||
port = "443"
|
|
||||||
} else {
|
|
||||||
port = "80"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return scheme + "://" + net.JoinHostPort(host, port), true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) expectedPicoProxyOrigin(r *http.Request) string {
|
|
||||||
return requestHTTPScheme(r) + "://" + h.picoWebUIAddr(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) validPicoProxyOrigin(r *http.Request) bool {
|
|
||||||
want, ok := canonicalOrigin(h.expectedPicoProxyOrigin(r))
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
got, ok := canonicalOrigin(r.Header.Get("Origin"))
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return got == want
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) {
|
func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return config.PicoSettings{}, false
|
return config.PicoSettings{}, false
|
||||||
@@ -146,16 +97,10 @@ func (h *Handler) writePicoInfoResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
|
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
|
||||||
// It relies on launcher dashboard auth and same-origin browser access, then
|
// It relies on launcher dashboard auth, then injects the raw pico token only
|
||||||
// injects the raw pico token only on the upstream gateway request.
|
// on the upstream gateway request.
|
||||||
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if !h.validPicoProxyOrigin(r) {
|
|
||||||
logger.Warnf("Invalid Pico WebSocket origin: %q", r.Header.Get("Origin"))
|
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
gateway.mu.Lock()
|
gateway.mu.Lock()
|
||||||
ensurePicoTokenCachedLocked(h.configPath)
|
ensurePicoTokenCachedLocked(h.configPath)
|
||||||
cachedPID := gateway.pidData
|
cachedPID := gateway.pidData
|
||||||
@@ -252,12 +197,7 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// EnsurePicoChannel enables the Pico channel with sane defaults if it isn't
|
// EnsurePicoChannel enables the Pico channel with sane defaults if it isn't
|
||||||
// already configured. Returns true when the config was modified.
|
// already configured. Returns true when the config was modified.
|
||||||
//
|
func (h *Handler) EnsurePicoChannel() (bool, error) {
|
||||||
// callerOrigin is the Origin header from the setup request. If non-empty and
|
|
||||||
// no origins are configured yet, it's written as the allowed origin so the
|
|
||||||
// WebSocket handshake works for whatever host the caller is on (LAN, custom
|
|
||||||
// port, etc.). Pass "" when there's no request context.
|
|
||||||
func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
|
|
||||||
cfg, err := config.LoadConfig(h.configPath)
|
cfg, err := config.LoadConfig(h.configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to load config: %w", err)
|
return false, fmt.Errorf("failed to load config: %w", err)
|
||||||
@@ -282,12 +222,6 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
|
|||||||
picoCfg.Token = *config.NewSecureString(generateSecureToken())
|
picoCfg.Token = *config.NewSecureString(generateSecureToken())
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Seed origins from the request instead of hardcoding ports.
|
|
||||||
if len(picoCfg.AllowOrigins) == 0 && callerOrigin != "" {
|
|
||||||
picoCfg.AllowOrigins = []string{callerOrigin}
|
|
||||||
changed = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -304,7 +238,7 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
|
|||||||
//
|
//
|
||||||
// POST /api/pico/setup
|
// POST /api/pico/setup
|
||||||
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
||||||
changed, err := h.EnsurePicoChannel(r.Header.Get("Origin"))
|
changed, err := h.EnsurePicoChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
|
|||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
|
|
||||||
changed, err := h.EnsurePicoChannel("")
|
changed, err := h.EnsurePicoChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||||
}
|
}
|
||||||
@@ -56,7 +56,7 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
|
|||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
|
|
||||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,11 +76,11 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
|
func TestEnsurePicoChannel_LeavesAllowOriginsEmptyByDefault(t *testing.T) {
|
||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
|
|
||||||
if _, err := h.EnsurePicoChannel("http://localhost:18800"); err != nil {
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,45 +95,16 @@ func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
|
|||||||
t.Fatalf("GetDecoded() error = %v", err)
|
t.Fatalf("GetDecoded() error = %v", err)
|
||||||
}
|
}
|
||||||
picoCfg := decoded.(*config.PicoSettings)
|
picoCfg := decoded.(*config.PicoSettings)
|
||||||
for _, origin := range picoCfg.AllowOrigins {
|
|
||||||
if origin == "*" {
|
|
||||||
t.Error("setup must not set wildcard origin '*'")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) {
|
|
||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
||||||
h := NewHandler(configPath)
|
|
||||||
|
|
||||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
|
||||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := config.LoadConfig(configPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("LoadConfig() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bc := cfg.Channels["pico"]
|
|
||||||
decoded, err := bc.GetDecoded()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetDecoded() error = %v", err)
|
|
||||||
}
|
|
||||||
picoCfg := decoded.(*config.PicoSettings)
|
|
||||||
// Without a caller origin, allow_origins stays empty (CheckOrigin
|
|
||||||
// allows all when the list is empty, so the channel still works).
|
|
||||||
if len(picoCfg.AllowOrigins) != 0 {
|
if len(picoCfg.AllowOrigins) != 0 {
|
||||||
t.Errorf("allow_origins = %v, want empty when no caller origin", picoCfg.AllowOrigins)
|
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
|
func TestEnsurePicoChannel_NoOriginConfigurationRequired(t *testing.T) {
|
||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
|
|
||||||
lanOrigin := "http://192.168.1.9:18800"
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||||
if _, err := h.EnsurePicoChannel(lanOrigin); err != nil {
|
|
||||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,8 +119,8 @@ func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
|
|||||||
t.Fatalf("GetDecoded() error = %v", err)
|
t.Fatalf("GetDecoded() error = %v", err)
|
||||||
}
|
}
|
||||||
picoCfg := decoded.(*config.PicoSettings)
|
picoCfg := decoded.(*config.PicoSettings)
|
||||||
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != lanOrigin {
|
if len(picoCfg.AllowOrigins) != 0 {
|
||||||
t.Errorf("allow_origins = %v, want [%s]", picoCfg.AllowOrigins, lanOrigin)
|
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,7 +145,7 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
|
|||||||
|
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
|
|
||||||
changed, err := h.EnsurePicoChannel("")
|
changed, err := h.EnsurePicoChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||||
}
|
}
|
||||||
@@ -218,7 +189,7 @@ func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) {
|
|||||||
|
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
|
|
||||||
changed, err := h.EnsurePicoChannel("")
|
changed, err := h.EnsurePicoChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||||
}
|
}
|
||||||
@@ -258,7 +229,7 @@ func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,10 +256,8 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
|||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
|
|
||||||
origin := "http://localhost:18800"
|
|
||||||
|
|
||||||
// First call sets things up
|
// First call sets things up
|
||||||
if _, err := h.EnsurePicoChannel(origin); err != nil {
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||||
t.Fatalf("first EnsurePicoChannel() error = %v", err)
|
t.Fatalf("first EnsurePicoChannel() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,7 +271,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
|||||||
token1 := picoCfg.Token.String()
|
token1 := picoCfg.Token.String()
|
||||||
|
|
||||||
// Second call should be a no-op
|
// Second call should be a no-op
|
||||||
changed, err := h.EnsurePicoChannel(origin)
|
changed, err := h.EnsurePicoChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("second EnsurePicoChannel() error = %v", err)
|
t.Fatalf("second EnsurePicoChannel() error = %v", err)
|
||||||
}
|
}
|
||||||
@@ -322,7 +291,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
|
func TestHandlePicoSetup_DoesNotPersistRequestOrigin(t *testing.T) {
|
||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
|
|
||||||
@@ -347,8 +316,8 @@ func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
|
|||||||
t.Fatalf("GetDecoded() error = %v", err)
|
t.Fatalf("GetDecoded() error = %v", err)
|
||||||
}
|
}
|
||||||
picoCfg := decoded.(*config.PicoSettings)
|
picoCfg := decoded.(*config.PicoSettings)
|
||||||
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != "http://10.0.0.5:3000" {
|
if len(picoCfg.AllowOrigins) != 0 {
|
||||||
t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", picoCfg.AllowOrigins)
|
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -391,7 +360,7 @@ func TestHandleGetPicoInfo_OmitsToken(t *testing.T) {
|
|||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
|
|
||||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -741,45 +710,75 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) {
|
func TestHandleWebSocketProxy_AllowsArbitraryOrigin(t *testing.T) {
|
||||||
|
origMatcher := gatewayProcessMatcher
|
||||||
|
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||||
|
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
||||||
|
|
||||||
|
home := t.TempDir()
|
||||||
|
t.Setenv("PICOCLAW_HOME", home)
|
||||||
|
|
||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||||
h := NewHandler(configPath)
|
h := NewHandler(configPath)
|
||||||
handler := h.handleWebSocketProxy()
|
handler := h.handleWebSocketProxy()
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/pico/ws" {
|
||||||
|
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = io.WriteString(w, "proxied")
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := config.DefaultConfig()
|
||||||
|
cfg.Gateway.Host = "127.0.0.1"
|
||||||
|
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
|
||||||
|
bc := cfg.Channels["pico"]
|
||||||
|
bc.Enabled = true
|
||||||
|
decoded, err := bc.GetDecoded()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetDecoded() error = %v", err)
|
||||||
|
}
|
||||||
|
decoded.(*config.PicoSettings).SetToken("ui-token")
|
||||||
|
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||||
|
t.Fatalf("SaveConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := startGatewayLikeProcess(t)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if cmd.Process != nil {
|
||||||
|
_ = cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
_ = cmd.Wait()
|
||||||
|
})
|
||||||
|
writeTestPidFile(t, ppid.PidFileData{
|
||||||
|
PID: cmd.Process.Pid,
|
||||||
|
Token: "test-token",
|
||||||
|
Host: cfg.Gateway.Host,
|
||||||
|
Port: cfg.Gateway.Port,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
ppid.RemovePidFile(globalConfigDir())
|
||||||
|
})
|
||||||
|
|
||||||
|
origPidData := gateway.pidData
|
||||||
|
origPicoToken := gateway.picoToken
|
||||||
|
t.Cleanup(func() {
|
||||||
|
gateway.pidData = origPidData
|
||||||
|
gateway.picoToken = origPicoToken
|
||||||
|
})
|
||||||
|
|
||||||
|
gateway.pidData = &ppid.PidFileData{}
|
||||||
|
gateway.picoToken = "ui-token"
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws?session_id=test-session", nil)
|
||||||
req.Header.Set("Origin", "http://evil.example")
|
req.Header.Set("Origin", "http://evil.example")
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
handler(rec, req)
|
handler(rec, req)
|
||||||
|
|
||||||
if rec.Code != http.StatusForbidden {
|
if rec.Code != http.StatusOK {
|
||||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidPicoProxyOriginAcceptsHTTPSOriginWithoutExplicitPort(t *testing.T) {
|
|
||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
||||||
h := NewHandler(configPath)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
|
|
||||||
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
|
|
||||||
req.Header.Set("X-Forwarded-Proto", "https")
|
|
||||||
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
|
|
||||||
|
|
||||||
if !h.validPicoProxyOrigin(req) {
|
|
||||||
t.Fatal("validPicoProxyOrigin() = false, want true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidPicoProxyOriginRejectsHTTPSOriginWhenProxyOmitsForwardedProto(t *testing.T) {
|
|
||||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
|
||||||
h := NewHandler(configPath)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
|
|
||||||
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
|
|
||||||
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
|
|
||||||
|
|
||||||
if h.validPicoProxyOrigin(req) {
|
|
||||||
t.Fatal("validPicoProxyOrigin() = true, want false")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -544,7 +544,7 @@ func main() {
|
|||||||
// API Routes (e.g. /api/status)
|
// API Routes (e.g. /api/status)
|
||||||
apiHandler = api.NewHandler(absPath)
|
apiHandler = api.NewHandler(absPath)
|
||||||
apiHandler.SetDebug(debug)
|
apiHandler.SetDebug(debug)
|
||||||
if _, err = apiHandler.EnsurePicoChannel(""); err != nil {
|
if _, err = apiHandler.EnsurePicoChannel(); err != nil {
|
||||||
logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err))
|
logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err))
|
||||||
}
|
}
|
||||||
apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)
|
apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)
|
||||||
|
|||||||
Reference in New Issue
Block a user