@@ -205,24 +205,28 @@ else:
205205
206206if  ml_dtypes is  not  None :
207207    MLDTYPES_DTYPE_TO_DTYPE =  {
208-         numpy.dtype(ml_dtypes.int2): DLDataType(0 , 2 , 1 ),
209208        numpy.dtype(ml_dtypes.int4): DLDataType(0 , 4 , 1 ),
210-         numpy.dtype(ml_dtypes.uint2): DLDataType(1 , 2 , 1 ),
211209        numpy.dtype(ml_dtypes.uint4): DLDataType(1 , 4 , 1 ),
212210        numpy.dtype(ml_dtypes.bfloat16): DLDataType(4 , 16 , 1 ),
213-         numpy.dtype(ml_dtypes.float8_e3m4): DLDataType(7 , 8 , 1 ),
214-         numpy.dtype(ml_dtypes.float8_e4m3): DLDataType(8 , 8 , 1 ),
215211        numpy.dtype(ml_dtypes.float8_e4m3b11fnuz): DLDataType(9 , 8 , 1 ),
216212        numpy.dtype(ml_dtypes.float8_e4m3fn): DLDataType(10 , 8 , 1 ),
217213        numpy.dtype(ml_dtypes.float8_e4m3fnuz): DLDataType(11 , 8 , 1 ),
218214        numpy.dtype(ml_dtypes.float8_e5m2): DLDataType(12 , 8 , 1 ),
219215        numpy.dtype(ml_dtypes.float8_e5m2fnuz): DLDataType(13 , 8 , 1 ),
220-         numpy.dtype(ml_dtypes.float8_e8m0fnu): DLDataType(14 , 8 , 1 ),
221-         numpy.dtype(ml_dtypes.float6_e2m3fn): DLDataType(15 , 6 , 1 ),
222-         numpy.dtype(ml_dtypes.float6_e3m2fn): DLDataType(16 , 6 , 1 ),
223-         numpy.dtype(ml_dtypes.float4_e2m1fn): DLDataType(17 , 4 , 1 ),
224216    }
225217
218+     if  hasattr (ml_dtypes, " int2"  ):  #  ml_dtypes >= 0.5.0
219+         MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.int2)] =  DLDataType(0 , 2 , 1 )
220+         MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.uint2)] =  DLDataType(1 , 2 , 1 )
221+ 
222+         MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e3m4)] =  DLDataType(7 , 8 , 1 )
223+         MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e4m3)] =  DLDataType(8 , 8 , 1 )
224+         MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e8m0fnu)] =  DLDataType(14 , 8 , 1 )
225+         MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e2m3fn)] =  DLDataType(15 , 6 , 1 )
226+         MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e3m2fn)] =  DLDataType(16 , 6 , 1 )
227+         MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float4_e2m1fn)] =  DLDataType(17 , 4 , 1 )
228+ 
229+ 
226230if  numpy is  not  None :
227231    NUMPY_DTYPE_TO_DTYPE =  {
228232        numpy.dtype(numpy.int8): DLDataType(0 , 8 , 1 ),
0 commit comments