package pico import ( "context" "encoding/json" "fmt" "net/http" "strings" "sync" "sync/atomic" "time" "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" ) // picoConn represents a single WebSocket connection. type picoConn struct { id string conn *websocket.Conn sessionID string writeMu sync.Mutex closed atomic.Bool } // writeJSON sends a JSON message to the connection with write locking. func (pc *picoConn) writeJSON(v any) error { if pc.closed.Load() { return fmt.Errorf("connection closed") } pc.writeMu.Lock() defer pc.writeMu.Unlock() return pc.conn.WriteJSON(v) } // close closes the connection. func (pc *picoConn) close() { if pc.closed.CompareAndSwap(false, true) { pc.conn.Close() } } // PicoChannel implements the native Pico Protocol WebSocket channel. // It serves as the reference implementation for all optional capability interfaces. type PicoChannel struct { *channels.BaseChannel config config.PicoConfig upgrader websocket.Upgrader connections sync.Map // connID → *picoConn connCount atomic.Int32 ctx context.Context cancel context.CancelFunc } // NewPicoChannel creates a new Pico Protocol channel. func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) { if cfg.Token == "" { return nil, fmt.Errorf("pico token is required") } base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom) allowOrigins := cfg.AllowOrigins checkOrigin := func(r *http.Request) bool { if len(allowOrigins) == 0 { return true // allow all if not configured } origin := r.Header.Get("Origin") for _, allowed := range allowOrigins { if allowed == "*" || allowed == origin { return true } } return false } return &PicoChannel{ BaseChannel: base, config: cfg, upgrader: websocket.Upgrader{ CheckOrigin: checkOrigin, ReadBufferSize: 1024, WriteBufferSize: 1024, }, }, nil } // Start implements Channel. func (c *PicoChannel) Start(ctx context.Context) error { logger.InfoC("pico", "Starting Pico Protocol channel") c.ctx, c.cancel = context.WithCancel(ctx) c.SetRunning(true) logger.InfoC("pico", "Pico Protocol channel started") return nil } // Stop implements Channel. func (c *PicoChannel) Stop(ctx context.Context) error { logger.InfoC("pico", "Stopping Pico Protocol channel") c.SetRunning(false) // Close all connections c.connections.Range(func(key, value any) bool { if pc, ok := value.(*picoConn); ok { pc.close() } c.connections.Delete(key) return true }) if c.cancel != nil { c.cancel() } logger.InfoC("pico", "Pico Protocol channel stopped") return nil } // WebhookPath implements channels.WebhookHandler. func (c *PicoChannel) WebhookPath() string { return "/pico/" } // ServeHTTP implements http.Handler for the shared HTTP server. func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := strings.TrimPrefix(r.URL.Path, "/pico") switch { case path == "/ws" || path == "/ws/": c.handleWebSocket(w, r) default: http.NotFound(w, r) } } // Send implements Channel — sends a message to the appropriate WebSocket connection. func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return channels.ErrNotRunning } outMsg := newMessage(TypeMessageCreate, map[string]any{ "content": msg.Content, }) return c.broadcastToSession(msg.ChatID, outMsg) } // EditMessage implements channels.MessageEditor. func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { outMsg := newMessage(TypeMessageUpdate, map[string]any{ "message_id": messageID, "content": content, }) return c.broadcastToSession(chatID, outMsg) } // StartTyping implements channels.TypingCapable. func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { startMsg := newMessage(TypeTypingStart, nil) if err := c.broadcastToSession(chatID, startMsg); err != nil { return func() {}, err } return func() { stopMsg := newMessage(TypeTypingStop, nil) c.broadcastToSession(chatID, stopMsg) }, nil } // SendPlaceholder implements channels.PlaceholderCapable. // It sends a placeholder message via the Pico Protocol that will later be // edited to the actual response via EditMessage (channels.MessageEditor). func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { if !c.config.Placeholder.Enabled { return "", nil } text := c.config.Placeholder.Text if text == "" { text = "Thinking... 💭" } msgID := uuid.New().String() outMsg := newMessage(TypeMessageCreate, map[string]any{ "content": text, "message_id": msgID, }) if err := c.broadcastToSession(chatID, outMsg); err != nil { return "", err } return msgID, nil } // broadcastToSession sends a message to all connections with a matching session. func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error { // chatID format: "pico:" sessionID := strings.TrimPrefix(chatID, "pico:") msg.SessionID = sessionID var sent bool c.connections.Range(func(key, value any) bool { pc, ok := value.(*picoConn) if !ok { return true } if pc.sessionID == sessionID { if err := pc.writeJSON(msg); err != nil { logger.DebugCF("pico", "Write to connection failed", map[string]any{ "conn_id": pc.id, "error": err.Error(), }) } else { sent = true } } return true }) if !sent { return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed) } return nil } // handleWebSocket upgrades the HTTP connection and manages the WebSocket lifecycle. func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { if !c.IsRunning() { http.Error(w, "channel not running", http.StatusServiceUnavailable) return } // Authenticate if !c.authenticate(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) return } // Check connection limit maxConns := c.config.MaxConnections if maxConns <= 0 { maxConns = 100 } if int(c.connCount.Load()) >= maxConns { http.Error(w, "too many connections", http.StatusServiceUnavailable) return } conn, err := c.upgrader.Upgrade(w, r, nil) if err != nil { logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{ "error": err.Error(), }) return } // Determine session ID from query param or generate one sessionID := r.URL.Query().Get("session_id") if sessionID == "" { sessionID = uuid.New().String() } pc := &picoConn{ id: uuid.New().String(), conn: conn, sessionID: sessionID, } c.connections.Store(pc.id, pc) c.connCount.Add(1) logger.InfoCF("pico", "WebSocket client connected", map[string]any{ "conn_id": pc.id, "session_id": sessionID, }) go c.readLoop(pc) } // authenticate checks the Bearer token from the Authorization header. // Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled. func (c *PicoChannel) authenticate(r *http.Request) bool { token := c.config.Token if token == "" { return false } // Check Authorization header auth := r.Header.Get("Authorization") if after, ok := strings.CutPrefix(auth, "Bearer "); ok { if after == token { return true } } // Check query parameter only when explicitly allowed if c.config.AllowTokenQuery { if r.URL.Query().Get("token") == token { return true } } return false } // readLoop reads messages from a WebSocket connection. func (c *PicoChannel) readLoop(pc *picoConn) { defer func() { pc.close() c.connections.Delete(pc.id) c.connCount.Add(-1) logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{ "conn_id": pc.id, "session_id": pc.sessionID, }) }() readTimeout := time.Duration(c.config.ReadTimeout) * time.Second if readTimeout <= 0 { readTimeout = 60 * time.Second } _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) pc.conn.SetPongHandler(func(appData string) error { _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) return nil }) // Start ping ticker pingInterval := time.Duration(c.config.PingInterval) * time.Second if pingInterval <= 0 { pingInterval = 30 * time.Second } go c.pingLoop(pc, pingInterval) for { select { case <-c.ctx.Done(): return default: } _, rawMsg, err := pc.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { logger.DebugCF("pico", "WebSocket read error", map[string]any{ "conn_id": pc.id, "error": err.Error(), }) } return } _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) var msg PicoMessage if err := json.Unmarshal(rawMsg, &msg); err != nil { errMsg := newError("invalid_message", "failed to parse message") pc.writeJSON(errMsg) continue } c.handleMessage(pc, msg) } } // pingLoop sends periodic ping frames to keep the connection alive. func (c *PicoChannel) pingLoop(pc *picoConn, interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-c.ctx.Done(): return case <-ticker.C: if pc.closed.Load() { return } pc.writeMu.Lock() err := pc.conn.WriteMessage(websocket.PingMessage, nil) pc.writeMu.Unlock() if err != nil { return } } } } // handleMessage processes an inbound Pico Protocol message. func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) { switch msg.Type { case TypePing: pong := newMessage(TypePong, nil) pong.ID = msg.ID pc.writeJSON(pong) case TypeMessageSend: c.handleMessageSend(pc, msg) default: errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type)) pc.writeJSON(errMsg) } } // handleMessageSend processes an inbound message.send from a client. func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { content, _ := msg.Payload["content"].(string) if strings.TrimSpace(content) == "" { errMsg := newError("empty_content", "message content is empty") pc.writeJSON(errMsg) return } sessionID := msg.SessionID if sessionID == "" { sessionID = pc.sessionID } chatID := "pico:" + sessionID senderID := "pico-user" peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID} metadata := map[string]string{ "platform": "pico", "session_id": sessionID, "conn_id": pc.id, } logger.DebugCF("pico", "Received message", map[string]any{ "session_id": sessionID, "preview": truncate(content, 50), }) sender := bus.SenderInfo{ Platform: "pico", PlatformID: senderID, CanonicalID: identity.BuildCanonicalID("pico", senderID), } if !c.IsAllowedSender(sender) { return } c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata, sender) } // truncate truncates a string to maxLen runes. func truncate(s string, maxLen int) string { runes := []rune(s) if len(runes) <= maxLen { return s } return string(runes[:maxLen]) + "..." }