mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix: Use secure defaults for Pico channel setup and stop leaking the token in the URL (#1563)
* fix: Use secure defaults for Pico channel setup and stop leaking the token in the URL * fix: Derive default allow_origins from the setup request's Origin header instead of hardcoding localhost ports
This commit is contained in:
@@ -251,7 +251,13 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := c.upgrader.Upgrade(w, r, nil)
|
||||
// Echo the matched subprotocol back so the browser accepts the upgrade.
|
||||
var responseHeader http.Header
|
||||
if proto := c.matchedSubprotocol(r); proto != "" {
|
||||
responseHeader = http.Header{"Sec-WebSocket-Protocol": {proto}}
|
||||
}
|
||||
|
||||
conn, err := c.upgrader.Upgrade(w, r, responseHeader)
|
||||
if err != nil {
|
||||
logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
@@ -282,8 +288,10 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
go c.readLoop(pc)
|
||||
}
|
||||
|
||||
// authenticate checks the Bearer token from the Authorization header.
|
||||
// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled.
|
||||
// authenticate checks the request for a valid token:
|
||||
// 1. Authorization: Bearer <token> header
|
||||
// 2. Sec-WebSocket-Protocol "token.<value>" (for browsers that can't set headers)
|
||||
// 3. Query parameter "token" (only when AllowTokenQuery is on)
|
||||
func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
token := c.config.Token
|
||||
if token == "" {
|
||||
@@ -298,6 +306,11 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Check Sec-WebSocket-Protocol subprotocol ("token.<value>")
|
||||
if c.matchedSubprotocol(r) != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check query parameter only when explicitly allowed
|
||||
if c.config.AllowTokenQuery {
|
||||
if r.URL.Query().Get("token") == token {
|
||||
@@ -308,6 +321,18 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchedSubprotocol returns the "token.<value>" subprotocol that matches
|
||||
// the configured token, or "" if none do.
|
||||
func (c *PicoChannel) matchedSubprotocol(r *http.Request) string {
|
||||
token := c.config.Token
|
||||
for _, proto := range websocket.Subprotocols(r) {
|
||||
if after, ok := strings.CutPrefix(proto, "token."); ok && after == token {
|
||||
return proto
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// readLoop reads messages from a WebSocket connection.
|
||||
func (c *PicoChannel) readLoop(pc *picoConn) {
|
||||
defer func() {
|
||||
|
||||
@@ -281,7 +281,7 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
|
||||
gateway.logs.Reset()
|
||||
|
||||
// Ensure Pico Channel is configured before starting gateway
|
||||
if _, err := h.ensurePicoChannel(); err != nil {
|
||||
if _, err := h.ensurePicoChannel(""); err != nil {
|
||||
log.Printf("Warning: failed to ensure pico channel: %v", err)
|
||||
// Non-fatal: gateway can still start without pico channel
|
||||
}
|
||||
|
||||
+12
-12
@@ -65,9 +65,14 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// ensurePicoChannel checks if the Pico Channel is properly configured and
|
||||
// enables it with sensible defaults if not. Returns true if config was changed.
|
||||
func (h *Handler) ensurePicoChannel() (bool, error) {
|
||||
// ensurePicoChannel enables the Pico channel with sane defaults if it isn't
|
||||
// already configured. Returns true when the config was modified.
|
||||
//
|
||||
// callerOrigin is the Origin header from the setup request. If non-empty and
|
||||
// no origins are configured yet, it's written as the allowed origin so the
|
||||
// WebSocket handshake works for whatever host the caller is on (LAN, custom
|
||||
// port, etc.). Pass "" when there's no request context.
|
||||
func (h *Handler) ensurePicoChannel(callerOrigin string) (bool, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to load config: %w", err)
|
||||
@@ -85,14 +90,9 @@ func (h *Handler) ensurePicoChannel() (bool, error) {
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !cfg.Channels.Pico.AllowTokenQuery {
|
||||
cfg.Channels.Pico.AllowTokenQuery = true
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Make sure origins are allowed (frontend might be running on a different port like 5173 during dev)
|
||||
if len(cfg.Channels.Pico.AllowOrigins) == 0 {
|
||||
cfg.Channels.Pico.AllowOrigins = []string{"*"}
|
||||
// Seed origins from the request instead of hardcoding ports.
|
||||
if len(cfg.Channels.Pico.AllowOrigins) == 0 && callerOrigin != "" {
|
||||
cfg.Channels.Pico.AllowOrigins = []string{callerOrigin}
|
||||
changed = true
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func (h *Handler) ensurePicoChannel() (bool, error) {
|
||||
//
|
||||
// POST /api/pico/setup
|
||||
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
||||
changed, err := h.ensurePicoChannel()
|
||||
changed, err := h.ensurePicoChannel(r.Header.Get("Origin"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.ensurePicoChannel("")
|
||||
if err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
if !changed {
|
||||
t.Fatal("ensurePicoChannel() should report changed on a fresh config")
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Channels.Pico.Enabled {
|
||||
t.Error("expected Pico to be enabled after setup")
|
||||
}
|
||||
if cfg.Channels.Pico.Token == "" {
|
||||
t.Error("expected a non-empty token after setup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.ensurePicoChannel(""); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Channels.Pico.AllowTokenQuery {
|
||||
t.Error("setup must not enable allow_token_query by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.ensurePicoChannel("http://localhost:18800"); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
for _, origin := range cfg.Channels.Pico.AllowOrigins {
|
||||
if origin == "*" {
|
||||
t.Error("setup must not set wildcard origin '*'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.ensurePicoChannel(""); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
// Without a caller origin, allow_origins stays empty (CheckOrigin
|
||||
// allows all when the list is empty, so the channel still works).
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 0 {
|
||||
t.Errorf("allow_origins = %v, want empty when no caller origin", cfg.Channels.Pico.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
lanOrigin := "http://192.168.1.9:18800"
|
||||
if _, err := h.ensurePicoChannel(lanOrigin); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != lanOrigin {
|
||||
t.Errorf("allow_origins = %v, want [%s]", cfg.Channels.Pico.AllowOrigins, lanOrigin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
// Pre-configure with custom user settings
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Channels.Pico.Enabled = true
|
||||
cfg.Channels.Pico.Token = "user-custom-token"
|
||||
cfg.Channels.Pico.AllowTokenQuery = true
|
||||
cfg.Channels.Pico.AllowOrigins = []string{"https://myapp.example.com"}
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.ensurePicoChannel("")
|
||||
if err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
if changed {
|
||||
t.Error("ensurePicoChannel() should not change a fully configured config")
|
||||
}
|
||||
|
||||
cfg, err = config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Channels.Pico.Token != "user-custom-token" {
|
||||
t.Errorf("token = %q, want %q", cfg.Channels.Pico.Token, "user-custom-token")
|
||||
}
|
||||
if !cfg.Channels.Pico.AllowTokenQuery {
|
||||
t.Error("user's allow_token_query=true must be preserved")
|
||||
}
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != "https://myapp.example.com" {
|
||||
t.Errorf("allow_origins = %v, want [https://myapp.example.com]", cfg.Channels.Pico.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
origin := "http://localhost:18800"
|
||||
|
||||
// First call sets things up
|
||||
if _, err := h.ensurePicoChannel(origin); err != nil {
|
||||
t.Fatalf("first ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg1, _ := config.LoadConfig(configPath)
|
||||
token1 := cfg1.Channels.Pico.Token
|
||||
|
||||
// Second call should be a no-op
|
||||
changed, err := h.ensurePicoChannel(origin)
|
||||
if err != nil {
|
||||
t.Fatalf("second ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
if changed {
|
||||
t.Error("second ensurePicoChannel() should not report changed")
|
||||
}
|
||||
|
||||
cfg2, _ := config.LoadConfig(configPath)
|
||||
if cfg2.Channels.Pico.Token != token1 {
|
||||
t.Error("token should not change on subsequent calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
|
||||
req.Header.Set("Origin", "http://10.0.0.5:3000")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handlePicoSetup(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != "http://10.0.0.5:3000" {
|
||||
t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", cfg.Channels.Pico.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePicoSetup_Response(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handlePicoSetup(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if resp["token"] == nil || resp["token"] == "" {
|
||||
t.Error("response should contain a non-empty token")
|
||||
}
|
||||
if resp["ws_url"] == nil || resp["ws_url"] == "" {
|
||||
t.Error("response should contain ws_url")
|
||||
}
|
||||
if resp["enabled"] != true {
|
||||
t.Error("response should have enabled=true")
|
||||
}
|
||||
if resp["changed"] != true {
|
||||
t.Error("response should have changed=true on first setup")
|
||||
}
|
||||
}
|
||||
@@ -165,8 +165,9 @@ export async function connectChat() {
|
||||
console.warn("Could not parse ws_url:", error)
|
||||
}
|
||||
|
||||
const url = `${finalWsUrl}?token=${encodeURIComponent(token)}&session_id=${encodeURIComponent(activeSessionIdRef)}`
|
||||
const socket = new WebSocket(url)
|
||||
const url = `${finalWsUrl}?session_id=${encodeURIComponent(activeSessionIdRef)}`
|
||||
// Send token as a subprotocol so it doesn't end up in the URL.
|
||||
const socket = new WebSocket(url, [`token.${token}`])
|
||||
|
||||
if (generation !== connectionGeneration) {
|
||||
socket.close()
|
||||
|
||||
Reference in New Issue
Block a user