Skip to content

Commit

Permalink
Add testing util function to join a list of processes
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Oct 16, 2024
1 parent 1f4e508 commit 9a9c4f7
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
35 changes: 34 additions & 1 deletion python/ucxx/ucxx/_lib/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import multiprocessing
import re
import time
from multiprocessing.queues import Empty

import pytest

from ucxx.testing import terminate_process
from ucxx.testing import join_processes, terminate_process


def _test_process(queue):
Expand Down Expand Up @@ -84,3 +85,35 @@ def test_terminate_process_kill_timeout(mp_context):
ValueError, match="Cannot close a process while it is still running.*"
):
terminate_process(proc, kill_wait=0.0)


@pytest.mark.parametrize("mp_context", ["default", "fork", "forkserver", "spawn"])
@pytest.mark.parametrize("num_processes", [1, 2, 4])
def test_join_processes(mp_context, num_processes):
mp = (
multiprocessing
if mp_context == "default"
else multiprocessing.get_context(mp_context)
)

queue = mp.Queue()
processes = []
for _ in range(num_processes):
proc = mp.Process(
target=_test_process,
args=(queue,),
)
proc.start()
processes.append(proc)

start = time.monotonic()
join_processes(processes, timeout=1.25)
total_time = time.monotonic() - start
assert total_time >= 1.25 and total_time < 2.5

for proc in processes:
try:
terminate_process(proc)
except RuntimeError:
# The process has to be killed and that will raise a `RuntimeError`
pass
24 changes: 24 additions & 0 deletions python/ucxx/ucxx/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,30 @@
from typing import Type, Union


def join_processes(
processes: list[Type[BaseProcess]],
timeout: Union[float, int],
) -> None:
"""
Join a list of processes with a combined timeout.
Join a list of processes with a combined timeout, for each process `join()`
is called with a timeout equal to the difference of `timeout` and the time
elapsed since this function was called.
Parameters
----------
processes:
The list of processes to be joined.
timeout: float or integer
Maximum time to wait for all the processes to be joined.
"""
start = time.monotonic()
for p in processes:
t = timeout - (time.monotonic() - start)
p.join(timeout=t)


def terminate_process(
process: Type[BaseProcess], kill_wait: Union[float, int] = 3.0
) -> None:
Expand Down

0 comments on commit 9a9c4f7

Please sign in to comment.