package pico import ( "context" "errors" "fmt" "sync" "testing" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" ) func newTestPicoChannel(t *testing.T) *PicoChannel { t.Helper() cfg := config.PicoConfig{} cfg.SetToken("test-token") ch, err := NewPicoChannel(cfg, bus.NewMessageBus()) if err != nil { t.Fatalf("NewPicoChannel: %v", err) } ch.ctx = context.Background() return ch } func TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently(t *testing.T) { ch := newTestPicoChannel(t) const ( maxConns = 5 goroutines = 64 sessionID = "session-a" ) var wg sync.WaitGroup var mu sync.Mutex successCount := 0 errCount := 0 wg.Add(goroutines) for i := 0; i < goroutines; i++ { go func() { defer wg.Done() pc, err := ch.createAndAddConnection(nil, sessionID, maxConns) mu.Lock() defer mu.Unlock() if err == nil { successCount++ if pc == nil { t.Errorf("pc is nil on success") } return } if !errors.Is(err, channels.ErrTemporary) { t.Errorf("unexpected error: %v", err) return } errCount++ }() } wg.Wait() if successCount > maxConns { t.Fatalf("successCount=%d > maxConns=%d", successCount, maxConns) } if successCount+errCount != goroutines { t.Fatalf("success=%d err=%d total=%d want=%d", successCount, errCount, successCount+errCount, goroutines) } if got := ch.currentConnCount(); got != maxConns { t.Fatalf("currentConnCount=%d want=%d", got, maxConns) } } func TestRemoveConnection_CleansBothIndexes(t *testing.T) { ch := newTestPicoChannel(t) pc, err := ch.createAndAddConnection(nil, "session-cleanup", 10) if err != nil { t.Fatalf("createAndAddConnection: %v", err) } removed := ch.removeConnection(pc.id) if removed == nil { t.Fatal("removeConnection returned nil") } ch.connsMu.RLock() defer ch.connsMu.RUnlock() if _, ok := ch.connections[pc.id]; ok { t.Fatalf("connID %s still exists in connections", pc.id) } if _, ok := ch.sessionConnections[pc.sessionID]; ok { t.Fatalf("session %s still exists in sessionConnections", pc.sessionID) } if got := len(ch.connections); got != 0 { t.Fatalf("len(connections)=%d want=0", got) } } func TestBroadcastToSession_TargetsOnlyRequestedSession(t *testing.T) { ch := newTestPicoChannel(t) target := &picoConn{id: "target", sessionID: "s-target"} target.closed.Store(true) ch.addConnForTest(target) other := &picoConn{id: "other", sessionID: "s-other"} ch.addConnForTest(other) err := ch.broadcastToSession("pico:s-target", newMessage(TypeMessageCreate, map[string]any{"content": "hello"})) if err == nil { t.Fatal("expected send failure due to closed target connection") } if !errors.Is(err, channels.ErrSendFailed) { t.Fatalf("expected ErrSendFailed, got %v", err) } } func (c *PicoChannel) addConnForTest(pc *picoConn) { c.connsMu.Lock() defer c.connsMu.Unlock() if c.connections == nil { c.connections = make(map[string]*picoConn) } if c.sessionConnections == nil { c.sessionConnections = make(map[string]map[string]*picoConn) } if _, exists := c.connections[pc.id]; exists { panic(fmt.Sprintf("duplicate conn id in test: %s", pc.id)) } c.connections[pc.id] = pc bySession, ok := c.sessionConnections[pc.sessionID] if !ok { bySession = make(map[string]*picoConn) c.sessionConnections[pc.sessionID] = bySession } bySession[pc.id] = pc }