diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index 8d8b62a67..206e71f92 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -251,7 +251,13 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { return } - conn, err := c.upgrader.Upgrade(w, r, nil) + // Echo the matched subprotocol back so the browser accepts the upgrade. + var responseHeader http.Header + if proto := c.matchedSubprotocol(r); proto != "" { + responseHeader = http.Header{"Sec-WebSocket-Protocol": {proto}} + } + + conn, err := c.upgrader.Upgrade(w, r, responseHeader) if err != nil { logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{ "error": err.Error(), @@ -282,8 +288,10 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { go c.readLoop(pc) } -// authenticate checks the Bearer token from the Authorization header. -// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled. +// authenticate checks the request for a valid token: +// 1. Authorization: Bearer header +// 2. Sec-WebSocket-Protocol "token." (for browsers that can't set headers) +// 3. Query parameter "token" (only when AllowTokenQuery is on) func (c *PicoChannel) authenticate(r *http.Request) bool { token := c.config.Token if token == "" { @@ -298,6 +306,11 @@ func (c *PicoChannel) authenticate(r *http.Request) bool { } } + // Check Sec-WebSocket-Protocol subprotocol ("token.") + if c.matchedSubprotocol(r) != "" { + return true + } + // Check query parameter only when explicitly allowed if c.config.AllowTokenQuery { if r.URL.Query().Get("token") == token { @@ -308,6 +321,18 @@ func (c *PicoChannel) authenticate(r *http.Request) bool { return false } +// matchedSubprotocol returns the "token." subprotocol that matches +// the configured token, or "" if none do. +func (c *PicoChannel) matchedSubprotocol(r *http.Request) string { + token := c.config.Token + for _, proto := range websocket.Subprotocols(r) { + if after, ok := strings.CutPrefix(proto, "token."); ok && after == token { + return proto + } + } + return "" +} + // readLoop reads messages from a WebSocket connection. func (c *PicoChannel) readLoop(pc *picoConn) { defer func() { diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 1813cac92..f50f7609a 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -281,7 +281,7 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) { gateway.logs.Reset() // Ensure Pico Channel is configured before starting gateway - if _, err := h.ensurePicoChannel(); err != nil { + if _, err := h.ensurePicoChannel(""); err != nil { log.Printf("Warning: failed to ensure pico channel: %v", err) // Non-fatal: gateway can still start without pico channel } diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index a4590dcde..2d2201e16 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -65,9 +65,14 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { }) } -// ensurePicoChannel checks if the Pico Channel is properly configured and -// enables it with sensible defaults if not. Returns true if config was changed. -func (h *Handler) ensurePicoChannel() (bool, error) { +// ensurePicoChannel enables the Pico channel with sane defaults if it isn't +// already configured. Returns true when the config was modified. +// +// callerOrigin is the Origin header from the setup request. If non-empty and +// no origins are configured yet, it's written as the allowed origin so the +// WebSocket handshake works for whatever host the caller is on (LAN, custom +// port, etc.). Pass "" when there's no request context. +func (h *Handler) ensurePicoChannel(callerOrigin string) (bool, error) { cfg, err := config.LoadConfig(h.configPath) if err != nil { return false, fmt.Errorf("failed to load config: %w", err) @@ -85,14 +90,9 @@ func (h *Handler) ensurePicoChannel() (bool, error) { changed = true } - if !cfg.Channels.Pico.AllowTokenQuery { - cfg.Channels.Pico.AllowTokenQuery = true - changed = true - } - - // Make sure origins are allowed (frontend might be running on a different port like 5173 during dev) - if len(cfg.Channels.Pico.AllowOrigins) == 0 { - cfg.Channels.Pico.AllowOrigins = []string{"*"} + // Seed origins from the request instead of hardcoding ports. + if len(cfg.Channels.Pico.AllowOrigins) == 0 && callerOrigin != "" { + cfg.Channels.Pico.AllowOrigins = []string{callerOrigin} changed = true } @@ -109,7 +109,7 @@ func (h *Handler) ensurePicoChannel() (bool, error) { // // POST /api/pico/setup func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) { - changed, err := h.ensurePicoChannel() + changed, err := h.ensurePicoChannel(r.Header.Get("Origin")) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go new file mode 100644 index 000000000..46149fa09 --- /dev/null +++ b/web/backend/api/pico_test.go @@ -0,0 +1,237 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestEnsurePicoChannel_FreshConfig(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + changed, err := h.ensurePicoChannel("") + if err != nil { + t.Fatalf("ensurePicoChannel() error = %v", err) + } + if !changed { + t.Fatal("ensurePicoChannel() should report changed on a fresh config") + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if !cfg.Channels.Pico.Enabled { + t.Error("expected Pico to be enabled after setup") + } + if cfg.Channels.Pico.Token == "" { + t.Error("expected a non-empty token after setup") + } +} + +func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + if _, err := h.ensurePicoChannel(""); err != nil { + t.Fatalf("ensurePicoChannel() error = %v", err) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if cfg.Channels.Pico.AllowTokenQuery { + t.Error("setup must not enable allow_token_query by default") + } +} + +func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + if _, err := h.ensurePicoChannel("http://localhost:18800"); err != nil { + t.Fatalf("ensurePicoChannel() error = %v", err) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + for _, origin := range cfg.Channels.Pico.AllowOrigins { + if origin == "*" { + t.Error("setup must not set wildcard origin '*'") + } + } +} + +func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + if _, err := h.ensurePicoChannel(""); err != nil { + t.Fatalf("ensurePicoChannel() error = %v", err) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + // Without a caller origin, allow_origins stays empty (CheckOrigin + // allows all when the list is empty, so the channel still works). + if len(cfg.Channels.Pico.AllowOrigins) != 0 { + t.Errorf("allow_origins = %v, want empty when no caller origin", cfg.Channels.Pico.AllowOrigins) + } +} + +func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + lanOrigin := "http://192.168.1.9:18800" + if _, err := h.ensurePicoChannel(lanOrigin); err != nil { + t.Fatalf("ensurePicoChannel() error = %v", err) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != lanOrigin { + t.Errorf("allow_origins = %v, want [%s]", cfg.Channels.Pico.AllowOrigins, lanOrigin) + } +} + +func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + + // Pre-configure with custom user settings + cfg := config.DefaultConfig() + cfg.Channels.Pico.Enabled = true + cfg.Channels.Pico.Token = "user-custom-token" + cfg.Channels.Pico.AllowTokenQuery = true + cfg.Channels.Pico.AllowOrigins = []string{"https://myapp.example.com"} + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + + changed, err := h.ensurePicoChannel("") + if err != nil { + t.Fatalf("ensurePicoChannel() error = %v", err) + } + if changed { + t.Error("ensurePicoChannel() should not change a fully configured config") + } + + cfg, err = config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if cfg.Channels.Pico.Token != "user-custom-token" { + t.Errorf("token = %q, want %q", cfg.Channels.Pico.Token, "user-custom-token") + } + if !cfg.Channels.Pico.AllowTokenQuery { + t.Error("user's allow_token_query=true must be preserved") + } + if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != "https://myapp.example.com" { + t.Errorf("allow_origins = %v, want [https://myapp.example.com]", cfg.Channels.Pico.AllowOrigins) + } +} + +func TestEnsurePicoChannel_Idempotent(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + origin := "http://localhost:18800" + + // First call sets things up + if _, err := h.ensurePicoChannel(origin); err != nil { + t.Fatalf("first ensurePicoChannel() error = %v", err) + } + + cfg1, _ := config.LoadConfig(configPath) + token1 := cfg1.Channels.Pico.Token + + // Second call should be a no-op + changed, err := h.ensurePicoChannel(origin) + if err != nil { + t.Fatalf("second ensurePicoChannel() error = %v", err) + } + if changed { + t.Error("second ensurePicoChannel() should not report changed") + } + + cfg2, _ := config.LoadConfig(configPath) + if cfg2.Channels.Pico.Token != token1 { + t.Error("token should not change on subsequent calls") + } +} + +func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + req := httptest.NewRequest("POST", "/api/pico/setup", nil) + req.Header.Set("Origin", "http://10.0.0.5:3000") + rec := httptest.NewRecorder() + + h.handlePicoSetup(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != "http://10.0.0.5:3000" { + t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", cfg.Channels.Pico.AllowOrigins) + } +} + +func TestHandlePicoSetup_Response(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + req := httptest.NewRequest("POST", "/api/pico/setup", nil) + rec := httptest.NewRecorder() + + h.handlePicoSetup(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if resp["token"] == nil || resp["token"] == "" { + t.Error("response should contain a non-empty token") + } + if resp["ws_url"] == nil || resp["ws_url"] == "" { + t.Error("response should contain ws_url") + } + if resp["enabled"] != true { + t.Error("response should have enabled=true") + } + if resp["changed"] != true { + t.Error("response should have changed=true on first setup") + } +} diff --git a/web/frontend/src/lib/pico-chat-controller.ts b/web/frontend/src/lib/pico-chat-controller.ts index be3397bae..0e77d1ad0 100644 --- a/web/frontend/src/lib/pico-chat-controller.ts +++ b/web/frontend/src/lib/pico-chat-controller.ts @@ -165,8 +165,9 @@ export async function connectChat() { console.warn("Could not parse ws_url:", error) } - const url = `${finalWsUrl}?token=${encodeURIComponent(token)}&session_id=${encodeURIComponent(activeSessionIdRef)}` - const socket = new WebSocket(url) + const url = `${finalWsUrl}?session_id=${encodeURIComponent(activeSessionIdRef)}` + // Send token as a subprotocol so it doesn't end up in the URL. + const socket = new WebSocket(url, [`token.${token}`]) if (generation !== connectionGeneration) { socket.close()