From 4b76196e2ca7f1d6364bb7e8b27410b271cb4007 Mon Sep 17 00:00:00 2001 From: wenjie Date: Thu, 16 Apr 2026 16:47:23 +0800 Subject: [PATCH 1/3] refactor(web): secure Pico websocket access behind launcher auth - stop exposing the raw Pico token to the frontend - add /api/pico/info for non-secret Pico connection metadata - proxy /pico/ws through the launcher with same-origin and dashboard auth checks - inject the upstream Pico websocket protocol server-side - update frontend chat connection flow and Vite websocket proxy path - refresh related docs and tests --- docs/guides/docker.md | 2 +- pkg/channels/pico/protocol.go | 2 - pkg/gateway/gateway.go | 24 +-- web/backend/api/config.go | 4 - web/backend/api/gateway.go | 32 +-- web/backend/api/gateway_host_test.go | 12 +- web/backend/api/gateway_test.go | 12 ++ web/backend/api/pico.go | 191 ++++++++++++------ web/backend/api/pico_test.go | 91 ++++++--- .../middleware/launcher_dashboard_auth.go | 4 + .../launcher_dashboard_auth_test.go | 20 ++ web/frontend/src/api/pico.ts | 16 +- web/frontend/src/features/chat/controller.ts | 12 +- web/frontend/vite.config.ts | 2 +- 14 files changed, 253 insertions(+), 171 deletions(-) diff --git a/docs/guides/docker.md b/docs/guides/docker.md index 6c32879a6..3ccc7a2a7 100644 --- a/docs/guides/docker.md +++ b/docs/guides/docker.md @@ -27,7 +27,7 @@ docker compose -f docker/docker-compose.yml --profile gateway up -d > **Docker Users**: By default, the Gateway listens on `127.0.0.1` which is not accessible from the host. If you need to access the health endpoints or expose ports, set `PICOCLAW_GATEWAY_HOST=0.0.0.0` in your environment or update `config.json`. > [!NOTE] -> The `gateway` profile only serves the webhook handlers (including Pico when enabled) and health endpoints on the gateway port, so it does not expose generic REST chat endpoints such as `/chat` or `/a2a`. Launcher mode adds the browser UI plus `/api/pico/token` and a `/pico/ws` proxy on the launcher port, but `/pico/ws` is also available directly on the gateway whenever the Pico channel is enabled. +> The `gateway` profile only serves the webhook handlers (including Pico when enabled) and health endpoints on the gateway port, so it does not expose generic REST chat endpoints such as `/chat` or `/a2a`. Launcher mode adds the browser UI plus `/api/pico/info` and an authenticated `/pico/ws` proxy on the launcher port, but `/pico/ws` is also available directly on the gateway whenever the Pico channel is enabled. ```bash # 5. Check logs diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go index ecdc2d140..051beed1b 100644 --- a/pkg/channels/pico/protocol.go +++ b/pkg/channels/pico/protocol.go @@ -18,8 +18,6 @@ const ( TypeError = "error" TypePong = "pong" - PicoTokenPrefix = "pico-" - PayloadKeyContent = "content" PayloadKeyThought = "thought" diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 039f45075..f58590d5b 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -9,7 +9,6 @@ import ( "path/filepath" "sort" "strconv" - "strings" "sync" "sync/atomic" "syscall" @@ -27,7 +26,7 @@ import ( _ "github.com/sipeed/picoclaw/pkg/channels/line" _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" _ "github.com/sipeed/picoclaw/pkg/channels/onebot" - "github.com/sipeed/picoclaw/pkg/channels/pico" + _ "github.com/sipeed/picoclaw/pkg/channels/pico" _ "github.com/sipeed/picoclaw/pkg/channels/qq" _ "github.com/sipeed/picoclaw/pkg/channels/slack" _ "github.com/sipeed/picoclaw/pkg/channels/teams_webhook" @@ -316,8 +315,6 @@ func executeReload( ) error { defer runningServices.reloading.Store(false) - overridePicoToken(newCfg, runningServices.authToken) - return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug) } @@ -386,8 +383,6 @@ func setupAndStartServices( fms.Start() } - overridePicoToken(cfg, authToken) - runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore) if err != nil { if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok { @@ -788,23 +783,6 @@ func setupCronTool( return cronService, nil } -// overridePicoToken replaces the pico channel token with the one from the PID file. -// The PID file is the single source of truth for the pico auth token; -// it is generated once at gateway startup and remains unchanged across reloads. -func overridePicoToken(cfg *config.Config, token string) { - picoBC := cfg.Channels.GetByType(config.ChannelPico) - if picoBC == nil || !picoBC.Enabled { - return - } - var picoCfg config.PicoSettings - picoBC.Decode(&picoCfg) - picoToken := picoCfg.Token.String() - if picoToken == "" || strings.HasPrefix(picoToken, pico.PicoTokenPrefix) { - return - } - picoCfg.SetToken(pico.PicoTokenPrefix + token + picoToken) -} - func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult { return func(prompt, channel, chatID string) *tools.ToolResult { if channel == "" || chatID == "" { diff --git a/web/backend/api/config.go b/web/backend/api/config.go index 80ab80f35..c7bd21197 100644 --- a/web/backend/api/config.go +++ b/web/backend/api/config.go @@ -94,8 +94,6 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) { return } - // Refresh cached pico token in case user changed it. - refreshPicoToken(&cfg) h.applyRuntimeLogLevel() logger.Infof("configuration updated successfully") @@ -193,8 +191,6 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) { return } - // Refresh cached pico token in case user changed it. - refreshPicoToken(&newCfg) h.applyRuntimeLogLevel() logger.Infof("configuration updated successfully") diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index fa5652323..ea43789d3 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -17,7 +17,6 @@ import ( "syscall" "time" - "github.com/sipeed/picoclaw/pkg/channels/pico" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" @@ -37,28 +36,12 @@ var gateway = struct { startupDeadline time.Time logs *LogBuffer pidData *ppid.PidFileData // pid file data read from picoclaw.pid.json - picoToken string // cached pico token from config (for proxy auth validation) + picoToken string // cached raw pico token for upstream gateway proxy injection }{ runtimeStatus: "stopped", logs: NewLogBuffer(200), } -// refreshPicoToken updates gateway.picoToken from cfg -func refreshPicoToken(cfg *config.Config) { - gateway.mu.Lock() - defer gateway.mu.Unlock() - var picoCfg config.PicoSettings - if bc := cfg.Channels.GetByType(config.ChannelPico); bc != nil { - decoded, err := bc.GetDecoded() - if err == nil && decoded != nil { - if p, ok := decoded.(*config.PicoSettings); ok { - picoCfg = *p - } - } - } - gateway.picoToken = picoCfg.Token.String() -} - // refreshPicoTokensLocked reads the pico token from config and caches it. // Caller must hold gateway.mu (or be sole writer). func refreshPicoTokensLocked(configPath string) { @@ -101,18 +84,15 @@ const ( tokenPrefix = "token." ) -// picoComposedToken returns "pico-"+pidToken+picoToken for gateway auth. -func picoComposedToken(token string) string { +// picoGatewayProtocol returns the gateway-facing pico subprotocol that the +// launcher should inject when proxying browser traffic upstream. +func picoGatewayProtocol() string { gateway.mu.Lock() defer gateway.mu.Unlock() - // if not initial pico token, don't allow gateway auth - if gateway.picoToken == "" || gateway.pidData == nil { + if gateway.picoToken == "" { return "" } - if tokenPrefix+gateway.picoToken != token { - return "" - } - return pico.PicoTokenPrefix + gateway.pidData.Token + gateway.picoToken + return tokenPrefix + gateway.picoToken } var ( diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go index d0fc26d7b..c9802b30b 100644 --- a/web/backend/api/gateway_host_test.go +++ b/web/backend/api/gateway_host_test.go @@ -50,7 +50,7 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) { cfg.Gateway.Host = "127.0.0.1" cfg.Gateway.Port = 18790 - req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil) + req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil) req.Host = "192.168.1.9:18800" if got := h.buildWsURL(req); got != "ws://192.168.1.9:18800/pico/ws" { @@ -181,7 +181,7 @@ func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) { cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 - req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil) + req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil) req.Host = "chat.example.com" req.Header.Set("X-Forwarded-Proto", "https") @@ -198,7 +198,7 @@ func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) { cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 - req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil) + req := httptest.NewRequest("GET", "https://launcher.local/api/pico/info", nil) req.Host = "secure.example.com" req.TLS = &tls.ConnectionState{} @@ -224,7 +224,7 @@ func TestBuildPicoURLsPreferXForwardedHost(t *testing.T) { cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 - req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/token", nil) + req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/info", nil) req.Host = "127.0.0.1:18800" req.Header.Set("X-Forwarded-Host", "vscode-tunnel.example.com") req.Header.Set("X-Forwarded-Proto", "https") @@ -249,7 +249,7 @@ func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) { cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 - req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil) + req := httptest.NewRequest("GET", "https://launcher.local/api/pico/info", nil) req.Host = "chat.example.com" req.TLS = &tls.ConnectionState{} req.Header.Set("X-Forwarded-Proto", "http") @@ -264,7 +264,7 @@ func TestBuildWsURLUsesRequestHostNotGatewayBindLoopback(t *testing.T) { h := NewHandler(configPath) h.SetServerOptions(18800, false, false, nil) - req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/token", nil) + req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/info", nil) req.Host = "localhost:18800" if got := h.buildWsURL(req); got != "ws://localhost:18800/pico/ws" { diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go index 78bf34a63..998ed3317 100644 --- a/web/backend/api/gateway_test.go +++ b/web/backend/api/gateway_test.go @@ -121,6 +121,18 @@ func resetGatewayTestState(t *testing.T) { }) } +func TestPicoGatewayProtocol(t *testing.T) { + resetGatewayTestState(t) + + gateway.mu.Lock() + gateway.picoToken = "ui-token" + gateway.mu.Unlock() + + if got := picoGatewayProtocol(); got != tokenPrefix+"ui-token" { + t.Fatalf("picoGatewayProtocol() = %q, want %q", got, tokenPrefix+"ui-token") + } +} + type gatewayStartEnvSnapshot struct { GatewayHost string `json:"gateway_host"` GatewayHostSet bool `json:"gateway_host_set"` diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index 00ffb8bb2..5e4848b01 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -5,8 +5,11 @@ import ( "encoding/hex" "encoding/json" "fmt" + "net" "net/http" "net/http/httputil" + "net/url" + "strings" "time" "github.com/sipeed/picoclaw/pkg/config" @@ -16,7 +19,7 @@ import ( // registerPicoRoutes binds Pico Channel management endpoints to the ServeMux. func (h *Handler) registerPicoRoutes(mux *http.ServeMux) { - mux.HandleFunc("GET /api/pico/token", h.handleGetPicoToken) + mux.HandleFunc("GET /api/pico/info", h.handleGetPicoInfo) mux.HandleFunc("POST /api/pico/token", h.handleRegenPicoToken) mux.HandleFunc("POST /api/pico/setup", h.handlePicoSetup) @@ -28,12 +31,15 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) { // createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint. // The gateway bind host and port are resolved from the latest configuration. -func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.ReverseProxy { +func (h *Handler) createWsProxy(origProtocol string, upstreamProtocol string) *httputil.ReverseProxy { wsProxy := &httputil.ReverseProxy{ Rewrite: func(r *httputil.ProxyRequest) { target := h.gatewayProxyURL() r.SetURL(target) - r.Out.Header.Set(protocolKey, tokenPrefix+token) + r.Out.Header.Del(protocolKey) + if upstreamProtocol != "" { + r.Out.Header.Set(protocolKey, upstreamProtocol) + } }, ModifyResponse: func(r *http.Response) error { if prot := r.Header.Values(protocolKey); len(prot) > 0 { @@ -52,10 +58,104 @@ func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.Rev return wsProxy } +func canonicalOrigin(raw string) (string, bool) { + u, err := url.Parse(strings.TrimSpace(raw)) + if err != nil || u == nil { + return "", false + } + + scheme := strings.ToLower(strings.TrimSpace(u.Scheme)) + if scheme != "http" && scheme != "https" { + return "", false + } + + host := strings.TrimSpace(u.Hostname()) + if host == "" { + return "", false + } + + port := u.Port() + if port == "" { + if scheme == "https" { + port = "443" + } else { + port = "80" + } + } + + return scheme + "://" + net.JoinHostPort(host, port), true +} + +func (h *Handler) expectedPicoProxyOrigin(r *http.Request) string { + return requestHTTPScheme(r) + "://" + h.picoWebUIAddr(r) +} + +func (h *Handler) validPicoProxyOrigin(r *http.Request) bool { + want, ok := canonicalOrigin(h.expectedPicoProxyOrigin(r)) + if !ok { + return false + } + + got, ok := canonicalOrigin(r.Header.Get("Origin")) + if !ok { + return false + } + + return got == want +} + +func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) { + if cfg == nil { + return config.PicoSettings{}, false + } + + bc := cfg.Channels.GetByType(config.ChannelPico) + if bc == nil { + return config.PicoSettings{}, false + } + + var picoCfg config.PicoSettings + if err := bc.Decode(&picoCfg); err != nil { + return config.PicoSettings{}, false + } + + return picoCfg, bc.Enabled +} + +func (h *Handler) writePicoInfoResponse( + w http.ResponseWriter, + r *http.Request, + cfg *config.Config, + changed *bool, +) { + picoCfg, enabled := decodePicoSettings(cfg) + + resp := map[string]any{ + "ws_url": h.buildWsURL(r), + "enabled": enabled, + } + if changed != nil { + resp["changed"] = *changed + } + if picoCfg.Token.String() != "" { + resp["configured"] = true + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) +} + // handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections. -// It validates the client token before forwarding; rejects immediately on failure. +// It relies on launcher dashboard auth and same-origin browser access, then +// injects the raw pico token only on the upstream gateway request. func (h *Handler) handleWebSocketProxy() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + if !h.validPicoProxyOrigin(r) { + logger.Warnf("Invalid Pico WebSocket origin: %q", r.Header.Get("Origin")) + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + gateway.mu.Lock() ensurePicoTokenCachedLocked(h.configPath) cachedPID := gateway.pidData @@ -91,51 +191,38 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc { http.Error(w, "Gateway not available", http.StatusServiceUnavailable) return } - prot := r.Header.Values(protocolKey) - if len(prot) > 0 { - origProtocol := prot[0] - newToken := picoComposedToken(prot[0]) - if newToken != "" { - h.createWsProxy(origProtocol, newToken).ServeHTTP(w, r) - return - } + + upstreamProtocol := picoGatewayProtocol() + if upstreamProtocol == "" { + logger.Warn("Pico token unavailable for WebSocket proxy") + http.Error(w, "Pico channel not configured", http.StatusServiceUnavailable) + return } - logger.Warnf("Invalid Pico token: %v", prot) - http.Error(w, "Invalid Pico token", http.StatusForbidden) + var origProtocol string + if prot := r.Header.Values(protocolKey); len(prot) > 0 { + origProtocol = prot[0] + } + + h.createWsProxy(origProtocol, upstreamProtocol).ServeHTTP(w, r) } } -// handleGetPicoToken returns the current WS token and URL for the frontend. +// handleGetPicoInfo returns non-secret Pico connection info for the launcher UI. // -// GET /api/pico/token -func (h *Handler) handleGetPicoToken(w http.ResponseWriter, r *http.Request) { +// GET /api/pico/info +func (h *Handler) handleGetPicoInfo(w http.ResponseWriter, r *http.Request) { cfg, err := config.LoadConfig(h.configPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) return } - wsURL := h.buildWsURL(r) - - w.Header().Set("Content-Type", "application/json") - bc := cfg.Channels.GetByType(config.ChannelPico) - var picoCfg config.PicoSettings - if bc != nil { - bc.Decode(&picoCfg) - } - enabled := false - if bc != nil { - enabled = bc.Enabled - } - json.NewEncoder(w).Encode(map[string]any{ - "token": picoCfg.Token.String(), - "ws_url": wsURL, - "enabled": enabled, - }) + h.writePicoInfoResponse(w, r, cfg, nil) } -// handleRegenPicoToken generates a new Pico WebSocket token and saves it. +// handleRegenPicoToken rotates the raw Pico WebSocket token and returns +// non-secret connection info for the launcher UI. // // POST /api/pico/token func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { @@ -160,18 +247,7 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { return } - // Refresh cached pico token. - gateway.mu.Lock() - gateway.picoToken = token - gateway.mu.Unlock() - - wsURL := h.buildWsURL(r) - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "token": token, - "ws_url": wsURL, - }) + h.writePicoInfoResponse(w, r, cfg, nil) } // EnsurePicoChannel enables the Pico channel with sane defaults if it isn't @@ -234,31 +310,14 @@ func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) { return } - // Reload config (EnsurePicoChannel may have modified it) and refresh cache. + // Reload config (EnsurePicoChannel may have modified it). cfg, err := config.LoadConfig(h.configPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) return } - if changed { - refreshPicoToken(cfg) - } - wsURL := h.buildWsURL(r) - - var picoCfg2 config.PicoSettings - if bc := cfg.Channels.GetByType(config.ChannelPico); bc != nil { - if decoded, err := bc.GetDecoded(); err == nil && decoded != nil { - picoCfg2 = *decoded.(*config.PicoSettings) - } - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "token": picoCfg2.Token.String(), - "ws_url": wsURL, - "enabled": true, - "changed": changed, - }) + h.writePicoInfoResponse(w, r, cfg, &changed) } // generateSecureToken creates a random 32-character hex string. diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 807c796dc..146f9e697 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -11,11 +11,16 @@ import ( "strconv" "testing" - "github.com/sipeed/picoclaw/pkg/channels/pico" "github.com/sipeed/picoclaw/pkg/config" ppid "github.com/sipeed/picoclaw/pkg/pid" ) +func newPicoProxyRequest(method, path string) *http.Request { + req := httptest.NewRequest(method, "http://launcher.local:18800"+path, nil) + req.Header.Set("Origin", "http://launcher.local:18800") + return req +} + func TestEnsurePicoChannel_FreshConfig(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) @@ -365,8 +370,8 @@ func TestHandlePicoSetup_Response(t *testing.T) { t.Fatalf("failed to decode response: %v", err) } - if resp["token"] == nil || resp["token"] == "" { - t.Error("response should contain a non-empty token") + if _, ok := resp["token"]; ok { + t.Error("response must not expose the raw pico token") } if resp["ws_url"] == nil || resp["ws_url"] == "" { t.Error("response should contain ws_url") @@ -377,6 +382,45 @@ func TestHandlePicoSetup_Response(t *testing.T) { if resp["changed"] != true { t.Error("response should have changed=true on first setup") } + if resp["configured"] != true { + t.Error("response should have configured=true") + } +} + +func TestHandleGetPicoInfo_OmitsToken(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + if _, err := h.EnsurePicoChannel(""); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://launcher.local/api/pico/info", nil) + rec := httptest.NewRecorder() + + h.handleGetPicoInfo(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if _, ok := resp["token"]; ok { + t.Fatal("info response must not expose the raw pico token") + } + if resp["enabled"] != true { + t.Fatalf("enabled = %#v, want true", resp["enabled"]) + } + if resp["configured"] != true { + t.Fatalf("configured = %#v, want true", resp["configured"]) + } + if resp["ws_url"] == nil || resp["ws_url"] == "" { + t.Fatal("response should contain ws_url") + } } func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { @@ -438,20 +482,10 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { gateway.pidData = &ppid.PidFileData{} gateway.picoToken = "pico" - req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil) - req1.Header.Set(protocolKey, tokenPrefix+"wrong_token") + req1 := newPicoProxyRequest(http.MethodGet, "/pico/ws") rec1 := httptest.NewRecorder() handler(rec1, req1) - if rec1.Code != http.StatusForbidden { - t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusForbidden) - } - - req1 = httptest.NewRequest(http.MethodGet, "/pico/ws", nil) - req1.Header.Set(protocolKey, tokenPrefix+"pico") - rec1 = httptest.NewRecorder() - handler(rec1, req1) - if rec1.Code != http.StatusOK { t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK) } @@ -464,8 +498,7 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { t.Fatalf("SaveConfig() error = %v", err) } - req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil) - req2.Header.Set(protocolKey, tokenPrefix+"pico") + req2 := newPicoProxyRequest(http.MethodGet, "/pico/ws") rec2 := httptest.NewRecorder() handler(rec2, req2) @@ -539,8 +572,7 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) { gateway.pidData = &ppid.PidFileData{} gateway.picoToken = "" - req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil) - req.Header.Set(protocolKey, tokenPrefix+"cached-token") + req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session") rec := httptest.NewRecorder() handler(rec, req) @@ -625,8 +657,7 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) { setGatewayRuntimeStatusLocked("stopped") gateway.mu.Unlock() - req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil) - req.Header.Set(protocolKey, tokenPrefix+"ui-token") + req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session") rec := httptest.NewRecorder() handler(rec, req) @@ -634,7 +665,7 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } - expected := tokenPrefix + pico.PicoTokenPrefix + pidData.Token + "ui-token" + expected := tokenPrefix + "ui-token" if got := rec.Body.String(); got != expected { t.Fatalf("forwarded protocol = %q, want %q", got, expected) } @@ -696,8 +727,7 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) { setGatewayRuntimeStatusLocked("running") gateway.mu.Unlock() - req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil) - req.Header.Set(protocolKey, tokenPrefix+"ui-token") + req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session") rec := httptest.NewRecorder() handler(rec, req) @@ -711,6 +741,21 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) { } } +func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + handler := h.handleWebSocketProxy() + + req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) + req.Header.Set("Origin", "http://evil.example") + rec := httptest.NewRecorder() + handler(rec, req) + + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) + } +} + func mustGatewayTestPort(t *testing.T, rawURL string) int { t.Helper() diff --git a/web/backend/middleware/launcher_dashboard_auth.go b/web/backend/middleware/launcher_dashboard_auth.go index c1c4c19c6..d72bd0f00 100644 --- a/web/backend/middleware/launcher_dashboard_auth.go +++ b/web/backend/middleware/launcher_dashboard_auth.go @@ -218,6 +218,10 @@ func validLauncherDashboardAuth(r *http.Request, cfg LauncherDashboardAuthConfig } func rejectLauncherDashboardAuth(w http.ResponseWriter, r *http.Request, canonicalPath string) { + if canonicalPath == "/pico/ws" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } if strings.HasPrefix(canonicalPath, "/api/") { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) diff --git a/web/backend/middleware/launcher_dashboard_auth_test.go b/web/backend/middleware/launcher_dashboard_auth_test.go index 1b919bf96..7b7418998 100644 --- a/web/backend/middleware/launcher_dashboard_auth_test.go +++ b/web/backend/middleware/launcher_dashboard_auth_test.go @@ -40,6 +40,7 @@ func TestLauncherDashboardAuth_AllowsPublicPaths(t *testing.T) { {http.MethodPost, "/api/auth/logout", http.StatusTeapot}, {http.MethodGet, "/api/auth/logout", http.StatusUnauthorized}, {http.MethodGet, "/api/config", http.StatusUnauthorized}, + {http.MethodGet, "/pico/ws", http.StatusUnauthorized}, } { rec := httptest.NewRecorder() req := httptest.NewRequest(tc.method, tc.path, nil) @@ -160,3 +161,22 @@ func TestLauncherDashboardAuth_CookieAndBearer(t *testing.T) { t.Fatalf("bearer auth: status = %d", rec2.Code) } } + +func TestLauncherDashboardAuth_WebSocketUnauthorizedDoesNotRedirect(t *testing.T) { + cfg := LauncherDashboardAuthConfig{ExpectedCookie: "deadbeef", Token: "x"} + next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Fatal("next handler should not run without auth") + }) + h := LauncherDashboardAuth(cfg, next) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/pico/ws", nil) + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } + if got := rec.Header().Get("Location"); got != "" { + t.Fatalf("Location = %q, want empty", got) + } +} diff --git a/web/frontend/src/api/pico.ts b/web/frontend/src/api/pico.ts index 6b8ceb49a..ca98a06da 100644 --- a/web/frontend/src/api/pico.ts +++ b/web/frontend/src/api/pico.ts @@ -2,16 +2,16 @@ import { launcherFetch } from "@/api/http" // API client for Pico Channel configuration. -interface PicoTokenResponse { - token: string +interface PicoInfoResponse { ws_url: string enabled: boolean + configured?: boolean } interface PicoSetupResponse { - token: string ws_url: string enabled: boolean + configured?: boolean changed: boolean } @@ -25,16 +25,16 @@ async function request(path: string, options?: RequestInit): Promise { return res.json() as Promise } -export async function getPicoToken(): Promise { - return request("/api/pico/token") +export async function getPicoInfo(): Promise { + return request("/api/pico/info") } -export async function regenPicoToken(): Promise { - return request("/api/pico/token", { method: "POST" }) +export async function regenPicoToken(): Promise { + return request("/api/pico/token", { method: "POST" }) } export async function setupPico(): Promise { return request("/api/pico/setup", { method: "POST" }) } -export type { PicoTokenResponse, PicoSetupResponse } +export type { PicoInfoResponse, PicoSetupResponse } diff --git a/web/frontend/src/features/chat/controller.ts b/web/frontend/src/features/chat/controller.ts index 28ef491fa..183b1ba6f 100644 --- a/web/frontend/src/features/chat/controller.ts +++ b/web/frontend/src/features/chat/controller.ts @@ -1,7 +1,6 @@ import { getDefaultStore } from "jotai" import { toast } from "sonner" -import { getPicoToken } from "@/api/pico" import { loadSessionMessages, mergeHistoryMessages, @@ -131,7 +130,6 @@ export async function connectChat() { updateChatStore({ connectionState: "connecting" }) try { - const { token } = await getPicoToken() const sessionId = activeSessionIdRef if (generation !== connectionGeneration) { @@ -139,18 +137,10 @@ export async function connectChat() { return } - if (!token) { - console.error("No pico token available") - updateChatStore({ connectionState: "error" }) - isConnecting = false - scheduleReconnect(generation, sessionId) - return - } - const wsScheme = window.location.protocol === "https:" ? "wss:" : "ws:" const wsUrl = `${wsScheme}//${window.location.host}/pico/ws` const url = `${wsUrl}?session_id=${encodeURIComponent(sessionId)}` - const socket = new WebSocket(url, [`token.${token}`]) + const socket = new WebSocket(url) if (generation !== connectionGeneration) { isConnecting = false diff --git a/web/frontend/vite.config.ts b/web/frontend/vite.config.ts index 0ef4e1415..57512c8b9 100644 --- a/web/frontend/vite.config.ts +++ b/web/frontend/vite.config.ts @@ -29,7 +29,7 @@ export default defineConfig({ target: "http://localhost:18800", changeOrigin: true, }, - "/ws": { + "/pico/ws": { target: "ws://localhost:18800", ws: true, }, From d002e1517ba1670d094a3752a6ff469e7c8cd00e Mon Sep 17 00:00:00 2001 From: wenjie Date: Thu, 16 Apr 2026 18:31:42 +0800 Subject: [PATCH 2/3] fix(web): improve Pico URL and origin handling behind proxies - read client scheme from X-Forwarded-Proto and RFC 7239 Forwarded - derive client-visible ports from forwarded host information - add coverage for HTTPS origins without explicit ports - verify behavior when proxies omit forwarded protocol headers --- web/backend/api/gateway_host.go | 50 ++++++++++++++++++++-------- web/backend/api/gateway_host_test.go | 29 ++++++++++++---- web/backend/api/pico_test.go | 27 +++++++++++++++ 3 files changed, 86 insertions(+), 20 deletions(-) diff --git a/web/backend/api/gateway_host.go b/web/backend/api/gateway_host.go index c6c2073e2..03af7a9d3 100644 --- a/web/backend/api/gateway_host.go +++ b/web/backend/api/gateway_host.go @@ -85,8 +85,22 @@ func requestHostName(r *http.Request) string { return netbind.ResolveAdaptiveLoopbackHost() } +func forwardedProtoFirst(r *http.Request) string { + raw := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")) + if raw == "" { + raw = forwardedRFC7239Proto(r) + } + if raw == "" { + return "" + } + if i := strings.IndexByte(raw, ','); i >= 0 { + raw = strings.TrimSpace(raw[:i]) + } + return strings.ToLower(raw) +} + func requestWSScheme(r *http.Request) string { - if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" { + if forwarded := forwardedProtoFirst(r); forwarded != "" { proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0])) if proto == "https" || proto == "wss" { return "wss" @@ -105,7 +119,7 @@ func requestWSScheme(r *http.Request) string { // requestHTTPScheme returns http or https for URLs that are not WebSockets (e.g. SSE). func requestHTTPScheme(r *http.Request) string { - if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" { + if forwarded := forwardedProtoFirst(r); forwarded != "" { proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0])) if proto == "https" || proto == "wss" { return "https" @@ -117,6 +131,7 @@ func requestHTTPScheme(r *http.Request) string { if r.TLS != nil { return "https" } + return "http" } @@ -138,6 +153,14 @@ func forwardedHostFirst(r *http.Request) string { // forwardedRFC7239Host parses host= from the first Forwarded header element (RFC 7239). func forwardedRFC7239Host(r *http.Request) string { + return forwardedRFC7239Param(r, "host") +} + +func forwardedRFC7239Proto(r *http.Request) string { + return forwardedRFC7239Param(r, "proto") +} + +func forwardedRFC7239Param(r *http.Request, key string) string { v := strings.TrimSpace(r.Header.Get("Forwarded")) if v == "" { return "" @@ -146,7 +169,7 @@ func forwardedRFC7239Host(r *http.Request) string { for _, part := range strings.Split(first, ";") { part = strings.TrimSpace(part) low := strings.ToLower(part) - if !strings.HasPrefix(low, "host=") { + if !strings.HasPrefix(low, key+"=") { continue } val := strings.TrimSpace(part[strings.IndexByte(part, '=')+1:]) @@ -177,13 +200,21 @@ func clientVisiblePort(r *http.Request, serverListenPort int) string { if p := forwardedPortFirst(r); p != "" { return p } + if fwdHost := forwardedHostFirst(r); fwdHost != "" { + if _, port, err := net.SplitHostPort(fwdHost); err == nil && port != "" { + return port + } + } if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" { return port } + if strings.TrimSpace(r.Host) == "" && forwardedHostFirst(r) == "" { + return strconv.Itoa(serverListenPort) + } if requestHTTPScheme(r) == "https" { return "443" } - return strconv.Itoa(serverListenPort) + return "80" } // joinClientVisibleHostPort builds host:port for absolute URLs returned to the browser. @@ -205,16 +236,7 @@ func (h *Handler) picoWebUIAddr(r *http.Request) string { if fwdHost := forwardedHostFirst(r); fwdHost != "" { return joinClientVisibleHostPort(r, fwdHost, wsPort) } - host := requestHostName(r) - // Use clientVisiblePort only when an explicit port is present in headers - // or Host header — do not infer from TLS/scheme, as serverPort takes priority. - if p := forwardedPortFirst(r); p != "" { - return net.JoinHostPort(host, p) - } - if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" { - return net.JoinHostPort(host, port) - } - return net.JoinHostPort(host, strconv.Itoa(wsPort)) + return joinClientVisibleHostPort(r, requestHostName(r), wsPort) } func (h *Handler) buildWsURL(r *http.Request) string { diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go index c9802b30b..54d1010d2 100644 --- a/web/backend/api/gateway_host_test.go +++ b/web/backend/api/gateway_host_test.go @@ -185,8 +185,8 @@ func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) { req.Host = "chat.example.com" req.Header.Set("X-Forwarded-Proto", "https") - if got := h.buildWsURL(req); got != "wss://chat.example.com:18800/pico/ws" { - t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws") + if got := h.buildWsURL(req); got != "wss://chat.example.com:443/pico/ws" { + t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:443/pico/ws") } } @@ -202,8 +202,8 @@ func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) { req.Host = "secure.example.com" req.TLS = &tls.ConnectionState{} - if got := h.buildWsURL(req); got != "wss://secure.example.com:18800/pico/ws" { - t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws") + if got := h.buildWsURL(req); got != "wss://secure.example.com:443/pico/ws" { + t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:443/pico/ws") } } @@ -254,8 +254,25 @@ func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) { req.TLS = &tls.ConnectionState{} req.Header.Set("X-Forwarded-Proto", "http") - if got := h.buildWsURL(req); got != "ws://chat.example.com:18800/pico/ws" { - t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws") + if got := h.buildWsURL(req); got != "ws://chat.example.com:80/pico/ws" { + t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:80/pico/ws") + } +} + +func TestBuildWsURLDoesNotTrustOriginWhenProxyOmitsForwardedProto(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil) + req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" + req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") + + if got := h.buildWsURL(req); got != "ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws" { + t.Fatalf( + "buildWsURL() = %q, want %q", + got, + "ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws", + ) } } diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 146f9e697..34b011127 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -756,6 +756,33 @@ func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) { } } +func TestValidPicoProxyOriginAcceptsHTTPSOriginWithoutExplicitPort(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) + req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") + + if !h.validPicoProxyOrigin(req) { + t.Fatal("validPicoProxyOrigin() = false, want true") + } +} + +func TestValidPicoProxyOriginRejectsHTTPSOriginWhenProxyOmitsForwardedProto(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) + req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" + req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") + + if h.validPicoProxyOrigin(req) { + t.Fatal("validPicoProxyOrigin() = true, want false") + } +} + func mustGatewayTestPort(t *testing.T, rawURL string) int { t.Helper() From f8190f04b7db62556a8e0cdb63a0a552b60752ce Mon Sep 17 00:00:00 2001 From: wenjie Date: Thu, 16 Apr 2026 19:04:47 +0800 Subject: [PATCH 3/3] fix(web): stop pinning Pico WebSocket origins during setup - remove request-origin seeding from `EnsurePicoChannel` - keep `allow_origins` empty by default for auto-configured Pico channels - relax launcher Pico WebSocket proxy origin validation - update Pico backend tests for the new setup and proxy behavior --- web/backend/api/gateway.go | 2 +- web/backend/api/pico.go | 74 +--------------- web/backend/api/pico_test.go | 159 +++++++++++++++++------------------ web/backend/main.go | 2 +- 4 files changed, 85 insertions(+), 152 deletions(-) diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index ea43789d3..201000ff3 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -732,7 +732,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int gateway.logs.Reset() // Ensure Pico Channel is configured before starting gateway - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err)) // Non-fatal: gateway can still start without pico channel diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index 5e4848b01..ffd0796c7 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -5,11 +5,8 @@ import ( "encoding/hex" "encoding/json" "fmt" - "net" "net/http" "net/http/httputil" - "net/url" - "strings" "time" "github.com/sipeed/picoclaw/pkg/config" @@ -58,52 +55,6 @@ func (h *Handler) createWsProxy(origProtocol string, upstreamProtocol string) *h return wsProxy } -func canonicalOrigin(raw string) (string, bool) { - u, err := url.Parse(strings.TrimSpace(raw)) - if err != nil || u == nil { - return "", false - } - - scheme := strings.ToLower(strings.TrimSpace(u.Scheme)) - if scheme != "http" && scheme != "https" { - return "", false - } - - host := strings.TrimSpace(u.Hostname()) - if host == "" { - return "", false - } - - port := u.Port() - if port == "" { - if scheme == "https" { - port = "443" - } else { - port = "80" - } - } - - return scheme + "://" + net.JoinHostPort(host, port), true -} - -func (h *Handler) expectedPicoProxyOrigin(r *http.Request) string { - return requestHTTPScheme(r) + "://" + h.picoWebUIAddr(r) -} - -func (h *Handler) validPicoProxyOrigin(r *http.Request) bool { - want, ok := canonicalOrigin(h.expectedPicoProxyOrigin(r)) - if !ok { - return false - } - - got, ok := canonicalOrigin(r.Header.Get("Origin")) - if !ok { - return false - } - - return got == want -} - func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) { if cfg == nil { return config.PicoSettings{}, false @@ -146,16 +97,10 @@ func (h *Handler) writePicoInfoResponse( } // handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections. -// It relies on launcher dashboard auth and same-origin browser access, then -// injects the raw pico token only on the upstream gateway request. +// It relies on launcher dashboard auth, then injects the raw pico token only +// on the upstream gateway request. func (h *Handler) handleWebSocketProxy() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if !h.validPicoProxyOrigin(r) { - logger.Warnf("Invalid Pico WebSocket origin: %q", r.Header.Get("Origin")) - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - gateway.mu.Lock() ensurePicoTokenCachedLocked(h.configPath) cachedPID := gateway.pidData @@ -252,12 +197,7 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { // EnsurePicoChannel enables the Pico channel with sane defaults if it isn't // already configured. Returns true when the config was modified. -// -// callerOrigin is the Origin header from the setup request. If non-empty and -// no origins are configured yet, it's written as the allowed origin so the -// WebSocket handshake works for whatever host the caller is on (LAN, custom -// port, etc.). Pass "" when there's no request context. -func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { +func (h *Handler) EnsurePicoChannel() (bool, error) { cfg, err := config.LoadConfig(h.configPath) if err != nil { return false, fmt.Errorf("failed to load config: %w", err) @@ -282,12 +222,6 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { picoCfg.Token = *config.NewSecureString(generateSecureToken()) changed = true } - - // Seed origins from the request instead of hardcoding ports. - if len(picoCfg.AllowOrigins) == 0 && callerOrigin != "" { - picoCfg.AllowOrigins = []string{callerOrigin} - changed = true - } } } @@ -304,7 +238,7 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { // // POST /api/pico/setup func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) { - changed, err := h.EnsurePicoChannel(r.Header.Get("Origin")) + changed, err := h.EnsurePicoChannel() if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 34b011127..a56cd9ba2 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -25,7 +25,7 @@ func TestEnsurePicoChannel_FreshConfig(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -56,7 +56,7 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel(""); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -76,11 +76,11 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) { } } -func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) { +func TestEnsurePicoChannel_LeavesAllowOriginsEmptyByDefault(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel("http://localhost:18800"); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -95,45 +95,16 @@ func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) { t.Fatalf("GetDecoded() error = %v", err) } picoCfg := decoded.(*config.PicoSettings) - for _, origin := range picoCfg.AllowOrigins { - if origin == "*" { - t.Error("setup must not set wildcard origin '*'") - } - } -} - -func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) { - configPath := filepath.Join(t.TempDir(), "config.json") - h := NewHandler(configPath) - - if _, err := h.EnsurePicoChannel(""); err != nil { - t.Fatalf("EnsurePicoChannel() error = %v", err) - } - - cfg, err := config.LoadConfig(configPath) - if err != nil { - t.Fatalf("LoadConfig() error = %v", err) - } - - bc := cfg.Channels["pico"] - decoded, err := bc.GetDecoded() - if err != nil { - t.Fatalf("GetDecoded() error = %v", err) - } - picoCfg := decoded.(*config.PicoSettings) - // Without a caller origin, allow_origins stays empty (CheckOrigin - // allows all when the list is empty, so the channel still works). if len(picoCfg.AllowOrigins) != 0 { - t.Errorf("allow_origins = %v, want empty when no caller origin", picoCfg.AllowOrigins) + t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins) } } -func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) { +func TestEnsurePicoChannel_NoOriginConfigurationRequired(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - lanOrigin := "http://192.168.1.9:18800" - if _, err := h.EnsurePicoChannel(lanOrigin); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -148,8 +119,8 @@ func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) { t.Fatalf("GetDecoded() error = %v", err) } picoCfg := decoded.(*config.PicoSettings) - if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != lanOrigin { - t.Errorf("allow_origins = %v, want [%s]", picoCfg.AllowOrigins, lanOrigin) + if len(picoCfg.AllowOrigins) != 0 { + t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins) } } @@ -174,7 +145,7 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) { h := NewHandler(configPath) - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -218,7 +189,7 @@ func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) { h := NewHandler(configPath) - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -258,7 +229,7 @@ func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) { } h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel(""); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -285,10 +256,8 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - origin := "http://localhost:18800" - // First call sets things up - if _, err := h.EnsurePicoChannel(origin); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("first EnsurePicoChannel() error = %v", err) } @@ -302,7 +271,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { token1 := picoCfg.Token.String() // Second call should be a no-op - changed, err := h.EnsurePicoChannel(origin) + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("second EnsurePicoChannel() error = %v", err) } @@ -322,7 +291,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { } } -func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) { +func TestHandlePicoSetup_DoesNotPersistRequestOrigin(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) @@ -347,8 +316,8 @@ func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) { t.Fatalf("GetDecoded() error = %v", err) } picoCfg := decoded.(*config.PicoSettings) - if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != "http://10.0.0.5:3000" { - t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", picoCfg.AllowOrigins) + if len(picoCfg.AllowOrigins) != 0 { + t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins) } } @@ -391,7 +360,7 @@ func TestHandleGetPicoInfo_OmitsToken(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel(""); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -741,45 +710,75 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) { } } -func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) { +func TestHandleWebSocketProxy_AllowsArbitraryOrigin(t *testing.T) { + origMatcher := gatewayProcessMatcher + gatewayProcessMatcher = func(int) (bool, bool) { return true, true } + t.Cleanup(func() { gatewayProcessMatcher = origMatcher }) + + home := t.TempDir() + t.Setenv("PICOCLAW_HOME", home) + configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) handler := h.handleWebSocketProxy() - req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/pico/ws" { + t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws") + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "proxied") + })) + defer server.Close() + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "127.0.0.1" + cfg.Gateway.Port = mustGatewayTestPort(t, server.URL) + bc := cfg.Channels["pico"] + bc.Enabled = true + decoded, err := bc.GetDecoded() + if err != nil { + t.Fatalf("GetDecoded() error = %v", err) + } + decoded.(*config.PicoSettings).SetToken("ui-token") + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + cmd := startGatewayLikeProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + writeTestPidFile(t, ppid.PidFileData{ + PID: cmd.Process.Pid, + Token: "test-token", + Host: cfg.Gateway.Host, + Port: cfg.Gateway.Port, + }) + t.Cleanup(func() { + ppid.RemovePidFile(globalConfigDir()) + }) + + origPidData := gateway.pidData + origPicoToken := gateway.picoToken + t.Cleanup(func() { + gateway.pidData = origPidData + gateway.picoToken = origPicoToken + }) + + gateway.pidData = &ppid.PidFileData{} + gateway.picoToken = "ui-token" + + req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws?session_id=test-session", nil) req.Header.Set("Origin", "http://evil.example") rec := httptest.NewRecorder() handler(rec, req) - if rec.Code != http.StatusForbidden { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) - } -} - -func TestValidPicoProxyOriginAcceptsHTTPSOriginWithoutExplicitPort(t *testing.T) { - configPath := filepath.Join(t.TempDir(), "config.json") - h := NewHandler(configPath) - - req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) - req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" - req.Header.Set("X-Forwarded-Proto", "https") - req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") - - if !h.validPicoProxyOrigin(req) { - t.Fatal("validPicoProxyOrigin() = false, want true") - } -} - -func TestValidPicoProxyOriginRejectsHTTPSOriginWhenProxyOmitsForwardedProto(t *testing.T) { - configPath := filepath.Join(t.TempDir(), "config.json") - h := NewHandler(configPath) - - req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil) - req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" - req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") - - if h.validPicoProxyOrigin(req) { - t.Fatal("validPicoProxyOrigin() = true, want false") + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } } diff --git a/web/backend/main.go b/web/backend/main.go index 01ef5edf0..e42558398 100644 --- a/web/backend/main.go +++ b/web/backend/main.go @@ -544,7 +544,7 @@ func main() { // API Routes (e.g. /api/status) apiHandler = api.NewHandler(absPath) apiHandler.SetDebug(debug) - if _, err = apiHandler.EnsurePicoChannel(""); err != nil { + if _, err = apiHandler.EnsurePicoChannel(); err != nil { logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err)) } apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)