@@ -205,24 +205,28 @@ else:
205205
206206if ml_dtypes is not None :
207207 MLDTYPES_DTYPE_TO_DTYPE = {
208- numpy.dtype(ml_dtypes.int2): DLDataType(0 , 2 , 1 ),
209208 numpy.dtype(ml_dtypes.int4): DLDataType(0 , 4 , 1 ),
210- numpy.dtype(ml_dtypes.uint2): DLDataType(1 , 2 , 1 ),
211209 numpy.dtype(ml_dtypes.uint4): DLDataType(1 , 4 , 1 ),
212210 numpy.dtype(ml_dtypes.bfloat16): DLDataType(4 , 16 , 1 ),
213- numpy.dtype(ml_dtypes.float8_e3m4): DLDataType(7 , 8 , 1 ),
214- numpy.dtype(ml_dtypes.float8_e4m3): DLDataType(8 , 8 , 1 ),
215211 numpy.dtype(ml_dtypes.float8_e4m3b11fnuz): DLDataType(9 , 8 , 1 ),
216212 numpy.dtype(ml_dtypes.float8_e4m3fn): DLDataType(10 , 8 , 1 ),
217213 numpy.dtype(ml_dtypes.float8_e4m3fnuz): DLDataType(11 , 8 , 1 ),
218214 numpy.dtype(ml_dtypes.float8_e5m2): DLDataType(12 , 8 , 1 ),
219215 numpy.dtype(ml_dtypes.float8_e5m2fnuz): DLDataType(13 , 8 , 1 ),
220- numpy.dtype(ml_dtypes.float8_e8m0fnu): DLDataType(14 , 8 , 1 ),
221- numpy.dtype(ml_dtypes.float6_e2m3fn): DLDataType(15 , 6 , 1 ),
222- numpy.dtype(ml_dtypes.float6_e3m2fn): DLDataType(16 , 6 , 1 ),
223- numpy.dtype(ml_dtypes.float4_e2m1fn): DLDataType(17 , 4 , 1 ),
224216 }
225217
218+ if hasattr (ml_dtypes, " int2" ): # ml_dtypes >= 0.5.0
219+ MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.int2] = DLDataType(0 , 2 , 1 )
220+ MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.uint2] = DLDataType(1 , 2 , 1 )
221+
222+ MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e3m4] = DLDataType(7 , 8 , 1 )
223+ MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e4m3] = DLDataType(8 , 8 , 1 )
224+ MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e8m0fnu] = DLDataType(14 , 8 , 1 )
225+ MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float6_e2m3fn] = DLDataType(15 , 6 , 1 )
226+ MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float6_e3m2fn] = DLDataType(16 , 6 , 1 )
227+ MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float4_e2m1fn] = DLDataType(17 , 4 , 1 )
228+
229+
226230if numpy is not None :
227231 NUMPY_DTYPE_TO_DTYPE = {
228232 numpy.dtype(numpy.int8): DLDataType(0 , 8 , 1 ),
0 commit comments