mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
6aff5b7ccd
* perf(pico): implement O(1) session lookup for pico connections - Replace `sync.Map` with `connections` and `sessionConnections`. - Add `addConnection`, `removeConnection`, `sessionConnectionsSnapshot`, and `takeAllConnections` with `connsMu` for concurrency. - `broadcastToSession` now dispatches directly to `sessionConnections`. - Add `newUniqueConnID` to avoid UUID collision/overwrites. - Ensure `Stop` and `readLoop` use the new helpers for safe cleanup and correct `connCount` updates. * refactor(pico): replace addConnection with createAndAddConnection for atomic connID generation * refactor(pico): clear connections in one time to improve perf Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix(pico): keep connCount consistent with connection indexes * refactor(pico): make connCount a regular int guarded by connsMu * fix(pico): enforce MaxConnections atomically on registration * fix(pico): use temporary over-limit error and remove conn counter --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
145 lines
3.4 KiB
Go
145 lines
3.4 KiB
Go
package pico
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/channels"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
)
|
|
|
|
func newTestPicoChannel(t *testing.T) *PicoChannel {
|
|
t.Helper()
|
|
|
|
cfg := config.PicoConfig{}
|
|
cfg.SetToken("test-token")
|
|
ch, err := NewPicoChannel(cfg, bus.NewMessageBus())
|
|
if err != nil {
|
|
t.Fatalf("NewPicoChannel: %v", err)
|
|
}
|
|
|
|
ch.ctx = context.Background()
|
|
return ch
|
|
}
|
|
|
|
func TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently(t *testing.T) {
|
|
ch := newTestPicoChannel(t)
|
|
|
|
const (
|
|
maxConns = 5
|
|
goroutines = 64
|
|
sessionID = "session-a"
|
|
)
|
|
|
|
var wg sync.WaitGroup
|
|
var mu sync.Mutex
|
|
successCount := 0
|
|
errCount := 0
|
|
|
|
wg.Add(goroutines)
|
|
for i := 0; i < goroutines; i++ {
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
pc, err := ch.createAndAddConnection(nil, sessionID, maxConns)
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if err == nil {
|
|
successCount++
|
|
if pc == nil {
|
|
t.Errorf("pc is nil on success")
|
|
}
|
|
return
|
|
}
|
|
if !errors.Is(err, channels.ErrTemporary) {
|
|
t.Errorf("unexpected error: %v", err)
|
|
return
|
|
}
|
|
errCount++
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
if successCount > maxConns {
|
|
t.Fatalf("successCount=%d > maxConns=%d", successCount, maxConns)
|
|
}
|
|
if successCount+errCount != goroutines {
|
|
t.Fatalf("success=%d err=%d total=%d want=%d", successCount, errCount, successCount+errCount, goroutines)
|
|
}
|
|
if got := ch.currentConnCount(); got != maxConns {
|
|
t.Fatalf("currentConnCount=%d want=%d", got, maxConns)
|
|
}
|
|
}
|
|
|
|
func TestRemoveConnection_CleansBothIndexes(t *testing.T) {
|
|
ch := newTestPicoChannel(t)
|
|
|
|
pc, err := ch.createAndAddConnection(nil, "session-cleanup", 10)
|
|
if err != nil {
|
|
t.Fatalf("createAndAddConnection: %v", err)
|
|
}
|
|
|
|
removed := ch.removeConnection(pc.id)
|
|
if removed == nil {
|
|
t.Fatal("removeConnection returned nil")
|
|
}
|
|
|
|
ch.connsMu.RLock()
|
|
defer ch.connsMu.RUnlock()
|
|
|
|
if _, ok := ch.connections[pc.id]; ok {
|
|
t.Fatalf("connID %s still exists in connections", pc.id)
|
|
}
|
|
if _, ok := ch.sessionConnections[pc.sessionID]; ok {
|
|
t.Fatalf("session %s still exists in sessionConnections", pc.sessionID)
|
|
}
|
|
if got := len(ch.connections); got != 0 {
|
|
t.Fatalf("len(connections)=%d want=0", got)
|
|
}
|
|
}
|
|
|
|
func TestBroadcastToSession_TargetsOnlyRequestedSession(t *testing.T) {
|
|
ch := newTestPicoChannel(t)
|
|
|
|
target := &picoConn{id: "target", sessionID: "s-target"}
|
|
target.closed.Store(true)
|
|
ch.addConnForTest(target)
|
|
|
|
other := &picoConn{id: "other", sessionID: "s-other"}
|
|
ch.addConnForTest(other)
|
|
|
|
err := ch.broadcastToSession("pico:s-target", newMessage(TypeMessageCreate, map[string]any{"content": "hello"}))
|
|
if err == nil {
|
|
t.Fatal("expected send failure due to closed target connection")
|
|
}
|
|
if !errors.Is(err, channels.ErrSendFailed) {
|
|
t.Fatalf("expected ErrSendFailed, got %v", err)
|
|
}
|
|
}
|
|
|
|
func (c *PicoChannel) addConnForTest(pc *picoConn) {
|
|
c.connsMu.Lock()
|
|
defer c.connsMu.Unlock()
|
|
if c.connections == nil {
|
|
c.connections = make(map[string]*picoConn)
|
|
}
|
|
if c.sessionConnections == nil {
|
|
c.sessionConnections = make(map[string]map[string]*picoConn)
|
|
}
|
|
if _, exists := c.connections[pc.id]; exists {
|
|
panic(fmt.Sprintf("duplicate conn id in test: %s", pc.id))
|
|
}
|
|
c.connections[pc.id] = pc
|
|
bySession, ok := c.sessionConnections[pc.sessionID]
|
|
if !ok {
|
|
bySession = make(map[string]*picoConn)
|
|
c.sessionConnections[pc.sessionID] = bySession
|
|
}
|
|
bySession[pc.id] = pc
|
|
}
|