mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(web): secure Pico websocket access behind launcher auth (#2545)
* refactor(web): secure Pico websocket access behind launcher auth - stop exposing the raw Pico token to the frontend - add /api/pico/info for non-secret Pico connection metadata - proxy /pico/ws through the launcher with same-origin and dashboard auth checks - inject the upstream Pico websocket protocol server-side - update frontend chat connection flow and Vite websocket proxy path - refresh related docs and tests * fix(web): improve Pico URL and origin handling behind proxies - read client scheme from X-Forwarded-Proto and RFC 7239 Forwarded - derive client-visible ports from forwarded host information - add coverage for HTTPS origins without explicit ports - verify behavior when proxies omit forwarded protocol headers * fix(web): stop pinning Pico WebSocket origins during setup - remove request-origin seeding from `EnsurePicoChannel` - keep `allow_origins` empty by default for auto-configured Pico channels - relax launcher Pico WebSocket proxy origin validation - update Pico backend tests for the new setup and proxy behavior
This commit is contained in:
@@ -27,7 +27,7 @@ docker compose -f docker/docker-compose.yml --profile gateway up -d
|
||||
> **Docker Users**: By default, the Gateway listens on `127.0.0.1` which is not accessible from the host. If you need to access the health endpoints or expose ports, set `PICOCLAW_GATEWAY_HOST=0.0.0.0` in your environment or update `config.json`.
|
||||
|
||||
> [!NOTE]
|
||||
> The `gateway` profile only serves the webhook handlers (including Pico when enabled) and health endpoints on the gateway port, so it does not expose generic REST chat endpoints such as `/chat` or `/a2a`. Launcher mode adds the browser UI plus `/api/pico/token` and a `/pico/ws` proxy on the launcher port, but `/pico/ws` is also available directly on the gateway whenever the Pico channel is enabled.
|
||||
> The `gateway` profile only serves the webhook handlers (including Pico when enabled) and health endpoints on the gateway port, so it does not expose generic REST chat endpoints such as `/chat` or `/a2a`. Launcher mode adds the browser UI plus `/api/pico/info` and an authenticated `/pico/ws` proxy on the launcher port, but `/pico/ws` is also available directly on the gateway whenever the Pico channel is enabled.
|
||||
|
||||
```bash
|
||||
# 5. Check logs
|
||||
|
||||
@@ -18,8 +18,6 @@ const (
|
||||
TypeError = "error"
|
||||
TypePong = "pong"
|
||||
|
||||
PicoTokenPrefix = "pico-"
|
||||
|
||||
PayloadKeyContent = "content"
|
||||
PayloadKeyThought = "thought"
|
||||
|
||||
|
||||
+1
-23
@@ -9,7 +9,6 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
@@ -27,7 +26,7 @@ import (
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/line"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/maixcam"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/onebot"
|
||||
"github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/teams_webhook"
|
||||
@@ -316,8 +315,6 @@ func executeReload(
|
||||
) error {
|
||||
defer runningServices.reloading.Store(false)
|
||||
|
||||
overridePicoToken(newCfg, runningServices.authToken)
|
||||
|
||||
return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug)
|
||||
}
|
||||
|
||||
@@ -386,8 +383,6 @@ func setupAndStartServices(
|
||||
fms.Start()
|
||||
}
|
||||
|
||||
overridePicoToken(cfg, authToken)
|
||||
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
@@ -788,23 +783,6 @@ func setupCronTool(
|
||||
return cronService, nil
|
||||
}
|
||||
|
||||
// overridePicoToken replaces the pico channel token with the one from the PID file.
|
||||
// The PID file is the single source of truth for the pico auth token;
|
||||
// it is generated once at gateway startup and remains unchanged across reloads.
|
||||
func overridePicoToken(cfg *config.Config, token string) {
|
||||
picoBC := cfg.Channels.GetByType(config.ChannelPico)
|
||||
if picoBC == nil || !picoBC.Enabled {
|
||||
return
|
||||
}
|
||||
var picoCfg config.PicoSettings
|
||||
picoBC.Decode(&picoCfg)
|
||||
picoToken := picoCfg.Token.String()
|
||||
if picoToken == "" || strings.HasPrefix(picoToken, pico.PicoTokenPrefix) {
|
||||
return
|
||||
}
|
||||
picoCfg.SetToken(pico.PicoTokenPrefix + token + picoToken)
|
||||
}
|
||||
|
||||
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
if channel == "" || chatID == "" {
|
||||
|
||||
@@ -94,8 +94,6 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh cached pico token in case user changed it.
|
||||
refreshPicoToken(&cfg)
|
||||
h.applyRuntimeLogLevel()
|
||||
logger.Infof("configuration updated successfully")
|
||||
|
||||
@@ -193,8 +191,6 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh cached pico token in case user changed it.
|
||||
refreshPicoToken(&newCfg)
|
||||
h.applyRuntimeLogLevel()
|
||||
logger.Infof("configuration updated successfully")
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -37,28 +36,12 @@ var gateway = struct {
|
||||
startupDeadline time.Time
|
||||
logs *LogBuffer
|
||||
pidData *ppid.PidFileData // pid file data read from picoclaw.pid.json
|
||||
picoToken string // cached pico token from config (for proxy auth validation)
|
||||
picoToken string // cached raw pico token for upstream gateway proxy injection
|
||||
}{
|
||||
runtimeStatus: "stopped",
|
||||
logs: NewLogBuffer(200),
|
||||
}
|
||||
|
||||
// refreshPicoToken updates gateway.picoToken from cfg
|
||||
func refreshPicoToken(cfg *config.Config) {
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
var picoCfg config.PicoSettings
|
||||
if bc := cfg.Channels.GetByType(config.ChannelPico); bc != nil {
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err == nil && decoded != nil {
|
||||
if p, ok := decoded.(*config.PicoSettings); ok {
|
||||
picoCfg = *p
|
||||
}
|
||||
}
|
||||
}
|
||||
gateway.picoToken = picoCfg.Token.String()
|
||||
}
|
||||
|
||||
// refreshPicoTokensLocked reads the pico token from config and caches it.
|
||||
// Caller must hold gateway.mu (or be sole writer).
|
||||
func refreshPicoTokensLocked(configPath string) {
|
||||
@@ -101,18 +84,15 @@ const (
|
||||
tokenPrefix = "token."
|
||||
)
|
||||
|
||||
// picoComposedToken returns "pico-"+pidToken+picoToken for gateway auth.
|
||||
func picoComposedToken(token string) string {
|
||||
// picoGatewayProtocol returns the gateway-facing pico subprotocol that the
|
||||
// launcher should inject when proxying browser traffic upstream.
|
||||
func picoGatewayProtocol() string {
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
// if not initial pico token, don't allow gateway auth
|
||||
if gateway.picoToken == "" || gateway.pidData == nil {
|
||||
if gateway.picoToken == "" {
|
||||
return ""
|
||||
}
|
||||
if tokenPrefix+gateway.picoToken != token {
|
||||
return ""
|
||||
}
|
||||
return pico.PicoTokenPrefix + gateway.pidData.Token + gateway.picoToken
|
||||
return tokenPrefix + gateway.picoToken
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -752,7 +732,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
gateway.logs.Reset()
|
||||
|
||||
// Ensure Pico Channel is configured before starting gateway
|
||||
changed, err := h.EnsurePicoChannel("")
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err))
|
||||
// Non-fatal: gateway can still start without pico channel
|
||||
|
||||
@@ -85,8 +85,22 @@ func requestHostName(r *http.Request) string {
|
||||
return netbind.ResolveAdaptiveLoopbackHost()
|
||||
}
|
||||
|
||||
func forwardedProtoFirst(r *http.Request) string {
|
||||
raw := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto"))
|
||||
if raw == "" {
|
||||
raw = forwardedRFC7239Proto(r)
|
||||
}
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if i := strings.IndexByte(raw, ','); i >= 0 {
|
||||
raw = strings.TrimSpace(raw[:i])
|
||||
}
|
||||
return strings.ToLower(raw)
|
||||
}
|
||||
|
||||
func requestWSScheme(r *http.Request) string {
|
||||
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
|
||||
if forwarded := forwardedProtoFirst(r); forwarded != "" {
|
||||
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
|
||||
if proto == "https" || proto == "wss" {
|
||||
return "wss"
|
||||
@@ -105,7 +119,7 @@ func requestWSScheme(r *http.Request) string {
|
||||
|
||||
// requestHTTPScheme returns http or https for URLs that are not WebSockets (e.g. SSE).
|
||||
func requestHTTPScheme(r *http.Request) string {
|
||||
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
|
||||
if forwarded := forwardedProtoFirst(r); forwarded != "" {
|
||||
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
|
||||
if proto == "https" || proto == "wss" {
|
||||
return "https"
|
||||
@@ -117,6 +131,7 @@ func requestHTTPScheme(r *http.Request) string {
|
||||
if r.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
|
||||
return "http"
|
||||
}
|
||||
|
||||
@@ -138,6 +153,14 @@ func forwardedHostFirst(r *http.Request) string {
|
||||
|
||||
// forwardedRFC7239Host parses host= from the first Forwarded header element (RFC 7239).
|
||||
func forwardedRFC7239Host(r *http.Request) string {
|
||||
return forwardedRFC7239Param(r, "host")
|
||||
}
|
||||
|
||||
func forwardedRFC7239Proto(r *http.Request) string {
|
||||
return forwardedRFC7239Param(r, "proto")
|
||||
}
|
||||
|
||||
func forwardedRFC7239Param(r *http.Request, key string) string {
|
||||
v := strings.TrimSpace(r.Header.Get("Forwarded"))
|
||||
if v == "" {
|
||||
return ""
|
||||
@@ -146,7 +169,7 @@ func forwardedRFC7239Host(r *http.Request) string {
|
||||
for _, part := range strings.Split(first, ";") {
|
||||
part = strings.TrimSpace(part)
|
||||
low := strings.ToLower(part)
|
||||
if !strings.HasPrefix(low, "host=") {
|
||||
if !strings.HasPrefix(low, key+"=") {
|
||||
continue
|
||||
}
|
||||
val := strings.TrimSpace(part[strings.IndexByte(part, '=')+1:])
|
||||
@@ -177,13 +200,21 @@ func clientVisiblePort(r *http.Request, serverListenPort int) string {
|
||||
if p := forwardedPortFirst(r); p != "" {
|
||||
return p
|
||||
}
|
||||
if fwdHost := forwardedHostFirst(r); fwdHost != "" {
|
||||
if _, port, err := net.SplitHostPort(fwdHost); err == nil && port != "" {
|
||||
return port
|
||||
}
|
||||
}
|
||||
if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" {
|
||||
return port
|
||||
}
|
||||
if strings.TrimSpace(r.Host) == "" && forwardedHostFirst(r) == "" {
|
||||
return strconv.Itoa(serverListenPort)
|
||||
}
|
||||
if requestHTTPScheme(r) == "https" {
|
||||
return "443"
|
||||
}
|
||||
return strconv.Itoa(serverListenPort)
|
||||
return "80"
|
||||
}
|
||||
|
||||
// joinClientVisibleHostPort builds host:port for absolute URLs returned to the browser.
|
||||
@@ -205,16 +236,7 @@ func (h *Handler) picoWebUIAddr(r *http.Request) string {
|
||||
if fwdHost := forwardedHostFirst(r); fwdHost != "" {
|
||||
return joinClientVisibleHostPort(r, fwdHost, wsPort)
|
||||
}
|
||||
host := requestHostName(r)
|
||||
// Use clientVisiblePort only when an explicit port is present in headers
|
||||
// or Host header — do not infer from TLS/scheme, as serverPort takes priority.
|
||||
if p := forwardedPortFirst(r); p != "" {
|
||||
return net.JoinHostPort(host, p)
|
||||
}
|
||||
if _, port, err := net.SplitHostPort(r.Host); err == nil && port != "" {
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
return net.JoinHostPort(host, strconv.Itoa(wsPort))
|
||||
return joinClientVisibleHostPort(r, requestHostName(r), wsPort)
|
||||
}
|
||||
|
||||
func (h *Handler) buildWsURL(r *http.Request) string {
|
||||
|
||||
@@ -50,7 +50,7 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) {
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil)
|
||||
req.Host = "192.168.1.9:18800"
|
||||
|
||||
if got := h.buildWsURL(req); got != "ws://192.168.1.9:18800/pico/ws" {
|
||||
@@ -181,12 +181,12 @@ func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
|
||||
cfg.Gateway.Host = "0.0.0.0"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil)
|
||||
req.Host = "chat.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
if got := h.buildWsURL(req); got != "wss://chat.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws")
|
||||
if got := h.buildWsURL(req); got != "wss://chat.example.com:443/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:443/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,12 +198,12 @@ func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) {
|
||||
cfg.Gateway.Host = "0.0.0.0"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
|
||||
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/info", nil)
|
||||
req.Host = "secure.example.com"
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
|
||||
if got := h.buildWsURL(req); got != "wss://secure.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws")
|
||||
if got := h.buildWsURL(req); got != "wss://secure.example.com:443/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:443/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,7 +224,7 @@ func TestBuildPicoURLsPreferXForwardedHost(t *testing.T) {
|
||||
cfg.Gateway.Host = "0.0.0.0"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/token", nil)
|
||||
req := httptest.NewRequest("GET", "http://127.0.0.1:18800/api/pico/info", nil)
|
||||
req.Host = "127.0.0.1:18800"
|
||||
req.Header.Set("X-Forwarded-Host", "vscode-tunnel.example.com")
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
@@ -249,13 +249,30 @@ func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) {
|
||||
cfg.Gateway.Host = "0.0.0.0"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
|
||||
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/info", nil)
|
||||
req.Host = "chat.example.com"
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
|
||||
if got := h.buildWsURL(req); got != "ws://chat.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws")
|
||||
if got := h.buildWsURL(req); got != "ws://chat.example.com:80/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:80/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLDoesNotTrustOriginWhenProxyOmitsForwardedProto(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/info", nil)
|
||||
req.Host = "fs-952210-xwj.picoclaw.lan.sipeed.com"
|
||||
req.Header.Set("Origin", "https://fs-952210-xwj.picoclaw.lan.sipeed.com")
|
||||
|
||||
if got := h.buildWsURL(req); got != "ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws" {
|
||||
t.Fatalf(
|
||||
"buildWsURL() = %q, want %q",
|
||||
got,
|
||||
"ws://fs-952210-xwj.picoclaw.lan.sipeed.com:80/pico/ws",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +281,7 @@ func TestBuildWsURLUsesRequestHostNotGatewayBindLoopback(t *testing.T) {
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/token", nil)
|
||||
req := httptest.NewRequest("GET", "http://localhost:18800/api/pico/info", nil)
|
||||
req.Host = "localhost:18800"
|
||||
|
||||
if got := h.buildWsURL(req); got != "ws://localhost:18800/pico/ws" {
|
||||
|
||||
@@ -121,6 +121,18 @@ func resetGatewayTestState(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPicoGatewayProtocol(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.picoToken = "ui-token"
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if got := picoGatewayProtocol(); got != tokenPrefix+"ui-token" {
|
||||
t.Fatalf("picoGatewayProtocol() = %q, want %q", got, tokenPrefix+"ui-token")
|
||||
}
|
||||
}
|
||||
|
||||
type gatewayStartEnvSnapshot struct {
|
||||
GatewayHost string `json:"gateway_host"`
|
||||
GatewayHostSet bool `json:"gateway_host_set"`
|
||||
|
||||
+72
-79
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
// registerPicoRoutes binds Pico Channel management endpoints to the ServeMux.
|
||||
func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/pico/token", h.handleGetPicoToken)
|
||||
mux.HandleFunc("GET /api/pico/info", h.handleGetPicoInfo)
|
||||
mux.HandleFunc("POST /api/pico/token", h.handleRegenPicoToken)
|
||||
mux.HandleFunc("POST /api/pico/setup", h.handlePicoSetup)
|
||||
|
||||
@@ -28,12 +28,15 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
|
||||
|
||||
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
|
||||
// The gateway bind host and port are resolved from the latest configuration.
|
||||
func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.ReverseProxy {
|
||||
func (h *Handler) createWsProxy(origProtocol string, upstreamProtocol string) *httputil.ReverseProxy {
|
||||
wsProxy := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
target := h.gatewayProxyURL()
|
||||
r.SetURL(target)
|
||||
r.Out.Header.Set(protocolKey, tokenPrefix+token)
|
||||
r.Out.Header.Del(protocolKey)
|
||||
if upstreamProtocol != "" {
|
||||
r.Out.Header.Set(protocolKey, upstreamProtocol)
|
||||
}
|
||||
},
|
||||
ModifyResponse: func(r *http.Response) error {
|
||||
if prot := r.Header.Values(protocolKey); len(prot) > 0 {
|
||||
@@ -52,8 +55,50 @@ func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.Rev
|
||||
return wsProxy
|
||||
}
|
||||
|
||||
func decodePicoSettings(cfg *config.Config) (config.PicoSettings, bool) {
|
||||
if cfg == nil {
|
||||
return config.PicoSettings{}, false
|
||||
}
|
||||
|
||||
bc := cfg.Channels.GetByType(config.ChannelPico)
|
||||
if bc == nil {
|
||||
return config.PicoSettings{}, false
|
||||
}
|
||||
|
||||
var picoCfg config.PicoSettings
|
||||
if err := bc.Decode(&picoCfg); err != nil {
|
||||
return config.PicoSettings{}, false
|
||||
}
|
||||
|
||||
return picoCfg, bc.Enabled
|
||||
}
|
||||
|
||||
func (h *Handler) writePicoInfoResponse(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
cfg *config.Config,
|
||||
changed *bool,
|
||||
) {
|
||||
picoCfg, enabled := decodePicoSettings(cfg)
|
||||
|
||||
resp := map[string]any{
|
||||
"ws_url": h.buildWsURL(r),
|
||||
"enabled": enabled,
|
||||
}
|
||||
if changed != nil {
|
||||
resp["changed"] = *changed
|
||||
}
|
||||
if picoCfg.Token.String() != "" {
|
||||
resp["configured"] = true
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
|
||||
// It validates the client token before forwarding; rejects immediately on failure.
|
||||
// It relies on launcher dashboard auth, then injects the raw pico token only
|
||||
// on the upstream gateway request.
|
||||
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
gateway.mu.Lock()
|
||||
@@ -91,51 +136,38 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
http.Error(w, "Gateway not available", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
prot := r.Header.Values(protocolKey)
|
||||
if len(prot) > 0 {
|
||||
origProtocol := prot[0]
|
||||
newToken := picoComposedToken(prot[0])
|
||||
if newToken != "" {
|
||||
h.createWsProxy(origProtocol, newToken).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
upstreamProtocol := picoGatewayProtocol()
|
||||
if upstreamProtocol == "" {
|
||||
logger.Warn("Pico token unavailable for WebSocket proxy")
|
||||
http.Error(w, "Pico channel not configured", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Warnf("Invalid Pico token: %v", prot)
|
||||
http.Error(w, "Invalid Pico token", http.StatusForbidden)
|
||||
var origProtocol string
|
||||
if prot := r.Header.Values(protocolKey); len(prot) > 0 {
|
||||
origProtocol = prot[0]
|
||||
}
|
||||
|
||||
h.createWsProxy(origProtocol, upstreamProtocol).ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// handleGetPicoToken returns the current WS token and URL for the frontend.
|
||||
// handleGetPicoInfo returns non-secret Pico connection info for the launcher UI.
|
||||
//
|
||||
// GET /api/pico/token
|
||||
func (h *Handler) handleGetPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
// GET /api/pico/info
|
||||
func (h *Handler) handleGetPicoInfo(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
wsURL := h.buildWsURL(r)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
bc := cfg.Channels.GetByType(config.ChannelPico)
|
||||
var picoCfg config.PicoSettings
|
||||
if bc != nil {
|
||||
bc.Decode(&picoCfg)
|
||||
}
|
||||
enabled := false
|
||||
if bc != nil {
|
||||
enabled = bc.Enabled
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"token": picoCfg.Token.String(),
|
||||
"ws_url": wsURL,
|
||||
"enabled": enabled,
|
||||
})
|
||||
h.writePicoInfoResponse(w, r, cfg, nil)
|
||||
}
|
||||
|
||||
// handleRegenPicoToken generates a new Pico WebSocket token and saves it.
|
||||
// handleRegenPicoToken rotates the raw Pico WebSocket token and returns
|
||||
// non-secret connection info for the launcher UI.
|
||||
//
|
||||
// POST /api/pico/token
|
||||
func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -160,28 +192,12 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh cached pico token.
|
||||
gateway.mu.Lock()
|
||||
gateway.picoToken = token
|
||||
gateway.mu.Unlock()
|
||||
|
||||
wsURL := h.buildWsURL(r)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"token": token,
|
||||
"ws_url": wsURL,
|
||||
})
|
||||
h.writePicoInfoResponse(w, r, cfg, nil)
|
||||
}
|
||||
|
||||
// EnsurePicoChannel enables the Pico channel with sane defaults if it isn't
|
||||
// already configured. Returns true when the config was modified.
|
||||
//
|
||||
// callerOrigin is the Origin header from the setup request. If non-empty and
|
||||
// no origins are configured yet, it's written as the allowed origin so the
|
||||
// WebSocket handshake works for whatever host the caller is on (LAN, custom
|
||||
// port, etc.). Pass "" when there's no request context.
|
||||
func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
|
||||
func (h *Handler) EnsurePicoChannel() (bool, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to load config: %w", err)
|
||||
@@ -206,12 +222,6 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
|
||||
picoCfg.Token = *config.NewSecureString(generateSecureToken())
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Seed origins from the request instead of hardcoding ports.
|
||||
if len(picoCfg.AllowOrigins) == 0 && callerOrigin != "" {
|
||||
picoCfg.AllowOrigins = []string{callerOrigin}
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,37 +238,20 @@ func (h *Handler) EnsurePicoChannel(callerOrigin string) (bool, error) {
|
||||
//
|
||||
// POST /api/pico/setup
|
||||
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
||||
changed, err := h.EnsurePicoChannel(r.Header.Get("Origin"))
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Reload config (EnsurePicoChannel may have modified it) and refresh cache.
|
||||
// Reload config (EnsurePicoChannel may have modified it).
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if changed {
|
||||
refreshPicoToken(cfg)
|
||||
}
|
||||
|
||||
wsURL := h.buildWsURL(r)
|
||||
|
||||
var picoCfg2 config.PicoSettings
|
||||
if bc := cfg.Channels.GetByType(config.ChannelPico); bc != nil {
|
||||
if decoded, err := bc.GetDecoded(); err == nil && decoded != nil {
|
||||
picoCfg2 = *decoded.(*config.PicoSettings)
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"token": picoCfg2.Token.String(),
|
||||
"ws_url": wsURL,
|
||||
"enabled": true,
|
||||
"changed": changed,
|
||||
})
|
||||
h.writePicoInfoResponse(w, r, cfg, &changed)
|
||||
}
|
||||
|
||||
// generateSecureToken creates a random 32-character hex string.
|
||||
|
||||
+142
-71
@@ -11,16 +11,21 @@ import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
)
|
||||
|
||||
func newPicoProxyRequest(method, path string) *http.Request {
|
||||
req := httptest.NewRequest(method, "http://launcher.local:18800"+path, nil)
|
||||
req.Header.Set("Origin", "http://launcher.local:18800")
|
||||
return req
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.EnsurePicoChannel("")
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
@@ -51,7 +56,7 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -71,11 +76,11 @@ func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
|
||||
func TestEnsurePicoChannel_LeavesAllowOriginsEmptyByDefault(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.EnsurePicoChannel("http://localhost:18800"); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -90,45 +95,16 @@ func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
|
||||
t.Fatalf("GetDecoded() error = %v", err)
|
||||
}
|
||||
picoCfg := decoded.(*config.PicoSettings)
|
||||
for _, origin := range picoCfg.AllowOrigins {
|
||||
if origin == "*" {
|
||||
t.Error("setup must not set wildcard origin '*'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
bc := cfg.Channels["pico"]
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
t.Fatalf("GetDecoded() error = %v", err)
|
||||
}
|
||||
picoCfg := decoded.(*config.PicoSettings)
|
||||
// Without a caller origin, allow_origins stays empty (CheckOrigin
|
||||
// allows all when the list is empty, so the channel still works).
|
||||
if len(picoCfg.AllowOrigins) != 0 {
|
||||
t.Errorf("allow_origins = %v, want empty when no caller origin", picoCfg.AllowOrigins)
|
||||
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
|
||||
func TestEnsurePicoChannel_NoOriginConfigurationRequired(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
lanOrigin := "http://192.168.1.9:18800"
|
||||
if _, err := h.EnsurePicoChannel(lanOrigin); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -143,8 +119,8 @@ func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
|
||||
t.Fatalf("GetDecoded() error = %v", err)
|
||||
}
|
||||
picoCfg := decoded.(*config.PicoSettings)
|
||||
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != lanOrigin {
|
||||
t.Errorf("allow_origins = %v, want [%s]", picoCfg.AllowOrigins, lanOrigin)
|
||||
if len(picoCfg.AllowOrigins) != 0 {
|
||||
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,7 +145,7 @@ func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
|
||||
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.EnsurePicoChannel("")
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
@@ -213,7 +189,7 @@ func TestEnsurePicoChannel_ExistingConfigWithoutSecurityFile(t *testing.T) {
|
||||
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.EnsurePicoChannel("")
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
@@ -253,7 +229,7 @@ func TestEnsurePicoChannel_ConfiguresPicoWithoutGateway(t *testing.T) {
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -280,10 +256,8 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
origin := "http://localhost:18800"
|
||||
|
||||
// First call sets things up
|
||||
if _, err := h.EnsurePicoChannel(origin); err != nil {
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("first EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -297,7 +271,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
||||
token1 := picoCfg.Token.String()
|
||||
|
||||
// Second call should be a no-op
|
||||
changed, err := h.EnsurePicoChannel(origin)
|
||||
changed, err := h.EnsurePicoChannel()
|
||||
if err != nil {
|
||||
t.Fatalf("second EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
@@ -317,7 +291,7 @@ func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
|
||||
func TestHandlePicoSetup_DoesNotPersistRequestOrigin(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
@@ -342,8 +316,8 @@ func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
|
||||
t.Fatalf("GetDecoded() error = %v", err)
|
||||
}
|
||||
picoCfg := decoded.(*config.PicoSettings)
|
||||
if len(picoCfg.AllowOrigins) != 1 || picoCfg.AllowOrigins[0] != "http://10.0.0.5:3000" {
|
||||
t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", picoCfg.AllowOrigins)
|
||||
if len(picoCfg.AllowOrigins) != 0 {
|
||||
t.Errorf("allow_origins = %v, want empty", picoCfg.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -365,8 +339,8 @@ func TestHandlePicoSetup_Response(t *testing.T) {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if resp["token"] == nil || resp["token"] == "" {
|
||||
t.Error("response should contain a non-empty token")
|
||||
if _, ok := resp["token"]; ok {
|
||||
t.Error("response must not expose the raw pico token")
|
||||
}
|
||||
if resp["ws_url"] == nil || resp["ws_url"] == "" {
|
||||
t.Error("response should contain ws_url")
|
||||
@@ -377,6 +351,45 @@ func TestHandlePicoSetup_Response(t *testing.T) {
|
||||
if resp["changed"] != true {
|
||||
t.Error("response should have changed=true on first setup")
|
||||
}
|
||||
if resp["configured"] != true {
|
||||
t.Error("response should have configured=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetPicoInfo_OmitsToken(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.EnsurePicoChannel(); err != nil {
|
||||
t.Fatalf("EnsurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/api/pico/info", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleGetPicoInfo(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := resp["token"]; ok {
|
||||
t.Fatal("info response must not expose the raw pico token")
|
||||
}
|
||||
if resp["enabled"] != true {
|
||||
t.Fatalf("enabled = %#v, want true", resp["enabled"])
|
||||
}
|
||||
if resp["configured"] != true {
|
||||
t.Fatalf("configured = %#v, want true", resp["configured"])
|
||||
}
|
||||
if resp["ws_url"] == nil || resp["ws_url"] == "" {
|
||||
t.Fatal("response should contain ws_url")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
@@ -438,20 +451,10 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
|
||||
gateway.pidData = &ppid.PidFileData{}
|
||||
gateway.picoToken = "pico"
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
req1.Header.Set(protocolKey, tokenPrefix+"wrong_token")
|
||||
req1 := newPicoProxyRequest(http.MethodGet, "/pico/ws")
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler(rec1, req1)
|
||||
|
||||
if rec1.Code != http.StatusForbidden {
|
||||
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusForbidden)
|
||||
}
|
||||
|
||||
req1 = httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
req1.Header.Set(protocolKey, tokenPrefix+"pico")
|
||||
rec1 = httptest.NewRecorder()
|
||||
handler(rec1, req1)
|
||||
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
|
||||
}
|
||||
@@ -464,8 +467,7 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
req2.Header.Set(protocolKey, tokenPrefix+"pico")
|
||||
req2 := newPicoProxyRequest(http.MethodGet, "/pico/ws")
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler(rec2, req2)
|
||||
|
||||
@@ -539,8 +541,7 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
||||
gateway.pidData = &ppid.PidFileData{}
|
||||
gateway.picoToken = ""
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
|
||||
req.Header.Set(protocolKey, tokenPrefix+"cached-token")
|
||||
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, req)
|
||||
|
||||
@@ -625,8 +626,7 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
|
||||
req.Header.Set(protocolKey, tokenPrefix+"ui-token")
|
||||
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, req)
|
||||
|
||||
@@ -634,7 +634,7 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
expected := tokenPrefix + pico.PicoTokenPrefix + pidData.Token + "ui-token"
|
||||
expected := tokenPrefix + "ui-token"
|
||||
if got := rec.Body.String(); got != expected {
|
||||
t.Fatalf("forwarded protocol = %q, want %q", got, expected)
|
||||
}
|
||||
@@ -696,8 +696,7 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
|
||||
req.Header.Set(protocolKey, tokenPrefix+"ui-token")
|
||||
req := newPicoProxyRequest(http.MethodGet, "/pico/ws?session_id=test-session")
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, req)
|
||||
|
||||
@@ -711,6 +710,78 @@ func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxy_AllowsArbitraryOrigin(t *testing.T) {
|
||||
origMatcher := gatewayProcessMatcher
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
||||
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
handler := h.handleWebSocketProxy()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "proxied")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
|
||||
bc := cfg.Channels["pico"]
|
||||
bc.Enabled = true
|
||||
decoded, err := bc.GetDecoded()
|
||||
if err != nil {
|
||||
t.Fatalf("GetDecoded() error = %v", err)
|
||||
}
|
||||
decoded.(*config.PicoSettings).SetToken("ui-token")
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: cfg.Gateway.Host,
|
||||
Port: cfg.Gateway.Port,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
ppid.RemovePidFile(globalConfigDir())
|
||||
})
|
||||
|
||||
origPidData := gateway.pidData
|
||||
origPicoToken := gateway.picoToken
|
||||
t.Cleanup(func() {
|
||||
gateway.pidData = origPidData
|
||||
gateway.picoToken = origPicoToken
|
||||
})
|
||||
|
||||
gateway.pidData = &ppid.PidFileData{}
|
||||
gateway.picoToken = "ui-token"
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://launcher.local/pico/ws?session_id=test-session", nil)
|
||||
req.Header.Set("Origin", "http://evil.example")
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func mustGatewayTestPort(t *testing.T, rawURL string) int {
|
||||
t.Helper()
|
||||
|
||||
|
||||
+1
-1
@@ -544,7 +544,7 @@ func main() {
|
||||
// API Routes (e.g. /api/status)
|
||||
apiHandler = api.NewHandler(absPath)
|
||||
apiHandler.SetDebug(debug)
|
||||
if _, err = apiHandler.EnsurePicoChannel(""); err != nil {
|
||||
if _, err = apiHandler.EnsurePicoChannel(); err != nil {
|
||||
logger.ErrorC("web", fmt.Sprintf("Warning: failed to ensure pico channel on startup: %v", err))
|
||||
}
|
||||
apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)
|
||||
|
||||
@@ -218,6 +218,10 @@ func validLauncherDashboardAuth(r *http.Request, cfg LauncherDashboardAuthConfig
|
||||
}
|
||||
|
||||
func rejectLauncherDashboardAuth(w http.ResponseWriter, r *http.Request, canonicalPath string) {
|
||||
if canonicalPath == "/pico/ws" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(canonicalPath, "/api/") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
|
||||
@@ -40,6 +40,7 @@ func TestLauncherDashboardAuth_AllowsPublicPaths(t *testing.T) {
|
||||
{http.MethodPost, "/api/auth/logout", http.StatusTeapot},
|
||||
{http.MethodGet, "/api/auth/logout", http.StatusUnauthorized},
|
||||
{http.MethodGet, "/api/config", http.StatusUnauthorized},
|
||||
{http.MethodGet, "/pico/ws", http.StatusUnauthorized},
|
||||
} {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.method, tc.path, nil)
|
||||
@@ -160,3 +161,22 @@ func TestLauncherDashboardAuth_CookieAndBearer(t *testing.T) {
|
||||
t.Fatalf("bearer auth: status = %d", rec2.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLauncherDashboardAuth_WebSocketUnauthorizedDoesNotRedirect(t *testing.T) {
|
||||
cfg := LauncherDashboardAuthConfig{ExpectedCookie: "deadbeef", Token: "x"}
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
t.Fatal("next handler should not run without auth")
|
||||
})
|
||||
h := LauncherDashboardAuth(cfg, next)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||
}
|
||||
if got := rec.Header().Get("Location"); got != "" {
|
||||
t.Fatalf("Location = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,16 +2,16 @@ import { launcherFetch } from "@/api/http"
|
||||
|
||||
// API client for Pico Channel configuration.
|
||||
|
||||
interface PicoTokenResponse {
|
||||
token: string
|
||||
interface PicoInfoResponse {
|
||||
ws_url: string
|
||||
enabled: boolean
|
||||
configured?: boolean
|
||||
}
|
||||
|
||||
interface PicoSetupResponse {
|
||||
token: string
|
||||
ws_url: string
|
||||
enabled: boolean
|
||||
configured?: boolean
|
||||
changed: boolean
|
||||
}
|
||||
|
||||
@@ -25,16 +25,16 @@ async function request<T>(path: string, options?: RequestInit): Promise<T> {
|
||||
return res.json() as Promise<T>
|
||||
}
|
||||
|
||||
export async function getPicoToken(): Promise<PicoTokenResponse> {
|
||||
return request<PicoTokenResponse>("/api/pico/token")
|
||||
export async function getPicoInfo(): Promise<PicoInfoResponse> {
|
||||
return request<PicoInfoResponse>("/api/pico/info")
|
||||
}
|
||||
|
||||
export async function regenPicoToken(): Promise<PicoTokenResponse> {
|
||||
return request<PicoTokenResponse>("/api/pico/token", { method: "POST" })
|
||||
export async function regenPicoToken(): Promise<PicoInfoResponse> {
|
||||
return request<PicoInfoResponse>("/api/pico/token", { method: "POST" })
|
||||
}
|
||||
|
||||
export async function setupPico(): Promise<PicoSetupResponse> {
|
||||
return request<PicoSetupResponse>("/api/pico/setup", { method: "POST" })
|
||||
}
|
||||
|
||||
export type { PicoTokenResponse, PicoSetupResponse }
|
||||
export type { PicoInfoResponse, PicoSetupResponse }
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { getDefaultStore } from "jotai"
|
||||
import { toast } from "sonner"
|
||||
|
||||
import { getPicoToken } from "@/api/pico"
|
||||
import {
|
||||
loadSessionMessages,
|
||||
mergeHistoryMessages,
|
||||
@@ -131,7 +130,6 @@ export async function connectChat() {
|
||||
updateChatStore({ connectionState: "connecting" })
|
||||
|
||||
try {
|
||||
const { token } = await getPicoToken()
|
||||
const sessionId = activeSessionIdRef
|
||||
|
||||
if (generation !== connectionGeneration) {
|
||||
@@ -139,18 +137,10 @@ export async function connectChat() {
|
||||
return
|
||||
}
|
||||
|
||||
if (!token) {
|
||||
console.error("No pico token available")
|
||||
updateChatStore({ connectionState: "error" })
|
||||
isConnecting = false
|
||||
scheduleReconnect(generation, sessionId)
|
||||
return
|
||||
}
|
||||
|
||||
const wsScheme = window.location.protocol === "https:" ? "wss:" : "ws:"
|
||||
const wsUrl = `${wsScheme}//${window.location.host}/pico/ws`
|
||||
const url = `${wsUrl}?session_id=${encodeURIComponent(sessionId)}`
|
||||
const socket = new WebSocket(url, [`token.${token}`])
|
||||
const socket = new WebSocket(url)
|
||||
|
||||
if (generation !== connectionGeneration) {
|
||||
isConnecting = false
|
||||
|
||||
@@ -29,7 +29,7 @@ export default defineConfig({
|
||||
target: "http://localhost:18800",
|
||||
changeOrigin: true,
|
||||
},
|
||||
"/ws": {
|
||||
"/pico/ws": {
|
||||
target: "ws://localhost:18800",
|
||||
ws: true,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user