fix(web): stop pinning Pico WebSocket origins during setup

- remove request-origin seeding from `EnsurePicoChannel`
- keep `allow_origins` empty by default for auto-configured Pico channels
- relax launcher Pico WebSocket proxy origin validation
- update Pico backend tests for the new setup and proxy behavior
This commit is contained in:
wenjie
2026-04-16 19:04:47 +08:00
parent d002e1517b
commit f8190f04b7
4 changed files with 85 additions and 152 deletions
+1 -1
View File
@@ -732,7 +732,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
gateway.logs.Reset()
// Ensure Pico Channel is configured before starting gateway
changed, err := h.EnsurePicoChannel("")
changed, err := h.EnsurePicoChannel()
if err != nil {
logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err))
// Non-fatal: gateway can still start without pico channel
+4 -70
View File
@@ -5,11 +5,8 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
@@ -58,52 +55,6 @@ func (h *Handler) createWsProxy(origProtocol string, upstreamProtocol string) *h
return wsProxy
}
func canonicalOrigin(raw string) (string, bool) {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil || u == nil {
return "", false
}
scheme := strings.ToLower(strings.TrimSpace(u.Scheme))
if scheme != "http" && scheme != "https" {
return "", false
}
host := strings.TrimSpace(u.Hostname())
if host == "" {
return "", false
}
port := u.Port()
if port == "" {
if scheme == "https" {
port = "443"
} else {
port = "80"
}
}
return scheme + "://" + net.JoinHostPort(host, port), true
}
func (h *Handler) expectedPicoProxyOrigin(r *http.Request) string {
return requestHTTPScheme(r) + "://" + h.picoWebUIAddr(r)
}
func (h *Handler) validPicoProxyOrigin(r *http.Request) bool {
want, ok := canonicalOrigin(h.expectedPicoProxyOrigin(r))
if !ok {
return false
}
got, ok := canonicalOrigin(r.Header.Get("Origin"))
if !ok {
return false
}
return got == want
}
func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) {
if cfg == nil {
return config.PicoSettings{}, false
@@ -146,16 +97,10 @@ func (h *Handler) writePicoInfoResponse(
}
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
// It relies on launcher dashboard auth and same-origin browser access, then
// injects the raw pico token only on the upstream gateway request.
// It relies on launcher dashboard auth, then injects the raw pico token only
// on the upstream gateway request.
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !h.validPicoProxyOrigin(r) {
logger.Warnf("Invalid Pico WebSocket origin: %q", r.Header.Get("Origin"))
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
gateway.mu.Lock()
ensurePicoTokenCachedLocked(h.configPath)
cachedPID := gateway.pidData
@@ -252,12 +197,7 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
// EnsurePicoChannel enables the Pico channel with sane defaults if it isn't
// already configured. Returns true when the config was modified.
//
// callerOrigin is the Origin header from the setup request. If non-empty and
// no origins are configured yet, it's written as the allowed origin so the
// WebSocket handshake works for whatever host the caller is on (LAN, custom
// port, etc.). Pass "" when there's no request context.
func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
func (h *Handler) EnsurePicoChannel() (bool, error) {
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
return false, fmt.Errorf("failed to load config: %w", err)
@@ -282,12 +222,6 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
picoCfg.Token = *config.NewSecureString(generateSecureToken())
changed = true
}
// Seed origins from the request instead of hardcoding ports.
if len(picoCfg.AllowOrigins) == 0 && callerOrigin != "" {
picoCfg.AllowOrigins = []string{callerOrigin}
changed = true
}
}
}
@@ -304,7 +238,7 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
//
// POST /api/pico/setup
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
changed, err := h.EnsurePicoChannel(r.Header.Get("Origin"))
changed, err := h.EnsurePicoChannel()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
+79 -80
View File
@@ -25,7 +25,7 @@ func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
changed, err := h.EnsurePicoChannel("")
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -56,7 +56,7 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(""); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -76,11 +76,11 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
}
}
func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
func TestEnsurePicoChannel_LeavesAllowOriginsEmptyByDefault(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel("http://localhost:18800"); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -95,45 +95,16 @@ func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
for _, origin := range picoCfg.AllowOrigins {
if origin == "*" {
t.Error("setup must not set wildcard origin '*'")
}
}
}
func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(""); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
// Without a caller origin, allow_origins stays empty (CheckOrigin
// allows all when the list is empty, so the channel still works).
if len(picoCfg.AllowOrigins) != 0 {
t.Errorf("allow_origins = %v, want empty when no caller origin", picoCfg.AllowOrigins)
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
}
}
func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
func TestEnsurePicoChannel_NoOriginConfigurationRequired(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
lanOrigin := "http://192.168.1.9:18800"
if _, err := h.EnsurePicoChannel(lanOrigin); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -148,8 +119,8 @@ func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != lanOrigin {
t.Errorf("allow_origins = %v, want [%s]", picoCfg.AllowOrigins, lanOrigin)
if len(picoCfg.AllowOrigins) != 0 {
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
}
}
@@ -174,7 +145,7 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
h := NewHandler(configPath)
changed, err := h.EnsurePicoChannel("")
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -218,7 +189,7 @@ func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) {
h := NewHandler(configPath)
changed, err := h.EnsurePicoChannel("")
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -258,7 +229,7 @@ func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) {
}
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(""); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -285,10 +256,8 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
origin := "http://localhost:18800"
// First call sets things up
if _, err := h.EnsurePicoChannel(origin); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("first EnsurePicoChannel() error = %v", err)
}
@@ -302,7 +271,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
token1 := picoCfg.Token.String()
// Second call should be a no-op
changed, err := h.EnsurePicoChannel(origin)
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("second EnsurePicoChannel() error = %v", err)
}
@@ -322,7 +291,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
}
}
func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
func TestHandlePicoSetup_DoesNotPersistRequestOrigin(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
@@ -347,8 +316,8 @@ func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != "http://10.0.0.5:3000" {
t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", picoCfg.AllowOrigins)
if len(picoCfg.AllowOrigins) != 0 {
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
}
}
@@ -391,7 +360,7 @@ func TestHandleGetPicoInfo_OmitsToken(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(""); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -741,45 +710,75 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
}
}
func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) {
func TestHandleWebSocketProxy_AllowsArbitraryOrigin(t *testing.T) {
origMatcher := gatewayProcessMatcher
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
home := t.TempDir()
t.Setenv("PICOCLAW_HOME", home)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "proxied")
}))
defer server.Close()
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
bc := cfg.Channels["pico"]
bc.Enabled = true
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
decoded.(*config.PicoSettings).SetToken("ui-token")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
cmd := startGatewayLikeProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
writeTestPidFile(t, ppid.PidFileData{
PID: cmd.Process.Pid,
Token: "test-token",
Host: cfg.Gateway.Host,
Port: cfg.Gateway.Port,
})
t.Cleanup(func() {
ppid.RemovePidFile(globalConfigDir())
})
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
t.Cleanup(func() {
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
})
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = "ui-token"
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws?session_id=test-session", nil)
req.Header.Set("Origin", "http://evil.example")
rec := httptest.NewRecorder()
handler(rec, req)
if rec.Code != http.StatusForbidden {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
}
}
func TestValidPicoProxyOriginAcceptsHTTPSOriginWithoutExplicitPort(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
if !h.validPicoProxyOrigin(req) {
t.Fatal("validPicoProxyOrigin() = false, want true")
}
}
func TestValidPicoProxyOriginRejectsHTTPSOriginWhenProxyOmitsForwardedProto(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
if h.validPicoProxyOrigin(req) {
t.Fatal("validPicoProxyOrigin() = true, want false")
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
+1 -1
View File
@@ -544,7 +544,7 @@ func main() {
// API Routes (e.g. /api/status)
apiHandler = api.NewHandler(absPath)
apiHandler.SetDebug(debug)
if _, err = apiHandler.EnsurePicoChannel(""); err != nil {
if _, err = apiHandler.EnsurePicoChannel(); err != nil {
logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err))
}
apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)