@@ -634,3 +634,44 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st
634634 if is_prerelease :
635635 return False
636636 return True
637+
638+
@functools.lru_cache(None)
def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool:
    """
    Compute whether the current system GPU CUDA compute capability is after or equal to the specified version.
    The current system GPU CUDA compute capability is determined by the first GPU in the system.
    The compared version is a string in the form of "major.minor".

    Args:
        major: major version number to be compared with.
        minor: minor version number to be compared with. Defaults to 0.
        current_ver_string: if None, the current system GPU CUDA compute capability will be used.

    Returns:
        True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
    """
    if current_ver_string is None:
        cuda_available = torch.cuda.is_available()
        pynvml, has_pynvml = optional_import("pynvml")
        if not has_pynvml:  # assuming that the user has Ampere and later GPU
            return True
        if not cuda_available:
            return False
        else:
            pynvml.nvmlInit()
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # get the first GPU
            major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            current_ver_string = f"{major_c}.{minor_c}"
            pynvml.nvmlShutdown()

    ver, has_ver = optional_import("packaging.version", name="parse")
    if has_ver:
        # parse() handles pre-release/local-version suffixes; spec <= current  <=>  current >= spec
        return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}")  # type: ignore
    # Fallback parsing when `packaging` is unavailable: drop any local-version
    # suffix ("+...") and compare (major, minor) tuples numerically.
    parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
    while len(parts) < 2:
        parts += ["0"]  # pad a bare major version, e.g. "8" -> ("8", "0")
    c_major, c_minor = parts[:2]
    c_mn = int(c_major), int(c_minor)
    mn = int(major), int(minor)
    # ">=" (not ">") so the equal-version case matches both the docstring
    # ("after or equal to") and the packaging-based branch above.
    return c_mn >= mn
0 commit comments