mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
71e2b636d6
* fix: Use secure defaults for Pico channel setup and stop leaking the token in the URL * fix: Derive default allow_origins from the setup request's Origin header instead of hardcoding localhost ports
488 lines
12 KiB
Go
488 lines
12 KiB
Go
package pico
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"

	"github.com/sipeed/picoclaw/pkg/bus"
	"github.com/sipeed/picoclaw/pkg/channels"
	"github.com/sipeed/picoclaw/pkg/config"
	"github.com/sipeed/picoclaw/pkg/identity"
	"github.com/sipeed/picoclaw/pkg/logger"
)
|
|
|
|
// picoConn represents a single WebSocket connection.
|
|
type picoConn struct {
|
|
id string
|
|
conn *websocket.Conn
|
|
sessionID string
|
|
writeMu sync.Mutex
|
|
closed atomic.Bool
|
|
}
|
|
|
|
// writeJSON sends a JSON message to the connection with write locking.
|
|
func (pc *picoConn) writeJSON(v any) error {
|
|
if pc.closed.Load() {
|
|
return fmt.Errorf("connection closed")
|
|
}
|
|
pc.writeMu.Lock()
|
|
defer pc.writeMu.Unlock()
|
|
return pc.conn.WriteJSON(v)
|
|
}
|
|
|
|
// close closes the connection.
|
|
func (pc *picoConn) close() {
|
|
if pc.closed.CompareAndSwap(false, true) {
|
|
pc.conn.Close()
|
|
}
|
|
}
|
|
|
|
// PicoChannel implements the native Pico Protocol WebSocket channel.
|
|
// It serves as the reference implementation for all optional capability interfaces.
|
|
type PicoChannel struct {
|
|
*channels.BaseChannel
|
|
config config.PicoConfig
|
|
upgrader websocket.Upgrader
|
|
connections sync.Map // connID → *picoConn
|
|
connCount atomic.Int32
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// NewPicoChannel creates a new Pico Protocol channel.
|
|
func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) {
|
|
if cfg.Token == "" {
|
|
return nil, fmt.Errorf("pico token is required")
|
|
}
|
|
|
|
base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom)
|
|
|
|
allowOrigins := cfg.AllowOrigins
|
|
checkOrigin := func(r *http.Request) bool {
|
|
if len(allowOrigins) == 0 {
|
|
return true // allow all if not configured
|
|
}
|
|
origin := r.Header.Get("Origin")
|
|
for _, allowed := range allowOrigins {
|
|
if allowed == "*" || allowed == origin {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
return &PicoChannel{
|
|
BaseChannel: base,
|
|
config: cfg,
|
|
upgrader: websocket.Upgrader{
|
|
CheckOrigin: checkOrigin,
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Start implements Channel.
|
|
func (c *PicoChannel) Start(ctx context.Context) error {
|
|
logger.InfoC("pico", "Starting Pico Protocol channel")
|
|
c.ctx, c.cancel = context.WithCancel(ctx)
|
|
c.SetRunning(true)
|
|
logger.InfoC("pico", "Pico Protocol channel started")
|
|
return nil
|
|
}
|
|
|
|
// Stop implements Channel.
|
|
func (c *PicoChannel) Stop(ctx context.Context) error {
|
|
logger.InfoC("pico", "Stopping Pico Protocol channel")
|
|
c.SetRunning(false)
|
|
|
|
// Close all connections
|
|
c.connections.Range(func(key, value any) bool {
|
|
if pc, ok := value.(*picoConn); ok {
|
|
pc.close()
|
|
}
|
|
c.connections.Delete(key)
|
|
return true
|
|
})
|
|
|
|
if c.cancel != nil {
|
|
c.cancel()
|
|
}
|
|
|
|
logger.InfoC("pico", "Pico Protocol channel stopped")
|
|
return nil
|
|
}
|
|
|
|
// WebhookPath implements channels.WebhookHandler.
|
|
func (c *PicoChannel) WebhookPath() string { return "/pico/" }
|
|
|
|
// ServeHTTP implements http.Handler for the shared HTTP server.
|
|
func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
path := strings.TrimPrefix(r.URL.Path, "/pico")
|
|
|
|
switch {
|
|
case path == "/ws" || path == "/ws/":
|
|
c.handleWebSocket(w, r)
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}
|
|
|
|
// Send implements Channel — sends a message to the appropriate WebSocket connection.
|
|
func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
|
if !c.IsRunning() {
|
|
return channels.ErrNotRunning
|
|
}
|
|
|
|
outMsg := newMessage(TypeMessageCreate, map[string]any{
|
|
"content": msg.Content,
|
|
})
|
|
|
|
return c.broadcastToSession(msg.ChatID, outMsg)
|
|
}
|
|
|
|
// EditMessage implements channels.MessageEditor.
|
|
func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
|
|
outMsg := newMessage(TypeMessageUpdate, map[string]any{
|
|
"message_id": messageID,
|
|
"content": content,
|
|
})
|
|
return c.broadcastToSession(chatID, outMsg)
|
|
}
|
|
|
|
// StartTyping implements channels.TypingCapable.
|
|
func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
|
|
startMsg := newMessage(TypeTypingStart, nil)
|
|
if err := c.broadcastToSession(chatID, startMsg); err != nil {
|
|
return func() {}, err
|
|
}
|
|
return func() {
|
|
stopMsg := newMessage(TypeTypingStop, nil)
|
|
c.broadcastToSession(chatID, stopMsg)
|
|
}, nil
|
|
}
|
|
|
|
// SendPlaceholder implements channels.PlaceholderCapable.
|
|
// It sends a placeholder message via the Pico Protocol that will later be
|
|
// edited to the actual response via EditMessage (channels.MessageEditor).
|
|
func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
|
if !c.config.Placeholder.Enabled {
|
|
return "", nil
|
|
}
|
|
|
|
text := c.config.Placeholder.Text
|
|
if text == "" {
|
|
text = "Thinking... 💭"
|
|
}
|
|
|
|
msgID := uuid.New().String()
|
|
outMsg := newMessage(TypeMessageCreate, map[string]any{
|
|
"content": text,
|
|
"message_id": msgID,
|
|
})
|
|
|
|
if err := c.broadcastToSession(chatID, outMsg); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return msgID, nil
|
|
}
|
|
|
|
// broadcastToSession sends a message to all connections with a matching session.
|
|
func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error {
|
|
// chatID format: "pico:<sessionID>"
|
|
sessionID := strings.TrimPrefix(chatID, "pico:")
|
|
msg.SessionID = sessionID
|
|
|
|
var sent bool
|
|
c.connections.Range(func(key, value any) bool {
|
|
pc, ok := value.(*picoConn)
|
|
if !ok {
|
|
return true
|
|
}
|
|
if pc.sessionID == sessionID {
|
|
if err := pc.writeJSON(msg); err != nil {
|
|
logger.DebugCF("pico", "Write to connection failed", map[string]any{
|
|
"conn_id": pc.id,
|
|
"error": err.Error(),
|
|
})
|
|
} else {
|
|
sent = true
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
if !sent {
|
|
return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handleWebSocket upgrades the HTTP connection and manages the WebSocket lifecycle.
|
|
func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
if !c.IsRunning() {
|
|
http.Error(w, "channel not running", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
// Authenticate
|
|
if !c.authenticate(r) {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Check connection limit
|
|
maxConns := c.config.MaxConnections
|
|
if maxConns <= 0 {
|
|
maxConns = 100
|
|
}
|
|
if int(c.connCount.Load()) >= maxConns {
|
|
http.Error(w, "too many connections", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
// Echo the matched subprotocol back so the browser accepts the upgrade.
|
|
var responseHeader http.Header
|
|
if proto := c.matchedSubprotocol(r); proto != "" {
|
|
responseHeader = http.Header{"Sec-WebSocket-Protocol": {proto}}
|
|
}
|
|
|
|
conn, err := c.upgrader.Upgrade(w, r, responseHeader)
|
|
if err != nil {
|
|
logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
|
|
"error": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Determine session ID from query param or generate one
|
|
sessionID := r.URL.Query().Get("session_id")
|
|
if sessionID == "" {
|
|
sessionID = uuid.New().String()
|
|
}
|
|
|
|
pc := &picoConn{
|
|
id: uuid.New().String(),
|
|
conn: conn,
|
|
sessionID: sessionID,
|
|
}
|
|
|
|
c.connections.Store(pc.id, pc)
|
|
c.connCount.Add(1)
|
|
|
|
logger.InfoCF("pico", "WebSocket client connected", map[string]any{
|
|
"conn_id": pc.id,
|
|
"session_id": sessionID,
|
|
})
|
|
|
|
go c.readLoop(pc)
|
|
}
|
|
|
|
// authenticate checks the request for a valid token:
|
|
// 1. Authorization: Bearer <token> header
|
|
// 2. Sec-WebSocket-Protocol "token.<value>" (for browsers that can't set headers)
|
|
// 3. Query parameter "token" (only when AllowTokenQuery is on)
|
|
func (c *PicoChannel) authenticate(r *http.Request) bool {
|
|
token := c.config.Token
|
|
if token == "" {
|
|
return false
|
|
}
|
|
|
|
// Check Authorization header
|
|
auth := r.Header.Get("Authorization")
|
|
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
|
|
if after == token {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Check Sec-WebSocket-Protocol subprotocol ("token.<value>")
|
|
if c.matchedSubprotocol(r) != "" {
|
|
return true
|
|
}
|
|
|
|
// Check query parameter only when explicitly allowed
|
|
if c.config.AllowTokenQuery {
|
|
if r.URL.Query().Get("token") == token {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// matchedSubprotocol returns the "token.<value>" subprotocol that matches
|
|
// the configured token, or "" if none do.
|
|
func (c *PicoChannel) matchedSubprotocol(r *http.Request) string {
|
|
token := c.config.Token
|
|
for _, proto := range websocket.Subprotocols(r) {
|
|
if after, ok := strings.CutPrefix(proto, "token."); ok && after == token {
|
|
return proto
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// readLoop reads messages from a WebSocket connection.
|
|
func (c *PicoChannel) readLoop(pc *picoConn) {
|
|
defer func() {
|
|
pc.close()
|
|
c.connections.Delete(pc.id)
|
|
c.connCount.Add(-1)
|
|
logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{
|
|
"conn_id": pc.id,
|
|
"session_id": pc.sessionID,
|
|
})
|
|
}()
|
|
|
|
readTimeout := time.Duration(c.config.ReadTimeout) * time.Second
|
|
if readTimeout <= 0 {
|
|
readTimeout = 60 * time.Second
|
|
}
|
|
|
|
_ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
pc.conn.SetPongHandler(func(appData string) error {
|
|
_ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
return nil
|
|
})
|
|
|
|
// Start ping ticker
|
|
pingInterval := time.Duration(c.config.PingInterval) * time.Second
|
|
if pingInterval <= 0 {
|
|
pingInterval = 30 * time.Second
|
|
}
|
|
go c.pingLoop(pc, pingInterval)
|
|
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
_, rawMsg, err := pc.conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
|
logger.DebugCF("pico", "WebSocket read error", map[string]any{
|
|
"conn_id": pc.id,
|
|
"error": err.Error(),
|
|
})
|
|
}
|
|
return
|
|
}
|
|
|
|
_ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
|
|
var msg PicoMessage
|
|
if err := json.Unmarshal(rawMsg, &msg); err != nil {
|
|
errMsg := newError("invalid_message", "failed to parse message")
|
|
pc.writeJSON(errMsg)
|
|
continue
|
|
}
|
|
|
|
c.handleMessage(pc, msg)
|
|
}
|
|
}
|
|
|
|
// pingLoop sends periodic ping frames to keep the connection alive.
|
|
func (c *PicoChannel) pingLoop(pc *picoConn, interval time.Duration) {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if pc.closed.Load() {
|
|
return
|
|
}
|
|
pc.writeMu.Lock()
|
|
err := pc.conn.WriteMessage(websocket.PingMessage, nil)
|
|
pc.writeMu.Unlock()
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleMessage processes an inbound Pico Protocol message.
|
|
func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) {
|
|
switch msg.Type {
|
|
case TypePing:
|
|
pong := newMessage(TypePong, nil)
|
|
pong.ID = msg.ID
|
|
pc.writeJSON(pong)
|
|
|
|
case TypeMessageSend:
|
|
c.handleMessageSend(pc, msg)
|
|
|
|
default:
|
|
errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type))
|
|
pc.writeJSON(errMsg)
|
|
}
|
|
}
|
|
|
|
// handleMessageSend processes an inbound message.send from a client.
|
|
func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
|
|
content, _ := msg.Payload["content"].(string)
|
|
if strings.TrimSpace(content) == "" {
|
|
errMsg := newError("empty_content", "message content is empty")
|
|
pc.writeJSON(errMsg)
|
|
return
|
|
}
|
|
|
|
sessionID := msg.SessionID
|
|
if sessionID == "" {
|
|
sessionID = pc.sessionID
|
|
}
|
|
|
|
chatID := "pico:" + sessionID
|
|
senderID := "pico-user"
|
|
|
|
peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID}
|
|
|
|
metadata := map[string]string{
|
|
"platform": "pico",
|
|
"session_id": sessionID,
|
|
"conn_id": pc.id,
|
|
}
|
|
|
|
logger.DebugCF("pico", "Received message", map[string]any{
|
|
"session_id": sessionID,
|
|
"preview": truncate(content, 50),
|
|
})
|
|
|
|
sender := bus.SenderInfo{
|
|
Platform: "pico",
|
|
PlatformID: senderID,
|
|
CanonicalID: identity.BuildCanonicalID("pico", senderID),
|
|
}
|
|
|
|
if !c.IsAllowedSender(sender) {
|
|
return
|
|
}
|
|
|
|
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata, sender)
|
|
}
|
|
|
|
// truncate returns s unchanged if it has at most maxLen runes. Otherwise it
// returns the first maxLen runes followed by "...". A negative maxLen is
// treated as 0 rather than panicking.
//
// It walks s rune by rune and slices at the cut point. Converting the whole
// string to []rune would allocate proportionally to the full message just to
// build a short log preview.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	seen := 0
	for i := range s {
		if seen == maxLen {
			return s[:i] + "..."
		}
		seen++
	}
	return s
}
|