From 5d1f8ffd2131d14b19cb0d0220a31e8af5e6dac7 Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Tue, 28 Oct 2025 17:15:55 +0800 Subject: [PATCH] feat: support `ml_dtypes<0.5` --- python/tvm_ffi/cython/dtype.pxi | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi index 3d903466..15f94187 100644 --- a/python/tvm_ffi/cython/dtype.pxi +++ b/python/tvm_ffi/cython/dtype.pxi @@ -205,24 +205,28 @@ else: if ml_dtypes is not None: MLDTYPES_DTYPE_TO_DTYPE = { - numpy.dtype(ml_dtypes.int2): DLDataType(0, 2, 1), numpy.dtype(ml_dtypes.int4): DLDataType(0, 4, 1), - numpy.dtype(ml_dtypes.uint2): DLDataType(1, 2, 1), numpy.dtype(ml_dtypes.uint4): DLDataType(1, 4, 1), numpy.dtype(ml_dtypes.bfloat16): DLDataType(4, 16, 1), - numpy.dtype(ml_dtypes.float8_e3m4): DLDataType(7, 8, 1), - numpy.dtype(ml_dtypes.float8_e4m3): DLDataType(8, 8, 1), numpy.dtype(ml_dtypes.float8_e4m3b11fnuz): DLDataType(9, 8, 1), numpy.dtype(ml_dtypes.float8_e4m3fn): DLDataType(10, 8, 1), numpy.dtype(ml_dtypes.float8_e4m3fnuz): DLDataType(11, 8, 1), numpy.dtype(ml_dtypes.float8_e5m2): DLDataType(12, 8, 1), numpy.dtype(ml_dtypes.float8_e5m2fnuz): DLDataType(13, 8, 1), - numpy.dtype(ml_dtypes.float8_e8m0fnu): DLDataType(14, 8, 1), - numpy.dtype(ml_dtypes.float6_e2m3fn): DLDataType(15, 6, 1), - numpy.dtype(ml_dtypes.float6_e3m2fn): DLDataType(16, 6, 1), - numpy.dtype(ml_dtypes.float4_e2m1fn): DLDataType(17, 4, 1), } + if hasattr(ml_dtypes, "int2"): # ml_dtypes >= 0.5.0 + MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.int2)] = DLDataType(0, 2, 1) + MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.uint2)] = DLDataType(1, 2, 1) + + MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e3m4)] = DLDataType(7, 8, 1) + MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e4m3)] = DLDataType(8, 8, 1) + MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e8m0fnu)] = DLDataType(14, 8, 1) + MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e2m3fn)] = DLDataType(15, 6, 1) + MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e3m2fn)] = DLDataType(16, 6, 1) + MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float4_e2m1fn)] = DLDataType(17, 4, 1) + + if numpy is not None: NUMPY_DTYPE_TO_DTYPE = { numpy.dtype(numpy.int8): DLDataType(0, 8, 1),