diff --git a/tests/conftest.py b/tests/conftest.py index b166177d9f..fb0a6dd459 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ import pytest import torch +from torch import distributed import composer from composer.utils import dist, reproducibility @@ -156,6 +157,14 @@ def chdir_to_tmpdir(tmpdir: pathlib.Path): os.chdir(tmpdir) +@pytest.fixture(scope="function", autouse=True) +def destroy_process_group(): + """Teardown any existing process groups between tests.""" + yield + if distributed.is_available() and distributed.is_initialized(): + distributed.destroy_process_group() + + def pytest_sessionfinish(session: pytest.Session, exitstatus: int): if exitstatus == 5: session.exitstatus = 0 # Ignore no-test-ran errors