mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
60b68b305a
Phase 10: Define TypingCapable, MessageEditor, PlaceholderRecorder interfaces. Manager orchestrates outbound typing stop and placeholder editing via preSend. Migrate Telegram, Discord, Slack, OneBot to register state with Manager instead of handling locally in Send. Phase 7: Add native WebSocket Pico Protocol channel as reference implementation of all optional capability interfaces.
431 lines
10 KiB
Go
431 lines
10 KiB
Go
package pico
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"

	"github.com/sipeed/picoclaw/pkg/bus"
	"github.com/sipeed/picoclaw/pkg/channels"
	"github.com/sipeed/picoclaw/pkg/config"
	"github.com/sipeed/picoclaw/pkg/logger"
)
|
|
|
|
// picoConn represents a single WebSocket connection.
type picoConn struct {
	// id uniquely identifies this connection; it is the key under which the
	// connection is stored in PicoChannel.connections.
	id string
	// conn is the upgraded WebSocket connection.
	conn *websocket.Conn
	// sessionID is the Pico session this connection belongs to: the
	// session_id query parameter from the handshake, or a generated UUID.
	sessionID string
	// writeMu serializes writes to conn (JSON frames and ping frames).
	writeMu sync.Mutex
	// closed is flipped to true exactly once by close(); writeJSON refuses
	// to write after that.
	closed atomic.Bool
}
|
|
|
|
// writeJSON sends a JSON message to the connection with write locking.
|
|
func (pc *picoConn) writeJSON(v any) error {
|
|
if pc.closed.Load() {
|
|
return fmt.Errorf("connection closed")
|
|
}
|
|
pc.writeMu.Lock()
|
|
defer pc.writeMu.Unlock()
|
|
return pc.conn.WriteJSON(v)
|
|
}
|
|
|
|
// close closes the connection.
|
|
func (pc *picoConn) close() {
|
|
if pc.closed.CompareAndSwap(false, true) {
|
|
pc.conn.Close()
|
|
}
|
|
}
|
|
|
|
// PicoChannel implements the native Pico Protocol WebSocket channel.
// It serves as the reference implementation for all optional capability interfaces.
type PicoChannel struct {
	// BaseChannel supplies the shared channel plumbing used by this type
	// (IsRunning, SetRunning, HandleMessage, GetPlaceholderRecorder).
	*channels.BaseChannel
	// config is the configuration passed to NewPicoChannel.
	config config.PicoConfig
	// upgrader performs the HTTP→WebSocket handshake; its CheckOrigin
	// enforces config.AllowOrigins.
	upgrader websocket.Upgrader
	connections sync.Map // connID → *picoConn
	// connCount is the number of live connections, compared against
	// config.MaxConnections when accepting new clients.
	connCount atomic.Int32
	// ctx and cancel are created by Start and cancelled by Stop; readLoop
	// and pingLoop watch ctx to know when to exit.
	ctx    context.Context
	cancel context.CancelFunc
}
|
|
|
|
// NewPicoChannel creates a new Pico Protocol channel.
|
|
func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) {
|
|
if cfg.Token == "" {
|
|
return nil, fmt.Errorf("pico token is required")
|
|
}
|
|
|
|
base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom)
|
|
|
|
allowOrigins := cfg.AllowOrigins
|
|
checkOrigin := func(r *http.Request) bool {
|
|
if len(allowOrigins) == 0 {
|
|
return true // allow all if not configured
|
|
}
|
|
origin := r.Header.Get("Origin")
|
|
for _, allowed := range allowOrigins {
|
|
if allowed == "*" || allowed == origin {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
return &PicoChannel{
|
|
BaseChannel: base,
|
|
config: cfg,
|
|
upgrader: websocket.Upgrader{
|
|
CheckOrigin: checkOrigin,
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Start implements Channel.
|
|
func (c *PicoChannel) Start(ctx context.Context) error {
|
|
logger.InfoC("pico", "Starting Pico Protocol channel")
|
|
c.ctx, c.cancel = context.WithCancel(ctx)
|
|
c.SetRunning(true)
|
|
logger.InfoC("pico", "Pico Protocol channel started")
|
|
return nil
|
|
}
|
|
|
|
// Stop implements Channel.
|
|
func (c *PicoChannel) Stop(ctx context.Context) error {
|
|
logger.InfoC("pico", "Stopping Pico Protocol channel")
|
|
c.SetRunning(false)
|
|
|
|
// Close all connections
|
|
c.connections.Range(func(key, value any) bool {
|
|
if pc, ok := value.(*picoConn); ok {
|
|
pc.close()
|
|
}
|
|
c.connections.Delete(key)
|
|
return true
|
|
})
|
|
|
|
if c.cancel != nil {
|
|
c.cancel()
|
|
}
|
|
|
|
logger.InfoC("pico", "Pico Protocol channel stopped")
|
|
return nil
|
|
}
|
|
|
|
// WebhookPath implements channels.WebhookHandler.
|
|
func (c *PicoChannel) WebhookPath() string { return "/pico/" }
|
|
|
|
// ServeHTTP implements http.Handler for the shared HTTP server.
|
|
func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
path := strings.TrimPrefix(r.URL.Path, "/pico")
|
|
|
|
switch {
|
|
case path == "/ws" || path == "/ws/":
|
|
c.handleWebSocket(w, r)
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}
|
|
|
|
// Send implements Channel — sends a message to the appropriate WebSocket connection.
|
|
func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
|
if !c.IsRunning() {
|
|
return channels.ErrNotRunning
|
|
}
|
|
|
|
outMsg := newMessage(TypeMessageCreate, map[string]any{
|
|
"content": msg.Content,
|
|
})
|
|
|
|
return c.broadcastToSession(msg.ChatID, outMsg)
|
|
}
|
|
|
|
// EditMessage implements channels.MessageEditor.
|
|
func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
|
|
outMsg := newMessage(TypeMessageUpdate, map[string]any{
|
|
"message_id": messageID,
|
|
"content": content,
|
|
})
|
|
return c.broadcastToSession(chatID, outMsg)
|
|
}
|
|
|
|
// StartTyping implements channels.TypingCapable.
|
|
func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
|
|
startMsg := newMessage(TypeTypingStart, nil)
|
|
if err := c.broadcastToSession(chatID, startMsg); err != nil {
|
|
return func() {}, err
|
|
}
|
|
return func() {
|
|
stopMsg := newMessage(TypeTypingStop, nil)
|
|
c.broadcastToSession(chatID, stopMsg)
|
|
}, nil
|
|
}
|
|
|
|
// broadcastToSession sends a message to all connections with a matching session.
|
|
func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error {
|
|
// chatID format: "pico:<sessionID>"
|
|
sessionID := strings.TrimPrefix(chatID, "pico:")
|
|
msg.SessionID = sessionID
|
|
|
|
var sent bool
|
|
c.connections.Range(func(key, value any) bool {
|
|
pc, ok := value.(*picoConn)
|
|
if !ok {
|
|
return true
|
|
}
|
|
if pc.sessionID == sessionID {
|
|
if err := pc.writeJSON(msg); err != nil {
|
|
logger.DebugCF("pico", "Write to connection failed", map[string]any{
|
|
"conn_id": pc.id,
|
|
"error": err.Error(),
|
|
})
|
|
} else {
|
|
sent = true
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
if !sent {
|
|
return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handleWebSocket upgrades the HTTP connection and manages the WebSocket lifecycle.
|
|
func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
if !c.IsRunning() {
|
|
http.Error(w, "channel not running", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
// Authenticate
|
|
if !c.authenticate(r) {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Check connection limit
|
|
maxConns := c.config.MaxConnections
|
|
if maxConns <= 0 {
|
|
maxConns = 100
|
|
}
|
|
if int(c.connCount.Load()) >= maxConns {
|
|
http.Error(w, "too many connections", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
conn, err := c.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
|
|
"error": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Determine session ID from query param or generate one
|
|
sessionID := r.URL.Query().Get("session_id")
|
|
if sessionID == "" {
|
|
sessionID = uuid.New().String()
|
|
}
|
|
|
|
pc := &picoConn{
|
|
id: uuid.New().String(),
|
|
conn: conn,
|
|
sessionID: sessionID,
|
|
}
|
|
|
|
c.connections.Store(pc.id, pc)
|
|
c.connCount.Add(1)
|
|
|
|
logger.InfoCF("pico", "WebSocket client connected", map[string]any{
|
|
"conn_id": pc.id,
|
|
"session_id": sessionID,
|
|
})
|
|
|
|
go c.readLoop(pc)
|
|
}
|
|
|
|
// authenticate checks the Bearer token from header or query parameter.
|
|
func (c *PicoChannel) authenticate(r *http.Request) bool {
|
|
token := c.config.Token
|
|
if token == "" {
|
|
return false
|
|
}
|
|
|
|
// Check Authorization header
|
|
auth := r.Header.Get("Authorization")
|
|
if strings.HasPrefix(auth, "Bearer ") {
|
|
if strings.TrimPrefix(auth, "Bearer ") == token {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Check query parameter
|
|
if r.URL.Query().Get("token") == token {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// readLoop reads messages from a WebSocket connection.
|
|
func (c *PicoChannel) readLoop(pc *picoConn) {
|
|
defer func() {
|
|
pc.close()
|
|
c.connections.Delete(pc.id)
|
|
c.connCount.Add(-1)
|
|
logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{
|
|
"conn_id": pc.id,
|
|
"session_id": pc.sessionID,
|
|
})
|
|
}()
|
|
|
|
readTimeout := time.Duration(c.config.ReadTimeout) * time.Second
|
|
if readTimeout <= 0 {
|
|
readTimeout = 60 * time.Second
|
|
}
|
|
|
|
_ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
pc.conn.SetPongHandler(func(appData string) error {
|
|
_ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
return nil
|
|
})
|
|
|
|
// Start ping ticker
|
|
pingInterval := time.Duration(c.config.PingInterval) * time.Second
|
|
if pingInterval <= 0 {
|
|
pingInterval = 30 * time.Second
|
|
}
|
|
go c.pingLoop(pc, pingInterval)
|
|
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
_, rawMsg, err := pc.conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
|
logger.DebugCF("pico", "WebSocket read error", map[string]any{
|
|
"conn_id": pc.id,
|
|
"error": err.Error(),
|
|
})
|
|
}
|
|
return
|
|
}
|
|
|
|
_ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
|
|
var msg PicoMessage
|
|
if err := json.Unmarshal(rawMsg, &msg); err != nil {
|
|
errMsg := newError("invalid_message", "failed to parse message")
|
|
pc.writeJSON(errMsg)
|
|
continue
|
|
}
|
|
|
|
c.handleMessage(pc, msg)
|
|
}
|
|
}
|
|
|
|
// pingLoop sends periodic ping frames to keep the connection alive.
|
|
func (c *PicoChannel) pingLoop(pc *picoConn, interval time.Duration) {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if pc.closed.Load() {
|
|
return
|
|
}
|
|
pc.writeMu.Lock()
|
|
err := pc.conn.WriteMessage(websocket.PingMessage, nil)
|
|
pc.writeMu.Unlock()
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleMessage processes an inbound Pico Protocol message.
|
|
func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) {
|
|
switch msg.Type {
|
|
case TypePing:
|
|
pong := newMessage(TypePong, nil)
|
|
pong.ID = msg.ID
|
|
pc.writeJSON(pong)
|
|
|
|
case TypeMessageSend:
|
|
c.handleMessageSend(pc, msg)
|
|
|
|
default:
|
|
errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type))
|
|
pc.writeJSON(errMsg)
|
|
}
|
|
}
|
|
|
|
// handleMessageSend processes an inbound message.send from a client.
|
|
func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
|
|
content, _ := msg.Payload["content"].(string)
|
|
if strings.TrimSpace(content) == "" {
|
|
errMsg := newError("empty_content", "message content is empty")
|
|
pc.writeJSON(errMsg)
|
|
return
|
|
}
|
|
|
|
sessionID := msg.SessionID
|
|
if sessionID == "" {
|
|
sessionID = pc.sessionID
|
|
}
|
|
|
|
chatID := "pico:" + sessionID
|
|
senderID := "pico-user"
|
|
|
|
peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID}
|
|
|
|
metadata := map[string]string{
|
|
"platform": "pico",
|
|
"session_id": sessionID,
|
|
"conn_id": pc.id,
|
|
}
|
|
|
|
logger.DebugCF("pico", "Received message", map[string]any{
|
|
"session_id": sessionID,
|
|
"preview": truncate(content, 50),
|
|
})
|
|
|
|
// Register typing with Manager
|
|
if rec := c.GetPlaceholderRecorder(); rec != nil {
|
|
stop, err := c.StartTyping(c.ctx, chatID)
|
|
if err == nil {
|
|
rec.RecordTypingStop("pico", chatID, stop)
|
|
}
|
|
}
|
|
|
|
c.HandleMessage(peer, msg.ID, senderID, chatID, content, nil, metadata)
|
|
}
|
|
|
|
// truncate returns s unchanged when it has at most maxLen runes; otherwise it
// returns the first maxLen runes followed by "...". Counting is done on
// runes, so multi-byte characters are never split.
func truncate(s string, maxLen int) string {
	if r := []rune(s); len(r) > maxLen {
		return string(r[:maxLen]) + "..."
	}
	return s
}
|