Skip to content

Commit ec202e0

Browse files
committed
update compute_capabilities_after
Signed-off-by: YunLiu <[email protected]>
1 parent 9f9ebcd commit ec202e0

File tree

5 files changed

+7
-7
lines changed

5 files changed

+7
-7
lines changed

monai/utils/module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s
649649
current_ver_string: if None, the current system GPU CUDA compute capability will be used.
650650
651651
Returns:
652-
True if the current system GPU CUDA compute capability is greater than the specified version.
652+
True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
653653
"""
654654
if current_ver_string is None:
655655
cuda_available = torch.cuda.is_available()
@@ -667,11 +667,11 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s
667667

668668
ver, has_ver = optional_import("packaging.version", name="parse")
669669
if has_ver:
670-
return ver(".".join((f"{major}", f"{minor}"))) < ver(f"{current_ver_string}") # type: ignore
670+
return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore
671671
parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
672672
while len(parts) < 2:
673673
parts += ["0"]
674674
c_major, c_minor = parts[:2]
675675
c_mn = int(c_major), int(c_minor)
676676
mn = int(major), int(minor)
677-
return c_mn >= mn
677+
return c_mn > mn

tests/test_bundle_trt_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
@skip_if_windows
5454
@skip_if_no_cuda
5555
@skip_if_quick
56-
@SkipIfBeforeComputeCapabilityVersion((7, 0))
56+
@SkipIfBeforeComputeCapabilityVersion((7, 5))
5757
class TestTRTExport(unittest.TestCase):
5858

5959
def setUp(self):

tests/test_convert_to_trt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
@skip_if_windows
3939
@skip_if_no_cuda
4040
@skip_if_quick
41-
@SkipIfBeforeComputeCapabilityVersion((7, 0))
41+
@SkipIfBeforeComputeCapabilityVersion((7, 5))
4242
class TestConvertToTRT(unittest.TestCase):
4343

4444
def setUp(self):

tests/test_trt_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: f
5050
@skip_if_quick
5151
@unittest.skipUnless(trt_imported, "tensorrt is required")
5252
@unittest.skipUnless(polygraphy_imported, "polygraphy is required")
53-
@SkipIfBeforeComputeCapabilityVersion((7, 0))
53+
@SkipIfBeforeComputeCapabilityVersion((7, 5))
5454
class TestTRTCompile(unittest.TestCase):
5555

5656
def setUp(self):

tests/test_version_after.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
TEST_CASES_SM = [
4040
# (major, minor, sm, expected)
41-
(6, 1, "6.1", False),
41+
(6, 1, "6.1", True),
4242
(6, 1, "6.0", False),
4343
(6, 0, "8.6", True),
4444
(7, 0, "8", True),

0 commit comments

Comments
 (0)