Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,8 @@ def trt_export(
"""
Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript.
Currently, this API only supports converting models whose inputs are all tensors.
Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
Review the TensorRT Support Matrix for which GPUs are supported.

There are two ways to export a model:
1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.
Expand Down
4 changes: 3 additions & 1 deletion monai/networks/trt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,9 @@ def trt_compile(
) -> torch.nn.Module:
"""
Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.
Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x
Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x.
NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
Review the TensorRT Support Matrix for which GPUs are supported.
Args:
model: module to patch with TrtCompiler object.
base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
InvalidPyTorchVersionError,
OptionalImportError,
allow_missing_reference,
compute_capabilities_after,
damerau_levenshtein_distance,
exact_version,
get_full_type_name,
Expand Down
41 changes: 41 additions & 0 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,44 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st
if is_prerelease:
return False
return True


@functools.lru_cache(None)
def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool:
    """
    Compute whether the current system GPU CUDA compute capability is after or equal to the specified version.
    The current system GPU CUDA compute capability is determined by the first GPU in the system (index 0),
    queried through ``pynvml``. Results are memoized per argument combination by ``functools.lru_cache``.

    Args:
        major: major version number to be compared with.
        minor: minor version number to be compared with. Defaults to 0.
        current_ver_string: compute capability in the form of "major.minor" (a bare "major" is treated
            as "major.0"). If None, the current system GPU CUDA compute capability will be used.

    Returns:
        True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
        When ``current_ver_string`` is None: True if ``pynvml`` is not installed (the capability cannot be
        queried, so a recent GPU is assumed), False if ``pynvml`` is installed but CUDA is not available.
    """
    if current_ver_string is None:
        cuda_available = torch.cuda.is_available()
        pynvml, has_pynvml = optional_import("pynvml")
        if not has_pynvml:  # assuming that the user has Ampere and later GPU
            return True
        if not cuda_available:
            return False
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # get the first GPU
            major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        finally:
            # always release NVML, even if one of the queries above raises
            pynvml.nvmlShutdown()
        current_ver_string = f"{major_c}.{minor_c}"

    ver, has_ver = optional_import("packaging.version", name="parse")
    if has_ver:
        return ver(f"{major}.{minor}") <= ver(f"{current_ver_string}")  # type: ignore

    # fallback when `packaging` is not installed: compare (major, minor) integer tuples
    parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
    while len(parts) < 2:
        parts += ["0"]  # a bare major such as "8" is compared as "8.0"
    c_major, c_minor = parts[:2]
    c_mn = int(c_major), int(c_minor)
    mn = int(major), int(minor)
    # `>=` (not `>`): an equal capability must satisfy the check, matching the `packaging` branch above
    return c_mn >= mn
9 changes: 8 additions & 1 deletion tests/test_bundle_trt_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from monai.data import load_net_with_metadata
from monai.networks import save_state
from monai.utils import optional_import
from tests.utils import command_line_tests, skip_if_no_cuda, skip_if_quick, skip_if_windows
from tests.utils import (
SkipIfBeforeComputeCapabilityVersion,
command_line_tests,
skip_if_no_cuda,
skip_if_quick,
skip_if_windows,
)

_, has_torchtrt = optional_import(
"torch_tensorrt",
Expand All @@ -47,6 +53,7 @@
@skip_if_windows
@skip_if_no_cuda
@skip_if_quick
@SkipIfBeforeComputeCapabilityVersion((7, 0))
class TestTRTExport(unittest.TestCase):

def setUp(self):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_convert_to_trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from monai.networks import convert_to_trt
from monai.networks.nets import UNet
from monai.utils import optional_import
from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows
from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows

_, has_torchtrt = optional_import(
"torch_tensorrt",
Expand All @@ -38,6 +38,7 @@
@skip_if_windows
@skip_if_no_cuda
@skip_if_quick
@SkipIfBeforeComputeCapabilityVersion((7, 0))
class TestConvertToTRT(unittest.TestCase):

def setUp(self):
Expand Down
9 changes: 8 additions & 1 deletion tests/test_trt_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
from monai.networks import trt_compile
from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132
from monai.utils import min_version, optional_import
from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows
from tests.utils import (
SkipIfAtLeastPyTorchVersion,
SkipIfBeforeComputeCapabilityVersion,
skip_if_no_cuda,
skip_if_quick,
skip_if_windows,
)

trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
polygraphy, polygraphy_imported = optional_import("polygraphy")
Expand All @@ -36,6 +42,7 @@
@skip_if_quick
@unittest.skipUnless(trt_imported, "tensorrt is required")
@unittest.skipUnless(polygraphy_imported, "polygraphy is required")
@SkipIfBeforeComputeCapabilityVersion((7, 0))
class TestTRTCompile(unittest.TestCase):

def setUp(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

from parameterized import parameterized

from monai.utils import pytorch_after
from monai.utils import compute_capabilities_after, pytorch_after

TEST_CASES = (
TEST_CASES_PT = (
(1, 5, 9, "1.6.0"),
(1, 6, 0, "1.6.0"),
(1, 6, 1, "1.6.0", False),
Expand All @@ -36,14 +36,30 @@
(1, 6, 1, "1.6.0+cpu", False),
)

# Cases for compute_capabilities_after. Each entry is (major, minor, sm, expected):
# ``major``/``minor`` are the required compute capability, ``sm`` is the capability string
# passed as the third positional argument, and ``expected`` is the expected boolean result.
TEST_CASES_SM = [
    # (major, minor, sm, expected)
    (6, 1, "6.1", True),
    (6, 1, "6.0", False),
    (6, 0, "8.6", True),
    (7, 0, "8", True),
    (8, 6, "8", False),
]


class TestPytorchVersionCompare(unittest.TestCase):

@parameterized.expand(TEST_CASES)
@parameterized.expand(TEST_CASES_PT)
def test_compare(self, a, b, p, current, expected=True):
"""Test pytorch_after with a and b"""
self.assertEqual(pytorch_after(a, b, p, current), expected)


class TestComputeCapabilitiesAfter(unittest.TestCase):

    @parameterized.expand(TEST_CASES_SM)
    def test_compute_capabilities_after(self, major, minor, sm, expected):
        """Test compute_capabilities_after with major, minor and an explicit capability string"""
        actual = compute_capabilities_after(major, minor, sm)
        self.assertEqual(actual, expected)


if __name__ == "__main__":
unittest.main()
16 changes: 15 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from monai.networks import convert_to_onnx, convert_to_torchscript
from monai.utils import optional_import
from monai.utils.misc import MONAIEnvVars
from monai.utils.module import pytorch_after
from monai.utils.module import compute_capabilities_after, pytorch_after
from monai.utils.tf32 import detect_default_tf32
from monai.utils.type_conversion import convert_data_type

Expand Down Expand Up @@ -286,6 +286,20 @@ def __call__(self, obj):
)(obj)


class SkipIfBeforeComputeCapabilityVersion:
    """Decorator that skips the decorated test when the Compute Capability
    reported by ``compute_capabilities_after`` is older than the given one.

    The check is evaluated once, when the decorator object is constructed.
    """

    def __init__(self, compute_capability_tuple):
        # the tuple is unpacked into the positional arguments of compute_capabilities_after
        self.min_version = compute_capability_tuple
        meets_minimum = compute_capabilities_after(*compute_capability_tuple)
        self.version_too_old = not meets_minimum

    def __call__(self, obj):
        reason = f"Skipping tests that fail on Compute Capability versions before: {self.min_version}"
        skip_decorator = unittest.skipIf(self.version_too_old, reason)
        return skip_decorator(obj)


def is_main_test_process():
ps = torch.multiprocessing.current_process()
if not ps or not hasattr(ps, "name"):
Expand Down
Loading