Files
picoclaw/pkg/channels/pico/pico_test.go
T
LC 6aff5b7ccd fix(pico): use O(1) session indexing and harden websocket concurrency handling (#1970)
* perf(pico): implement O(1) session lookup for pico connections

- Replace `sync.Map` with `connections` and `sessionConnections`.
- Add `addConnection`, `removeConnection`, `sessionConnectionsSnapshot`, and `takeAllConnections` with `connsMu` for concurrency.
- `broadcastToSession` now dispatches directly to `sessionConnections`.
- Add `newUniqueConnID` to avoid UUID collision/overwrites.
- Ensure `Stop` and `readLoop` use the new helpers for safe cleanup and correct `connCount` updates.

* refactor(pico): replace addConnection with createAndAddConnection for atomic connID generation

* refactor(pico): clear all connections at once to improve perf

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix(pico): keep connCount consistent with connection indexes

* refactor(pico): make connCount a regular int guarded by connsMu

* fix(pico): enforce MaxConnections atomically on registration

* fix(pico): use temporary over-limit error and remove conn counter

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-24 23:25:27 +08:00

145 lines
3.4 KiB
Go

package pico
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
// newTestPicoChannel constructs a PicoChannel configured with the token
// "test-token" and a fresh message bus, failing the calling test immediately
// if construction returns an error. The channel's ctx field is then set to
// context.Background(); presumably this lets code paths that read ch.ctx run
// without going through the channel's normal startup — confirm if Start ever
// changes how ctx is populated.
func newTestPicoChannel(t *testing.T) *PicoChannel {
	t.Helper()

	var cfg config.PicoConfig
	cfg.SetToken("test-token")

	msgBus := bus.NewMessageBus()
	ch, err := NewPicoChannel(cfg, msgBus)
	if err != nil {
		t.Fatalf("NewPicoChannel: %v", err)
	}

	ch.ctx = context.Background()
	return ch
}
// TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently races many
// registrations for a single session against a small MaxConnections limit.
//
// It checks that:
//   - exactly maxConns registrations succeed (neither over- nor under-admission);
//   - every rejection wraps channels.ErrTemporary;
//   - each admitted connection carries the requested session and a connID that
//     no other concurrent registration received;
//   - the global count and the per-session index agree with the admissions.
func TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently(t *testing.T) {
	ch := newTestPicoChannel(t)

	const (
		maxConns   = 5
		goroutines = 64
		sessionID  = "session-a"
	)

	var (
		wg sync.WaitGroup
		// mu guards successCount, errCount and ids, which every worker updates.
		mu           sync.Mutex
		successCount int
		errCount     int
		ids          = make(map[string]struct{}, maxConns)
	)

	wg.Add(goroutines)
	for i := 0; i < goroutines; i++ {
		go func() {
			defer wg.Done()
			// nil websocket: this test exercises only the registration bookkeeping.
			pc, err := ch.createAndAddConnection(nil, sessionID, maxConns)

			mu.Lock()
			defer mu.Unlock()
			// Only t.Errorf is used here: t.Fatal* must not be called from a
			// goroutine other than the one running the test.
			if err != nil {
				if !errors.Is(err, channels.ErrTemporary) {
					t.Errorf("unexpected error: %v", err)
					return
				}
				errCount++
				return
			}

			successCount++
			if pc == nil {
				t.Errorf("pc is nil on success")
				return
			}
			if pc.sessionID != sessionID {
				t.Errorf("pc.sessionID=%q want=%q", pc.sessionID, sessionID)
			}
			// Concurrent registrations must never be handed the same connID,
			// otherwise one entry would silently overwrite another in the index.
			if _, dup := ids[pc.id]; dup {
				t.Errorf("duplicate connID %q handed out concurrently", pc.id)
			}
			ids[pc.id] = struct{}{}
		}()
	}
	wg.Wait()

	// goroutines > maxConns, so the limit must be reached exactly.
	if successCount != maxConns {
		t.Fatalf("successCount=%d want=%d (errCount=%d)", successCount, maxConns, errCount)
	}
	if want := goroutines - maxConns; errCount != want {
		t.Fatalf("errCount=%d want=%d", errCount, want)
	}
	if successCount+errCount != goroutines {
		t.Fatalf("success=%d err=%d total=%d want=%d", successCount, errCount, successCount+errCount, goroutines)
	}
	if got := ch.currentConnCount(); got != maxConns {
		t.Fatalf("currentConnCount=%d want=%d", got, maxConns)
	}

	ch.connsMu.RLock()
	perSession := len(ch.sessionConnections[sessionID])
	ch.connsMu.RUnlock()
	if perSession != maxConns {
		t.Fatalf("len(sessionConnections[%q])=%d want=%d", sessionID, perSession, maxConns)
	}
}
// TestRemoveConnection_CleansBothIndexes registers a single connection,
// removes it by connID, and verifies that neither the connID index
// (connections) nor the per-session index (sessionConnections) still
// references it.
func TestRemoveConnection_CleansBothIndexes(t *testing.T) {
	ch := newTestPicoChannel(t)

	pc, err := ch.createAndAddConnection(nil, "session-cleanup", 10)
	if err != nil {
		t.Fatalf("createAndAddConnection: %v", err)
	}

	if ch.removeConnection(pc.id) == nil {
		t.Fatal("removeConnection returned nil")
	}

	ch.connsMu.RLock()
	defer ch.connsMu.RUnlock()

	connID, session := pc.id, pc.sessionID
	if _, stillIndexed := ch.connections[connID]; stillIndexed {
		t.Fatalf("connID %s still exists in connections", connID)
	}
	// The session held only this one connection, so the test expects its
	// bucket to be gone entirely, not left behind as an empty map.
	if _, stillIndexed := ch.sessionConnections[session]; stillIndexed {
		t.Fatalf("session %s still exists in sessionConnections", session)
	}
	if n := len(ch.connections); n != 0 {
		t.Fatalf("len(connections)=%d want=0", n)
	}
}
// TestBroadcastToSession_TargetsOnlyRequestedSession registers one connection
// in each of two sessions and broadcasts to "pico:s-target". The target
// connection is pre-marked closed, so the broadcast must fail and the error
// must wrap channels.ErrSendFailed.
//
// NOTE(review): the "only requested session" half is checked only implicitly.
// The "other" connection is open but has no websocket attached, so the test
// presumably relies on a send to it panicking or producing a different error.
// Confirm this against picoConn's send path.
func TestBroadcastToSession_TargetsOnlyRequestedSession(t *testing.T) {
	ch := newTestPicoChannel(t)

	target := &picoConn{id: "target", sessionID: "s-target"}
	target.closed.Store(true)
	other := &picoConn{id: "other", sessionID: "s-other"}

	ch.addConnForTest(target)
	ch.addConnForTest(other)

	msg := newMessage(TypeMessageCreate, map[string]any{"content": "hello"})
	switch err := ch.broadcastToSession("pico:s-target", msg); {
	case err == nil:
		t.Fatal("expected send failure due to closed target connection")
	case !errors.Is(err, channels.ErrSendFailed):
		t.Fatalf("expected ErrSendFailed, got %v", err)
	}
}
// addConnForTest inserts pc into both connection indexes (connections keyed by
// connID, sessionConnections keyed by session then connID) while holding
// connsMu. It bypasses createAndAddConnection, so no connID is generated and
// no MaxConnections limit is applied. Nil index maps are allocated on demand.
// It panics on a duplicate connID, since a collision here means the test
// fixture itself is wrong.
func (c *PicoChannel) addConnForTest(pc *picoConn) {
	c.connsMu.Lock()
	defer c.connsMu.Unlock()

	if c.connections == nil {
		c.connections = make(map[string]*picoConn)
	}
	if c.sessionConnections == nil {
		c.sessionConnections = make(map[string]map[string]*picoConn)
	}

	id, session := pc.id, pc.sessionID
	if _, dup := c.connections[id]; dup {
		panic(fmt.Sprintf("duplicate conn id in test: %s", id))
	}
	c.connections[id] = pc

	peers, found := c.sessionConnections[session]
	if !found {
		peers = make(map[string]*picoConn)
		c.sessionConnections[session] = peers
	}
	peers[id] = pc
}