16
16
# ==============================================================================
17
17
"""Utilities for the NVML Python bindings (`nvidia-ml-py <https://pypi.org/project/nvidia-ml-py>`_)."""
18
18
19
- # pylint: disable=invalid-name
19
+ # pylint: disable=too-many-lines, invalid-name
20
20
21
21
from __future__ import annotations
22
22
@@ -265,9 +265,13 @@ def _lazy_init() -> None:
265
265
If cannot find function :func:`pynvml.nvmlInitWithFlags`, usually the :mod:`pynvml` module
266
266
is overridden by other modules. Need to reinstall package ``nvidia-ml-py``.
267
267
"""
268
+ if __initialized :
269
+ return
270
+
268
271
with __lock :
269
272
if __initialized :
270
- return
273
+ return # type: ignore[unreachable]
274
+
271
275
nvmlInit ()
272
276
_atexit .register (nvmlShutdown )
273
277
@@ -531,12 +535,24 @@ def nvmlCheckReturn(retval: _Any, types: type | tuple[type, ...] | None = None,
531
535
# Patch layers for backward compatibility ##########################################################
532
536
# The patch layers below import both `_nvmlGetFunctionPointer` and `_PrintableStructure`
# from `pynvml`, so the installation must be flagged as corrupted when EITHER of them is
# missing or has the wrong kind. Note that `not a and b` parses as `(not a) and b`, which
# would report a healthy installation whenever `_PrintableStructure` is absent; the whole
# conjunction is therefore negated explicitly.
_pynvml_installation_corrupted: bool = not (
    callable(getattr(_pynvml, '_nvmlGetFunctionPointer', None))
    and isinstance(getattr(_pynvml, '_PrintableStructure', None), type)
)
535
539
536
540
# Patch function `nvmlDeviceGet{Compute,Graphics,MPSCompute}RunningProcesses`
537
541
if not _pynvml_installation_corrupted :
542
+ # pylint: disable-next=ungrouped-imports
543
+ from pynvml import _nvmlGetFunctionPointer , _PrintableStructure , nvmlStructToFriendlyObject
544
+
545
def _nvmlLookupFunctionPointer(symbol: str) -> _Any | None:
    """Resolve an NVML symbol without raising when the symbol is missing.

    Args:
        symbol: Name of the NVML function to resolve.

    Returns:
        The value returned by ``_nvmlGetFunctionPointer`` for ``symbol``, or :data:`None`
        when that call raises ``NVMLError_FunctionNotFound``. The outcome is logged at
        debug level in both cases.
    """
    try:
        pointer = _nvmlGetFunctionPointer(symbol)
    except NVMLError_FunctionNotFound:
        LOGGER.debug('Failed to found symbol `%s`.', symbol)
        pointer = None
    else:
        LOGGER.debug('Found symbol `%s`.', symbol)
    return pointer
553
+
538
554
# pylint: disable-next=missing-class-docstring,too-few-public-methods,function-redefined
539
- class c_nvmlProcessInfo_v1_t (_pynvml . _PrintableStructure ): # pylint: disable=protected-access
555
+ class c_nvmlProcessInfo_v1_t (_PrintableStructure ):
540
556
_fields_ : _ClassVar [list [tuple [str , type ]]] = [
541
557
# Process ID
542
558
('pid' , _ctypes .c_uint ),
@@ -550,7 +566,7 @@ class c_nvmlProcessInfo_v1_t(_pynvml._PrintableStructure): # pylint: disable=pr
550
566
}
551
567
552
568
# pylint: disable-next=missing-class-docstring,too-few-public-methods,function-redefined
553
- class c_nvmlProcessInfo_v2_t (_pynvml . _PrintableStructure ): # pylint: disable=protected-access
569
+ class c_nvmlProcessInfo_v2_t (_PrintableStructure ):
554
570
_fields_ : _ClassVar [list [tuple [str , type ]]] = [
555
571
# Process ID
556
572
('pid' , _ctypes .c_uint ),
@@ -570,7 +586,7 @@ class c_nvmlProcessInfo_v2_t(_pynvml._PrintableStructure): # pylint: disable=pr
570
586
}
571
587
572
588
# pylint: disable-next=missing-class-docstring,too-few-public-methods,function-redefined
573
- class c_nvmlProcessInfo_v3_t (_pynvml . _PrintableStructure ): # pylint: disable=protected-access
589
+ class c_nvmlProcessInfo_v3_t (_PrintableStructure ):
574
590
_fields_ : _ClassVar [list [tuple [str , type ]]] = [
575
591
# Process ID
576
592
('pid' , _ctypes .c_uint ),
@@ -599,22 +615,11 @@ def __determine_get_running_processes_version_suffix() -> str:
599
615
global __get_running_processes_version_suffix , c_nvmlProcessInfo_t # pylint: disable=global-statement
600
616
601
617
if __get_running_processes_version_suffix is None :
602
- # pylint: disable-next=protected-access,no-member
603
- nvmlGetFunctionPointer = _pynvml ._nvmlGetFunctionPointer
604
618
__get_running_processes_version_suffix = '_v3'
605
-
606
- def lookup (symbol : str ) -> _Any | None :
607
- try :
608
- ptr = nvmlGetFunctionPointer (symbol )
609
- except NVMLError_FunctionNotFound :
610
- LOGGER .debug ('Failed to found symbol `%s`.' , symbol )
611
- return None
612
- LOGGER .debug ('Found symbol `%s`.' , symbol )
613
- return ptr
614
-
615
- if lookup ('nvmlDeviceGetComputeRunningProcesses_v3' ):
616
- if lookup ('nvmlDeviceGetConfComputeMemSizeInfo' ) and not lookup (
617
- 'nvmlDeviceGetRunningProcessDetailList' ,
619
+ if _nvmlLookupFunctionPointer ('nvmlDeviceGetComputeRunningProcesses_v3' ) is not None :
620
+ if (
621
+ _nvmlLookupFunctionPointer ('nvmlDeviceGetConfComputeMemSizeInfo' ) is not None
622
+ and _nvmlLookupFunctionPointer ('nvmlDeviceGetRunningProcessDetailList' ) is None
618
623
):
619
624
LOGGER .debug (
620
625
'NVML get running process version 3 API with v3 type struct is available.' ,
@@ -634,7 +639,10 @@ def lookup(symbol: str) -> _Any | None:
634
639
'due to incompatible NVIDIA driver. Fallback to use get running process '
635
640
'version 2 API with v2 type struct.' ,
636
641
)
637
- if lookup ('nvmlDeviceGetComputeRunningProcesses_v2' ):
642
+ if (
643
+ _nvmlLookupFunctionPointer ('nvmlDeviceGetComputeRunningProcesses_v2' )
644
+ is not None
645
+ ):
638
646
LOGGER .debug (
639
647
'NVML get running process version 2 API with v2 type struct is available.' ,
640
648
)
@@ -663,8 +671,7 @@ def __nvml_device_get_running_processes(
663
671
664
672
# First call to get the size
665
673
c_count = _ctypes .c_uint (0 )
666
- # pylint: disable-next=protected-access
667
- fn = _pynvml ._nvmlGetFunctionPointer (f'{ func } { version_suffix } ' )
674
+ fn = _nvmlGetFunctionPointer (f'{ func } { version_suffix } ' )
668
675
ret = fn (handle , _ctypes .byref (c_count ), None )
669
676
670
677
if ret == NVML_SUCCESS :
@@ -679,12 +686,13 @@ def __nvml_device_get_running_processes(
679
686
680
687
# Make the call again
681
688
ret = fn (handle , _ctypes .byref (c_count ), c_processes )
682
- _pynvml ._nvmlCheckReturn (ret ) # pylint: disable=protected-access
689
+ if ret != NVML_SUCCESS :
690
+ raise NVMLError (ret )
683
691
684
692
processes = []
685
693
for i in range (c_count .value ):
686
694
# Use an alternative struct for this object
687
- obj = _pynvml . nvmlStructToFriendlyObject (c_processes [i ])
695
+ obj = nvmlStructToFriendlyObject (c_processes [i ])
688
696
if obj .usedGpuMemory == ULONGLONG_MAX :
689
697
# Special case for WDDM on Windows, see comment above
690
698
obj .usedGpuMemory = None
@@ -781,7 +789,7 @@ def nvmlDeviceGetMPSComputeRunningProcesses( # pylint: disable=function-redefin
781
789
# Patch function `nvmlDeviceGetMemoryInfo`
782
790
if not _pynvml_installation_corrupted :
783
791
# pylint: disable-next=missing-class-docstring,too-few-public-methods,function-redefined
784
- class c_nvmlMemory_v1_t (_pynvml . _PrintableStructure ): # pylint: disable=protected-access
792
+ class c_nvmlMemory_v1_t (_PrintableStructure ):
785
793
_fields_ : _ClassVar [list [tuple [str , type ]]] = [
786
794
# Total physical device memory (in bytes).
787
795
('total' , _ctypes .c_ulonglong ),
@@ -794,7 +802,7 @@ class c_nvmlMemory_v1_t(_pynvml._PrintableStructure): # pylint: disable=protect
794
802
_fmt_ : _ClassVar [dict [str , str ]] = {'<default>' : '%d B' }
795
803
796
804
# pylint: disable-next=missing-class-docstring,too-few-public-methods,function-redefined
797
- class c_nvmlMemory_v2_t (_pynvml . _PrintableStructure ): # pylint: disable=protected-access
805
+ class c_nvmlMemory_v2_t (_PrintableStructure ):
798
806
_fields_ : _ClassVar [list [tuple [str , type ]]] = [
799
807
# Structure format version (must be 2).
800
808
('version' , _ctypes .c_uint ),
@@ -810,30 +818,24 @@ class c_nvmlMemory_v2_t(_pynvml._PrintableStructure): # pylint: disable=protect
810
818
]
811
819
_fmt_ : _ClassVar [dict [str , str ]] = {'<default>' : '%d B' }
812
820
813
- nvmlMemory_v2 = getattr (_pynvml , 'nvmlMemory_v2' , _ctypes .sizeof (c_nvmlMemory_v2_t ) | 2 << 24 )
821
+ nvmlMemory_v2 = getattr (_pynvml , 'nvmlMemory_v2' , _ctypes .sizeof (c_nvmlMemory_v2_t ) | ( 2 << 24 ) )
814
822
__get_memory_info_version_suffix : str | None = None
815
823
c_nvmlMemory_t = c_nvmlMemory_v2_t
816
824
817
825
def __determine_get_memory_info_version_suffix() -> str:
    """Return the cached version suffix for ``nvmlDeviceGetMemoryInfo``, probing it on first use.

    Returns:
        ``'_v2'`` when the symbol ``nvmlDeviceGetMemoryInfo_v2`` can be resolved, otherwise
        ``''``. In the fallback case the module-level ``c_nvmlMemory_t`` alias is also
        rebound to ``c_nvmlMemory_v1_t``.
    """
    global __get_memory_info_version_suffix, c_nvmlMemory_t  # pylint: disable=global-statement

    # Already probed: reuse the cached answer.
    if __get_memory_info_version_suffix is not None:
        return __get_memory_info_version_suffix

    # Record the optimistic default before probing so that the cache is populated even if
    # the lookup raises something other than "function not found".
    __get_memory_info_version_suffix = '_v2'
    if _nvmlLookupFunctionPointer('nvmlDeviceGetMemoryInfo_v2') is None:
        c_nvmlMemory_t = c_nvmlMemory_v1_t
        __get_memory_info_version_suffix = ''
        LOGGER.debug(
            'NVML get memory info version 2 API is not available due to incompatible '
            'NVIDIA driver. Fallback to use NVML get memory info version 1 API.',
        )
    else:
        LOGGER.debug('NVML get memory info version 2 is available.')

    return __get_memory_info_version_suffix
839
841
@@ -865,19 +867,19 @@ def nvmlDeviceGetMemoryInfo( # pylint: disable=function-redefined
865
867
if version_suffix == '_v2' :
866
868
c_memory = c_nvmlMemory_v2_t ()
867
869
c_memory .version = nvmlMemory_v2 # pylint: disable=attribute-defined-outside-init
868
- # pylint: disable-next=protected-access
869
- fn = _pynvml ._nvmlGetFunctionPointer ('nvmlDeviceGetMemoryInfo_v2' )
870
870
elif version_suffix in {'_v1' , '' }:
871
871
c_memory = c_nvmlMemory_v1_t ()
872
- # pylint: disable-next=protected-access
873
- fn = _pynvml ._nvmlGetFunctionPointer ('nvmlDeviceGetMemoryInfo' )
872
+ version_suffix = ''
874
873
else :
875
874
raise ValueError (
876
875
f'Unknown version suffix { version_suffix !r} for '
877
876
'function `nvmlDeviceGetMemoryInfo`.' ,
878
877
)
878
+
879
+ fn = _nvmlGetFunctionPointer (f'nvmlDeviceGetMemoryInfo{ version_suffix } ' )
879
880
ret = fn (handle , _ctypes .byref (c_memory ))
880
- _pynvml ._nvmlCheckReturn (ret ) # pylint: disable=protected-access
881
+ if ret != NVML_SUCCESS :
882
+ raise NVMLError (ret )
881
883
return c_memory
882
884
883
885
else :
@@ -888,6 +890,94 @@ def nvmlDeviceGetMemoryInfo( # pylint: disable=function-redefined
888
890
'`nvidia-ml-py` via `pip3 install --force-reinstall nvidia-ml-py nvitop`.' ,
889
891
)
890
892
893
+ # Patch function `nvmlDeviceGetTemperature`
894
+ if not _pynvml_installation_corrupted :
895
+ # pylint: disable-next=missing-class-docstring,too-few-public-methods,function-redefined
896
class c_nvmlTemperature_v1_t(_PrintableStructure):
    """Versioned ctypes structure handed to the ``nvmlDeviceGetTemperatureV`` entry point."""

    _fields_: _ClassVar[list[tuple[str, type]]] = [
        ('version', _ctypes.c_uint),  # structure format version (must be 1)
        ('sensorType', _ctypes.c_uint),  # which sensor to read
        ('temperature', _ctypes.c_int),  # reading, in degrees Celsius
    ]
905
+
906
+ nvmlTemperature_v1 : int = getattr (
907
+ _pynvml ,
908
+ 'nvmlTemperature_v1' ,
909
+ _ctypes .sizeof (c_nvmlTemperature_v1_t ) | (1 << 24 ),
910
+ )
911
+ __get_temperature_version_suffix : str | None = None
912
+
913
def __determine_get_temperature_version_suffix() -> str:
    """Return the cached suffix selecting which NVML temperature entry point to call.

    Returns:
        ``'V'`` when the symbol ``nvmlDeviceGetTemperatureV`` can be resolved, otherwise
        ``''`` (the unversioned ``nvmlDeviceGetTemperature``).
    """
    global __get_temperature_version_suffix  # pylint: disable=global-statement

    # Already probed: reuse the cached answer.
    if __get_temperature_version_suffix is not None:
        return __get_temperature_version_suffix

    # Record the optimistic default before probing so that the cache is populated even if
    # the lookup raises something other than "function not found".
    __get_temperature_version_suffix = 'V'
    if _nvmlLookupFunctionPointer('nvmlDeviceGetTemperatureV') is None:
        __get_temperature_version_suffix = ''
        LOGGER.debug(
            'NVML get temperature version 1 API is not available due to incompatible '
            'NVIDIA driver. Fallback to use NVML get temperature API without version.',
        )
    else:
        LOGGER.debug('NVML get temperature version 1 API is available.')

    return __get_temperature_version_suffix
929
+
930
def nvmlDeviceGetTemperature(  # pylint: disable=function-redefined
    handle: c_nvmlDevice_t,
    sensor: int,
) -> int:
    """Retrieve the current temperature readings (in degrees C) for the given device.

    The versioned ``nvmlDeviceGetTemperatureV`` entry point is used when it is available;
    otherwise the call falls back to the unversioned ``nvmlDeviceGetTemperature``.

    Args:
        handle: The NVML device handle to query.
        sensor: The sensor type to read.

    Returns:
        The temperature reading as an integer.

    Raises:
        NVMLError_Uninitialized:
            If NVML was not first initialized with :func:`nvmlInit`.
        NVMLError_InvalidArgument:
            If device is invalid, sensorType is invalid or temp is NULL.
        NVMLError_NotSupported:
            If the device does not have the specified sensor.
        NVMLError_GpuIsLost:
            If the target GPU has fallen off the bus or is otherwise inaccessible.
        NVMLError_Unknown:
            On any unexpected error.
        ValueError:
            If the detected version suffix is neither ``'V'`` nor ``''``.
    """
    suffix = __determine_get_temperature_version_suffix()

    if suffix == 'V':
        # Versioned API: the request and the result travel in one struct.
        request = c_nvmlTemperature_v1_t()
        # pylint: disable-next=attribute-defined-outside-init
        request.version = nvmlTemperature_v1
        # pylint: disable-next=attribute-defined-outside-init
        request.sensorType = _ctypes.c_uint(sensor)
        query = _nvmlGetFunctionPointer('nvmlDeviceGetTemperatureV')
        status = query(handle, _ctypes.byref(request))
        if status != NVML_SUCCESS:
            raise NVMLError(status)
        return int(request.temperature)

    if suffix == '':
        # Unversioned API: the reading comes back through an out-parameter.
        reading = _ctypes.c_uint(0)
        query = _nvmlGetFunctionPointer('nvmlDeviceGetTemperature')
        status = query(handle, _ctypes.c_uint(sensor), _ctypes.byref(reading))
        if status != NVML_SUCCESS:
            raise NVMLError(status)
        return reading.value

    raise ValueError(
        f'Unknown version suffix {suffix!r} for function `nvmlDeviceGetTemperature`.',
    )
972
+
973
+ else :
974
+ LOGGER .warning (
975
+ 'Your installed package `nvidia-ml-py` is corrupted. '
976
+ 'Skip patch functions `nvmlDeviceGetTemperature`. '
977
+ 'You may get incorrect or incomplete results. Please consider reinstall package '
978
+ '`nvidia-ml-py` via `pip3 install --force-reinstall nvidia-ml-py nvitop`.' ,
979
+ )
980
+
891
981
892
982
# Add support for lookup fallback and context manager ##############################################
893
983
class _CustomModule (_ModuleType ):
0 commit comments