refactor(web): secure Pico websocket access behind launcher auth (#2545)

* refactor(web): secure Pico websocket access behind launcher auth

- stop exposing the raw Pico token to the frontend
- add /api/pico/info for non-secret Pico connection metadata
- proxy /pico/ws through the launcher with same-origin and dashboard auth checks
- inject the upstream Pico websocket protocol server-side
- update frontend chat connection flow and Vite websocket proxy path
- refresh related docs and tests

* fix(web): improve Pico URL and origin handling behind proxies

- read client scheme from X-Forwarded-Proto and RFC 7239 Forwarded
- derive client-visible ports from forwarded host information
- add coverage for HTTPS origins without explicit ports
- verify behavior when proxies omit forwarded protocol headers

* fix(web): stop pinning Pico WebSocket origins during setup

- remove request-origin seeding from `EnsurePicoChannel`
- keep `allow_origins` empty by default for auto-configured Pico channels
- relax launcher Pico WebSocket proxy origin validation
- update Pico backend tests for the new setup and proxy behavior
This commit is contained in:
Guoguo
2026-04-20 11:17:42 +08:00
committed by GitHub
16 changed files with 335 additions and 254 deletions
+1 -1
View File
@@ -27,7 +27,7 @@ docker compose -f docker/docker-compose.yml --profile gateway up -d
> **Docker Users**: By default, the Gateway listens on `127.0.0.1` which is not accessible from the host. If you need to access the health endpoints or expose ports, set `PICOCLAW_GATEWAY_HOST=0.0.0.0` in your environment or update `config.json`.
> [!NOTE]
> The `gateway` profile only serves the webhook handlers (including Pico when enabled) and health endpoints on the gateway port, so it does not expose generic REST chat endpoints such as `/chat` or `/a2a`. Launcher mode adds the browser UI plus `/api/pico/token` and a `/pico/ws` proxy on the launcher port, but `/pico/ws` is also available directly on the gateway whenever the Pico channel is enabled.
> The `gateway` profile only serves the webhook handlers (including Pico when enabled) and health endpoints on the gateway port, so it does not expose generic REST chat endpoints such as `/chat` or `/a2a`. Launcher mode adds the browser UI plus `/api/pico/info` and an authenticated `/pico/ws` proxy on the launcher port, but `/pico/ws` is also available directly on the gateway whenever the Pico channel is enabled.
```bash
# 5. Check logs
-2
View File
@@ -18,8 +18,6 @@ const (
TypeError = "error"
TypePong = "pong"
PicoTokenPrefix = "pico-"
PayloadKeyContent = "content"
PayloadKeyThought = "thought"
+1 -23
View File
@@ -9,7 +9,6 @@ import (
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
@@ -27,7 +26,7 @@ import (
_ "github.com/sipeed/picoclaw/pkg/channels/line"
_ "github.com/sipeed/picoclaw/pkg/channels/maixcam"
_ "github.com/sipeed/picoclaw/pkg/channels/onebot"
"github.com/sipeed/picoclaw/pkg/channels/pico"
_ "github.com/sipeed/picoclaw/pkg/channels/pico"
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
_ "github.com/sipeed/picoclaw/pkg/channels/teams_webhook"
@@ -316,8 +315,6 @@ func executeReload(
) error {
defer runningServices.reloading.Store(false)
overridePicoToken(newCfg, runningServices.authToken)
return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug)
}
@@ -386,8 +383,6 @@ func setupAndStartServices(
fms.Start()
}
overridePicoToken(cfg, authToken)
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
if err != nil {
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
@@ -788,23 +783,6 @@ func setupCronTool(
return cronService, nil
}
// overridePicoToken replaces the pico channel token with the one from the PID file.
// The PID file is the single source of truth for the pico auth token;
// it is generated once at gateway startup and remains unchanged across reloads.
func overridePicoToken(cfg *config.Config, token string) {
picoBC := cfg.Channels.GetByType(config.ChannelPico)
if picoBC == nil || !picoBC.Enabled {
return
}
var picoCfg config.PicoSettings
picoBC.Decode(&picoCfg)
picoToken := picoCfg.Token.String()
if picoToken == "" || strings.HasPrefix(picoToken, pico.PicoTokenPrefix) {
return
}
picoCfg.SetToken(pico.PicoTokenPrefix + token + picoToken)
}
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
return func(prompt, channel, chatID string) *tools.ToolResult {
if channel == "" || chatID == "" {
-4
View File
@@ -94,8 +94,6 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) {
return
}
// Refresh cached pico token in case user changed it.
refreshPicoToken(&cfg)
h.applyRuntimeLogLevel()
logger.Infof("configuration updated successfully")
@@ -193,8 +191,6 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) {
return
}
// Refresh cached pico token in case user changed it.
refreshPicoToken(&newCfg)
h.applyRuntimeLogLevel()
logger.Infof("configuration updated successfully")
+7 -27
View File
@@ -17,7 +17,6 @@ import (
"syscall"
"time"
"github.com/sipeed/picoclaw/pkg/channels/pico"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -37,28 +36,12 @@ var gateway = struct {
startupDeadline time.Time
logs *LogBuffer
pidData *ppid.PidFileData // pid file data read from picoclaw.pid.json
picoToken string // cached pico token from config (for proxy auth validation)
picoToken string // cached raw pico token for upstream gateway proxy injection
}{
runtimeStatus: "stopped",
logs: NewLogBuffer(200),
}
// refreshPicoToken updates gateway.picoToken from cfg
func refreshPicoToken(cfg *config.Config) {
gateway.mu.Lock()
defer gateway.mu.Unlock()
var picoCfg config.PicoSettings
if bc := cfg.Channels.GetByType(config.ChannelPico); bc != nil {
decoded, err := bc.GetDecoded()
if err == nil && decoded != nil {
if p, ok := decoded.(*config.PicoSettings); ok {
picoCfg = *p
}
}
}
gateway.picoToken = picoCfg.Token.String()
}
// refreshPicoTokensLocked reads the pico token from config and caches it.
// Caller must hold gateway.mu (or be sole writer).
func refreshPicoTokensLocked(configPath string) {
@@ -101,18 +84,15 @@ const (
tokenPrefix = "token."
)
// picoComposedToken returns "pico-"+pidToken+picoToken for gateway auth.
func picoComposedToken(token string) string {
// picoGatewayProtocol returns the gateway-facing pico subprotocol that the
// launcher should inject when proxying browser traffic upstream.
func picoGatewayProtocol() string {
gateway.mu.Lock()
defer gateway.mu.Unlock()
// if not initial pico token, don't allow gateway auth
if gateway.picoToken == "" || gateway.pidData == nil {
if gateway.picoToken == "" {
return ""
}
if tokenPrefix+gateway.picoToken != token {
return ""
}
return pico.PicoTokenPrefix + gateway.pidData.Token + gateway.picoToken
return tokenPrefix + gateway.picoToken
}
var (
@@ -752,7 +732,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
gateway.logs.Reset()
// Ensure Pico Channel is configured before starting gateway
changed, err := h.EnsurePicoChannel("")
changed, err := h.EnsurePicoChannel()
if err != nil {
logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err))
// Non-fatal: gateway can still start without pico channel
+36 -14
View File
@@ -85,8 +85,22 @@ func requestHostName(r *http.Request) string {
return netbind.ResolveAdaptiveLoopbackHost()
}
func forwardedProtoFirst(r *http.Request) string {
raw := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto"))
if raw == "" {
raw = forwardedRFC7239Proto(r)
}
if raw == "" {
return ""
}
if i := strings.IndexByte(raw, ','); i >= 0 {
raw = strings.TrimSpace(raw[:i])
}
return strings.ToLower(raw)
}
func requestWSScheme(r *http.Request) string {
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
if forwarded := forwardedProtoFirst(r); forwarded != "" {
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
if proto == "https" || proto == "wss" {
return "wss"
@@ -105,7 +119,7 @@ func requestWSScheme(r *http.Request) string {
// requestHTTPScheme returns http or https for URLs that are not WebSockets (e.g. SSE).
func requestHTTPScheme(r *http.Request) string {
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
if forwarded := forwardedProtoFirst(r); forwarded != "" {
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
if proto == "https" || proto == "wss" {
return "https"
@@ -117,6 +131,7 @@ func requestHTTPScheme(r *http.Request) string {
if r.TLS != nil {
return "https"
}
return "http"
}
@@ -138,6 +153,14 @@ func forwardedHostFirst(r *http.Request) string {
// forwardedRFC7239Host parses host= from the first Forwarded header element (RFC 7239).
func forwardedRFC7239Host(r *http.Request) string {
return forwardedRFC7239Param(r, "host")
}
func forwardedRFC7239Proto(r *http.Request) string {
return forwardedRFC7239Param(r, "proto")
}
func forwardedRFC7239Param(r *http.Request, key string) string {
v := strings.TrimSpace(r.Header.Get("Forwarded"))
if v == "" {
return ""
@@ -146,7 +169,7 @@ func forwardedRFC7239Host(r *http.Request) string {
for _, part := range strings.Split(first, ";") {
part = strings.TrimSpace(part)
low := strings.ToLower(part)
if !strings.HasPrefix(low, "host=") {
if !strings.HasPrefix(low, key+"=") {
continue
}
val := strings.TrimSpace(part[strings.IndexByte(part, '=')+1:])
@@ -177,13 +200,21 @@ func clientVisiblePort(r *http.Request, serverListenPort int) string {
if p := forwardedPortFirst(r); p != "" {
return p
}
if fwdHost := forwardedHostFirst(r); fwdHost != "" {
if _, port, err := net.SplitHostPort(fwdHost); err == nil && port != "" {
return port
}
}
if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" {
return port
}
if strings.TrimSpace(r.Host) == "" && forwardedHostFirst(r) == "" {
return strconv.Itoa(serverListenPort)
}
if requestHTTPScheme(r) == "https" {
return "443"
}
return strconv.Itoa(serverListenPort)
return "80"
}
// joinClientVisibleHostPort builds host:port for absolute URLs returned to the browser.
@@ -205,16 +236,7 @@ func (h *Handler) picoWebUIAddr(r *http.Request) string {
if fwdHost := forwardedHostFirst(r); fwdHost != "" {
return joinClientVisibleHostPort(r, fwdHost, wsPort)
}
host := requestHostName(r)
// Use clientVisiblePort only when an explicit port is present in headers
// or Host header — do not infer from TLS/scheme, as serverPort takes priority.
if p := forwardedPortFirst(r); p != "" {
return net.JoinHostPort(host, p)
}
if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" {
return net.JoinHostPort(host, port)
}
return net.JoinHostPort(host, strconv.Itoa(wsPort))
return joinClientVisibleHostPort(r, requestHostName(r), wsPort)
}
func (h *Handler) buildWsURL(r *http.Request) string {
+29 -12
View File
@@ -50,7 +50,7 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) {
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil)
req.Host = "192.168.1.9:18800"
if got := h.buildWsURL(req); got != "ws://192.168.1.9:18800/pico/ws" {
@@ -181,12 +181,12 @@ func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil)
req.Host = "chat.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
if got := h.buildWsURL(req); got != "wss://chat.example.com:18800/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws")
if got := h.buildWsURL(req); got != "wss://chat.example.com:443/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:443/pico/ws")
}
}
@@ -198,12 +198,12 @@ func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) {
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/info", nil)
req.Host = "secure.example.com"
req.TLS = &tls.ConnectionState{}
if got := h.buildWsURL(req); got != "wss://secure.example.com:18800/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws")
if got := h.buildWsURL(req); got != "wss://secure.example.com:443/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:443/pico/ws")
}
}
@@ -224,7 +224,7 @@ func TestBuildPicoURLsPreferXForwardedHost(t *testing.T) {
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/token", nil)
req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/info", nil)
req.Host = "127.0.0.1:18800"
req.Header.Set("X-Forwarded-Host", "vscode-tunnel.example.com")
req.Header.Set("X-Forwarded-Proto", "https")
@@ -249,13 +249,30 @@ func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) {
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/info", nil)
req.Host = "chat.example.com"
req.TLS = &tls.ConnectionState{}
req.Header.Set("X-Forwarded-Proto", "http")
if got := h.buildWsURL(req); got != "ws://chat.example.com:18800/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws")
if got := h.buildWsURL(req); got != "ws://chat.example.com:80/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:80/pico/ws")
}
}
func TestBuildWsURLDoesNotTrustOriginWhenProxyOmitsForwardedProto(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil)
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
if got := h.buildWsURL(req); got != "ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws" {
t.Fatalf(
"buildWsURL() = %q, want %q",
got,
"ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws",
)
}
}
@@ -264,7 +281,7 @@ func TestBuildWsURLUsesRequestHostNotGatewayBindLoopback(t *testing.T) {
h := NewHandler(configPath)
h.SetServerOptions(18800, false, false, nil)
req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/token", nil)
req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/info", nil)
req.Host = "localhost:18800"
if got := h.buildWsURL(req); got != "ws://localhost:18800/pico/ws" {
+12
View File
@@ -121,6 +121,18 @@ func resetGatewayTestState(t *testing.T) {
})
}
func TestPicoGatewayProtocol(t *testing.T) {
resetGatewayTestState(t)
gateway.mu.Lock()
gateway.picoToken = "ui-token"
gateway.mu.Unlock()
if got := picoGatewayProtocol(); got != tokenPrefix+"ui-token" {
t.Fatalf("picoGatewayProtocol() = %q, want %q", got, tokenPrefix+"ui-token")
}
}
type gatewayStartEnvSnapshot struct {
GatewayHost string `json:"gateway_host"`
GatewayHostSet bool `json:"gateway_host_set"`
+72 -79
View File
@@ -16,7 +16,7 @@ import (
// registerPicoRoutes binds Pico Channel management endpoints to the ServeMux.
func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET /api/pico/token", h.handleGetPicoToken)
mux.HandleFunc("GET /api/pico/info", h.handleGetPicoInfo)
mux.HandleFunc("POST /api/pico/token", h.handleRegenPicoToken)
mux.HandleFunc("POST /api/pico/setup", h.handlePicoSetup)
@@ -28,12 +28,15 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
// The gateway bind host and port are resolved from the latest configuration.
func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.ReverseProxy {
func (h *Handler) createWsProxy(origProtocol string, upstreamProtocol string) *httputil.ReverseProxy {
wsProxy := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
target := h.gatewayProxyURL()
r.SetURL(target)
r.Out.Header.Set(protocolKey, tokenPrefix+token)
r.Out.Header.Del(protocolKey)
if upstreamProtocol != "" {
r.Out.Header.Set(protocolKey, upstreamProtocol)
}
},
ModifyResponse: func(r *http.Response) error {
if prot := r.Header.Values(protocolKey); len(prot) > 0 {
@@ -52,8 +55,50 @@ func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.Rev
return wsProxy
}
func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) {
if cfg == nil {
return config.PicoSettings{}, false
}
bc := cfg.Channels.GetByType(config.ChannelPico)
if bc == nil {
return config.PicoSettings{}, false
}
var picoCfg config.PicoSettings
if err := bc.Decode(&picoCfg); err != nil {
return config.PicoSettings{}, false
}
return picoCfg, bc.Enabled
}
func (h *Handler) writePicoInfoResponse(
w http.ResponseWriter,
r *http.Request,
cfg *config.Config,
changed *bool,
) {
picoCfg, enabled := decodePicoSettings(cfg)
resp := map[string]any{
"ws_url": h.buildWsURL(r),
"enabled": enabled,
}
if changed != nil {
resp["changed"] = *changed
}
if picoCfg.Token.String() != "" {
resp["configured"] = true
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
// It validates the client token before forwarding; rejects immediately on failure.
// It relies on launcher dashboard auth, then injects the raw pico token only
// on the upstream gateway request.
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
gateway.mu.Lock()
@@ -91,51 +136,38 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
http.Error(w, "Gateway not available", http.StatusServiceUnavailable)
return
}
prot := r.Header.Values(protocolKey)
if len(prot) > 0 {
origProtocol := prot[0]
newToken := picoComposedToken(prot[0])
if newToken != "" {
h.createWsProxy(origProtocol, newToken).ServeHTTP(w, r)
return
}
upstreamProtocol := picoGatewayProtocol()
if upstreamProtocol == "" {
logger.Warn("Pico token unavailable for WebSocket proxy")
http.Error(w, "Pico channel not configured", http.StatusServiceUnavailable)
return
}
logger.Warnf("Invalid Pico token: %v", prot)
http.Error(w, "Invalid Pico token", http.StatusForbidden)
var origProtocol string
if prot := r.Header.Values(protocolKey); len(prot) > 0 {
origProtocol = prot[0]
}
h.createWsProxy(origProtocol, upstreamProtocol).ServeHTTP(w, r)
}
}
// handleGetPicoToken returns the current WS token and URL for the frontend.
// handleGetPicoInfo returns non-secret Pico connection info for the launcher UI.
//
// GET /api/pico/token
func (h *Handler) handleGetPicoToken(w http.ResponseWriter, r *http.Request) {
// GET /api/pico/info
func (h *Handler) handleGetPicoInfo(w http.ResponseWriter, r *http.Request) {
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
return
}
wsURL := h.buildWsURL(r)
w.Header().Set("Content-Type", "application/json")
bc := cfg.Channels.GetByType(config.ChannelPico)
var picoCfg config.PicoSettings
if bc != nil {
bc.Decode(&picoCfg)
}
enabled := false
if bc != nil {
enabled = bc.Enabled
}
json.NewEncoder(w).Encode(map[string]any{
"token": picoCfg.Token.String(),
"ws_url": wsURL,
"enabled": enabled,
})
h.writePicoInfoResponse(w, r, cfg, nil)
}
// handleRegenPicoToken generates a new Pico WebSocket token and saves it.
// handleRegenPicoToken rotates the raw Pico WebSocket token and returns
// non-secret connection info for the launcher UI.
//
// POST /api/pico/token
func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
@@ -160,28 +192,12 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
return
}
// Refresh cached pico token.
gateway.mu.Lock()
gateway.picoToken = token
gateway.mu.Unlock()
wsURL := h.buildWsURL(r)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"token": token,
"ws_url": wsURL,
})
h.writePicoInfoResponse(w, r, cfg, nil)
}
// EnsurePicoChannel enables the Pico channel with sane defaults if it isn't
// already configured. Returns true when the config was modified.
//
// callerOrigin is the Origin header from the setup request. If non-empty and
// no origins are configured yet, it's written as the allowed origin so the
// WebSocket handshake works for whatever host the caller is on (LAN, custom
// port, etc.). Pass "" when there's no request context.
func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
func (h *Handler) EnsurePicoChannel() (bool, error) {
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
return false, fmt.Errorf("failed to load config: %w", err)
@@ -206,12 +222,6 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
picoCfg.Token = *config.NewSecureString(generateSecureToken())
changed = true
}
// Seed origins from the request instead of hardcoding ports.
if len(picoCfg.AllowOrigins) == 0 && callerOrigin != "" {
picoCfg.AllowOrigins = []string{callerOrigin}
changed = true
}
}
}
@@ -228,37 +238,20 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
//
// POST /api/pico/setup
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
changed, err := h.EnsurePicoChannel(r.Header.Get("Origin"))
changed, err := h.EnsurePicoChannel()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Reload config (EnsurePicoChannel may have modified it) and refresh cache.
// Reload config (EnsurePicoChannel may have modified it).
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
return
}
if changed {
refreshPicoToken(cfg)
}
wsURL := h.buildWsURL(r)
var picoCfg2 config.PicoSettings
if bc := cfg.Channels.GetByType(config.ChannelPico); bc != nil {
if decoded, err := bc.GetDecoded(); err == nil && decoded != nil {
picoCfg2 = *decoded.(*config.PicoSettings)
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"token": picoCfg2.Token.String(),
"ws_url": wsURL,
"enabled": true,
"changed": changed,
})
h.writePicoInfoResponse(w, r, cfg, &changed)
}
// generateSecureToken creates a random 32-character hex string.
+142 -71
View File
@@ -11,16 +11,21 @@ import (
"strconv"
"testing"
"github.com/sipeed/picoclaw/pkg/channels/pico"
"github.com/sipeed/picoclaw/pkg/config"
ppid "github.com/sipeed/picoclaw/pkg/pid"
)
func newPicoProxyRequest(method, path string) *http.Request {
req := httptest.NewRequest(method, "http://launcher.local:18800"+path, nil)
req.Header.Set("Origin", "http://launcher.local:18800")
return req
}
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
changed, err := h.EnsurePicoChannel("")
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -51,7 +56,7 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(""); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -71,11 +76,11 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
}
}
func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
func TestEnsurePicoChannel_LeavesAllowOriginsEmptyByDefault(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel("http://localhost:18800"); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -90,45 +95,16 @@ func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
for _, origin := range picoCfg.AllowOrigins {
if origin == "*" {
t.Error("setup must not set wildcard origin '*'")
}
}
}
func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(""); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
bc := cfg.Channels["pico"]
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
// Without a caller origin, allow_origins stays empty (CheckOrigin
// allows all when the list is empty, so the channel still works).
if len(picoCfg.AllowOrigins) != 0 {
t.Errorf("allow_origins = %v, want empty when no caller origin", picoCfg.AllowOrigins)
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
}
}
func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
func TestEnsurePicoChannel_NoOriginConfigurationRequired(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
lanOrigin := "http://192.168.1.9:18800"
if _, err := h.EnsurePicoChannel(lanOrigin); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -143,8 +119,8 @@ func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != lanOrigin {
t.Errorf("allow_origins = %v, want [%s]", picoCfg.AllowOrigins, lanOrigin)
if len(picoCfg.AllowOrigins) != 0 {
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
}
}
@@ -169,7 +145,7 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
h := NewHandler(configPath)
changed, err := h.EnsurePicoChannel("")
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -213,7 +189,7 @@ func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) {
h := NewHandler(configPath)
changed, err := h.EnsurePicoChannel("")
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -253,7 +229,7 @@ func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) {
}
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(""); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
@@ -280,10 +256,8 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
origin := "http://localhost:18800"
// First call sets things up
if _, err := h.EnsurePicoChannel(origin); err != nil {
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("first EnsurePicoChannel() error = %v", err)
}
@@ -297,7 +271,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
token1 := picoCfg.Token.String()
// Second call should be a no-op
changed, err := h.EnsurePicoChannel(origin)
changed, err := h.EnsurePicoChannel()
if err != nil {
t.Fatalf("second EnsurePicoChannel() error = %v", err)
}
@@ -317,7 +291,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
}
}
func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
func TestHandlePicoSetup_DoesNotPersistRequestOrigin(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
@@ -342,8 +316,8 @@ func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
t.Fatalf("GetDecoded() error = %v", err)
}
picoCfg := decoded.(*config.PicoSettings)
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != "http://10.0.0.5:3000" {
t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", picoCfg.AllowOrigins)
if len(picoCfg.AllowOrigins) != 0 {
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
}
}
@@ -365,8 +339,8 @@ func TestHandlePicoSetup_Response(t *testing.T) {
t.Fatalf("failed to decode response: %v", err)
}
if resp["token"] == nil || resp["token"] == "" {
t.Error("response should contain a non-empty token")
if _, ok := resp["token"]; ok {
t.Error("response must not expose the raw pico token")
}
if resp["ws_url"] == nil || resp["ws_url"] == "" {
t.Error("response should contain ws_url")
@@ -377,6 +351,45 @@ func TestHandlePicoSetup_Response(t *testing.T) {
if resp["changed"] != true {
t.Error("response should have changed=true on first setup")
}
if resp["configured"] != true {
t.Error("response should have configured=true")
}
}
func TestHandleGetPicoInfo_OmitsToken(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.EnsurePicoChannel(); err != nil {
t.Fatalf("EnsurePicoChannel() error = %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/api/pico/info", nil)
rec := httptest.NewRecorder()
h.handleGetPicoInfo(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var resp map[string]any
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if _, ok := resp["token"]; ok {
t.Fatal("info response must not expose the raw pico token")
}
if resp["enabled"] != true {
t.Fatalf("enabled = %#v, want true", resp["enabled"])
}
if resp["configured"] != true {
t.Fatalf("configured = %#v, want true", resp["configured"])
}
if resp["ws_url"] == nil || resp["ws_url"] == "" {
t.Fatal("response should contain ws_url")
}
}
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
@@ -438,20 +451,10 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = "pico"
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
req1.Header.Set(protocolKey, tokenPrefix+"wrong_token")
req1 := newPicoProxyRequest(http.MethodGet, "/pico/ws")
rec1 := httptest.NewRecorder()
handler(rec1, req1)
if rec1.Code != http.StatusForbidden {
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusForbidden)
}
req1 = httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
req1.Header.Set(protocolKey, tokenPrefix+"pico")
rec1 = httptest.NewRecorder()
handler(rec1, req1)
if rec1.Code != http.StatusOK {
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
}
@@ -464,8 +467,7 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
t.Fatalf("SaveConfig() error = %v", err)
}
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
req2.Header.Set(protocolKey, tokenPrefix+"pico")
req2 := newPicoProxyRequest(http.MethodGet, "/pico/ws")
rec2 := httptest.NewRecorder()
handler(rec2, req2)
@@ -539,8 +541,7 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = ""
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
req.Header.Set(protocolKey, tokenPrefix+"cached-token")
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
rec := httptest.NewRecorder()
handler(rec, req)
@@ -625,8 +626,7 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
setGatewayRuntimeStatusLocked("stopped")
gateway.mu.Unlock()
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
req.Header.Set(protocolKey, tokenPrefix+"ui-token")
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
rec := httptest.NewRecorder()
handler(rec, req)
@@ -634,7 +634,7 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
expected := tokenPrefix + pico.PicoTokenPrefix + pidData.Token + "ui-token"
expected := tokenPrefix + "ui-token"
if got := rec.Body.String(); got != expected {
t.Fatalf("forwarded protocol = %q, want %q", got, expected)
}
@@ -696,8 +696,7 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
req.Header.Set(protocolKey, tokenPrefix+"ui-token")
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
rec := httptest.NewRecorder()
handler(rec, req)
@@ -711,6 +710,78 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
}
}
func TestHandleWebSocketProxy_AllowsArbitraryOrigin(t *testing.T) {
origMatcher := gatewayProcessMatcher
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
home := t.TempDir()
t.Setenv("PICOCLAW_HOME", home)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "proxied")
}))
defer server.Close()
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
bc := cfg.Channels["pico"]
bc.Enabled = true
decoded, err := bc.GetDecoded()
if err != nil {
t.Fatalf("GetDecoded() error = %v", err)
}
decoded.(*config.PicoSettings).SetToken("ui-token")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
cmd := startGatewayLikeProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
writeTestPidFile(t, ppid.PidFileData{
PID: cmd.Process.Pid,
Token: "test-token",
Host: cfg.Gateway.Host,
Port: cfg.Gateway.Port,
})
t.Cleanup(func() {
ppid.RemovePidFile(globalConfigDir())
})
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
t.Cleanup(func() {
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
})
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = "ui-token"
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws?session_id=test-session", nil)
req.Header.Set("Origin", "http://evil.example")
rec := httptest.NewRecorder()
handler(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
}
func mustGatewayTestPort(t *testing.T, rawURL string) int {
t.Helper()
+1 -1
View File
@@ -544,7 +544,7 @@ func main() {
// API Routes (e.g. /api/status)
apiHandler = api.NewHandler(absPath)
apiHandler.SetDebug(debug)
if _, err = apiHandler.EnsurePicoChannel(""); err != nil {
if _, err = apiHandler.EnsurePicoChannel(); err != nil {
logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err))
}
apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)
@@ -218,6 +218,10 @@ func validLauncherDashboardAuth(r *http.Request, cfg LauncherDashboardAuthConfig
}
func rejectLauncherDashboardAuth(w http.ResponseWriter, r *http.Request, canonicalPath string) {
if canonicalPath == "/pico/ws" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if strings.HasPrefix(canonicalPath, "/api/") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
@@ -40,6 +40,7 @@ func TestLauncherDashboardAuth_AllowsPublicPaths(t *testing.T) {
{http.MethodPost, "/api/auth/logout", http.StatusTeapot},
{http.MethodGet, "/api/auth/logout", http.StatusUnauthorized},
{http.MethodGet, "/api/config", http.StatusUnauthorized},
{http.MethodGet, "/pico/ws", http.StatusUnauthorized},
} {
rec := httptest.NewRecorder()
req := httptest.NewRequest(tc.method, tc.path, nil)
@@ -160,3 +161,22 @@ func TestLauncherDashboardAuth_CookieAndBearer(t *testing.T) {
t.Fatalf("bearer auth: status = %d", rec2.Code)
}
}
func TestLauncherDashboardAuth_WebSocketUnauthorizedDoesNotRedirect(t *testing.T) {
cfg := LauncherDashboardAuthConfig{ExpectedCookie: "deadbeef", Token: "x"}
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
t.Fatal("next handler should not run without auth")
})
h := LauncherDashboardAuth(cfg, next)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
h.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
}
if got := rec.Header().Get("Location"); got != "" {
t.Fatalf("Location = %q, want empty", got)
}
}
+8 -8
View File
@@ -2,16 +2,16 @@ import { launcherFetch } from "@/api/http"
// API client for Pico Channel configuration.
interface PicoTokenResponse {
token: string
interface PicoInfoResponse {
ws_url: string
enabled: boolean
configured?: boolean
}
interface PicoSetupResponse {
token: string
ws_url: string
enabled: boolean
configured?: boolean
changed: boolean
}
@@ -25,16 +25,16 @@ async function request<T>(path: string, options?: RequestInit): Promise<T> {
return res.json() as Promise<T>
}
export async function getPicoToken(): Promise<PicoTokenResponse> {
return request<PicoTokenResponse>("/api/pico/token")
export async function getPicoInfo(): Promise<PicoInfoResponse> {
return request<PicoInfoResponse>("/api/pico/info")
}
export async function regenPicoToken(): Promise<PicoTokenResponse> {
return request<PicoTokenResponse>("/api/pico/token", { method: "POST" })
export async function regenPicoToken(): Promise<PicoInfoResponse> {
return request<PicoInfoResponse>("/api/pico/token", { method: "POST" })
}
export async function setupPico(): Promise<PicoSetupResponse> {
return request<PicoSetupResponse>("/api/pico/setup", { method: "POST" })
}
export type { PicoTokenResponse, PicoSetupResponse }
export type { PicoInfoResponse, PicoSetupResponse }
+1 -11
View File
@@ -1,7 +1,6 @@
import { getDefaultStore } from "jotai"
import { toast } from "sonner"
import { getPicoToken } from "@/api/pico"
import {
loadSessionMessages,
mergeHistoryMessages,
@@ -131,7 +130,6 @@ export async function connectChat() {
updateChatStore({ connectionState: "connecting" })
try {
const { token } = await getPicoToken()
const sessionId = activeSessionIdRef
if (generation !== connectionGeneration) {
@@ -139,18 +137,10 @@ export async function connectChat() {
return
}
if (!token) {
console.error("No pico token available")
updateChatStore({ connectionState: "error" })
isConnecting = false
scheduleReconnect(generation, sessionId)
return
}
const wsScheme = window.location.protocol === "https:" ? "wss:" : "ws:"
const wsUrl = `${wsScheme}//${window.location.host}/pico/ws`
const url = `${wsUrl}?session_id=${encodeURIComponent(sessionId)}`
const socket = new WebSocket(url, [`token.${token}`])
const socket = new WebSocket(url)
if (generation !== connectionGeneration) {
isConnecting = false
+1 -1
View File
@@ -29,7 +29,7 @@ export default defineConfig({
target: "http://localhost:18800",
changeOrigin: true,
},
"/ws": {
"/pico/ws": {
target: "ws://localhost:18800",
ws: true,
},