diff --git a/docs/guides/docker.md b/docs/guides/docker.md index 6c32879a6..3ccc7a2a7 100644 --- a/docs/guides/docker.md +++ b/docs/guides/docker.md @@ -27,7 +27,7 @@ docker compose -f docker/docker-compose.yml --profile gateway up -d > **Docker Users**: By default, the Gateway listens on `127.0.0.1` which is not accessible from the host. If you need to access the health endpoints or expose ports, set `PICOCLAW_GATEWAY_HOST=0.0.0.0` in your environment or update `config.json`. > [!NOTE] -> The `gateway` profile only serves the webhook handlers (including Pico when enabled) and health endpoints on the gateway port, so it does not expose generic REST chat endpoints such as `/chat` or `/a2a`. Launcher mode adds the browser UI plus `/api/pico/token` and a `/pico/ws` proxy on the launcher port, but `/pico/ws` is also available directly on the gateway whenever the Pico channel is enabled. +> The `gateway` profile only serves the webhook handlers (including Pico when enabled) and health endpoints on the gateway port, so it does not expose generic REST chat endpoints such as `/chat` or `/a2a`. Launcher mode adds the browser UI plus `/api/pico/info` and an authenticated `/pico/ws` proxy on the launcher port, but `/pico/ws` is also available directly on the gateway whenever the Pico channel is enabled. ```bash # 5. Check logs diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go index ecdc2d140..051beed1b 100644 --- a/pkg/channels/pico/protocol.go +++ b/pkg/channels/pico/protocol.go @@ -18,8 +18,6 @@ const ( TypeError = "error" TypePong = "pong" - PicoTokenPrefix = "pico-" - PayloadKeyContent = "content" PayloadKeyThought = "thought" diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 039f45075..f58590d5b 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -9,7 +9,6 @@ import ( "path/filepath" "sort" "strconv" - "strings" "sync" "sync/atomic" "syscall" @@ -27,7 +26,7 @@ import ( _ "github.com/sipeed/picoclaw/pkg/channels/line" _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" _ "github.com/sipeed/picoclaw/pkg/channels/onebot" - "github.com/sipeed/picoclaw/pkg/channels/pico" + _ "github.com/sipeed/picoclaw/pkg/channels/pico" _ "github.com/sipeed/picoclaw/pkg/channels/qq" _ "github.com/sipeed/picoclaw/pkg/channels/slack" _ "github.com/sipeed/picoclaw/pkg/channels/teams_webhook" @@ -316,8 +315,6 @@ func executeReload( ) error { defer runningServices.reloading.Store(false) - overridePicoToken(newCfg, runningServices.authToken) - return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug) } @@ -386,8 +383,6 @@ func setupAndStartServices( fms.Start() } - overridePicoToken(cfg, authToken) - runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore) if err != nil { if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok { @@ -788,23 +783,6 @@ func setupCronTool( return cronService, nil } -// overridePicoToken replaces the pico channel token with the one from the PID file. -// The PID file is the single source of truth for the pico auth token; -// it is generated once at gateway startup and remains unchanged across reloads. -func overridePicoToken(cfg *config.Config, token string) { - picoBC := cfg.Channels.GetByType(config.ChannelPico) - if picoBC == nil || !picoBC.Enabled { - return - } - var picoCfg config.PicoSettings - picoBC.Decode(&picoCfg) - picoToken := picoCfg.Token.String() - if picoToken == "" || strings.HasPrefix(picoToken, pico.PicoTokenPrefix) { - return - } - picoCfg.SetToken(pico.PicoTokenPrefix + token + picoToken) -} - func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult { return func(prompt, channel, chatID string) *tools.ToolResult { if channel == "" || chatID == "" { diff --git a/web/backend/api/config.go b/web/backend/api/config.go index 80ab80f35..c7bd21197 100644 --- a/web/backend/api/config.go +++ b/web/backend/api/config.go @@ -94,8 +94,6 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) { return } - // Refresh cached pico token in case user changed it. - refreshPicoToken(&cfg) h.applyRuntimeLogLevel() logger.Infof("configuration updated successfully") @@ -193,8 +191,6 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) { return } - // Refresh cached pico token in case user changed it. - refreshPicoToken(&newCfg) h.applyRuntimeLogLevel() logger.Infof("configuration updated successfully") diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index fa5652323..201000ff3 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -17,7 +17,6 @@ import ( "syscall" "time" - "github.com/sipeed/picoclaw/pkg/channels/pico" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" @@ -37,28 +36,12 @@ var gateway = struct { startupDeadline time.Time logs *LogBuffer pidData *ppid.PidFileData // pid file data read from picoclaw.pid.json - picoToken string // cached pico token from config (for proxy auth validation) + picoToken string // cached raw pico token for upstream gateway proxy injection }{ runtimeStatus: "stopped", logs: NewLogBuffer(200), } -// refreshPicoToken updates gateway.picoToken from cfg -func refreshPicoToken(cfg *config.Config) { - gateway.mu.Lock() - defer gateway.mu.Unlock() - var picoCfg config.PicoSettings - if bc := cfg.Channels.GetByType(config.ChannelPico); bc != nil { - decoded, err := bc.GetDecoded() - if err == nil && decoded != nil { - if p, ok := decoded.(*config.PicoSettings); ok { - picoCfg = *p - } - } - } - gateway.picoToken = picoCfg.Token.String() -} - // refreshPicoTokensLocked reads the pico token from config and caches it. // Caller must hold gateway.mu (or be sole writer). func refreshPicoTokensLocked(configPath string) { @@ -101,18 +84,15 @@ const ( tokenPrefix = "token." ) -// picoComposedToken returns "pico-"+pidToken+picoToken for gateway auth. -func picoComposedToken(token string) string { +// picoGatewayProtocol returns the gateway-facing pico subprotocol that the +// launcher should inject when proxying browser traffic upstream. +func picoGatewayProtocol() string { gateway.mu.Lock() defer gateway.mu.Unlock() - // if not initial pico token, don't allow gateway auth - if gateway.picoToken == "" || gateway.pidData == nil { + if gateway.picoToken == "" { return "" } - if tokenPrefix+gateway.picoToken != token { - return "" - } - return pico.PicoTokenPrefix + gateway.pidData.Token + gateway.picoToken + return tokenPrefix + gateway.picoToken } var ( @@ -752,7 +732,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int gateway.logs.Reset() // Ensure Pico Channel is configured before starting gateway - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err)) // Non-fatal: gateway can still start without pico channel diff --git a/web/backend/api/gateway_host.go b/web/backend/api/gateway_host.go index c6c2073e2..03af7a9d3 100644 --- a/web/backend/api/gateway_host.go +++ b/web/backend/api/gateway_host.go @@ -85,8 +85,22 @@ func requestHostName(r *http.Request) string { return netbind.ResolveAdaptiveLoopbackHost() } +func forwardedProtoFirst(r *http.Request) string { + raw := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")) + if raw == "" { + raw = forwardedRFC7239Proto(r) + } + if raw == "" { + return "" + } + if i := strings.IndexByte(raw, ','); i >= 0 { + raw = strings.TrimSpace(raw[:i]) + } + return strings.ToLower(raw) +} + func requestWSScheme(r *http.Request) string { - if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" { + if forwarded := forwardedProtoFirst(r); forwarded != "" { proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0])) if proto == "https" || proto == "wss" { return "wss" @@ -105,7 +119,7 @@ func requestWSScheme(r *http.Request) string { // requestHTTPScheme returns http or https for URLs that are not WebSockets (e.g. SSE). func requestHTTPScheme(r *http.Request) string { - if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" { + if forwarded := forwardedProtoFirst(r); forwarded != "" { proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0])) if proto == "https" || proto == "wss" { return "https" @@ -117,6 +131,7 @@ func requestHTTPScheme(r *http.Request) string { if r.TLS != nil { return "https" } + return "http" } @@ -138,6 +153,14 @@ func forwardedHostFirst(r *http.Request) string { // forwardedRFC7239Host parses host= from the first Forwarded header element (RFC 7239). func forwardedRFC7239Host(r *http.Request) string { + return forwardedRFC7239Param(r, "host") +} + +func forwardedRFC7239Proto(r *http.Request) string { + return forwardedRFC7239Param(r, "proto") +} + +func forwardedRFC7239Param(r *http.Request, key string) string { v := strings.TrimSpace(r.Header.Get("Forwarded")) if v == "" { return "" @@ -146,7 +169,7 @@ func forwardedRFC7239Host(r *http.Request) string { for _, part := range strings.Split(first, ";") { part = strings.TrimSpace(part) low := strings.ToLower(part) - if !strings.HasPrefix(low, "host=") { + if !strings.HasPrefix(low, key+"=") { continue } val := strings.TrimSpace(part[strings.IndexByte(part, '=')+1:]) @@ -177,13 +200,21 @@ func clientVisiblePort(r *http.Request, serverListenPort int) string { if p := forwardedPortFirst(r); p != "" { return p } + if fwdHost := forwardedHostFirst(r); fwdHost != "" { + if _, port, err := net.SplitHostPort(fwdHost); err == nil && port != "" { + return port + } + } if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" { return port } + if strings.TrimSpace(r.Host) == "" && forwardedHostFirst(r) == "" { + return strconv.Itoa(serverListenPort) + } if requestHTTPScheme(r) == "https" { return "443" } - return strconv.Itoa(serverListenPort) + return "80" } // joinClientVisibleHostPort builds host:port for absolute URLs returned to the browser. @@ -205,16 +236,7 @@ func (h *Handler) picoWebUIAddr(r *http.Request) string { if fwdHost := forwardedHostFirst(r); fwdHost != "" { return joinClientVisibleHostPort(r, fwdHost, wsPort) } - host := requestHostName(r) - // Use clientVisiblePort only when an explicit port is present in headers - // or Host header — do not infer from TLS/scheme, as serverPort takes priority. - if p := forwardedPortFirst(r); p != "" { - return net.JoinHostPort(host, p) - } - if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" { - return net.JoinHostPort(host, port) - } - return net.JoinHostPort(host, strconv.Itoa(wsPort)) + return joinClientVisibleHostPort(r, requestHostName(r), wsPort) } func (h *Handler) buildWsURL(r *http.Request) string { diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go index d0fc26d7b..54d1010d2 100644 --- a/web/backend/api/gateway_host_test.go +++ b/web/backend/api/gateway_host_test.go @@ -50,7 +50,7 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) { cfg.Gateway.Host = "127.0.0.1" cfg.Gateway.Port = 18790 - req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil) + req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil) req.Host = "192.168.1.9:18800" if got := h.buildWsURL(req); got != "ws://192.168.1.9:18800/pico/ws" { @@ -181,12 +181,12 @@ func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) { cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 - req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil) + req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil) req.Host = "chat.example.com" req.Header.Set("X-Forwarded-Proto", "https") - if got := h.buildWsURL(req); got != "wss://chat.example.com:18800/pico/ws" { - t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws") + if got := h.buildWsURL(req); got != "wss://chat.example.com:443/pico/ws" { + t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:443/pico/ws") } } @@ -198,12 +198,12 @@ func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) { cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 - req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil) + req := httptest.NewRequest("GET", "https://launcher.local/api/pico/info", nil) req.Host = "secure.example.com" req.TLS = &tls.ConnectionState{} - if got := h.buildWsURL(req); got != "wss://secure.example.com:18800/pico/ws" { - t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws") + if got := h.buildWsURL(req); got != "wss://secure.example.com:443/pico/ws" { + t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:443/pico/ws") } } @@ -224,7 +224,7 @@ func TestBuildPicoURLsPreferXForwardedHost(t *testing.T) { cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 - req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/token", nil) + req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/info", nil) req.Host = "127.0.0.1:18800" req.Header.Set("X-Forwarded-Host", "vscode-tunnel.example.com") req.Header.Set("X-Forwarded-Proto", "https") @@ -249,13 +249,30 @@ func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) { cfg.Gateway.Host = "0.0.0.0" cfg.Gateway.Port = 18790 - req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil) + req := httptest.NewRequest("GET", "https://launcher.local/api/pico/info", nil) req.Host = "chat.example.com" req.TLS = &tls.ConnectionState{} req.Header.Set("X-Forwarded-Proto", "http") - if got := h.buildWsURL(req); got != "ws://chat.example.com:18800/pico/ws" { - t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws") + if got := h.buildWsURL(req); got != "ws://chat.example.com:80/pico/ws" { + t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:80/pico/ws") + } +} + +func TestBuildWsURLDoesNotTrustOriginWhenProxyOmitsForwardedProto(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil) + req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com" + req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com") + + if got := h.buildWsURL(req); got != "ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws" { + t.Fatalf( + "buildWsURL() = %q, want %q", + got, + "ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws", + ) } } @@ -264,7 +281,7 @@ func TestBuildWsURLUsesRequestHostNotGatewayBindLoopback(t *testing.T) { h := NewHandler(configPath) h.SetServerOptions(18800, false, false, nil) - req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/token", nil) + req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/info", nil) req.Host = "localhost:18800" if got := h.buildWsURL(req); got != "ws://localhost:18800/pico/ws" { diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go index 78bf34a63..998ed3317 100644 --- a/web/backend/api/gateway_test.go +++ b/web/backend/api/gateway_test.go @@ -121,6 +121,18 @@ func resetGatewayTestState(t *testing.T) { }) } +func TestPicoGatewayProtocol(t *testing.T) { + resetGatewayTestState(t) + + gateway.mu.Lock() + gateway.picoToken = "ui-token" + gateway.mu.Unlock() + + if got := picoGatewayProtocol(); got != tokenPrefix+"ui-token" { + t.Fatalf("picoGatewayProtocol() = %q, want %q", got, tokenPrefix+"ui-token") + } +} + type gatewayStartEnvSnapshot struct { GatewayHost string `json:"gateway_host"` GatewayHostSet bool `json:"gateway_host_set"` diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index 00ffb8bb2..ffd0796c7 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -16,7 +16,7 @@ import ( // registerPicoRoutes binds Pico Channel management endpoints to the ServeMux. func (h *Handler) registerPicoRoutes(mux *http.ServeMux) { - mux.HandleFunc("GET /api/pico/token", h.handleGetPicoToken) + mux.HandleFunc("GET /api/pico/info", h.handleGetPicoInfo) mux.HandleFunc("POST /api/pico/token", h.handleRegenPicoToken) mux.HandleFunc("POST /api/pico/setup", h.handlePicoSetup) @@ -28,12 +28,15 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) { // createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint. // The gateway bind host and port are resolved from the latest configuration. -func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.ReverseProxy { +func (h *Handler) createWsProxy(origProtocol string, upstreamProtocol string) *httputil.ReverseProxy { wsProxy := &httputil.ReverseProxy{ Rewrite: func(r *httputil.ProxyRequest) { target := h.gatewayProxyURL() r.SetURL(target) - r.Out.Header.Set(protocolKey, tokenPrefix+token) + r.Out.Header.Del(protocolKey) + if upstreamProtocol != "" { + r.Out.Header.Set(protocolKey, upstreamProtocol) + } }, ModifyResponse: func(r *http.Response) error { if prot := r.Header.Values(protocolKey); len(prot) > 0 { @@ -52,8 +55,50 @@ func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.Rev return wsProxy } +func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) { + if cfg == nil { + return config.PicoSettings{}, false + } + + bc := cfg.Channels.GetByType(config.ChannelPico) + if bc == nil { + return config.PicoSettings{}, false + } + + var picoCfg config.PicoSettings + if err := bc.Decode(&picoCfg); err != nil { + return config.PicoSettings{}, false + } + + return picoCfg, bc.Enabled +} + +func (h *Handler) writePicoInfoResponse( + w http.ResponseWriter, + r *http.Request, + cfg *config.Config, + changed *bool, +) { + picoCfg, enabled := decodePicoSettings(cfg) + + resp := map[string]any{ + "ws_url": h.buildWsURL(r), + "enabled": enabled, + } + if changed != nil { + resp["changed"] = *changed + } + if picoCfg.Token.String() != "" { + resp["configured"] = true + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) +} + // handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections. -// It validates the client token before forwarding; rejects immediately on failure. +// It relies on launcher dashboard auth, then injects the raw pico token only +// on the upstream gateway request. func (h *Handler) handleWebSocketProxy() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { gateway.mu.Lock() @@ -91,51 +136,38 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc { http.Error(w, "Gateway not available", http.StatusServiceUnavailable) return } - prot := r.Header.Values(protocolKey) - if len(prot) > 0 { - origProtocol := prot[0] - newToken := picoComposedToken(prot[0]) - if newToken != "" { - h.createWsProxy(origProtocol, newToken).ServeHTTP(w, r) - return - } + + upstreamProtocol := picoGatewayProtocol() + if upstreamProtocol == "" { + logger.Warn("Pico token unavailable for WebSocket proxy") + http.Error(w, "Pico channel not configured", http.StatusServiceUnavailable) + return } - logger.Warnf("Invalid Pico token: %v", prot) - http.Error(w, "Invalid Pico token", http.StatusForbidden) + var origProtocol string + if prot := r.Header.Values(protocolKey); len(prot) > 0 { + origProtocol = prot[0] + } + + h.createWsProxy(origProtocol, upstreamProtocol).ServeHTTP(w, r) } } -// handleGetPicoToken returns the current WS token and URL for the frontend. +// handleGetPicoInfo returns non-secret Pico connection info for the launcher UI. // -// GET /api/pico/token -func (h *Handler) handleGetPicoToken(w http.ResponseWriter, r *http.Request) { +// GET /api/pico/info +func (h *Handler) handleGetPicoInfo(w http.ResponseWriter, r *http.Request) { cfg, err := config.LoadConfig(h.configPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) return } - wsURL := h.buildWsURL(r) - - w.Header().Set("Content-Type", "application/json") - bc := cfg.Channels.GetByType(config.ChannelPico) - var picoCfg config.PicoSettings - if bc != nil { - bc.Decode(&picoCfg) - } - enabled := false - if bc != nil { - enabled = bc.Enabled - } - json.NewEncoder(w).Encode(map[string]any{ - "token": picoCfg.Token.String(), - "ws_url": wsURL, - "enabled": enabled, - }) + h.writePicoInfoResponse(w, r, cfg, nil) } -// handleRegenPicoToken generates a new Pico WebSocket token and saves it. +// handleRegenPicoToken rotates the raw Pico WebSocket token and returns +// non-secret connection info for the launcher UI. // // POST /api/pico/token func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { @@ -160,28 +192,12 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { return } - // Refresh cached pico token. - gateway.mu.Lock() - gateway.picoToken = token - gateway.mu.Unlock() - - wsURL := h.buildWsURL(r) - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "token": token, - "ws_url": wsURL, - }) + h.writePicoInfoResponse(w, r, cfg, nil) } // EnsurePicoChannel enables the Pico channel with sane defaults if it isn't // already configured. Returns true when the config was modified. -// -// callerOrigin is the Origin header from the setup request. If non-empty and -// no origins are configured yet, it's written as the allowed origin so the -// WebSocket handshake works for whatever host the caller is on (LAN, custom -// port, etc.). Pass "" when there's no request context. -func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { +func (h *Handler) EnsurePicoChannel() (bool, error) { cfg, err := config.LoadConfig(h.configPath) if err != nil { return false, fmt.Errorf("failed to load config: %w", err) @@ -206,12 +222,6 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { picoCfg.Token = *config.NewSecureString(generateSecureToken()) changed = true } - - // Seed origins from the request instead of hardcoding ports. - if len(picoCfg.AllowOrigins) == 0 && callerOrigin != "" { - picoCfg.AllowOrigins = []string{callerOrigin} - changed = true - } } } @@ -228,37 +238,20 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) { // // POST /api/pico/setup func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) { - changed, err := h.EnsurePicoChannel(r.Header.Get("Origin")) + changed, err := h.EnsurePicoChannel() if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - // Reload config (EnsurePicoChannel may have modified it) and refresh cache. + // Reload config (EnsurePicoChannel may have modified it). cfg, err := config.LoadConfig(h.configPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) return } - if changed { - refreshPicoToken(cfg) - } - wsURL := h.buildWsURL(r) - - var picoCfg2 config.PicoSettings - if bc := cfg.Channels.GetByType(config.ChannelPico); bc != nil { - if decoded, err := bc.GetDecoded(); err == nil && decoded != nil { - picoCfg2 = *decoded.(*config.PicoSettings) - } - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "token": picoCfg2.Token.String(), - "ws_url": wsURL, - "enabled": true, - "changed": changed, - }) + h.writePicoInfoResponse(w, r, cfg, &changed) } // generateSecureToken creates a random 32-character hex string. diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 807c796dc..a56cd9ba2 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -11,16 +11,21 @@ import ( "strconv" "testing" - "github.com/sipeed/picoclaw/pkg/channels/pico" "github.com/sipeed/picoclaw/pkg/config" ppid "github.com/sipeed/picoclaw/pkg/pid" ) +func newPicoProxyRequest(method, path string) *http.Request { + req := httptest.NewRequest(method, "http://launcher.local:18800"+path, nil) + req.Header.Set("Origin", "http://launcher.local:18800") + return req +} + func TestEnsurePicoChannel_FreshConfig(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -51,7 +56,7 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel(""); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -71,11 +76,11 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) { } } -func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) { +func TestEnsurePicoChannel_LeavesAllowOriginsEmptyByDefault(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel("http://localhost:18800"); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -90,45 +95,16 @@ func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) { t.Fatalf("GetDecoded() error = %v", err) } picoCfg := decoded.(*config.PicoSettings) - for _, origin := range picoCfg.AllowOrigins { - if origin == "*" { - t.Error("setup must not set wildcard origin '*'") - } - } -} - -func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) { - configPath := filepath.Join(t.TempDir(), "config.json") - h := NewHandler(configPath) - - if _, err := h.EnsurePicoChannel(""); err != nil { - t.Fatalf("EnsurePicoChannel() error = %v", err) - } - - cfg, err := config.LoadConfig(configPath) - if err != nil { - t.Fatalf("LoadConfig() error = %v", err) - } - - bc := cfg.Channels["pico"] - decoded, err := bc.GetDecoded() - if err != nil { - t.Fatalf("GetDecoded() error = %v", err) - } - picoCfg := decoded.(*config.PicoSettings) - // Without a caller origin, allow_origins stays empty (CheckOrigin - // allows all when the list is empty, so the channel still works). if len(picoCfg.AllowOrigins) != 0 { - t.Errorf("allow_origins = %v, want empty when no caller origin", picoCfg.AllowOrigins) + t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins) } } -func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) { +func TestEnsurePicoChannel_NoOriginConfigurationRequired(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - lanOrigin := "http://192.168.1.9:18800" - if _, err := h.EnsurePicoChannel(lanOrigin); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -143,8 +119,8 @@ func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) { t.Fatalf("GetDecoded() error = %v", err) } picoCfg := decoded.(*config.PicoSettings) - if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != lanOrigin { - t.Errorf("allow_origins = %v, want [%s]", picoCfg.AllowOrigins, lanOrigin) + if len(picoCfg.AllowOrigins) != 0 { + t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins) } } @@ -169,7 +145,7 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) { h := NewHandler(configPath) - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -213,7 +189,7 @@ func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) { h := NewHandler(configPath) - changed, err := h.EnsurePicoChannel("") + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -253,7 +229,7 @@ func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) { } h := NewHandler(configPath) - if _, err := h.EnsurePicoChannel(""); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("EnsurePicoChannel() error = %v", err) } @@ -280,10 +256,8 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) - origin := "http://localhost:18800" - // First call sets things up - if _, err := h.EnsurePicoChannel(origin); err != nil { + if _, err := h.EnsurePicoChannel(); err != nil { t.Fatalf("first EnsurePicoChannel() error = %v", err) } @@ -297,7 +271,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { token1 := picoCfg.Token.String() // Second call should be a no-op - changed, err := h.EnsurePicoChannel(origin) + changed, err := h.EnsurePicoChannel() if err != nil { t.Fatalf("second EnsurePicoChannel() error = %v", err) } @@ -317,7 +291,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) { } } -func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) { +func TestHandlePicoSetup_DoesNotPersistRequestOrigin(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) @@ -342,8 +316,8 @@ func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) { t.Fatalf("GetDecoded() error = %v", err) } picoCfg := decoded.(*config.PicoSettings) - if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != "http://10.0.0.5:3000" { - t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", picoCfg.AllowOrigins) + if len(picoCfg.AllowOrigins) != 0 { + t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins) } } @@ -365,8 +339,8 @@ func TestHandlePicoSetup_Response(t *testing.T) { t.Fatalf("failed to decode response: %v", err) } - if resp["token"] == nil || resp["token"] == "" { - t.Error("response should contain a non-empty token") + if _, ok := resp["token"]; ok { + t.Error("response must not expose the raw pico token") } if resp["ws_url"] == nil || resp["ws_url"] == "" { t.Error("response should contain ws_url") @@ -377,6 +351,45 @@ func TestHandlePicoSetup_Response(t *testing.T) { if resp["changed"] != true { t.Error("response should have changed=true on first setup") } + if resp["configured"] != true { + t.Error("response should have configured=true") + } +} + +func TestHandleGetPicoInfo_OmitsToken(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + if _, err := h.EnsurePicoChannel(); err != nil { + t.Fatalf("EnsurePicoChannel() error = %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://launcher.local/api/pico/info", nil) + rec := httptest.NewRecorder() + + h.handleGetPicoInfo(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if _, ok := resp["token"]; ok { + t.Fatal("info response must not expose the raw pico token") + } + if resp["enabled"] != true { + t.Fatalf("enabled = %#v, want true", resp["enabled"]) + } + if resp["configured"] != true { + t.Fatalf("configured = %#v, want true", resp["configured"]) + } + if resp["ws_url"] == nil || resp["ws_url"] == "" { + t.Fatal("response should contain ws_url") + } } func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { @@ -438,20 +451,10 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { gateway.pidData = &ppid.PidFileData{} gateway.picoToken = "pico" - req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil) - req1.Header.Set(protocolKey, tokenPrefix+"wrong_token") + req1 := newPicoProxyRequest(http.MethodGet, "/pico/ws") rec1 := httptest.NewRecorder() handler(rec1, req1) - if rec1.Code != http.StatusForbidden { - t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusForbidden) - } - - req1 = httptest.NewRequest(http.MethodGet, "/pico/ws", nil) - req1.Header.Set(protocolKey, tokenPrefix+"pico") - rec1 = httptest.NewRecorder() - handler(rec1, req1) - if rec1.Code != http.StatusOK { t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK) } @@ -464,8 +467,7 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { t.Fatalf("SaveConfig() error = %v", err) } - req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil) - req2.Header.Set(protocolKey, tokenPrefix+"pico") + req2 := newPicoProxyRequest(http.MethodGet, "/pico/ws") rec2 := httptest.NewRecorder() handler(rec2, req2) @@ -539,8 +541,7 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) { gateway.pidData = &ppid.PidFileData{} gateway.picoToken = "" - req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil) - req.Header.Set(protocolKey, tokenPrefix+"cached-token") + req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session") rec := httptest.NewRecorder() handler(rec, req) @@ -625,8 +626,7 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) { setGatewayRuntimeStatusLocked("stopped") gateway.mu.Unlock() - req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil) - req.Header.Set(protocolKey, tokenPrefix+"ui-token") + req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session") rec := httptest.NewRecorder() handler(rec, req) @@ -634,7 +634,7 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } - expected := tokenPrefix + pico.PicoTokenPrefix + pidData.Token + "ui-token" + expected := tokenPrefix + "ui-token" if got := rec.Body.String(); got != expected { t.Fatalf("forwarded protocol = %q, want %q", got, expected) } @@ -696,8 +696,7 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) { setGatewayRuntimeStatusLocked("running") gateway.mu.Unlock() - req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil) - req.Header.Set(protocolKey, tokenPrefix+"ui-token") + req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session") rec := httptest.NewRecorder() handler(rec, req) @@ -711,6 +710,78 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) { } } +func TestHandleWebSocketProxy_AllowsArbitraryOrigin(t *testing.T) { + origMatcher := gatewayProcessMatcher + gatewayProcessMatcher = func(int) (bool, bool) { return true, true } + t.Cleanup(func() { gatewayProcessMatcher = origMatcher }) + + home := t.TempDir() + t.Setenv("PICOCLAW_HOME", home) + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + handler := h.handleWebSocketProxy() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/pico/ws" { + t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws") + } + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "proxied") + })) + defer server.Close() + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "127.0.0.1" + cfg.Gateway.Port = mustGatewayTestPort(t, server.URL) + bc := cfg.Channels["pico"] + bc.Enabled = true + decoded, err := bc.GetDecoded() + if err != nil { + t.Fatalf("GetDecoded() error = %v", err) + } + decoded.(*config.PicoSettings).SetToken("ui-token") + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + cmd := startGatewayLikeProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + writeTestPidFile(t, ppid.PidFileData{ + PID: cmd.Process.Pid, + Token: "test-token", + Host: cfg.Gateway.Host, + Port: cfg.Gateway.Port, + }) + t.Cleanup(func() { + ppid.RemovePidFile(globalConfigDir()) + }) + + origPidData := gateway.pidData + origPicoToken := gateway.picoToken + t.Cleanup(func() { + gateway.pidData = origPidData + gateway.picoToken = origPicoToken + }) + + gateway.pidData = &ppid.PidFileData{} + gateway.picoToken = "ui-token" + + req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws?session_id=test-session", nil) + req.Header.Set("Origin", "http://evil.example") + rec := httptest.NewRecorder() + handler(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } +} + func mustGatewayTestPort(t *testing.T, rawURL string) int { t.Helper() diff --git a/web/backend/main.go b/web/backend/main.go index 01ef5edf0..e42558398 100644 --- a/web/backend/main.go +++ b/web/backend/main.go @@ -544,7 +544,7 @@ func main() { // API Routes (e.g. /api/status) apiHandler = api.NewHandler(absPath) apiHandler.SetDebug(debug) - if _, err = apiHandler.EnsurePicoChannel(""); err != nil { + if _, err = apiHandler.EnsurePicoChannel(); err != nil { logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err)) } apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs) diff --git a/web/backend/middleware/launcher_dashboard_auth.go b/web/backend/middleware/launcher_dashboard_auth.go index c1c4c19c6..d72bd0f00 100644 --- a/web/backend/middleware/launcher_dashboard_auth.go +++ b/web/backend/middleware/launcher_dashboard_auth.go @@ -218,6 +218,10 @@ func validLauncherDashboardAuth(r *http.Request, cfg LauncherDashboardAuthConfig } func rejectLauncherDashboardAuth(w http.ResponseWriter, r *http.Request, canonicalPath string) { + if canonicalPath == "/pico/ws" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } if strings.HasPrefix(canonicalPath, "/api/") { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) diff --git a/web/backend/middleware/launcher_dashboard_auth_test.go b/web/backend/middleware/launcher_dashboard_auth_test.go index 1b919bf96..7b7418998 100644 --- a/web/backend/middleware/launcher_dashboard_auth_test.go +++ b/web/backend/middleware/launcher_dashboard_auth_test.go @@ -40,6 +40,7 @@ func TestLauncherDashboardAuth_AllowsPublicPaths(t *testing.T) { {http.MethodPost, "/api/auth/logout", http.StatusTeapot}, {http.MethodGet, "/api/auth/logout", http.StatusUnauthorized}, {http.MethodGet, "/api/config", http.StatusUnauthorized}, + {http.MethodGet, "/pico/ws", http.StatusUnauthorized}, } { rec := httptest.NewRecorder() req := httptest.NewRequest(tc.method, tc.path, nil) @@ -160,3 +161,22 @@ func TestLauncherDashboardAuth_CookieAndBearer(t *testing.T) { t.Fatalf("bearer auth: status = %d", rec2.Code) } } + +func TestLauncherDashboardAuth_WebSocketUnauthorizedDoesNotRedirect(t *testing.T) { + cfg := LauncherDashboardAuthConfig{ExpectedCookie: "deadbeef", Token: "x"} + next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Fatal("next handler should not run without auth") + }) + h := LauncherDashboardAuth(cfg, next) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/pico/ws", nil) + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } + if got := rec.Header().Get("Location"); got != "" { + t.Fatalf("Location = %q, want empty", got) + } +} diff --git a/web/frontend/src/api/pico.ts b/web/frontend/src/api/pico.ts index 6b8ceb49a..ca98a06da 100644 --- a/web/frontend/src/api/pico.ts +++ b/web/frontend/src/api/pico.ts @@ -2,16 +2,16 @@ import { launcherFetch } from "@/api/http" // API client for Pico Channel configuration. -interface PicoTokenResponse { - token: string +interface PicoInfoResponse { ws_url: string enabled: boolean + configured?: boolean } interface PicoSetupResponse { - token: string ws_url: string enabled: boolean + configured?: boolean changed: boolean } @@ -25,16 +25,16 @@ async function request(path: string, options?: RequestInit): Promise { return res.json() as Promise } -export async function getPicoToken(): Promise { - return request("/api/pico/token") +export async function getPicoInfo(): Promise { + return request("/api/pico/info") } -export async function regenPicoToken(): Promise { - return request("/api/pico/token", { method: "POST" }) +export async function regenPicoToken(): Promise { + return request("/api/pico/token", { method: "POST" }) } export async function setupPico(): Promise { return request("/api/pico/setup", { method: "POST" }) } -export type { PicoTokenResponse, PicoSetupResponse } +export type { PicoInfoResponse, PicoSetupResponse } diff --git a/web/frontend/src/features/chat/controller.ts b/web/frontend/src/features/chat/controller.ts index 28ef491fa..183b1ba6f 100644 --- a/web/frontend/src/features/chat/controller.ts +++ b/web/frontend/src/features/chat/controller.ts @@ -1,7 +1,6 @@ import { getDefaultStore } from "jotai" import { toast } from "sonner" -import { getPicoToken } from "@/api/pico" import { loadSessionMessages, mergeHistoryMessages, @@ -131,7 +130,6 @@ export async function connectChat() { updateChatStore({ connectionState: "connecting" }) try { - const { token } = await getPicoToken() const sessionId = activeSessionIdRef if (generation !== connectionGeneration) { @@ -139,18 +137,10 @@ export async function connectChat() { return } - if (!token) { - console.error("No pico token available") - updateChatStore({ connectionState: "error" }) - isConnecting = false - scheduleReconnect(generation, sessionId) - return - } - const wsScheme = window.location.protocol === "https:" ? "wss:" : "ws:" const wsUrl = `${wsScheme}//${window.location.host}/pico/ws` const url = `${wsUrl}?session_id=${encodeURIComponent(sessionId)}` - const socket = new WebSocket(url, [`token.${token}`]) + const socket = new WebSocket(url) if (generation !== connectionGeneration) { isConnecting = false diff --git a/web/frontend/vite.config.ts b/web/frontend/vite.config.ts index 0ef4e1415..57512c8b9 100644 --- a/web/frontend/vite.config.ts +++ b/web/frontend/vite.config.ts @@ -29,7 +29,7 @@ export default defineConfig({ target: "http://localhost:18800", changeOrigin: true, }, - "/ws": { + "/pico/ws": { target: "ws://localhost:18800", ws: true, },