Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remodel multinode tests #685

Open
wants to merge 7 commits into
base: branch-0.19
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions tests/test_multiple_processes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import asyncio
import multiprocessing
import random
import sys

import numpy as np
import pytest

import ucp


def listener(ports):
    """Serve one UCX listener per port until every listener has been closed.

    Intended to run as the target of a separate process: it initialises
    ``ucp``, creates a listener for each entry of ``ports`` and then blocks
    on the event loop until all of them report ``closed()``.

    Parameters
    ----------
    ports : iterable of int
        Ports to listen on. Each port is also the key under which its
        listener is stored, so a client can request that a listener be
        closed by sending its port number as the close message.
    """
    ucp.init()

    async def _listener(ports):
        async def write(ep):
            """Handle one connection: exchange messages, optionally shut down."""
            close_msg = bytearray(2)
            msg2send = np.arange(10)
            msg2recv = np.empty_like(msg2send)

            msgs = [ep.recv(close_msg), ep.send(msg2send), ep.recv(msg2recv)]
            # No ``loop=`` argument: it was deprecated in Python 3.8 and
            # removed in 3.10. ``gather`` uses the running loop, which is the
            # loop this coroutine already executes on.
            await asyncio.gather(*msgs)

            # The 2-byte close message is an integer in native byte order.
            close_msg = int.from_bytes(close_msg, sys.byteorder)

            # Zero means "keep serving"; any other value is the key (port)
            # of the listener that should be shut down.
            if close_msg != 0:
                await ep.close()
                listeners[close_msg].close()

        listeners = {}
        for port in ports:
            listeners[port] = ucp.create_listener(write, port=port)

        try:
            # Poll until every listener has been closed by its handler.
            while not all(lf.closed() for lf in listeners.values()):
                await asyncio.sleep(0.1)
        except ucp.UCXCloseError:
            pass

    asyncio.get_event_loop().run_until_complete(_listener(ports))


def client(listener_ports):
    """Connect repeatedly to every listener port and exchange messages.

    Intended to run as the target of a separate process: it initialises
    ``ucp`` and opens ``close_after`` (100) connections to each port in
    ``listener_ports``, all scheduled concurrently. The connections created
    in the last round carry the port number as their close message; all
    earlier ones carry zero.

    Parameters
    ----------
    listener_ports : iterable of int
        Ports to connect to. Each value must fit in two bytes because it is
        encoded with ``int.to_bytes(2, ...)`` when used as a close message.
    """
    ucp.init()

    async def _client(listener_ports):
        async def read(port, close):
            """Open an endpoint to ``port`` and run one message exchange."""
            # Close message: the port number when ``close`` is set, else 0,
            # as a 2-byte integer in native byte order.
            close_value = int(port if close else 0)
            close_msg = bytearray(close_value.to_bytes(2, sys.byteorder))
            msg2send = np.arange(10)
            msg2recv = np.empty_like(msg2send)

            ep = await ucp.create_endpoint(ucp.get_address(), port)
            msgs = [ep.send(close_msg), ep.send(msg2send), ep.recv(msg2recv)]
            # No ``loop=`` argument: it was deprecated in Python 3.8 and
            # removed in 3.10. ``gather`` uses the running loop.
            await asyncio.gather(*msgs)

        close_after = 100
        # Same ordering as nested loops: rounds outermost, ports innermost.
        # Only the final round (i == close_after - 1) sets ``close``.
        clients = [
            read(port, close=(i == close_after - 1))
            for i in range(close_after)
            for port in listener_ports
        ]

        await asyncio.gather(*clients)

    asyncio.get_event_loop().run_until_complete(_client(listener_ports))


@pytest.mark.parametrize("num_listeners", [1, 2, 4, 8])
def test_send_recv_cu(num_listeners):
    """Run a listener process and a client process against random ports.

    Picks ``num_listeners`` distinct ports, spawns one process running
    ``listener`` and one running ``client`` with that port list, waits for
    both to finish and asserts that each exited with code 0.
    """
    # ``random.sample`` draws without replacement, so it yields exactly
    # ``num_listeners`` distinct ports in one step. Growing a set until its
    # size *equals* the target can overshoot after a duplicate and then never
    # terminate. The range matches randint(13000, 23000), inclusive at both
    # ends.
    ports = random.sample(range(13000, 23001), num_listeners)

    ctx = multiprocessing.get_context("spawn")
    listener_process = ctx.Process(name="listener", target=listener, args=[ports])
    client_process = ctx.Process(name="client", target=client, args=[ports])

    listener_process.start()
    client_process.start()

    listener_process.join()
    client_process.join()

    assert listener_process.exitcode == 0
    assert client_process.exitcode == 0
23 changes: 6 additions & 17 deletions tests/test_multiple_nodes.py → tests/test_single_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,22 @@ async def client_node(port):


@pytest.mark.asyncio
async def test_multiple_nodes():
    """Run ten clients against each of two listeners concurrently."""
    lf1 = ucp.create_listener(server_node)
    lf2 = ucp.create_listener(server_node)
    assert lf1.port != lf2.port

    nodes = []
    for _ in range(10):
        nodes.append(client_node(lf1.port))
        nodes.append(client_node(lf2.port))
    # No ``loop=`` argument: it was deprecated in Python 3.8 and removed in
    # 3.10. ``gather`` uses the running loop.
    await asyncio.gather(*nodes)


@pytest.mark.asyncio
async def test_one_server_many_clients():
async def test_one_listener_many_clients():
lf = ucp.create_listener(server_node)
clients = []
for _ in range(100):
for _ in range(50):
clients.append(client_node(lf.port))
await asyncio.gather(*clients, loop=asyncio.get_event_loop())


@pytest.mark.asyncio
async def test_two_servers_many_clients():
async def test_two_listeners_many_clients():
lf1 = ucp.create_listener(server_node)
lf2 = ucp.create_listener(server_node)
assert lf1.port != lf2.port

clients = []
for _ in range(100):
for _ in range(25):
clients.append(client_node(lf1.port))
clients.append(client_node(lf2.port))
await asyncio.gather(*clients, loop=asyncio.get_event_loop())