Skip to content

Commit f8ce93f

Browse files
committed
feat: support ml_types<0.5
1 parent 789e9e5 commit f8ce93f

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

python/tvm_ffi/cython/dtype.pxi

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -205,24 +205,28 @@ else:
205205

206206
if ml_dtypes is not None:
207207
MLDTYPES_DTYPE_TO_DTYPE = {
208-
numpy.dtype(ml_dtypes.int2): DLDataType(0, 2, 1),
209208
numpy.dtype(ml_dtypes.int4): DLDataType(0, 4, 1),
210-
numpy.dtype(ml_dtypes.uint2): DLDataType(1, 2, 1),
211209
numpy.dtype(ml_dtypes.uint4): DLDataType(1, 4, 1),
212210
numpy.dtype(ml_dtypes.bfloat16): DLDataType(4, 16, 1),
213-
numpy.dtype(ml_dtypes.float8_e3m4): DLDataType(7, 8, 1),
214-
numpy.dtype(ml_dtypes.float8_e4m3): DLDataType(8, 8, 1),
215211
numpy.dtype(ml_dtypes.float8_e4m3b11fnuz): DLDataType(9, 8, 1),
216212
numpy.dtype(ml_dtypes.float8_e4m3fn): DLDataType(10, 8, 1),
217213
numpy.dtype(ml_dtypes.float8_e4m3fnuz): DLDataType(11, 8, 1),
218214
numpy.dtype(ml_dtypes.float8_e5m2): DLDataType(12, 8, 1),
219215
numpy.dtype(ml_dtypes.float8_e5m2fnuz): DLDataType(13, 8, 1),
220-
numpy.dtype(ml_dtypes.float8_e8m0fnu): DLDataType(14, 8, 1),
221-
numpy.dtype(ml_dtypes.float6_e2m3fn): DLDataType(15, 6, 1),
222-
numpy.dtype(ml_dtypes.float6_e3m2fn): DLDataType(16, 6, 1),
223-
numpy.dtype(ml_dtypes.float4_e2m1fn): DLDataType(17, 4, 1),
224216
}
225217

218+
if hasattr(ml_dtypes, "int2"): # ml_dtypes >= 0.5.0
219+
MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.int2] = DLDataType(0, 2, 1)
220+
MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.uint2] = DLDataType(1, 2, 1)
221+
222+
MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e3m4] = DLDataType(7, 8, 1)
223+
MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e4m3] = DLDataType(8, 8, 1)
224+
MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e8m0fnu] = DLDataType(14, 8, 1)
225+
MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float6_e2m3fn] = DLDataType(15, 6, 1)
226+
MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float6_e3m2fn] = DLDataType(16, 6, 1)
227+
MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float4_e2m1fn] = DLDataType(17, 4, 1)
228+
229+
226230
if numpy is not None:
227231
NUMPY_DTYPE_TO_DTYPE = {
228232
numpy.dtype(numpy.int8): DLDataType(0, 8, 1),

0 commit comments

Comments
 (0)