diff --git a/pytest_parallel/mpi_reporter.py b/pytest_parallel/mpi_reporter.py index 1a9e912..3c0b3f8 100644 --- a/pytest_parallel/mpi_reporter.py +++ b/pytest_parallel/mpi_reporter.py @@ -33,7 +33,7 @@ def create_sub_comm_of_size(global_comm, n_proc, mpi_comm_creation_function): color = MPI.UNDEFINED return global_comm.Split(color, key=i_rank) else: - assert 0, 'unknown MPI communicator creation function' + assert 0, 'Unknown MPI communicator creation function. Available: `MPI_Comm_create`, `MPI_Comm_split`' def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function): i_rank = global_comm.Get_rank() @@ -45,36 +45,6 @@ def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function): return sub_comms - -def filter_and_add_sub_comm__old(items, global_comm): - i_rank = global_comm.Get_rank() - n_workers = global_comm.Get_size() - - filtered_items = [] - for item in items: - n_proc_test = get_n_proc_for_test(item) - - if n_proc_test > n_workers: # not enough procs: will be skipped - if global_comm.Get_rank() == 0: - item.sub_comm = MPI.COMM_SELF - mark_skip(item) - filtered_items += [item] - else: - item.sub_comm = MPI.COMM_NULL # TODO this should not be needed - else: - if i_rank < n_proc_test: - color = 1 - else: - color = MPI.UNDEFINED - - sub_comm = global_comm.Split(color) - - if sub_comm != MPI.COMM_NULL: - item.sub_comm = sub_comm - filtered_items += [item] - - return filtered_items - def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_function): i_rank = global_comm.Get_rank() n_rank = global_comm.Get_size() @@ -91,27 +61,19 @@ def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_funct if n_proc_test > n_rank: # not enough procs: mark as to be skipped mark_skip(item) item.sub_comm = MPI.COMM_NULL - #if n_proc_test > n_workers: # not enough procs: will be skipped - # if global_comm.Get_rank() == 0: - # item.sub_comm = MPI.COMM_SELF - # mark_skip(item) - # else: - # item.sub_comm = MPI.COMM_NULL # TODO this should not be needed else: if test_comm_creation == 'by_rank': item.sub_comm = sub_comms[n_proc_test-1] elif test_comm_creation == 'by_test': item.sub_comm = create_sub_comm_of_size(global_comm, n_proc_test, mpi_comm_creation_function) else: - assert 0, 'unknown test MPI communicator creation strategy' - + assert 0, 'Unknown test MPI communicator creation strategy. Available: `by_rank`, `by_test`' class SequentialScheduler: - def __init__(self, global_comm, test_comm_creation='by_rank', mpi_comm_creation_function='MPI_Comm_create', barrier_at_test_start=True, barrier_at_test_end=False): + def __init__(self, global_comm, test_comm_creation='by_rank', mpi_comm_creation_function='MPI_Comm_create', barrier_at_test_start=True, barrier_at_test_end=True): self.global_comm = global_comm.Dup() # ensure that all communications within the framework are private to the framework self.test_comm_creation = test_comm_creation self.mpi_comm_creation_function = mpi_comm_creation_function - self.barrier_at_test_start = barrier_at_test_start self.barrier_at_test_end = barrier_at_test_end @@ -119,45 +81,31 @@ def __init__(self, global_comm, test_comm_creation='by_rank', mpi_comm_creation_ def pytest_collection_modifyitems(self, config, items): add_sub_comm(items, self.global_comm, self.test_comm_creation, self.mpi_comm_creation_function) - #@pytest.hookimpl(tryfirst=True) - #def pytest_runtest_protocol(self, item, nextitem): - # #i_rank = self.global_comm.Get_rank() - # #n_proc_test = get_n_proc_for_test(item) - # #if i_rank < n_proc_test: - # # sub_comm = sub_comm_from_ranks(self.global_comm, range(0,n_proc_test)) - # #else: - # # sub_comm = MPI.COMM_NULL - # #item.sub_comm = sub_comm - # n_proc_test = get_n_proc_for_test(item) - # #if n_proc_test <= self.global_comm.Get_size(): - # #if n_proc_test < self.global_comm.rank: - # item.sub_comm = self.sub_comms[n_proc_test-1] - # #else: - # # item.sub_comm = MPI.COMM_NULL - @pytest.hookimpl(hookwrapper=True, tryfirst=True) def pytest_runtest_protocol(self, item, nextitem): if self.barrier_at_test_start: self.global_comm.barrier() + #print(f'pytest_runtest_protocol beg {MPI.COMM_WORLD.rank=}') _ = yield + #print(f'pytest_runtest_protocol end {MPI.COMM_WORLD.rank=}') if self.barrier_at_test_end: self.global_comm.barrier() #@pytest.hookimpl(tryfirst=True) #def pytest_runtest_protocol(self, item, nextitem): - # pass - # #return True - # #if item.sub_comm != MPI.COMM_NULL: - # # _ = yield - # #else: - # # return True + # if self.barrier_at_test_start: + # self.global_comm.barrier() + # print(f'pytest_runtest_protocol beg {MPI.COMM_WORLD.rank=}') + # if item.sub_comm == MPI.COMM_NULL: + # return True # for this hook, `firstresult=True` so returning a non-None will stop other hooks to run - @pytest.hookimpl(hookwrapper=True, tryfirst=True) + @pytest.hookimpl(tryfirst=True) def pytest_pyfunc_call(self, pyfuncitem): - if pyfuncitem.sub_comm != MPI.COMM_NULL: - _ = yield - else: # the rank does not participate in the test, so do nothing - return True + #print(f'pytest_pyfunc_call {MPI.COMM_WORLD.rank=}') + # This is where the test is normally run. + # Only run the test for the ranks that do participate in the test + if pyfuncitem.sub_comm == MPI.COMM_NULL: + return True # for this hook, `firstresult=True` so returning a non-None will stop other hooks to run @pytest.hookimpl(hookwrapper=True, tryfirst=True) def pytest_runtestloop(self, session) -> bool: diff --git a/pytest_parallel/plugin.py b/pytest_parallel/plugin.py index 5a9ad05..57b12aa 100644 --- a/pytest_parallel/plugin.py +++ b/pytest_parallel/plugin.py @@ -7,7 +7,8 @@ import pytest from pathlib import Path import argparse - +#from mpi4py import MPI +#from logger import consoleLogger # -------------------------------------------------------------------------- def pytest_addoption(parser): @@ -164,15 +165,19 @@ def __init__(self, comm): self.tmp_path = None def __enter__(self): - rank = self.comm.Get_rank() - self.tmp_dir = tempfile.TemporaryDirectory() if rank == 0 else None - self.tmp_path = Path(self.tmp_dir.name) if rank == 0 else None - return self.comm.bcast(self.tmp_path, root=0) + from mpi4py import MPI + if self.comm != MPI.COMM_NULL: # TODO DEL once non-participating rank do not participate in fixtures either + rank = self.comm.Get_rank() + self.tmp_dir = tempfile.TemporaryDirectory() if rank == 0 else None + self.tmp_path = Path(self.tmp_dir.name) if rank == 0 else None + return self.comm.bcast(self.tmp_path, root=0) def __exit__(self, type, value, traceback): - self.comm.barrier() - if self.comm.Get_rank() == 0: - self.tmp_dir.cleanup() + from mpi4py import MPI + if self.comm != MPI.COMM_NULL: # TODO DEL once non-participating rank do not participate in fixtures either + self.comm.barrier() + if self.comm.Get_rank() == 0: + self.tmp_dir.cleanup() @pytest.fixture