diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index ce3a9ca1e..1e2520920 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -59,6 +59,16 @@ func refreshPicoTokensLocked(configPath string) { gateway.picoToken = cfg.Channels.Pico.Token.String() } +// ensurePicoTokenCachedLocked lazily fills the in-memory pico token cache when +// the launcher has already discovered a running gateway via pidData, but has +// not yet refreshed the token into memory. +func ensurePicoTokenCachedLocked(configPath string) { + if gateway.picoToken != "" { + return + } + refreshPicoTokensLocked(configPath) +} + const ( protocolKey = "Sec-Websocket-Protocol" tokenPrefix = "token." diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index 0e8cd07fc..c8ef47308 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -56,6 +56,7 @@ func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.Rev func (h *Handler) handleWebSocketProxy() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { gateway.mu.Lock() + ensurePicoTokenCachedLocked(h.configPath) gatewayAvailable := gateway.pidData != nil gateway.mu.Unlock() diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index beff4d77f..ee5586746 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -377,6 +377,55 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { } } +func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + handler := h.handleWebSocketProxy() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/pico/ws" { + t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws") + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "proxied") + })) + defer server.Close() + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "127.0.0.1" + cfg.Gateway.Port = mustGatewayTestPort(t, server.URL) + cfg.Channels.Pico.Enabled = true + cfg.Channels.Pico.SetToken("cached-token") + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + origPidData := gateway.pidData + origPicoToken := gateway.picoToken + t.Cleanup(func() { + gateway.pidData = origPidData + gateway.picoToken = origPicoToken + }) + + gateway.pidData = &ppid.PidFileData{} + gateway.picoToken = "" + + req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil) + req.Header.Set(protocolKey, tokenPrefix+"cached-token") + rec := httptest.NewRecorder() + handler(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + if body := rec.Body.String(); body != "proxied" { + t.Fatalf("body = %q, want %q", body, "proxied") + } + if gateway.picoToken != "cached-token" { + t.Fatalf("gateway.picoToken = %q, want %q", gateway.picoToken, "cached-token") + } +} + func mustGatewayTestPort(t *testing.T, rawURL string) int { t.Helper()