From 815e833431aaebed7e76a47743af4f7fab47e35a Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 12 Nov 2024 16:40:43 +0800 Subject: [PATCH 1/9] add compute_capabilities_after Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/utils/__init__.py | 1 + monai/utils/module.py | 34 ++++++++++++++++++++ tests/test_pytorch_version_after.py | 49 ----------------------------- tests/test_trt_compile.py | 3 +- tests/utils.py | 16 +++++++++- 5 files changed, 52 insertions(+), 51 deletions(-) delete mode 100644 tests/test_pytorch_version_after.py diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 79dc1f2304..40735aa8cb 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -123,6 +123,7 @@ run_eval, version_geq, version_leq, + compute_capabilities_after, ) from .nvtx import Range from .ordering import Ordering diff --git a/monai/utils/module.py b/monai/utils/module.py index 1f7f8aecfc..184169509e 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -26,6 +26,7 @@ from re import match from types import FunctionType, ModuleType from typing import Any, cast +import pynvml import torch @@ -634,3 +635,36 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st if is_prerelease: return False return True + +@functools.lru_cache(None) +def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool: + """ + Compute whether the current system GPU CUDA compute capability is after or equal to the specified version. + The current system GPU CUDA compute capability is determined by the first GPU in the system. + The compared version is a string in the form of "major.minor". + + Args: + major: major version number to be compared with. + minor: minor version number to be compared with. Defaults to 0. + current_ver_string: if None, the current system GPU CUDA compute capability will be used. + + Returns: + True if the current system GPU CUDA compute capability is greater than or equal to the specified version. + """ + if current_ver_string is None: + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) # get the first GPU + major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle) + pynvml.nvmlShutdown() + current_ver_string = f"{major_c}.{minor_c}" + + ver, has_ver = optional_import("packaging.version", name="parse") + if has_ver: + return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore + parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2) + while len(parts) < 2: + parts += ["0"] + c_major, c_minor = parts[:2] + c_mn = int(c_major), int(c_minor) + mn = int(major), int(minor) + return c_mn > mn diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py deleted file mode 100644 index 147707d2c0..0000000000 --- a/tests/test_pytorch_version_after.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -from parameterized import parameterized - -from monai.utils import pytorch_after - -TEST_CASES = ( - (1, 5, 9, "1.6.0"), - (1, 6, 0, "1.6.0"), - (1, 6, 1, "1.6.0", False), - (1, 7, 0, "1.6.0", False), - (2, 6, 0, "1.6.0", False), - (0, 6, 0, "1.6.0a0+3fd9dcf"), - (1, 5, 9, "1.6.0a0+3fd9dcf"), - (1, 6, 0, "1.6.0a0+3fd9dcf", False), - (1, 6, 1, "1.6.0a0+3fd9dcf", False), - (2, 6, 0, "1.6.0a0+3fd9dcf", False), - (1, 6, 0, "1.6.0-rc0+3fd9dcf", False), # defaults to prerelease - (1, 6, 0, "1.6.0rc0", False), - (1, 6, 0, "1.6", True), - (1, 6, 0, "1", False), - (1, 6, 0, "1.6.0+cpu", True), - (1, 6, 1, "1.6.0+cpu", False), -) - - -class TestPytorchVersionCompare(unittest.TestCase): - - @parameterized.expand(TEST_CASES) - def test_compare(self, a, b, p, current, expected=True): - """Test pytorch_after with a and b""" - self.assertEqual(pytorch_after(a, b, p, current), expected) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 6df5d520bd..5e3a70da1e 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -21,7 +21,7 @@ from monai.networks import trt_compile from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 from monai.utils import min_version, optional_import -from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows, SkipIfBeforeComputeCapabilityVersion trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) polygraphy, polygraphy_imported = optional_import("polygraphy") @@ -36,6 +36,7 @@ @skip_if_quick @unittest.skipUnless(trt_imported, "tensorrt is required") @unittest.skipUnless(polygraphy_imported, "polygraphy is required") +@SkipIfBeforeComputeCapabilityVersion((7, 0)) class TestTRTCompile(unittest.TestCase): def setUp(self): diff --git a/tests/utils.py b/tests/utils.py index 77b53cebb8..3d64d4e22d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -47,7 +47,7 @@ from monai.networks import convert_to_onnx, convert_to_torchscript from monai.utils import optional_import from monai.utils.misc import MONAIEnvVars -from monai.utils.module import pytorch_after +from monai.utils.module import pytorch_after, compute_capabilities_after from monai.utils.tf32 import detect_default_tf32 from monai.utils.type_conversion import convert_data_type @@ -286,6 +286,20 @@ def __call__(self, obj): )(obj) +class SkipIfBeforeComputeCapabilityVersion: + """Decorator to be used if test should be skipped + with Compute Capability older than that given.""" + + def __init__(self, compute_capability_tuple): + self.min_version = compute_capability_tuple + self.version_too_old = not compute_capabilities_after(*compute_capability_tuple) + + def __call__(self, obj): + return unittest.skipIf( + self.version_too_old, f"Skipping tests that fail on Compute Capability versions before: {self.min_version}" + )(obj) + + def is_main_test_process(): ps = torch.multiprocessing.current_process() if not ps or not hasattr(ps, "name"): From 8fafc3f046fe49b851f2d79ac0ed69c19af00339 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 12 Nov 2024 16:42:42 +0800 Subject: [PATCH 2/9] fix #8198 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_bundle_trt_export.py | 3 ++- tests/test_convert_to_trt.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py index 833a0ca1dc..e6e26f3647 100644 --- a/tests/test_bundle_trt_export.py +++ b/tests/test_bundle_trt_export.py @@ -22,7 +22,7 @@ from monai.data import load_net_with_metadata from monai.networks import save_state from monai.utils import optional_import -from tests.utils import command_line_tests, skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import command_line_tests, skip_if_no_cuda, skip_if_quick, skip_if_windows, SkipIfBeforeComputeCapabilityVersion _, has_torchtrt = optional_import( "torch_tensorrt", @@ -47,6 +47,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick +@SkipIfBeforeComputeCapabilityVersion((7, 0)) class TestTRTExport(unittest.TestCase): def setUp(self): diff --git a/tests/test_convert_to_trt.py b/tests/test_convert_to_trt.py index 5579539764..4799d651f8 100644 --- a/tests/test_convert_to_trt.py +++ b/tests/test_convert_to_trt.py @@ -20,7 +20,7 @@ from monai.networks import convert_to_trt from monai.networks.nets import UNet from monai.utils import optional_import -from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows +from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows, SkipIfBeforeComputeCapabilityVersion _, has_torchtrt = optional_import( "torch_tensorrt", @@ -38,6 +38,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick +@SkipIfBeforeComputeCapabilityVersion((7, 0)) class TestConvertToTRT(unittest.TestCase): def setUp(self): From b55bf6a84a4008a8da15ff74fcad06b81d22b896 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 12 Nov 2024 16:42:59 +0800 Subject: [PATCH 3/9] add test Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_version_after.py | 65 +++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 tests/test_version_after.py diff --git a/tests/test_version_after.py b/tests/test_version_after.py new file mode 100644 index 0000000000..aecc0baedd --- /dev/null +++ b/tests/test_version_after.py @@ -0,0 +1,65 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from parameterized import parameterized + +from monai.utils import pytorch_after, compute_capabilities_after + +TEST_CASES_PT = ( + (1, 5, 9, "1.6.0"), + (1, 6, 0, "1.6.0"), + (1, 6, 1, "1.6.0", False), + (1, 7, 0, "1.6.0", False), + (2, 6, 0, "1.6.0", False), + (0, 6, 0, "1.6.0a0+3fd9dcf"), + (1, 5, 9, "1.6.0a0+3fd9dcf"), + (1, 6, 0, "1.6.0a0+3fd9dcf", False), + (1, 6, 1, "1.6.0a0+3fd9dcf", False), + (2, 6, 0, "1.6.0a0+3fd9dcf", False), + (1, 6, 0, "1.6.0-rc0+3fd9dcf", False), # defaults to prerelease + (1, 6, 0, "1.6.0rc0", False), + (1, 6, 0, "1.6", True), + (1, 6, 0, "1", False), + (1, 6, 0, "1.6.0+cpu", True), + (1, 6, 1, "1.6.0+cpu", False), +) + +TEST_CASES_SM = [ + # (major, minor, sm, expected) + (6, 1, "6.1", True), + (6, 1, "6.0", False), + (6, 0, "8.6", True), + (7, 0, "8", True), + (8, 6, "8", False), +] + + +class TestPytorchVersionCompare(unittest.TestCase): + + @parameterized.expand(TEST_CASES_PT) + def test_compare(self, a, b, p, current, expected=True): + """Test pytorch_after with a and b""" + self.assertEqual(pytorch_after(a, b, p, current), expected) + + +class TestComputeCapabilitiesAfter(unittest.TestCase): + + @parameterized.expand(TEST_CASES_SM) + def test_compute_capabilities_after(self, major, minor, sm, expected): + self.assertEqual(compute_capabilities_after(major, minor, sm), expected) + + +if __name__ == "__main__": + unittest.main() From 83b5a4926e9b11b2c03932c9f0fb12e4d4cfb09e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Nov 2024 08:45:56 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/utils/module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index 184169509e..65a3dcf0ec 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -642,12 +642,12 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s Compute whether the current system GPU CUDA compute capability is after or equal to the specified version. The current system GPU CUDA compute capability is determined by the first GPU in the system. The compared version is a string in the form of "major.minor". - + Args: major: major version number to be compared with. minor: minor version number to be compared with. Defaults to 0. current_ver_string: if None, the current system GPU CUDA compute capability will be used. - + Returns: True if the current system GPU CUDA compute capability is greater than or equal to the specified version. """ @@ -657,7 +657,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle) pynvml.nvmlShutdown() current_ver_string = f"{major_c}.{minor_c}" - + ver, has_ver = optional_import("packaging.version", name="parse") if has_ver: return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore From 7dfad50387f6a5d67d41b6d340065a1f928d3985 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 12 Nov 2024 16:47:52 +0800 Subject: [PATCH 5/9] add docstring Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 2 ++ monai/networks/trt_compiler.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 884723ed68..131c78008b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1589,6 +1589,8 @@ def trt_export( """ Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript. Currently, this API only supports converting models whose inputs are all tensors. + Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5. + Review the TensorRT Support Matrix for which GPUs are supported. There are two ways to export a model: 1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript. diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index a360f63dbd..d2d05fae22 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -505,7 +505,9 @@ def trt_compile( ) -> torch.nn.Module: """ Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. - Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x + Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x. + NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5. + Review the TensorRT Support Matrix for which GPUs are supported. Args: model: module to patch with TrtCompiler object. base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. From 46e9bce519d309962c2b54468297c759f4de8e77 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 12 Nov 2024 16:51:12 +0800 Subject: [PATCH 6/9] fix format Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/utils/__init__.py | 2 +- monai/utils/module.py | 3 ++- tests/test_bundle_trt_export.py | 8 +++++++- tests/test_convert_to_trt.py | 2 +- tests/test_trt_compile.py | 8 +++++++- tests/test_version_after.py | 2 +- tests/utils.py | 2 +- 7 files changed, 20 insertions(+), 7 deletions(-) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 40735aa8cb..8f2f400b5d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -107,6 +107,7 @@ InvalidPyTorchVersionError, OptionalImportError, allow_missing_reference, + compute_capabilities_after, damerau_levenshtein_distance, exact_version, get_full_type_name, @@ -123,7 +124,6 @@ run_eval, version_geq, version_leq, - compute_capabilities_after, ) from .nvtx import Range from .ordering import Ordering diff --git a/monai/utils/module.py b/monai/utils/module.py index 65a3dcf0ec..e16d7371f3 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -26,8 +26,8 @@ from re import match from types import FunctionType, ModuleType from typing import Any, cast -import pynvml +import pynvml import torch # bundle config system flags @@ -636,6 +636,7 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st return False return True + @functools.lru_cache(None) def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool: """ diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py index e6e26f3647..835c8e5c1d 100644 --- a/tests/test_bundle_trt_export.py +++ b/tests/test_bundle_trt_export.py @@ -22,7 +22,13 @@ from monai.data import load_net_with_metadata from monai.networks import save_state from monai.utils import optional_import -from tests.utils import command_line_tests, skip_if_no_cuda, skip_if_quick, skip_if_windows, SkipIfBeforeComputeCapabilityVersion +from tests.utils import ( + SkipIfBeforeComputeCapabilityVersion, + command_line_tests, + skip_if_no_cuda, + skip_if_quick, + skip_if_windows, +) _, has_torchtrt = optional_import( "torch_tensorrt", diff --git a/tests/test_convert_to_trt.py b/tests/test_convert_to_trt.py index 4799d651f8..712d887c3b 100644 --- a/tests/test_convert_to_trt.py +++ b/tests/test_convert_to_trt.py @@ -20,7 +20,7 @@ from monai.networks import convert_to_trt from monai.networks.nets import UNet from monai.utils import optional_import -from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows, SkipIfBeforeComputeCapabilityVersion +from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows _, has_torchtrt = optional_import( "torch_tensorrt", diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 5e3a70da1e..49404fdbbe 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -21,7 +21,13 @@ from monai.networks import trt_compile from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 from monai.utils import min_version, optional_import -from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows, SkipIfBeforeComputeCapabilityVersion +from tests.utils import ( + SkipIfAtLeastPyTorchVersion, + SkipIfBeforeComputeCapabilityVersion, + skip_if_no_cuda, + skip_if_quick, + skip_if_windows, +) trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) polygraphy, polygraphy_imported = optional_import("polygraphy") diff --git a/tests/test_version_after.py b/tests/test_version_after.py index aecc0baedd..b6cb741382 100644 --- a/tests/test_version_after.py +++ b/tests/test_version_after.py @@ -15,7 +15,7 @@ from parameterized import parameterized -from monai.utils import pytorch_after, compute_capabilities_after +from monai.utils import compute_capabilities_after, pytorch_after TEST_CASES_PT = ( (1, 5, 9, "1.6.0"), diff --git a/tests/utils.py b/tests/utils.py index 3d64d4e22d..2a00af50e9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -47,7 +47,7 @@ from monai.networks import convert_to_onnx, convert_to_torchscript from monai.utils import optional_import from monai.utils.misc import MONAIEnvVars -from monai.utils.module import pytorch_after, compute_capabilities_after +from monai.utils.module import compute_capabilities_after, pytorch_after from monai.utils.tf32 import detect_default_tf32 from monai.utils.type_conversion import convert_data_type From b72220a3f3f247eac858bf3e6c5a27735db77b4e Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 12 Nov 2024 17:15:49 +0800 Subject: [PATCH 7/9] fix ci Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/utils/module.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index e16d7371f3..03ee8f2003 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -27,7 +27,6 @@ from types import FunctionType, ModuleType from typing import Any, cast -import pynvml import torch # bundle config system flags @@ -653,11 +652,19 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s True if the current system GPU CUDA compute capability is greater than or equal to the specified version. """ if current_ver_string is None: - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(0) # get the first GPU - major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle) - pynvml.nvmlShutdown() - current_ver_string = f"{major_c}.{minor_c}" + pynvml, has_pynvml = optional_import("pynvml") + if not has_pynvml: # assuming that the user has Ampere and later GPU + return True + + try: + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) # get the first GPU + major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle) + current_ver_string = f"{major_c}.{minor_c}" + except BaseException: + pass + finally: + pynvml.nvmlShutdown() ver, has_ver = optional_import("packaging.version", name="parse") if has_ver: From 8631db715611fb1f72aa4aba713b3c7388e4aaa2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Nov 2024 09:16:30 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/utils/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index 03ee8f2003..a8e9c86eab 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -655,7 +655,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s pynvml, has_pynvml = optional_import("pynvml") if not has_pynvml: # assuming that the user has Ampere and later GPU return True - + try: pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(0) # get the first GPU From ebeaeab71a98de9ecc816eb234be9be151e04f43 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 12 Nov 2024 18:03:02 +0800 Subject: [PATCH 9/9] fix ci Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/utils/module.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index 03ee8f2003..59fdd00acd 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -655,16 +655,12 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s pynvml, has_pynvml = optional_import("pynvml") if not has_pynvml: # assuming that the user has Ampere and later GPU return True - - try: - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(0) # get the first GPU - major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle) - current_ver_string = f"{major_c}.{minor_c}" - except BaseException: - pass - finally: - pynvml.nvmlShutdown() + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) # get the first GPU + major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle) + current_ver_string = f"{major_c}.{minor_c}" + pynvml.nvmlShutdown() ver, has_ver = optional_import("packaging.version", name="parse") if has_ver: