mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(pico): use O(1) session indexing and harden websocket concurrency handling (#1970)
* perf(pico): implement O(1) session lookup for pico connections - Replace `sync.Map` with `connections` and `sessionConnections`. - Add `addConnection`, `removeConnection`, `sessionConnectionsSnapshot`, and `takeAllConnections` with `connsMu` for concurrency. - `broadcastToSession` now dispatches directly to `sessionConnections`. - Add `newUniqueConnID` to avoid UUID collision/overwrites. - Ensure `Stop` and `readLoop` use the new helpers for safe cleanup and correct `connCount` updates. * refactor(pico): replace addConnection with createAndAddConnection for atomic connID generation * refactor(pico): clear connections in one time to improve perf Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix(pico): keep connCount consistent with connection indexes * refactor(pico): make connCount a regular int guarded by connsMu * fix(pico): enforce MaxConnections atomically on registration * fix(pico): use temporary over-limit error and remove conn counter --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
+132
-45
@@ -54,12 +54,13 @@ func (pc *picoConn) close() {
|
||||
// It serves as the reference implementation for all optional capability interfaces.
|
||||
type PicoChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.PicoConfig
|
||||
upgrader websocket.Upgrader
|
||||
connections sync.Map // connID → *picoConn
|
||||
connCount atomic.Int32
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config config.PicoConfig
|
||||
upgrader websocket.Upgrader
|
||||
connections map[string]*picoConn // connID -> *picoConn
|
||||
sessionConnections map[string]map[string]*picoConn // sessionID -> connID -> *picoConn
|
||||
connsMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewPicoChannel creates a new Pico Protocol channel.
|
||||
@@ -92,9 +93,104 @@ func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoCha
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
},
|
||||
connections: make(map[string]*picoConn),
|
||||
sessionConnections: make(map[string]map[string]*picoConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createAndAddConnection checks MaxConnections and registers a connection atomically.
|
||||
func (c *PicoChannel) createAndAddConnection(conn *websocket.Conn, sessionID string, maxConns int) (*picoConn, error) {
|
||||
c.connsMu.Lock()
|
||||
defer c.connsMu.Unlock()
|
||||
if len(c.connections) >= maxConns {
|
||||
return nil, channels.ErrTemporary
|
||||
}
|
||||
|
||||
var connID string
|
||||
for {
|
||||
connID = uuid.New().String()
|
||||
if _, exists := c.connections[connID]; !exists {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
pc := &picoConn{
|
||||
id: connID,
|
||||
conn: conn,
|
||||
sessionID: sessionID,
|
||||
}
|
||||
|
||||
c.connections[pc.id] = pc
|
||||
bySession, ok := c.sessionConnections[pc.sessionID]
|
||||
if !ok {
|
||||
bySession = make(map[string]*picoConn)
|
||||
c.sessionConnections[pc.sessionID] = bySession
|
||||
}
|
||||
bySession[pc.id] = pc
|
||||
|
||||
return pc, nil
|
||||
}
|
||||
|
||||
// removeConnection deletes a connection from indexes and returns it when found.
|
||||
func (c *PicoChannel) removeConnection(connID string) *picoConn {
|
||||
c.connsMu.Lock()
|
||||
defer c.connsMu.Unlock()
|
||||
|
||||
pc, ok := c.connections[connID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(c.connections, connID)
|
||||
if bySession, ok := c.sessionConnections[pc.sessionID]; ok {
|
||||
delete(bySession, connID)
|
||||
if len(bySession) == 0 {
|
||||
delete(c.sessionConnections, pc.sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
return pc
|
||||
}
|
||||
|
||||
// takeAllConnections snapshots and clears all connection indexes.
|
||||
func (c *PicoChannel) takeAllConnections() []*picoConn {
|
||||
c.connsMu.Lock()
|
||||
defer c.connsMu.Unlock()
|
||||
|
||||
all := make([]*picoConn, 0, len(c.connections))
|
||||
for _, pc := range c.connections {
|
||||
all = append(all, pc)
|
||||
}
|
||||
clear(c.connections)
|
||||
clear(c.sessionConnections)
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// sessionConnectionsSnapshot returns all active connections for a session.
|
||||
func (c *PicoChannel) sessionConnectionsSnapshot(sessionID string) []*picoConn {
|
||||
c.connsMu.RLock()
|
||||
defer c.connsMu.RUnlock()
|
||||
|
||||
bySession, ok := c.sessionConnections[sessionID]
|
||||
if !ok || len(bySession) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
conns := make([]*picoConn, 0, len(bySession))
|
||||
for _, pc := range bySession {
|
||||
conns = append(conns, pc)
|
||||
}
|
||||
return conns
|
||||
}
|
||||
|
||||
// currentConnCount returns a lock-protected snapshot of active connection count.
|
||||
func (c *PicoChannel) currentConnCount() int {
|
||||
c.connsMu.RLock()
|
||||
defer c.connsMu.RUnlock()
|
||||
return len(c.connections)
|
||||
}
|
||||
|
||||
// Start implements Channel.
|
||||
func (c *PicoChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("pico", "Starting Pico Protocol channel")
|
||||
@@ -110,13 +206,9 @@ func (c *PicoChannel) Stop(ctx context.Context) error {
|
||||
c.SetRunning(false)
|
||||
|
||||
// Close all connections
|
||||
c.connections.Range(func(key, value any) bool {
|
||||
if pc, ok := value.(*picoConn); ok {
|
||||
pc.close()
|
||||
}
|
||||
c.connections.Delete(key)
|
||||
return true
|
||||
})
|
||||
for _, pc := range c.takeAllConnections() {
|
||||
pc.close()
|
||||
}
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
@@ -133,8 +225,8 @@ func (c *PicoChannel) WebhookPath() string { return "/pico/" }
|
||||
func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
path := strings.TrimPrefix(r.URL.Path, "/pico")
|
||||
|
||||
switch {
|
||||
case path == "/ws" || path == "/ws/":
|
||||
switch path {
|
||||
case "/ws", "/ws/":
|
||||
c.handleWebSocket(w, r)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
@@ -208,23 +300,16 @@ func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error {
|
||||
msg.SessionID = sessionID
|
||||
|
||||
var sent bool
|
||||
c.connections.Range(func(key, value any) bool {
|
||||
pc, ok := value.(*picoConn)
|
||||
if !ok {
|
||||
return true
|
||||
for _, pc := range c.sessionConnectionsSnapshot(sessionID) {
|
||||
if err := pc.writeJSON(msg); err != nil {
|
||||
logger.DebugCF("pico", "Write to connection failed", map[string]any{
|
||||
"conn_id": pc.id,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
sent = true
|
||||
}
|
||||
if pc.sessionID == sessionID {
|
||||
if err := pc.writeJSON(msg); err != nil {
|
||||
logger.DebugCF("pico", "Write to connection failed", map[string]any{
|
||||
"conn_id": pc.id,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
sent = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if !sent {
|
||||
return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed)
|
||||
@@ -250,7 +335,7 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
if maxConns <= 0 {
|
||||
maxConns = 100
|
||||
}
|
||||
if int(c.connCount.Load()) >= maxConns {
|
||||
if c.currentConnCount() >= maxConns {
|
||||
http.Error(w, "too many connections", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
@@ -275,15 +360,17 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID = uuid.New().String()
|
||||
}
|
||||
|
||||
pc := &picoConn{
|
||||
id: uuid.New().String(),
|
||||
conn: conn,
|
||||
sessionID: sessionID,
|
||||
pc, err := c.createAndAddConnection(conn, sessionID, maxConns)
|
||||
if err != nil {
|
||||
_ = conn.WriteControl(
|
||||
websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseTryAgainLater, "too many connections"),
|
||||
time.Now().Add(2*time.Second),
|
||||
)
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
c.connections.Store(pc.id, pc)
|
||||
c.connCount.Add(1)
|
||||
|
||||
logger.InfoCF("pico", "WebSocket client connected", map[string]any{
|
||||
"conn_id": pc.id,
|
||||
"session_id": sessionID,
|
||||
@@ -341,12 +428,12 @@ func (c *PicoChannel) matchedSubprotocol(r *http.Request) string {
|
||||
func (c *PicoChannel) readLoop(pc *picoConn) {
|
||||
defer func() {
|
||||
pc.close()
|
||||
c.connections.Delete(pc.id)
|
||||
c.connCount.Add(-1)
|
||||
logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{
|
||||
"conn_id": pc.id,
|
||||
"session_id": pc.sessionID,
|
||||
})
|
||||
if removed := c.removeConnection(pc.id); removed != nil {
|
||||
logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{
|
||||
"conn_id": removed.id,
|
||||
"session_id": removed.sessionID,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
readTimeout := time.Duration(c.config.ReadTimeout) * time.Second
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
package pico
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func newTestPicoChannel(t *testing.T) *PicoChannel {
|
||||
t.Helper()
|
||||
|
||||
cfg := config.PicoConfig{}
|
||||
cfg.SetToken("test-token")
|
||||
ch, err := NewPicoChannel(cfg, bus.NewMessageBus())
|
||||
if err != nil {
|
||||
t.Fatalf("NewPicoChannel: %v", err)
|
||||
}
|
||||
|
||||
ch.ctx = context.Background()
|
||||
return ch
|
||||
}
|
||||
|
||||
func TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently(t *testing.T) {
|
||||
ch := newTestPicoChannel(t)
|
||||
|
||||
const (
|
||||
maxConns = 5
|
||||
goroutines = 64
|
||||
sessionID = "session-a"
|
||||
)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
successCount := 0
|
||||
errCount := 0
|
||||
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
pc, err := ch.createAndAddConnection(nil, sessionID, maxConns)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if err == nil {
|
||||
successCount++
|
||||
if pc == nil {
|
||||
t.Errorf("pc is nil on success")
|
||||
}
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, channels.ErrTemporary) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
errCount++
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if successCount > maxConns {
|
||||
t.Fatalf("successCount=%d > maxConns=%d", successCount, maxConns)
|
||||
}
|
||||
if successCount+errCount != goroutines {
|
||||
t.Fatalf("success=%d err=%d total=%d want=%d", successCount, errCount, successCount+errCount, goroutines)
|
||||
}
|
||||
if got := ch.currentConnCount(); got != maxConns {
|
||||
t.Fatalf("currentConnCount=%d want=%d", got, maxConns)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveConnection_CleansBothIndexes(t *testing.T) {
|
||||
ch := newTestPicoChannel(t)
|
||||
|
||||
pc, err := ch.createAndAddConnection(nil, "session-cleanup", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("createAndAddConnection: %v", err)
|
||||
}
|
||||
|
||||
removed := ch.removeConnection(pc.id)
|
||||
if removed == nil {
|
||||
t.Fatal("removeConnection returned nil")
|
||||
}
|
||||
|
||||
ch.connsMu.RLock()
|
||||
defer ch.connsMu.RUnlock()
|
||||
|
||||
if _, ok := ch.connections[pc.id]; ok {
|
||||
t.Fatalf("connID %s still exists in connections", pc.id)
|
||||
}
|
||||
if _, ok := ch.sessionConnections[pc.sessionID]; ok {
|
||||
t.Fatalf("session %s still exists in sessionConnections", pc.sessionID)
|
||||
}
|
||||
if got := len(ch.connections); got != 0 {
|
||||
t.Fatalf("len(connections)=%d want=0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastToSession_TargetsOnlyRequestedSession(t *testing.T) {
|
||||
ch := newTestPicoChannel(t)
|
||||
|
||||
target := &picoConn{id: "target", sessionID: "s-target"}
|
||||
target.closed.Store(true)
|
||||
ch.addConnForTest(target)
|
||||
|
||||
other := &picoConn{id: "other", sessionID: "s-other"}
|
||||
ch.addConnForTest(other)
|
||||
|
||||
err := ch.broadcastToSession("pico:s-target", newMessage(TypeMessageCreate, map[string]any{"content": "hello"}))
|
||||
if err == nil {
|
||||
t.Fatal("expected send failure due to closed target connection")
|
||||
}
|
||||
if !errors.Is(err, channels.ErrSendFailed) {
|
||||
t.Fatalf("expected ErrSendFailed, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PicoChannel) addConnForTest(pc *picoConn) {
|
||||
c.connsMu.Lock()
|
||||
defer c.connsMu.Unlock()
|
||||
if c.connections == nil {
|
||||
c.connections = make(map[string]*picoConn)
|
||||
}
|
||||
if c.sessionConnections == nil {
|
||||
c.sessionConnections = make(map[string]map[string]*picoConn)
|
||||
}
|
||||
if _, exists := c.connections[pc.id]; exists {
|
||||
panic(fmt.Sprintf("duplicate conn id in test: %s", pc.id))
|
||||
}
|
||||
c.connections[pc.id] = pc
|
||||
bySession, ok := c.sessionConnections[pc.sessionID]
|
||||
if !ok {
|
||||
bySession = make(map[string]*picoConn)
|
||||
c.sessionConnections[pc.sessionID] = bySession
|
||||
}
|
||||
bySession[pc.id] = pc
|
||||
}
|
||||
Reference in New Issue
Block a user