diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index 86ce98b06..f3ba55a92 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -54,12 +54,13 @@ func (pc *picoConn) close() { // It serves as the reference implementation for all optional capability interfaces. type PicoChannel struct { *channels.BaseChannel - config config.PicoConfig - upgrader websocket.Upgrader - connections sync.Map // connID → *picoConn - connCount atomic.Int32 - ctx context.Context - cancel context.CancelFunc + config config.PicoConfig + upgrader websocket.Upgrader + connections map[string]*picoConn // connID -> *picoConn + sessionConnections map[string]map[string]*picoConn // sessionID -> connID -> *picoConn + connsMu sync.RWMutex + ctx context.Context + cancel context.CancelFunc } // NewPicoChannel creates a new Pico Protocol channel. @@ -92,9 +93,104 @@ func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoCha ReadBufferSize: 1024, WriteBufferSize: 1024, }, + connections: make(map[string]*picoConn), + sessionConnections: make(map[string]map[string]*picoConn), }, nil } +// createAndAddConnection checks MaxConnections and registers a connection atomically. +func (c *PicoChannel) createAndAddConnection(conn *websocket.Conn, sessionID string, maxConns int) (*picoConn, error) { + c.connsMu.Lock() + defer c.connsMu.Unlock() + if len(c.connections) >= maxConns { + return nil, channels.ErrTemporary + } + + var connID string + for { + connID = uuid.New().String() + if _, exists := c.connections[connID]; !exists { + break + } + } + + pc := &picoConn{ + id: connID, + conn: conn, + sessionID: sessionID, + } + + c.connections[pc.id] = pc + bySession, ok := c.sessionConnections[pc.sessionID] + if !ok { + bySession = make(map[string]*picoConn) + c.sessionConnections[pc.sessionID] = bySession + } + bySession[pc.id] = pc + + return pc, nil +} + +// removeConnection deletes a connection from indexes and returns it when found. +func (c *PicoChannel) removeConnection(connID string) *picoConn { + c.connsMu.Lock() + defer c.connsMu.Unlock() + + pc, ok := c.connections[connID] + if !ok { + return nil + } + + delete(c.connections, connID) + if bySession, ok := c.sessionConnections[pc.sessionID]; ok { + delete(bySession, connID) + if len(bySession) == 0 { + delete(c.sessionConnections, pc.sessionID) + } + } + + return pc +} + +// takeAllConnections snapshots and clears all connection indexes. +func (c *PicoChannel) takeAllConnections() []*picoConn { + c.connsMu.Lock() + defer c.connsMu.Unlock() + + all := make([]*picoConn, 0, len(c.connections)) + for _, pc := range c.connections { + all = append(all, pc) + } + clear(c.connections) + clear(c.sessionConnections) + + return all +} + +// sessionConnectionsSnapshot returns all active connections for a session. +func (c *PicoChannel) sessionConnectionsSnapshot(sessionID string) []*picoConn { + c.connsMu.RLock() + defer c.connsMu.RUnlock() + + bySession, ok := c.sessionConnections[sessionID] + if !ok || len(bySession) == 0 { + return nil + } + + conns := make([]*picoConn, 0, len(bySession)) + for _, pc := range bySession { + conns = append(conns, pc) + } + return conns +} + +// currentConnCount returns a lock-protected snapshot of active connection count. +func (c *PicoChannel) currentConnCount() int { + c.connsMu.RLock() + defer c.connsMu.RUnlock() + return len(c.connections) +} + // Start implements Channel. func (c *PicoChannel) Start(ctx context.Context) error { logger.InfoC("pico", "Starting Pico Protocol channel") @@ -110,13 +206,9 @@ func (c *PicoChannel) Stop(ctx context.Context) error { c.SetRunning(false) // Close all connections - c.connections.Range(func(key, value any) bool { - if pc, ok := value.(*picoConn); ok { - pc.close() - } - c.connections.Delete(key) - return true - }) + for _, pc := range c.takeAllConnections() { + pc.close() + } if c.cancel != nil { c.cancel() @@ -133,8 +225,8 @@ func (c *PicoChannel) WebhookPath() string { return "/pico/" } func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := strings.TrimPrefix(r.URL.Path, "/pico") - switch { - case path == "/ws" || path == "/ws/": + switch path { + case "/ws", "/ws/": c.handleWebSocket(w, r) default: http.NotFound(w, r) @@ -208,23 +300,16 @@ func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error { msg.SessionID = sessionID var sent bool - c.connections.Range(func(key, value any) bool { - pc, ok := value.(*picoConn) - if !ok { - return true + for _, pc := range c.sessionConnectionsSnapshot(sessionID) { + if err := pc.writeJSON(msg); err != nil { + logger.DebugCF("pico", "Write to connection failed", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } else { + sent = true } - if pc.sessionID == sessionID { - if err := pc.writeJSON(msg); err != nil { - logger.DebugCF("pico", "Write to connection failed", map[string]any{ - "conn_id": pc.id, - "error": err.Error(), - }) - } else { - sent = true - } - } - return true - }) + } if !sent { return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed) @@ -250,7 +335,7 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { if maxConns <= 0 { maxConns = 100 } - if int(c.connCount.Load()) >= maxConns { + if c.currentConnCount() >= maxConns { http.Error(w, "too many connections", http.StatusServiceUnavailable) return } @@ -275,15 +360,17 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { sessionID = uuid.New().String() } - pc := &picoConn{ - id: uuid.New().String(), - conn: conn, - sessionID: sessionID, + pc, err := c.createAndAddConnection(conn, sessionID, maxConns) + if err != nil { + _ = conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseTryAgainLater, "too many connections"), + time.Now().Add(2*time.Second), + ) + _ = conn.Close() + return } - c.connections.Store(pc.id, pc) - c.connCount.Add(1) - logger.InfoCF("pico", "WebSocket client connected", map[string]any{ "conn_id": pc.id, "session_id": sessionID, @@ -341,12 +428,12 @@ func (c *PicoChannel) matchedSubprotocol(r *http.Request) string { func (c *PicoChannel) readLoop(pc *picoConn) { defer func() { pc.close() - c.connections.Delete(pc.id) - c.connCount.Add(-1) - logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{ - "conn_id": pc.id, - "session_id": pc.sessionID, - }) + if removed := c.removeConnection(pc.id); removed != nil { + logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{ + "conn_id": removed.id, + "session_id": removed.sessionID, + }) + } }() readTimeout := time.Duration(c.config.ReadTimeout) * time.Second diff --git a/pkg/channels/pico/pico_test.go b/pkg/channels/pico/pico_test.go new file mode 100644 index 000000000..e712767ad --- /dev/null +++ b/pkg/channels/pico/pico_test.go @@ -0,0 +1,144 @@ +package pico + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func newTestPicoChannel(t *testing.T) *PicoChannel { + t.Helper() + + cfg := config.PicoConfig{} + cfg.SetToken("test-token") + ch, err := NewPicoChannel(cfg, bus.NewMessageBus()) + if err != nil { + t.Fatalf("NewPicoChannel: %v", err) + } + + ch.ctx = context.Background() + return ch +} + +func TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently(t *testing.T) { + ch := newTestPicoChannel(t) + + const ( + maxConns = 5 + goroutines = 64 + sessionID = "session-a" + ) + + var wg sync.WaitGroup + var mu sync.Mutex + successCount := 0 + errCount := 0 + + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + + pc, err := ch.createAndAddConnection(nil, sessionID, maxConns) + mu.Lock() + defer mu.Unlock() + + if err == nil { + successCount++ + if pc == nil { + t.Errorf("pc is nil on success") + } + return + } + if !errors.Is(err, channels.ErrTemporary) { + t.Errorf("unexpected error: %v", err) + return + } + errCount++ + }() + } + wg.Wait() + + if successCount > maxConns { + t.Fatalf("successCount=%d > maxConns=%d", successCount, maxConns) + } + if successCount+errCount != goroutines { + t.Fatalf("success=%d err=%d total=%d want=%d", successCount, errCount, successCount+errCount, goroutines) + } + if got := ch.currentConnCount(); got != maxConns { + t.Fatalf("currentConnCount=%d want=%d", got, maxConns) + } +} + +func TestRemoveConnection_CleansBothIndexes(t *testing.T) { + ch := newTestPicoChannel(t) + + pc, err := ch.createAndAddConnection(nil, "session-cleanup", 10) + if err != nil { + t.Fatalf("createAndAddConnection: %v", err) + } + + removed := ch.removeConnection(pc.id) + if removed == nil { + t.Fatal("removeConnection returned nil") + } + + ch.connsMu.RLock() + defer ch.connsMu.RUnlock() + + if _, ok := ch.connections[pc.id]; ok { + t.Fatalf("connID %s still exists in connections", pc.id) + } + if _, ok := ch.sessionConnections[pc.sessionID]; ok { + t.Fatalf("session %s still exists in sessionConnections", pc.sessionID) + } + if got := len(ch.connections); got != 0 { + t.Fatalf("len(connections)=%d want=0", got) + } +} + +func TestBroadcastToSession_TargetsOnlyRequestedSession(t *testing.T) { + ch := newTestPicoChannel(t) + + target := &picoConn{id: "target", sessionID: "s-target"} + target.closed.Store(true) + ch.addConnForTest(target) + + other := &picoConn{id: "other", sessionID: "s-other"} + ch.addConnForTest(other) + + err := ch.broadcastToSession("pico:s-target", newMessage(TypeMessageCreate, map[string]any{"content": "hello"})) + if err == nil { + t.Fatal("expected send failure due to closed target connection") + } + if !errors.Is(err, channels.ErrSendFailed) { + t.Fatalf("expected ErrSendFailed, got %v", err) + } +} + +func (c *PicoChannel) addConnForTest(pc *picoConn) { + c.connsMu.Lock() + defer c.connsMu.Unlock() + if c.connections == nil { + c.connections = make(map[string]*picoConn) + } + if c.sessionConnections == nil { + c.sessionConnections = make(map[string]map[string]*picoConn) + } + if _, exists := c.connections[pc.id]; exists { + panic(fmt.Sprintf("duplicate conn id in test: %s", pc.id)) + } + c.connections[pc.id] = pc + bySession, ok := c.sessionConnections[pc.sessionID] + if !ok { + bySession = make(map[string]*picoConn) + c.sessionConnections[pc.sessionID] = bySession + } + bySession[pc.id] = pc +}