From 9a9c4f78faf694687761af9ffb6c20ef976daea0 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 16 Oct 2024 04:04:30 -0700 Subject: [PATCH] Add testing util function to join a list of processes --- python/ucxx/ucxx/_lib/tests/test_utils.py | 35 ++++++++++++++++++++++- python/ucxx/ucxx/testing.py | 24 ++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/python/ucxx/ucxx/_lib/tests/test_utils.py b/python/ucxx/ucxx/_lib/tests/test_utils.py index 50f596e2..22d993ee 100644 --- a/python/ucxx/ucxx/_lib/tests/test_utils.py +++ b/python/ucxx/ucxx/_lib/tests/test_utils.py @@ -3,11 +3,12 @@ import multiprocessing import re +import time from multiprocessing.queues import Empty import pytest -from ucxx.testing import terminate_process +from ucxx.testing import join_processes, terminate_process def _test_process(queue): @@ -84,3 +85,35 @@ def test_terminate_process_kill_timeout(mp_context): ValueError, match="Cannot close a process while it is still running.*" ): terminate_process(proc, kill_wait=0.0) + + +@pytest.mark.parametrize("mp_context", ["default", "fork", "forkserver", "spawn"]) +@pytest.mark.parametrize("num_processes", [1, 2, 4]) +def test_join_processes(mp_context, num_processes): + mp = ( + multiprocessing + if mp_context == "default" + else multiprocessing.get_context(mp_context) + ) + + queue = mp.Queue() + processes = [] + for _ in range(num_processes): + proc = mp.Process( + target=_test_process, + args=(queue,), + ) + proc.start() + processes.append(proc) + + start = time.monotonic() + join_processes(processes, timeout=1.25) + total_time = time.monotonic() - start + assert total_time >= 1.25 and total_time < 2.5 + + for proc in processes: + try: + terminate_process(proc) + except RuntimeError: + # The process has to be killed and that will raise a `RuntimeError` + pass diff --git a/python/ucxx/ucxx/testing.py b/python/ucxx/ucxx/testing.py index fe2f0f72..31d107a8 100644 --- a/python/ucxx/ucxx/testing.py +++ b/python/ucxx/ucxx/testing.py @@ -6,6 +6,30 @@ from typing import Type, Union +def join_processes( + processes: list[Type[BaseProcess]], + timeout: Union[float, int], +) -> None: + """ + Join a list of processes with a combined timeout. + + Join a list of processes with a combined timeout, for each process `join()` + is called with a timeout equal to the difference of `timeout` and the time + elapsed since this function was called. + + Parameters + ---------- + processes: + The list of processes to be joined. + timeout: float or integer + Maximum time to wait for all the processes to be joined. + """ + start = time.monotonic() + for p in processes: + t = timeout - (time.monotonic() - start) + p.join(timeout=t) + + def terminate_process( process: Type[BaseProcess], kill_wait: Union[float, int] = 3.0 ) -> None: