refactor(web): secure Pico websocket access behind launcher auth

- stop exposing the raw Pico token to the frontend
- add /api/pico/info for non-secret Pico connection metadata
- proxy /pico/ws through the launcher with same-origin and dashboard auth checks
- inject the upstream Pico websocket protocol server-side
- update frontend chat connection flow and Vite websocket proxy path
- refresh related docs and tests
This commit is contained in:
wenjie
2026-04-16 16:47:23 +08:00
parent 6126ede963
commit 4b76196e2c
14 changed files with 253 additions and 171 deletions
-4
View File
@@ -94,8 +94,6 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) {
return
}
// Refresh cached pico token in case user changed it.
refreshPicoToken(&cfg)
h.applyRuntimeLogLevel()
logger.Infof("configuration updated successfully")
@@ -193,8 +191,6 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) {
return
}
// Refresh cached pico token in case user changed it.
refreshPicoToken(&newCfg)
h.applyRuntimeLogLevel()
logger.Infof("configuration updated successfully")
+6 -26
View File
@@ -17,7 +17,6 @@ import (
"syscall"
"time"
"github.com/sipeed/picoclaw/pkg/channels/pico"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -37,28 +36,12 @@ var gateway = struct {
startupDeadline time.Time
logs *LogBuffer
pidData *ppid.PidFileData // pid file data read from picoclaw.pid.json
picoToken string // cached pico token from config (for proxy auth validation)
picoToken string // cached raw pico token for upstream gateway proxy injection
}{
runtimeStatus: "stopped",
logs: NewLogBuffer(200),
}
// refreshPicoToken updates gateway.picoToken from cfg
func refreshPicoToken(cfg *config.Config) {
gateway.mu.Lock()
defer gateway.mu.Unlock()
var picoCfg config.PicoSettings
if bc := cfg.Channels.GetByType(config.ChannelPico); bc != nil {
decoded, err := bc.GetDecoded()
if err == nil && decoded != nil {
if p, ok := decoded.(*config.PicoSettings); ok {
picoCfg = *p
}
}
}
gateway.picoToken = picoCfg.Token.String()
}
// refreshPicoTokensLocked reads the pico token from config and caches it.
// Caller must hold gateway.mu (or be sole writer).
func refreshPicoTokensLocked(configPath string) {
@@ -101,18 +84,15 @@ const (
tokenPrefix = "token."
)
// picoComposedToken returns "pico-"+pidToken+picoToken for gateway auth.
func picoComposedToken(token string) string {
// picoGatewayProtocol returns the gateway-facing pico subprotocol that the
// launcher should inject when proxying browser traffic upstream.
func picoGatewayProtocol() string {
gateway.mu.Lock()
defer gateway.mu.Unlock()
// if not initial pico token, don't allow gateway auth
if gateway.picoToken == "" || gateway.pidData == nil {
if gateway.picoToken == "" {
return ""
}
if tokenPrefix+gateway.picoToken != token {
return ""
}
return pico.PicoTokenPrefix + gateway.pidData.Token + gateway.picoToken
return tokenPrefix + gateway.picoToken
}
var (
+6 -6
View File
@@ -50,7 +50,7 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) {
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil)
req.Host = "192.168.1.9:18800"
if got := h.buildWsURL(req); got != "ws://192.168.1.9:18800/pico/ws" {
@@ -181,7 +181,7 @@ func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil)
req.Host = "chat.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
@@ -198,7 +198,7 @@ func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) {
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/info", nil)
req.Host = "secure.example.com"
req.TLS = &tls.ConnectionState{}
@@ -224,7 +224,7 @@ func TestBuildPicoURLsPreferXForwardedHost(t *testing.T) {
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/token", nil)
req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/info", nil)
req.Host = "127.0.0.1:18800"
req.Header.Set("X-Forwarded-Host", "vscode-tunnel.example.com")
req.Header.Set("X-Forwarded-Proto", "https")
@@ -249,7 +249,7 @@ func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) {
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/info", nil)
req.Host = "chat.example.com"
req.TLS = &tls.ConnectionState{}
req.Header.Set("X-Forwarded-Proto", "http")
@@ -264,7 +264,7 @@ func TestBuildWsURLUsesRequestHostNotGatewayBindLoopback(t *testing.T) {
h := NewHandler(configPath)
h.SetServerOptions(18800, false, false, nil)
req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/token", nil)
req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/info", nil)
req.Host = "localhost:18800"
if got := h.buildWsURL(req); got != "ws://localhost:18800/pico/ws" {
+12
View File
@@ -121,6 +121,18 @@ func resetGatewayTestState(t *testing.T) {
})
}
func TestPicoGatewayProtocol(t *testing.T) {
resetGatewayTestState(t)
gateway.mu.Lock()
gateway.picoToken = "ui-token"
gateway.mu.Unlock()
if got := picoGatewayProtocol(); got != tokenPrefix+"ui-token" {
t.Fatalf("picoGatewayProtocol() = %q, want %q", got, tokenPrefix+"ui-token")
}
}
type gatewayStartEnvSnapshot struct {
GatewayHost string `json:"gateway_host"`
GatewayHostSet bool `json:"gateway_host_set"`
+125 -66
View File
@@ -5,8 +5,11 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
@@ -16,7 +19,7 @@ import (
// registerPicoRoutes binds Pico Channel management endpoints to the ServeMux.
func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET /api/pico/token", h.handleGetPicoToken)
mux.HandleFunc("GET /api/pico/info", h.handleGetPicoInfo)
mux.HandleFunc("POST /api/pico/token", h.handleRegenPicoToken)
mux.HandleFunc("POST /api/pico/setup", h.handlePicoSetup)
@@ -28,12 +31,15 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
// The gateway bind host and port are resolved from the latest configuration.
func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.ReverseProxy {
func (h *Handler) createWsProxy(origProtocol string, upstreamProtocol string) *httputil.ReverseProxy {
wsProxy := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
target := h.gatewayProxyURL()
r.SetURL(target)
r.Out.Header.Set(protocolKey, tokenPrefix+token)
r.Out.Header.Del(protocolKey)
if upstreamProtocol != "" {
r.Out.Header.Set(protocolKey, upstreamProtocol)
}
},
ModifyResponse: func(r *http.Response) error {
if prot := r.Header.Values(protocolKey); len(prot) > 0 {
@@ -52,10 +58,104 @@ func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.Rev
return wsProxy
}
func canonicalOrigin(raw string) (string, bool) {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil || u == nil {
return "", false
}
scheme := strings.ToLower(strings.TrimSpace(u.Scheme))
if scheme != "http" && scheme != "https" {
return "", false
}
host := strings.TrimSpace(u.Hostname())
if host == "" {
return "", false
}
port := u.Port()
if port == "" {
if scheme == "https" {
port = "443"
} else {
port = "80"
}
}
return scheme + "://" + net.JoinHostPort(host, port), true
}
func (h *Handler) expectedPicoProxyOrigin(r *http.Request) string {
return requestHTTPScheme(r) + "://" + h.picoWebUIAddr(r)
}
func (h *Handler) validPicoProxyOrigin(r *http.Request) bool {
want, ok := canonicalOrigin(h.expectedPicoProxyOrigin(r))
if !ok {
return false
}
got, ok := canonicalOrigin(r.Header.Get("Origin"))
if !ok {
return false
}
return got == want
}
func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) {
if cfg == nil {
return config.PicoSettings{}, false
}
bc := cfg.Channels.GetByType(config.ChannelPico)
if bc == nil {
return config.PicoSettings{}, false
}
var picoCfg config.PicoSettings
if err := bc.Decode(&picoCfg); err != nil {
return config.PicoSettings{}, false
}
return picoCfg, bc.Enabled
}
func (h *Handler) writePicoInfoResponse(
w http.ResponseWriter,
r *http.Request,
cfg *config.Config,
changed *bool,
) {
picoCfg, enabled := decodePicoSettings(cfg)
resp := map[string]any{
"ws_url": h.buildWsURL(r),
"enabled": enabled,
}
if changed != nil {
resp["changed"] = *changed
}
if picoCfg.Token.String() != "" {
resp["configured"] = true
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
// It validates the client token before forwarding; rejects immediately on failure.
// It relies on launcher dashboard auth and same-origin browser access, then
// injects the raw pico token only on the upstream gateway request.
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !h.validPicoProxyOrigin(r) {
logger.Warnf("Invalid Pico WebSocket origin: %q", r.Header.Get("Origin"))
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
gateway.mu.Lock()
ensurePicoTokenCachedLocked(h.configPath)
cachedPID := gateway.pidData
@@ -91,51 +191,38 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
http.Error(w, "Gateway not available", http.StatusServiceUnavailable)
return
}
prot := r.Header.Values(protocolKey)
if len(prot) > 0 {
origProtocol := prot[0]
newToken := picoComposedToken(prot[0])
if newToken != "" {
h.createWsProxy(origProtocol, newToken).ServeHTTP(w, r)
return
}
upstreamProtocol := picoGatewayProtocol()
if upstreamProtocol == "" {
logger.Warn("Pico token unavailable for WebSocket proxy")
http.Error(w, "Pico channel not configured", http.StatusServiceUnavailable)
return
}
logger.Warnf("Invalid Pico token: %v", prot)
http.Error(w, "Invalid Pico token", http.StatusForbidden)
var origProtocol string
if prot := r.Header.Values(protocolKey); len(prot) > 0 {
origProtocol = prot[0]
}
h.createWsProxy(origProtocol, upstreamProtocol).ServeHTTP(w, r)
}
}
// handleGetPicoToken returns the current WS token and URL for the frontend.
// handleGetPicoInfo returns non-secret Pico connection info for the launcher UI.
//
// GET /api/pico/token
func (h *Handler) handleGetPicoToken(w http.ResponseWriter, r *http.Request) {
// GET /api/pico/info
func (h *Handler) handleGetPicoInfo(w http.ResponseWriter, r *http.Request) {
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
return
}
wsURL := h.buildWsURL(r)
w.Header().Set("Content-Type", "application/json")
bc := cfg.Channels.GetByType(config.ChannelPico)
var picoCfg config.PicoSettings
if bc != nil {
bc.Decode(&picoCfg)
}
enabled := false
if bc != nil {
enabled = bc.Enabled
}
json.NewEncoder(w).Encode(map[string]any{
"token": picoCfg.Token.String(),
"ws_url": wsURL,
"enabled": enabled,
})
h.writePicoInfoResponse(w, r, cfg, nil)
}
// handleRegenPicoToken generates a new Pico WebSocket token and saves it.
// handleRegenPicoToken rotates the raw Pico WebSocket token and returns
// non-secret connection info for the launcher UI.
//
// POST /api/pico/token
func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
@@ -160,18 +247,7 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
return
}
// Refresh cached pico token.
gateway.mu.Lock()
gateway.picoToken = token
gateway.mu.Unlock()
wsURL := h.buildWsURL(r)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"token": token,
"ws_url": wsURL,
})
h.writePicoInfoResponse(w, r, cfg, nil)
}
// EnsurePicoChannel enables the Pico channel with sane defaults if it isn't
@@ -234,31 +310,14 @@ func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
return
}
// Reload config (EnsurePicoChannel may have modified it) and refresh cache.
// Reload config (EnsurePicoChannel may have modified it).
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
return
}
if changed {
refreshPicoToken(cfg)
}
wsURL := h.buildWsURL(r)
var picoCfg2 config.PicoSettings
if bc := cfg.Channels.GetByType(config.ChannelPico); bc != nil {
if decoded, err := bc.GetDecoded(); err == nil && decoded != nil {
picoCfg2 = *decoded.(*config.PicoSettings)
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"token": picoCfg2.Token.String(),
"ws_url": wsURL,
"enabled": true,
"changed": changed,
})
h.writePicoInfoResponse(w, r, cfg, &changed)
}
// generateSecureToken creates a random 32-character hex string.
+68 -23
View File
@@ -11,11 +11,16 @@ import (
"strconv"
"testing"
"github.com/sipeed/picoclaw/pkg/channels/pico"
"github.com/sipeed/picoclaw/pkg/config"
ppid "github.com/sipeed/picoclaw/pkg/pid"
)
func newPicoProxyRequest(method, path string) *http.Request {
req := httptest.NewRequest(method, "http://launcher.local:18800"+path, nil)
req.Header.Set("Origin", "http://launcher.local:18800")
return req
}
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
@@ -365,8 +370,8 @@ func TestHandlePicoSetup_Response(t *testing.T) {
t.Fatalf("failed to decode response: %v", err)
}
if resp["token"] == nil || resp["token"] == "" {
t.Error("response should contain a non-empty token")
if _, ok := resp["token"]; ok {
t.Error("response must not expose the raw pico token")
}
if resp["ws_url"] == nil || resp["ws_url"] == "" {
t.Error("response should contain ws_url")
@@ -377,6 +382,45 @@ func TestHandlePicoSetup_Response(t *testing.T) {
if resp["changed"] != true {
t.Error("response should have changed=true on first setup")
}
if resp["configured"] != true {
t.Error("response should have configured=true")
}
}
func TestHandleGetPicoInfo_OmitsToken(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(""); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/api/pico/info", nil)
rec := httptest.NewRecorder()
h.handleGetPicoInfo(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var resp map[string]any
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if _, ok := resp["token"]; ok {
t.Fatal("info response must not expose the raw pico token")
}
if resp["enabled"] != true {
t.Fatalf("enabled = %#v, want true", resp["enabled"])
}
if resp["configured"] != true {
t.Fatalf("configured = %#v, want true", resp["configured"])
}
if resp["ws_url"] == nil || resp["ws_url"] == "" {
t.Fatal("response should contain ws_url")
}
}
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
@@ -438,20 +482,10 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = "pico"
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
req1.Header.Set(protocolKey, tokenPrefix+"wrong_token")
req1 := newPicoProxyRequest(http.MethodGet, "/pico/ws")
rec1 := httptest.NewRecorder()
handler(rec1, req1)
if rec1.Code != http.StatusForbidden {
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusForbidden)
}
req1 = httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
req1.Header.Set(protocolKey, tokenPrefix+"pico")
rec1 = httptest.NewRecorder()
handler(rec1, req1)
if rec1.Code != http.StatusOK {
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
}
@@ -464,8 +498,7 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
t.Fatalf("SaveConfig() error = %v", err)
}
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
req2.Header.Set(protocolKey, tokenPrefix+"pico")
req2 := newPicoProxyRequest(http.MethodGet, "/pico/ws")
rec2 := httptest.NewRecorder()
handler(rec2, req2)
@@ -539,8 +572,7 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = ""
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
req.Header.Set(protocolKey, tokenPrefix+"cached-token")
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
rec := httptest.NewRecorder()
handler(rec, req)
@@ -625,8 +657,7 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
setGatewayRuntimeStatusLocked("stopped")
gateway.mu.Unlock()
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
req.Header.Set(protocolKey, tokenPrefix+"ui-token")
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
rec := httptest.NewRecorder()
handler(rec, req)
@@ -634,7 +665,7 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
expected := tokenPrefix + pico.PicoTokenPrefix + pidData.Token + "ui-token"
expected := tokenPrefix + "ui-token"
if got := rec.Body.String(); got != expected {
t.Fatalf("forwarded protocol = %q, want %q", got, expected)
}
@@ -696,8 +727,7 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
req.Header.Set(protocolKey, tokenPrefix+"ui-token")
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
rec := httptest.NewRecorder()
handler(rec, req)
@@ -711,6 +741,21 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
}
}
func TestHandleWebSocketProxyRejectsInvalidOrigin(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws", nil)
req.Header.Set("Origin", "http://evil.example")
rec := httptest.NewRecorder()
handler(rec, req)
if rec.Code != http.StatusForbidden {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
}
}
func mustGatewayTestPort(t *testing.T, rawURL string) int {
t.Helper()