Skip to content

Commit

Permalink
Ensure all Python multiprocessing tests have timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Oct 16, 2024
1 parent 9a9c4f7 commit 3c09652
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 49 deletions.
5 changes: 2 additions & 3 deletions python/ucxx/ucxx/_lib/tests/test_cancel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import ucxx._lib.libucxx as ucx_api
from ucxx._lib.arr import Array
from ucxx.testing import terminate_process
from ucxx.testing import join_processes, terminate_process

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -80,7 +80,6 @@ def test_message_probe():
args=(queue,),
)
client.start()
client.join(timeout=10)
server.join(timeout=10)
join_processes([client, server], timeout=10)
terminate_process(client)
terminate_process(server)
5 changes: 2 additions & 3 deletions python/ucxx/ucxx/_lib/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import ucxx._lib.libucxx as ucx_api
from ucxx._lib.arr import Array
from ucxx.testing import terminate_process, wait_requests
from ucxx.testing import join_processes, terminate_process, wait_requests

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -108,7 +108,6 @@ def test_close_callback(server_close_callback):
args=(port, server_close_callback),
)
client.start()
client.join(timeout=10)
server.join(timeout=10)
join_processes([client, server], timeout=10)
terminate_process(client)
terminate_process(server)
5 changes: 2 additions & 3 deletions python/ucxx/ucxx/_lib/tests/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ucxx._lib import libucxx as ucx_api
from ucxx._lib.arr import Array
from ucxx.testing import terminate_process, wait_requests
from ucxx.testing import join_processes, terminate_process, wait_requests

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -128,7 +128,6 @@ def test_message_probe(transfer_api):
server.start()
client = mp.Process(target=_client_probe, args=(queue, transfer_api))
client.start()
client.join(timeout=10)
server.join(timeout=10)
join_processes([client, server], timeout=10)
terminate_process(client)
terminate_process(server)
9 changes: 4 additions & 5 deletions python/ucxx/ucxx/_lib_async/tests/test_benchmark_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

from ucxx.benchmarks.utils import _run_cluster_server, _run_cluster_workers
from ucxx.testing import join_processes, terminate_process


async def _worker(rank, eps, args):
Expand Down Expand Up @@ -46,9 +47,7 @@ async def test_benchmark_cluster(n_chunks=1, n_nodes=2, n_workers=2):
)
)

join_processes(workers + [server], timeout=30)
for worker in workers:
worker.join()
assert not worker.exitcode

server.join()
assert not server.exitcode
terminate_process(worker)
terminate_process(server)
9 changes: 5 additions & 4 deletions python/ucxx/ucxx/_lib_async/tests/test_disconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ucxx
from ucxx._lib_async.utils import get_event_loop
from ucxx._lib_async.utils_test import wait_listener_client_handlers
from ucxx.testing import terminate_process

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -127,9 +128,9 @@ def test_shutdown_unexpected_closed_peer(caplog, endpoint_error_handling):
args=(client_queue, server_queue, endpoint_error_handling),
)
p2.start()
p2.join()
p2.join(timeout=30)
server_queue.put("client is down")
p1.join()
p1.join(timeout=30)

assert not p1.exitcode
assert not p2.exitcode
terminate_process(p2)
terminate_process(p1)
18 changes: 7 additions & 11 deletions python/ucxx/ucxx/_lib_async/tests/test_from_worker_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import ucxx
from ucxx._lib_async.utils import get_event_loop, hash64bits
from ucxx.testing import join_processes, terminate_process

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -90,11 +91,9 @@ def test_from_worker_address():
)
client.start()

client.join()
server.join()

assert not server.exitcode
assert not client.exitcode
join_processes([client, server], timeout=30)
terminate_process(client)
terminate_process(server)


def _get_address_info(address=None):
Expand Down Expand Up @@ -259,10 +258,7 @@ def test_from_worker_address_multinode(num_nodes):
client.start()
clients.append(client)

join_processes(clients + [server], timeout=30)
for client in clients:
client.join()

server.join()

assert not server.exitcode
assert not client.exitcode
terminate_process(client)
terminate_process(server)
32 changes: 17 additions & 15 deletions python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import ucxx
from ucxx._lib_async.utils import get_event_loop
from ucxx.testing import join_processes, terminate_process

mp = mp.get_context("spawn")

Expand Down Expand Up @@ -174,18 +175,19 @@ def test_from_worker_address_error(error_type):
)
client.start()

server.join()
client.join()

assert not server.exitcode

if ucxx.get_ucx_version() < (1, 12, 0) and client.exitcode == 1:
if all(t in error_type for t in ["timeout", "send"]):
pytest.xfail(
"Requires https://github.com/openucx/ucx/pull/7527 with rc/ud."
)
elif all(t in error_type for t in ["timeout", "recv"]):
pytest.xfail(
"Requires https://github.com/openucx/ucx/pull/7531 with rc/ud."
)
assert not client.exitcode
join_processes([client, server], timeout=30)
terminate_process(server)
try:
terminate_process(client)
except RuntimeError as e:
if ucxx.get_ucx_version() < (1, 12, 0):
if all(t in error_type for t in ["timeout", "send"]):
pytest.xfail(
"Requires https://github.com/openucx/ucx/pull/7527 with rc/ud."
)
elif all(t in error_type for t in ["timeout", "recv"]):
pytest.xfail(
"Requires https://github.com/openucx/ucx/pull/7531 with rc/ud."
)
else:
raise e
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
send,
wait_listener_client_handlers,
)
from ucxx.testing import join_processes, terminate_process

cupy = pytest.importorskip("cupy")
rmm = pytest.importorskip("rmm")
Expand Down Expand Up @@ -240,8 +241,6 @@ def test_send_recv_cu(cuda_obj_generator, comm_api):
os.environ.update(env_client)
client_process.start()

server_process.join()
client_process.join()

assert server_process.exitcode == 0
assert client_process.exitcode == 0
join_processes([client_process, server_process], timeout=30)
terminate_process(client_process)
terminate_process(server_process)

0 comments on commit 3c09652

Please sign in to comment.