Skip to content

Commit 815e833

Browse files
committed
add compute_capabilities_after
Signed-off-by: YunLiu <[email protected]>
1 parent 0bb20a8 commit 815e833

File tree

5 files changed

+52
-51
lines changed

5 files changed

+52
-51
lines changed

monai/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
run_eval,
124124
version_geq,
125125
version_leq,
126+
compute_capabilities_after,
126127
)
127128
from .nvtx import Range
128129
from .ordering import Ordering

monai/utils/module.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from re import match
2727
from types import FunctionType, ModuleType
2828
from typing import Any, cast
29+
import pynvml
2930

3031
import torch
3132

@@ -634,3 +635,36 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st
634635
if is_prerelease:
635636
return False
636637
return True
638+
639+
@functools.lru_cache(None)
640+
def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool:
641+
"""
642+
Compute whether the current system GPU CUDA compute capability is after or equal to the specified version.
643+
The current system GPU CUDA compute capability is determined by the first GPU in the system.
644+
The compared version is a string in the form of "major.minor".
645+
646+
Args:
647+
major: major version number to be compared with.
648+
minor: minor version number to be compared with. Defaults to 0.
649+
current_ver_string: if None, the current system GPU CUDA compute capability will be used.
650+
651+
Returns:
652+
True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
653+
"""
654+
if current_ver_string is None:
655+
pynvml.nvmlInit()
656+
handle = pynvml.nvmlDeviceGetHandleByIndex(0) # get the first GPU
657+
major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
658+
pynvml.nvmlShutdown()
659+
current_ver_string = f"{major_c}.{minor_c}"
660+
661+
ver, has_ver = optional_import("packaging.version", name="parse")
662+
if has_ver:
663+
return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore
664+
parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
665+
while len(parts) < 2:
666+
parts += ["0"]
667+
c_major, c_minor = parts[:2]
668+
c_mn = int(c_major), int(c_minor)
669+
mn = int(major), int(minor)
670+
return c_mn > mn

tests/test_pytorch_version_after.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

tests/test_trt_compile.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from monai.networks import trt_compile
2222
from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132
2323
from monai.utils import min_version, optional_import
24-
from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows
24+
from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows, SkipIfBeforeComputeCapabilityVersion
2525

2626
trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
2727
polygraphy, polygraphy_imported = optional_import("polygraphy")
@@ -36,6 +36,7 @@
3636
@skip_if_quick
3737
@unittest.skipUnless(trt_imported, "tensorrt is required")
3838
@unittest.skipUnless(polygraphy_imported, "polygraphy is required")
39+
@SkipIfBeforeComputeCapabilityVersion((7, 0))
3940
class TestTRTCompile(unittest.TestCase):
4041

4142
def setUp(self):

tests/utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from monai.networks import convert_to_onnx, convert_to_torchscript
4848
from monai.utils import optional_import
4949
from monai.utils.misc import MONAIEnvVars
50-
from monai.utils.module import pytorch_after
50+
from monai.utils.module import pytorch_after, compute_capabilities_after
5151
from monai.utils.tf32 import detect_default_tf32
5252
from monai.utils.type_conversion import convert_data_type
5353

@@ -286,6 +286,20 @@ def __call__(self, obj):
286286
)(obj)
287287

288288

289+
class SkipIfBeforeComputeCapabilityVersion:
290+
"""Decorator to be used if test should be skipped
291+
with Compute Capability older than that given."""
292+
293+
def __init__(self, compute_capability_tuple):
294+
self.min_version = compute_capability_tuple
295+
self.version_too_old = not compute_capabilities_after(*compute_capability_tuple)
296+
297+
def __call__(self, obj):
298+
return unittest.skipIf(
299+
self.version_too_old, f"Skipping tests that fail on Compute Capability versions before: {self.min_version}"
300+
)(obj)
301+
302+
289303
def is_main_test_process():
290304
ps = torch.multiprocessing.current_process()
291305
if not ps or not hasattr(ps, "name"):

0 commit comments

Comments
 (0)