mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(web): stop pinning Pico WebSocket origins during setup
- remove request-origin seeding from `EnsurePicoChannel` - keep `allow_origins` empty by default for auto-configured Pico channels - relax launcher Pico WebSocket proxy origin validation - update Pico backend tests for the new setup and proxy behavior
This commit is contained in:
@@ -732,7 +732,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
gateway.logs.Reset()
|
||||
|
||||
// Ensure Pico Channel is configured before starting gateway
|
||||
changed, err := h.EnsurePicoChannel("")
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err))
|
||||
// Non-fatal: gateway can still start without pico channel
|
||||
|
||||
+4
-70
@@ -5,11 +5,8 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -58,52 +55,6 @@ func (h *Handler) createWsProxy(origProtocol string, upstreamProtocol string) *h
|
||||
return wsProxy
|
||||
}
|
||||
|
||||
func canonicalOrigin(raw string) (string, bool) {
|
||||
u, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil || u == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(strings.TrimSpace(u.Scheme))
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(u.Hostname())
|
||||
if host == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
port := u.Port()
|
||||
if port == "" {
|
||||
if scheme == "https" {
|
||||
port = "443"
|
||||
} else {
|
||||
port = "80"
|
||||
}
|
||||
}
|
||||
|
||||
return scheme + "://" + net.JoinHostPort(host, port), true
|
||||
}
|
||||
|
||||
func (h *Handler) expectedPicoProxyOrigin(r *http.Request) string {
|
||||
return requestHTTPScheme(r) + "://" + h.picoWebUIAddr(r)
|
||||
}
|
||||
|
||||
func (h *Handler) validPicoProxyOrigin(r *http.Request) bool {
|
||||
want, ok := canonicalOrigin(h.expectedPicoProxyOrigin(r))
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
got, ok := canonicalOrigin(r.Header.Get("Origin"))
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return got == want
|
||||
}
|
||||
|
||||
func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) {
|
||||
if cfg == nil {
|
||||
return config.PicoSettings{}, false
|
||||
@@ -146,16 +97,10 @@ func (h *Handler) writePicoInfoResponse(
|
||||
}
|
||||
|
||||
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
|
||||
// It relies on launcher dashboard auth and same-origin browser access, then
|
||||
// injects the raw pico token only on the upstream gateway request.
|
||||
// It relies on launcher dashboard auth, then injects the raw pico token only
|
||||
// on the upstream gateway request.
|
||||
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.validPicoProxyOrigin(r) {
|
||||
logger.Warnf("Invalid Pico WebSocket origin: %q", r.Header.Get("Origin"))
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
ensurePicoTokenCachedLocked(h.configPath)
|
||||
cachedPID := gateway.pidData
|
||||
@@ -252,12 +197,7 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// EnsurePicoChannel enables the Pico channel with sane defaults if it isn't
|
||||
// already configured. Returns true when the config was modified.
|
||||
//
|
||||
// callerOrigin is the Origin header from the setup request. If non-empty and
|
||||
// no origins are configured yet, it's written as the allowed origin so the
|
||||
// WebSocket handshake works for whatever host the caller is on (LAN, custom
|
||||
// port, etc.). Pass "" when there's no request context.
|
||||
func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
|
||||
func (h *Handler) EnsurePicoChannel() (bool, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to load config: %w", err)
|
||||
@@ -282,12 +222,6 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
|
||||
picoCfg.Token = *config.NewSecureString(generateSecureToken())
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Seed origins from the request instead of hardcoding ports.
|
||||
if len(picoCfg.AllowOrigins) == 0 && callerOrigin != "" {
|
||||
picoCfg.AllowOrigins = []string{callerOrigin}
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,7 +238,7 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
|
||||
//
|
||||
// POST /api/pico/setup
|
||||
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
||||
changed, err := h.EnsurePicoChannel(r.Header.Get("Origin"))
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -25,7 +25,7 @@ func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.EnsurePicoChannel("")
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -76,11 +76,11 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
|
||||
func TestEnsurePicoChannel_LeavesAllowOriginsEmptyByDefault(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.EnsurePicoChannel("http://localhost:18800"); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -95,45 +95,16 @@ func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
|
||||
t.Fatalf("GetDecoded() error = %v", err)
|
||||
}
|
||||
picoCfg := decoded.(*config.PicoSettings)
|
||||
for _, origin := range picoCfg.AllowOrigins {
|
||||
if origin == "*" {
|
||||
t.Error("setup must not set wildcard origin '*'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
bc := cfg.Channels["pico"]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
t.Fatalf("GetDecoded() error = %v", err)
|
||||
}
|
||||
picoCfg := decoded.(*config.PicoSettings)
|
||||
// Without a caller origin, allow_origins stays empty (CheckOrigin
|
||||
// allows all when the list is empty, so the channel still works).
|
||||
if len(picoCfg.AllowOrigins) != 0 {
|
||||
t.Errorf("allow_origins = %v, want empty when no caller origin", picoCfg.AllowOrigins)
|
||||
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
|
||||
func TestEnsurePicoChannel_NoOriginConfigurationRequired(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
lanOrigin := "http://192.168.1.9:18800"
|
||||
if _, err := h.EnsurePicoChannel(lanOrigin); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -148,8 +119,8 @@ func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
|
||||
t.Fatalf("GetDecoded() error = %v", err)
|
||||
}
|
||||
picoCfg := decoded.(*config.PicoSettings)
|
||||
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != lanOrigin {
|
||||
t.Errorf("allow_origins = %v, want [%s]", picoCfg.AllowOrigins, lanOrigin)
|
||||
if len(picoCfg.AllowOrigins) != 0 {
|
||||
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,7 +145,7 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
|
||||
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.EnsurePicoChannel("")
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
@@ -218,7 +189,7 @@ func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) {
|
||||
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.EnsurePicoChannel("")
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
@@ -258,7 +229,7 @@ func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) {
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -285,10 +256,8 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
origin := "http://localhost:18800"
|
||||
|
||||
// First call sets things up
|
||||
if _, err := h.EnsurePicoChannel(origin); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("first EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -302,7 +271,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
||||
token1 := picoCfg.Token.String()
|
||||
|
||||
// Second call should be a no-op
|
||||
changed, err := h.EnsurePicoChannel(origin)
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
t.Fatalf("second EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
@@ -322,7 +291,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
|
||||
func TestHandlePicoSetup_DoesNotPersistRequestOrigin(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
@@ -347,8 +316,8 @@ func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
|
||||
t.Fatalf("GetDecoded() error = %v", err)
|
||||
}
|
||||
picoCfg := decoded.(*config.PicoSettings)
|
||||
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != "http://10.0.0.5:3000" {
|
||||
t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", picoCfg.AllowOrigins)
|
||||
if len(picoCfg.AllowOrigins) != 0 {
|
||||
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,7 +360,7 @@ func TestHandleGetPicoInfo_OmitsToken(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -741,45 +710,75 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) {
|
||||
func TestHandleWebSocketProxy_AllowsArbitraryOrigin(t *testing.T) {
|
||||
origMatcher := gatewayProcessMatcher
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
||||
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
handler := h.handleWebSocketProxy()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "proxied")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
|
||||
bc := cfg.Channels["pico"]
|
||||
bc.Enabled = true
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
t.Fatalf("GetDecoded() error = %v", err)
|
||||
}
|
||||
decoded.(*config.PicoSettings).SetToken("ui-token")
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: cfg.Gateway.Host,
|
||||
Port: cfg.Gateway.Port,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
ppid.RemovePidFile(globalConfigDir())
|
||||
})
|
||||
|
||||
origPidData := gateway.pidData
|
||||
origPicoToken := gateway.picoToken
|
||||
t.Cleanup(func() {
|
||||
gateway.pidData = origPidData
|
||||
gateway.picoToken = origPicoToken
|
||||
})
|
||||
|
||||
gateway.pidData = &ppid.PidFileData{}
|
||||
gateway.picoToken = "ui-token"
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws?session_id=test-session", nil)
|
||||
req.Header.Set("Origin", "http://evil.example")
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, req)
|
||||
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidPicoProxyOriginAcceptsHTTPSOriginWithoutExplicitPort(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
|
||||
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
|
||||
|
||||
if !h.validPicoProxyOrigin(req) {
|
||||
t.Fatal("validPicoProxyOrigin() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidPicoProxyOriginRejectsHTTPSOriginWhenProxyOmitsForwardedProto(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
|
||||
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
|
||||
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
|
||||
|
||||
if h.validPicoProxyOrigin(req) {
|
||||
t.Fatal("validPicoProxyOrigin() = true, want false")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user