From f8190f04b7db62556a8e0cdb63a0a552b60752ce Mon Sep 17 00:00:00 2001 From: wenjie Date: Thu, 16 Apr 2026 19:04:47 +0800 Subject: [PATCH] fix(web): stop pinning Pico WebSocket origins during setup - remove request-origin seeding from `EnsurePicoChannel` - keep `allow_origins` empty by default for auto-configured Pico channels - relax launcher Pico WebSocket proxy origin validation - update Pico backend tests for the new setup and proxy behavior --- web/backend/api/gateway.go | 2 +- web/backend/api/pico.go | 74 +--------------- web/backend/api/pico_test.go | 159 +++++++++++++++++------------------ web/backend/main.go | 2 +- 4 files changed, 85 insertions(+), 152 deletions(-) diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index ea43789d3..201000ff3 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -732,7 +732,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int gateway.logs.Reset() // Ensure Pico Channel is configured before starting gateway - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err)) // Non-fatal: gateway can still start without pico channel diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index 5e4848b01..ffd0796c7 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -5,11 +5,8 @@ import ( "encoding/hex" "encoding/json" "fmt" - "net" "net/http" "net/http/httputil" - "net/url" - "strings" "time" "github.com/sipeed/picoclaw/pkg/config" @@ -58,52 +55,6 @@ func (h *Handler) createWsProxy(origProtocol string, upstreamProtocol string) *h return wsProxy } -func canonicalOrigin(raw string) (string, bool) { - u, err := url.Parse(strings.TrimSpace(raw)) - if err != nil || u == nil { - return "", false - } - - scheme := strings.ToLower(strings.TrimSpace(u.Scheme)) - if scheme != "http" && scheme != "https" { - return "", false - } - - host := strings.TrimSpace(u.Hostname()) - if host == "" { - return "", false - } - - port := u.Port() - if port == "" { - if scheme == "https" { - port = "443" - } else { - port = "80" - } - } - - return scheme + "://" + net.JoinHostPort(host, port), true -} - -func (h *Handler) expectedPicoProxyOrigin(r *http.Request) string { - return requestHTTPScheme(r) + "://" + h.picoWebUIAddr(r) -} - -func (h *Handler) validPicoProxyOrigin(r *http.Request) bool { - want, ok := canonicalOrigin(h.expectedPicoProxyOrigin(r)) - if !ok { - return false - } - - got, ok := canonicalOrigin(r.Header.Get("Origin")) - if !ok { - return false - } - - return got == want -} - func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) { if cfg == nil { return config.PicoSettings{}, false @@ -146,16 +97,10 @@ func (h *Handler) writePicoInfoResponse( } // handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections. -// It relies on launcher dashboard auth and same-origin browser access, then -// injects the raw pico token only on the upstream gateway request. +// It relies on launcher dashboard auth, then injects the raw pico token only +// on the upstream gateway request. func (h *Handler) handleWebSocketProxy() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if !h.validPicoProxyOrigin(r) { - logger.Warnf("Invalid Pico WebSocket origin: %q", r.Header.Get("Origin")) - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - gateway.mu.Lock() ensurePicoTokenCachedLocked(h.configPath) cachedPID := gateway.pidData @@ -252,12 +197,7 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { // EnsurePicoChannel enables the Pico channel with sane defaults if it isn't // already configured. Returns true when the config was modified. -// -// callerOrigin is the Origin header from the setup request. If non-empty and -// no origins are configured yet, it's written as the allowed origin so the -// WebSocket handshake works for whatever host the caller is on (LAN, custom -// port, etc.). Pass "" when there's no request context. -func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { +func (h *Handler) EnsurePicoChannel() (bool, error) { cfg, err := config.LoadConfig(h.configPath) if err != nil { return false, fmt.Errorf("failed to load config: %w", err) @@ -282,12 +222,6 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { picoCfg.Token = *config.NewSecureString(generateSecureToken()) changed = true } - - // Seed origins from the request instead of hardcoding ports. - if len(picoCfg.AllowOrigins) == 0 && callerOrigin != "" { - picoCfg.AllowOrigins = []string{callerOrigin} - changed = true - } } } @@ -304,7 +238,7 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { // // POST /api/pico/setup func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) { - changed, err := h.EnsurePicoChannel(r.Header.Get("Origin")) + changed, err := h.EnsurePicoChannel() if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 34b011127..a56cd9ba2 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -25,7 +25,7 @@ func TestEnsurePicoChannel_FreshConfig(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -56,7 +56,7 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel(""); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -76,11 +76,11 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) { } } -func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) { +func TestEnsurePicoChannel_LeavesAllowOriginsEmptyByDefault(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel("http://localhost:18800"); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -95,45 +95,16 @@ func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) { t.Fatalf("GetDecoded() error = %v", err) } picoCfg := decoded.(*config.PicoSettings) - for _, origin := range picoCfg.AllowOrigins { - if origin == "*" { - t.Error("setup must not set wildcard origin '*'") - } - } -} - -func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) { - configPath := filepath.Join(t.TempDir(), "config.json") - h := NewHandler(configPath) - - if _, err := h.EnsurePicoChannel(""); err != nil { - t.Fatalf("EnsurePicoChannel() error = %v", err) - } - - cfg, err := config.LoadConfig(configPath) - if err != nil { - t.Fatalf("LoadConfig() error = %v", err) - } - - bc := cfg.Channels["pico"] - decoded, err := bc.GetDecoded() - if err != nil { - t.Fatalf("GetDecoded() error = %v", err) - } - picoCfg := decoded.(*config.PicoSettings) - // Without a caller origin, allow_origins stays empty (CheckOrigin - // allows all when the list is empty, so the channel still works). if len(picoCfg.AllowOrigins) != 0 { - t.Errorf("allow_origins = %v, want empty when no caller origin", picoCfg.AllowOrigins) + t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins) } } -func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) { +func TestEnsurePicoChannel_NoOriginConfigurationRequired(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - lanOrigin := "http://192.168.1.9:18800" - if _, err := h.EnsurePicoChannel(lanOrigin); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -148,8 +119,8 @@ func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) { t.Fatalf("GetDecoded() error = %v", err) } picoCfg := decoded.(*config.PicoSettings) - if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != lanOrigin { - t.Errorf("allow_origins = %v, want [%s]", picoCfg.AllowOrigins, lanOrigin) + if len(picoCfg.AllowOrigins) != 0 { + t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins) } } @@ -174,7 +145,7 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) { h := NewHandler(configPath) - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -218,7 +189,7 @@ func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) { h := NewHandler(configPath) - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -258,7 +229,7 @@ func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) { } h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel(""); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -285,10 +256,8 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - origin := "http://localhost:18800" - // First call sets things up - if _, err := h.EnsurePicoChannel(origin); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("first EnsurePicoChannel() error = %v", err) } @@ -302,7 +271,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { token1 := picoCfg.Token.String() // Second call should be a no-op - changed, err := h.EnsurePicoChannel(origin) + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("second EnsurePicoChannel() error = %v", err) } @@ -322,7 +291,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { } } -func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) { +func TestHandlePicoSetup_DoesNotPersistRequestOrigin(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) @@ -347,8 +316,8 @@ func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) { t.Fatalf("GetDecoded() error = %v", err) } picoCfg := decoded.(*config.PicoSettings) - if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != "http://10.0.0.5:3000" { - t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", picoCfg.AllowOrigins) + if len(picoCfg.AllowOrigins) != 0 { + t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins) } } @@ -391,7 +360,7 @@ func TestHandleGetPicoInfo_OmitsToken(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel(""); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -741,45 +710,75 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) { } } -func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) { +func TestHandleWebSocketProxy_AllowsArbitraryOrigin(t *testing.T) { + origMatcher := gatewayProcessMatcher + gatewayProcessMatcher = func(int) (bool, bool) { return true, true } + t.Cleanup(func() { gatewayProcessMatcher = origMatcher }) + + home := t.TempDir() + t.Setenv("PICOCLAW_HOME", home) + configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) handler := h.handleWebSocketProxy() - req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/pico/ws" { + t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws") + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "proxied") + })) + defer server.Close() + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "127.0.0.1" + cfg.Gateway.Port = mustGatewayTestPort(t, server.URL) + bc := cfg.Channels["pico"] + bc.Enabled = true + decoded, err := bc.GetDecoded() + if err != nil { + t.Fatalf("GetDecoded() error = %v", err) + } + decoded.(*config.PicoSettings).SetToken("ui-token") + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + cmd := startGatewayLikeProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + writeTestPidFile(t, ppid.PidFileData{ + PID: cmd.Process.Pid, + Token: "test-token", + Host: cfg.Gateway.Host, + Port: cfg.Gateway.Port, + }) + t.Cleanup(func() { + ppid.RemovePidFile(globalConfigDir()) + }) + + origPidData := gateway.pidData + origPicoToken := gateway.picoToken + t.Cleanup(func() { + gateway.pidData = origPidData + gateway.picoToken = origPicoToken + }) + + gateway.pidData = &ppid.PidFileData{} + gateway.picoToken = "ui-token" + + req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws?session_id=test-session", nil) req.Header.Set("Origin", "http://evil.example") rec := httptest.NewRecorder() handler(rec, req) - if rec.Code != http.StatusForbidden { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) - } -} - -func TestValidPicoProxyOriginAcceptsHTTPSOriginWithoutExplicitPort(t *testing.T) { - configPath := filepath.Join(t.TempDir(), "config.json") - h := NewHandler(configPath) - - req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) - req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" - req.Header.Set("X-Forwarded-Proto", "https") - req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") - - if !h.validPicoProxyOrigin(req) { - t.Fatal("validPicoProxyOrigin() = false, want true") - } -} - -func TestValidPicoProxyOriginRejectsHTTPSOriginWhenProxyOmitsForwardedProto(t *testing.T) { - configPath := filepath.Join(t.TempDir(), "config.json") - h := NewHandler(configPath) - - req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) - req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" - req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") - - if h.validPicoProxyOrigin(req) { - t.Fatal("validPicoProxyOrigin() = true, want false") + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } } diff --git a/web/backend/main.go b/web/backend/main.go index 01ef5edf0..e42558398 100644 --- a/web/backend/main.go +++ b/web/backend/main.go @@ -544,7 +544,7 @@ func main() { // API Routes (e.g. /api/status) apiHandler = api.NewHandler(absPath) apiHandler.SetDebug(debug) - if _, err = apiHandler.EnsurePicoChannel(""); err != nil { + if _, err = apiHandler.EnsurePicoChannel(); err != nil { logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err)) } apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)