Files
picoclaw/pkg/channels/pico/pico.go
T
Hoshina 90b4a64683 feat(channels): add typing/placeholder automation and Pico Protocol channel (Phase 10 + 7)
Phase 10: Define TypingCapable, MessageEditor, PlaceholderRecorder interfaces.
Manager orchestrates outbound typing stop and placeholder editing via preSend.
Migrate Telegram, Discord, Slack, OneBot to register state with Manager instead
of handling locally in Send. Phase 7: Add native WebSocket Pico Protocol channel
as reference implementation of all optional capability interfaces.
2026-02-24 12:10:45 +08:00

431 lines
10 KiB
Go

package pico
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"

	"github.com/sipeed/picoclaw/pkg/bus"
	"github.com/sipeed/picoclaw/pkg/channels"
	"github.com/sipeed/picoclaw/pkg/config"
	"github.com/sipeed/picoclaw/pkg/logger"
)
// picoConn represents a single WebSocket connection.
//
// It bundles the underlying socket with the identifiers used to look the
// connection up and the synchronization state that guards writes and close.
type picoConn struct {
// id is the key under which the connection is stored in PicoChannel.connections.
id string
// conn is the underlying gorilla/websocket connection.
conn *websocket.Conn
// sessionID is matched against the target session when broadcasting.
sessionID string
// writeMu serializes writes to conn; writeJSON and pingLoop both hold it while writing.
writeMu sync.Mutex
// closed is set to true by close(); writeJSON and pingLoop check it before writing.
closed atomic.Bool
}
// writeJSON encodes v as JSON and writes it to the connection.
//
// If the connection has already been marked closed, it returns an error
// without touching the socket. Otherwise the write runs while holding
// writeMu, so it cannot interleave with another writer that takes the same
// lock.
func (pc *picoConn) writeJSON(v any) error {
	alreadyClosed := pc.closed.Load()
	if !alreadyClosed {
		pc.writeMu.Lock()
		defer pc.writeMu.Unlock()
		return pc.conn.WriteJSON(v)
	}
	return fmt.Errorf("connection closed")
}
// close shuts the underlying socket down.
//
// The closed flag is flipped with a compare-and-swap, so only the first
// caller actually closes the socket; later calls are no-ops.
func (pc *picoConn) close() {
	firstCall := pc.closed.CompareAndSwap(false, true)
	if !firstCall {
		return
	}
	pc.conn.Close()
}
// PicoChannel implements the native Pico Protocol WebSocket channel.
// It serves as the reference implementation for all optional capability interfaces.
//
// Instances are built by NewPicoChannel. The ctx and cancel fields are only
// populated once Start has been called.
type PicoChannel struct {
// BaseChannel is embedded. SetRunning, IsRunning, HandleMessage and
// GetPlaceholderRecorder are called on the channel in this file but not
// defined here; presumably they are promoted from BaseChannel.
*channels.BaseChannel
// config is the Pico configuration passed to NewPicoChannel. This file reads
// Token, MaxConnections, ReadTimeout and PingInterval from it.
config config.PicoConfig
// upgrader performs the HTTP-to-WebSocket upgrade; its CheckOrigin is installed by NewPicoChannel.
upgrader websocket.Upgrader
connections sync.Map // connID → *picoConn
// connCount is the number of live connections. It is compared against the
// connection limit in handleWebSocket and decremented when readLoop exits.
connCount atomic.Int32
// ctx is derived from the context given to Start; readLoop and pingLoop stop when it is done.
ctx context.Context
// cancel cancels ctx and is invoked by Stop (nil until Start has run).
cancel context.CancelFunc
}
// NewPicoChannel builds a Pico Protocol channel from cfg and attaches it to
// messageBus.
//
// It returns an error when cfg.Token is empty. The WebSocket upgrader is
// configured with 1024-byte read/write buffers and an origin check driven by
// cfg.AllowOrigins: an empty list admits every origin, otherwise the request's
// Origin header must equal one of the entries or the list must contain "*".
func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) {
	if cfg.Token == "" {
		return nil, fmt.Errorf("pico token is required")
	}

	ch := &PicoChannel{
		BaseChannel: channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom),
		config:      cfg,
	}

	origins := cfg.AllowOrigins
	ch.upgrader = websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin: func(r *http.Request) bool {
			// An empty allow-list disables origin filtering entirely.
			if len(origins) == 0 {
				return true
			}
			requestOrigin := r.Header.Get("Origin")
			for i := range origins {
				switch origins[i] {
				case "*", requestOrigin:
					return true
				}
			}
			return false
		},
	}
	return ch, nil
}
// Start implements Channel.
//
// It derives a cancellable context from ctx, stores it (with its cancel
// function) on the channel, and marks the channel as running. It always
// returns nil.
func (c *PicoChannel) Start(ctx context.Context) error {
	logger.InfoC("pico", "Starting Pico Protocol channel")

	runCtx, stop := context.WithCancel(ctx)
	c.ctx = runCtx
	c.cancel = stop
	c.SetRunning(true)

	logger.InfoC("pico", "Pico Protocol channel started")
	return nil
}
// Stop implements Channel.
//
// It marks the channel as not running, closes and removes every tracked
// connection, and then cancels the context created by Start (if any). It
// always returns nil.
func (c *PicoChannel) Stop(ctx context.Context) error {
	logger.InfoC("pico", "Stopping Pico Protocol channel")
	c.SetRunning(false)

	// Close every tracked connection and drop it from the map.
	dropConn := func(id, entry any) bool {
		pc, isConn := entry.(*picoConn)
		if isConn {
			pc.close()
		}
		c.connections.Delete(id)
		return true
	}
	c.connections.Range(dropConn)

	if stop := c.cancel; stop != nil {
		stop()
	}

	logger.InfoC("pico", "Pico Protocol channel stopped")
	return nil
}
// WebhookPath implements channels.WebhookHandler.
//
// It returns the path prefix for this channel, "/pico/".
func (c *PicoChannel) WebhookPath() string {
	const prefix = "/pico/"
	return prefix
}
// ServeHTTP implements http.Handler for the shared HTTP server.
//
// After stripping the "/pico" prefix from the request path, "/ws" and "/ws/"
// are handed to handleWebSocket; every other path gets a 404.
func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	subPath := strings.TrimPrefix(r.URL.Path, "/pico")
	if subPath != "/ws" && subPath != "/ws/" {
		http.NotFound(w, r)
		return
	}
	c.handleWebSocket(w, r)
}
// Send implements Channel.
//
// It wraps msg.Content in a message-create frame and broadcasts it to the
// connections of the session named by msg.ChatID. It returns
// channels.ErrNotRunning when the channel is not running, otherwise whatever
// broadcastToSession returns.
func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
	if running := c.IsRunning(); !running {
		return channels.ErrNotRunning
	}
	payload := map[string]any{"content": msg.Content}
	return c.broadcastToSession(msg.ChatID, newMessage(TypeMessageCreate, payload))
}
// EditMessage implements channels.MessageEditor.
//
// It broadcasts a message-update frame carrying messageID and the new content
// to the connections of the session named by chatID, and returns the result
// of that broadcast.
func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
	payload := map[string]any{
		"message_id": messageID,
		"content":    content,
	}
	update := newMessage(TypeMessageUpdate, payload)
	return c.broadcastToSession(chatID, update)
}
// StartTyping implements channels.TypingCapable.
//
// It broadcasts a typing-start frame to the session named by chatID. On
// success the returned function broadcasts the matching typing-stop frame,
// discarding any error from that broadcast. If the start frame cannot be
// delivered, a no-op function is returned together with the error.
func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
	err := c.broadcastToSession(chatID, newMessage(TypeTypingStart, nil))
	if err != nil {
		return func() {}, err
	}

	stopTyping := func() {
		c.broadcastToSession(chatID, newMessage(TypeTypingStop, nil))
	}
	return stopTyping, nil
}
// broadcastToSession writes msg to every tracked connection whose session
// matches chatID.
//
// chatID is expected in the form "pico:<sessionID>"; the prefix is stripped
// and the remainder is stamped onto msg.SessionID before sending. Per-connection
// write failures are logged at debug level and do not abort the broadcast.
// If no connection accepted the message, the returned error wraps
// channels.ErrSendFailed.
func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error {
	target := strings.TrimPrefix(chatID, "pico:")
	msg.SessionID = target

	delivered := false
	c.connections.Range(func(_, entry any) bool {
		pc, isConn := entry.(*picoConn)
		if !isConn || pc.sessionID != target {
			return true
		}

		writeErr := pc.writeJSON(msg)
		if writeErr == nil {
			delivered = true
			return true
		}
		logger.DebugCF("pico", "Write to connection failed", map[string]any{
			"conn_id": pc.id,
			"error":   writeErr.Error(),
		})
		return true
	})

	if delivered {
		return nil
	}
	return fmt.Errorf("no active connections for session %s: %w", target, channels.ErrSendFailed)
}
// handleWebSocket upgrades the HTTP connection and manages the WebSocket lifecycle.
//
// The request is rejected with 503 when the channel is not running, with 401
// when authentication fails, and with 503 when the connection limit
// (config.MaxConnections, defaulting to 100 when unset or non-positive) has
// been reached. On success the connection is registered under a fresh UUID
// and a read loop is started for it in a new goroutine.
//
// The session ID comes from the "session_id" query parameter; when that is
// empty a random UUID is generated.
func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	if !c.IsRunning() {
		http.Error(w, "channel not running", http.StatusServiceUnavailable)
		return
	}

	// Authenticate
	if !c.authenticate(r) {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}

	// Check connection limit
	maxConns := c.config.MaxConnections
	if maxConns <= 0 {
		maxConns = 100
	}
	// Reserve the slot atomically *before* upgrading. A separate Load followed
	// by a later Add would let concurrent handshakes all pass the check and
	// push the live count past the limit.
	if int(c.connCount.Add(1)) > maxConns {
		c.connCount.Add(-1)
		http.Error(w, "too many connections", http.StatusServiceUnavailable)
		return
	}

	conn, err := c.upgrader.Upgrade(w, r, nil)
	if err != nil {
		// The upgrade failed, so the reserved slot must be given back here;
		// no read loop will run to release it.
		c.connCount.Add(-1)
		logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
			"error": err.Error(),
		})
		return
	}

	// Determine session ID from query param or generate one
	sessionID := r.URL.Query().Get("session_id")
	if sessionID == "" {
		sessionID = uuid.New().String()
	}

	pc := &picoConn{
		id:        uuid.New().String(),
		conn:      conn,
		sessionID: sessionID,
	}
	c.connections.Store(pc.id, pc)

	logger.InfoCF("pico", "WebSocket client connected", map[string]any{
		"conn_id":    pc.id,
		"session_id": sessionID,
	})

	// readLoop owns the connection from here on and releases the reserved
	// slot when it exits.
	go c.readLoop(pc)
}
// authenticate reports whether the request carries the configured token,
// either as an "Authorization: Bearer <token>" header or as a "token" query
// parameter.
//
// It returns false when no token is configured. Candidates are compared with
// subtle.ConstantTimeCompare rather than ==, so the time taken does not
// reveal how many leading bytes of a guess were correct. The comparison still
// returns immediately when the lengths differ.
func (c *PicoChannel) authenticate(r *http.Request) bool {
	expected := []byte(c.config.Token)
	if len(expected) == 0 {
		return false
	}

	matches := func(candidate string) bool {
		return subtle.ConstantTimeCompare([]byte(candidate), expected) == 1
	}

	// Check Authorization header
	if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
		if matches(strings.TrimPrefix(auth, "Bearer ")) {
			return true
		}
	}

	// Check query parameter
	return matches(r.URL.Query().Get("token"))
}
// readLoop reads and dispatches inbound frames for one connection until the
// read fails or the channel context is done.
//
// On exit it closes the connection, removes it from the connection map and
// decrements the live-connection counter. The read deadline is set before
// the first read and pushed forward by config.ReadTimeout seconds (default 60)
// on every pong and every successfully read message. A ping loop is started
// for the connection with an interval of config.PingInterval seconds
// (default 30). Frames that fail to parse as JSON are answered with an
// "invalid_message" error and otherwise ignored.
func (c *PicoChannel) readLoop(pc *picoConn) {
	defer func() {
		pc.close()
		c.connections.Delete(pc.id)
		c.connCount.Add(-1)
		logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{
			"conn_id":    pc.id,
			"session_id": pc.sessionID,
		})
	}()

	readTimeout := time.Duration(c.config.ReadTimeout) * time.Second
	if readTimeout <= 0 {
		readTimeout = 60 * time.Second
	}
	// extendDeadline pushes the read deadline one timeout into the future.
	extendDeadline := func() {
		_ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
	}

	extendDeadline()
	pc.conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})

	// Start ping ticker
	pingInterval := time.Duration(c.config.PingInterval) * time.Second
	if pingInterval <= 0 {
		pingInterval = 30 * time.Second
	}
	go c.pingLoop(pc, pingInterval)

	for {
		// Non-blocking check so a cancelled channel stops before the next read.
		select {
		case <-c.ctx.Done():
			return
		default:
		}

		_, raw, readErr := pc.conn.ReadMessage()
		if readErr != nil {
			expected := !websocket.IsUnexpectedCloseError(readErr, websocket.CloseGoingAway, websocket.CloseNormalClosure)
			if !expected {
				logger.DebugCF("pico", "WebSocket read error", map[string]any{
					"conn_id": pc.id,
					"error":   readErr.Error(),
				})
			}
			return
		}
		extendDeadline()

		var inbound PicoMessage
		if json.Unmarshal(raw, &inbound) != nil {
			pc.writeJSON(newError("invalid_message", "failed to parse message"))
			continue
		}
		c.handleMessage(pc, inbound)
	}
}
// pingLoop sends a WebSocket ping frame on pc every interval.
//
// It stops when the channel context is done, when the connection has been
// marked closed, or when writing a ping fails. Each ping is written while
// holding the connection's write mutex.
func (c *PicoChannel) pingLoop(pc *picoConn, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	// sendPing writes one ping frame under the write lock.
	sendPing := func() error {
		pc.writeMu.Lock()
		pingErr := pc.conn.WriteMessage(websocket.PingMessage, nil)
		pc.writeMu.Unlock()
		return pingErr
	}

	for {
		select {
		case <-c.ctx.Done():
			return
		case <-ticker.C:
		}

		if pc.closed.Load() || sendPing() != nil {
			return
		}
	}
}
// handleMessage dispatches one inbound Pico Protocol message by its type.
//
// A ping is answered with a pong that echoes the ping's ID, a message-send is
// forwarded to handleMessageSend, and any other type is answered with an
// "unknown_type" error. Write errors on the replies are not checked.
func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) {
	if msg.Type == TypeMessageSend {
		c.handleMessageSend(pc, msg)
		return
	}

	if msg.Type == TypePing {
		reply := newMessage(TypePong, nil)
		reply.ID = msg.ID
		pc.writeJSON(reply)
		return
	}

	pc.writeJSON(newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type)))
}
// handleMessageSend processes an inbound message.send from a client.
//
// The "content" payload field must be a non-blank string; otherwise an
// "empty_content" error is sent back and nothing else happens. The session is
// taken from the message itself, falling back to the connection's session
// when the message does not name one. If a placeholder recorder is available,
// a typing indicator is started and its stop function is registered with the
// recorder. Finally the message is passed to HandleMessage with the fixed
// sender ID "pico-user".
func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
	text, _ := msg.Payload["content"].(string)
	if strings.TrimSpace(text) == "" {
		pc.writeJSON(newError("empty_content", "message content is empty"))
		return
	}

	session := pc.sessionID
	if msg.SessionID != "" {
		session = msg.SessionID
	}
	chatID := "pico:" + session

	logger.DebugCF("pico", "Received message", map[string]any{
		"session_id": session,
		"preview":    truncate(text, 50),
	})

	// Register typing with Manager
	if rec := c.GetPlaceholderRecorder(); rec != nil {
		if stop, err := c.StartTyping(c.ctx, chatID); err == nil {
			rec.RecordTypingStop("pico", chatID, stop)
		}
	}

	c.HandleMessage(
		bus.Peer{Kind: "direct", ID: chatID},
		msg.ID,
		"pico-user",
		chatID,
		text,
		nil,
		map[string]string{
			"platform":   "pico",
			"session_id": session,
			"conn_id":    pc.id,
		},
	)
}
// truncate shortens s to at most maxLen runes.
//
// When s already fits it is returned unchanged; otherwise the first maxLen
// runes are kept and "..." is appended. A negative maxLen is treated as 0
// instead of panicking. The string is scanned only up to the cut point, so a
// long message is not converted to a rune slice just to produce a short
// preview. Bytes that are not valid UTF-8 each count as one rune and are
// copied through as-is.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	// The byte length is an upper bound on the rune count, so a string with
	// no more than maxLen bytes cannot need truncating.
	if len(s) <= maxLen {
		return s
	}

	seen := 0
	for i := range s {
		// i is the byte offset at which rune number `seen` starts.
		if seen == maxLen {
			return s[:i] + "..."
		}
		seen++
	}
	return s
}