-
Notifications
You must be signed in to change notification settings - Fork 49
[Feature] support C++ dtype_trait and Python-side mapping to C++ dtype #374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,206 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
| /*! | ||
| * \file tvm/ffi/extra/dtype.h | ||
| * \brief Type traits to map C++ types to DLPack dtypes. | ||
| */ | ||
| #ifndef TVM_FFI_EXTRA_DTYPE_H_ | ||
| #define TVM_FFI_EXTRA_DTYPE_H_ | ||
|
|
||
| #include <dlpack/dlpack.h> | ||
|
|
||
| #include <type_traits> | ||
|
|
||
| // Common for both CUDA and HIP. Forward-declared so the specializations below can name the type; this header itself includes only <dlpack/dlpack.h> and <type_traits>. | ||
| struct __half; | ||
|
|
||
| // CUDA types (forward declarations only; each one has a dtype_trait specialization below) | ||
| struct __nv_fp8_e4m3; | ||
| struct __nv_bfloat16; | ||
| struct __nv_fp8_e5m2; | ||
| struct __nv_fp8_e8m0; | ||
| struct __nv_fp4_e2m1; | ||
| struct __nv_fp4x2_e2m1; | ||
|
|
||
| // HIP types (forward declarations only; each one has a dtype_trait specialization below) | ||
| struct __hip_bfloat16; | ||
| struct hip_bfloat16; // declared as a struct distinct from __hip_bfloat16 (not an alias), so it gets its own specialization below | ||
| struct __hip_fp8_e4m3; | ||
| struct __hip_fp8_e4m3_fnuz; | ||
| struct __hip_fp8_e5m2; | ||
| struct __hip_fp8_e5m2_fnuz; | ||
| struct __hip_fp4_e2m1; | ||
| struct __hip_fp4x2_e2m1; | ||
|
|
||
| namespace tvm_ffi { | ||
|
|
||
| /// \cond Doxygen_Suppress | ||
|
|
||
| template <typename T> | ||
| struct dtype_trait {}; /* Primary template: empty, so it has no `value` member. Types with a mapping get `static constexpr DLDataType value` from the explicit specializations below. */ | ||
|
|
||
| namespace details::dtypes { | ||
|
|
||
| template <typename T> | ||
| struct integer_trait { /* Scalar DLDataType for an integer type T; used as the base class of the integer dtype_trait specializations. */ | ||
| static constexpr DLDataType value = { | ||
| /* code = kDLInt for signed T, kDLUInt otherwise */ std::is_signed_v<T> ? kDLInt : kDLUInt, | ||
| /* bits = size of T in bytes times 8 */ static_cast<uint8_t>(sizeof(T) * 8), | ||
| /* lanes = 1 (scalar) */ 1, | ||
| }; | ||
| }; | ||
|
|
||
| template <typename T> | ||
| struct float_trait { /* Scalar DLDataType for a floating-point type T; used as the base class of the float/double dtype_trait specializations. */ | ||
| static constexpr DLDataType value = { | ||
| /* code = always kDLFloat */ kDLFloat, | ||
| /* bits = size of T in bytes times 8 */ static_cast<uint8_t>(sizeof(T) * 8), | ||
| /* lanes = 1 (scalar) */ 1, | ||
| }; | ||
| }; | ||
|
|
||
| } // namespace details::dtypes | ||
|
|
||
| template <> | ||
| struct dtype_trait<signed char> : details::dtypes::integer_trait<signed char> {}; /* plain `char` is a separate type and has no specialization in this header */ | ||
|
|
||
| template <> | ||
| struct dtype_trait<unsigned char> : details::dtypes::integer_trait<unsigned char> {}; | ||
|
|
||
| template <> | ||
| struct dtype_trait<signed short> : details::dtypes::integer_trait<signed short> {}; | ||
|
|
||
| template <> | ||
| struct dtype_trait<unsigned short> : details::dtypes::integer_trait<unsigned short> {}; | ||
|
|
||
| template <> | ||
| struct dtype_trait<signed int> : details::dtypes::integer_trait<signed int> {}; | ||
|
|
||
| template <> | ||
| struct dtype_trait<unsigned int> : details::dtypes::integer_trait<unsigned int> {}; | ||
|
|
||
| template <> | ||
| struct dtype_trait<signed long> : details::dtypes::integer_trait<signed long> {}; /* bits come from sizeof(long), so the reported width depends on the platform */ | ||
|
|
||
| template <> | ||
| struct dtype_trait<unsigned long> : details::dtypes::integer_trait<unsigned long> {}; /* bits come from sizeof(unsigned long), so the reported width depends on the platform */ | ||
|
|
||
| template <> | ||
| struct dtype_trait<signed long long> : details::dtypes::integer_trait<signed long long> {}; | ||
|
|
||
| template <> | ||
| struct dtype_trait<unsigned long long> : details::dtypes::integer_trait<unsigned long long> {}; | ||
|
|
||
| template <> | ||
| struct dtype_trait<float> : details::dtypes::float_trait<float> {}; /* kDLFloat with bits = sizeof(float) * 8 */ | ||
|
|
||
| template <> | ||
| struct dtype_trait<double> : details::dtypes::float_trait<double> {}; /* kDLFloat with bits = sizeof(double) * 8 */ | ||
|
|
||
| // Specialization for bool: kDLBool with bits fixed at 8 (a literal, not derived from sizeof(bool)) and a single lane | ||
|
|
||
| template <> | ||
| struct dtype_trait<bool> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLBool, 8, 1}; | ||
| }; | ||
|
|
||
| // Specializations for CUDA types. Each initializer is {code, bits, lanes}. __half is the type forward-declared above as common to CUDA and HIP. | ||
|
|
||
| template <> | ||
| struct dtype_trait<__half> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat, 16, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__nv_bfloat16> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__nv_fp8_e4m3> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fn, 8, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__nv_fp8_e5m2> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2, 8, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__nv_fp8_e8m0> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e8m0fnu, 8, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__nv_fp4_e2m1> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__nv_fp4x2_e2m1> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 2}; /* same code as __nv_fp4_e2m1, reported as two 4-bit lanes */ | ||
| }; | ||
|
|
||
| // Specializations for HIP types. Each initializer is {code, bits, lanes}. | ||
|
|
||
| template <> | ||
| struct dtype_trait<__hip_bfloat16> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<hip_bfloat16> { /* second HIP bfloat16 struct; maps to the same DLDataType as __hip_bfloat16 */ | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__hip_fp8_e4m3> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fn, 8, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__hip_fp8_e4m3_fnuz> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fnuz, 8, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__hip_fp8_e5m2> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2, 8, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__hip_fp8_e5m2_fnuz> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2fnuz, 8, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__hip_fp4_e2m1> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 1}; | ||
| }; | ||
|
|
||
| template <> | ||
| struct dtype_trait<__hip_fp4x2_e2m1> { | ||
| static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 2}; /* same code as __hip_fp4_e2m1, reported as two 4-bit lanes */ | ||
| }; | ||
|
|
||
| /// \endcond | ||
|
|
||
| } // namespace tvm_ffi | ||
|
|
||
| #endif // TVM_FFI_EXTRA_DTYPE_H_ | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,104 @@ | ||||||||||||||||||
| # Licensed to the Apache Software Foundation (ASF) under one | ||||||||||||||||||
| # or more contributor license agreements. See the NOTICE file | ||||||||||||||||||
| # distributed with this work for additional information | ||||||||||||||||||
| # regarding copyright ownership. The ASF licenses this file | ||||||||||||||||||
| # to you under the Apache License, Version 2.0 (the | ||||||||||||||||||
| # "License"); you may not use this file except in compliance | ||||||||||||||||||
| # with the License. You may obtain a copy of the License at | ||||||||||||||||||
| # | ||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||
| # | ||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, | ||||||||||||||||||
| # software distributed under the License is distributed on an | ||||||||||||||||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||||||||||||||||||
| # KIND, either express or implied. See the License for the | ||||||||||||||||||
| # specific language governing permissions and limitations | ||||||||||||||||||
| # under the License. | ||||||||||||||||||
| """Utilities for C++ dtype conversion.""" | ||||||||||||||||||
|
|
||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||
|
|
||||||||||||||||||
| import functools | ||||||||||||||||||
| from typing import Any, Literal | ||||||||||||||||||
|
|
||||||||||||||||||
| # Dtype name -> C++ type name for types that need no GPU toolkit. | ||||||||||||||||||
| # to_cpp_dtype() consults this map first, whatever backend is detected. | ||||||||||||||||||
| CPU_DTYPE_MAP = { | ||||||||||||||||||
| "int8": "int8_t", | ||||||||||||||||||
| "int16": "int16_t", | ||||||||||||||||||
| "int32": "int32_t", | ||||||||||||||||||
| "int64": "int64_t", | ||||||||||||||||||
| "uint8": "uint8_t", | ||||||||||||||||||
| "uint16": "uint16_t", | ||||||||||||||||||
| "uint32": "uint32_t", | ||||||||||||||||||
| "uint64": "uint64_t", | ||||||||||||||||||
| "float32": "float", | ||||||||||||||||||
| "float64": "double", | ||||||||||||||||||
| "bool": "bool", | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| # Dtype name -> CUDA C++ type name. to_cpp_dtype() uses this map only when | ||||||||||||||||||
| # the detected backend is "cuda" and the name is not in CPU_DTYPE_MAP. | ||||||||||||||||||
| # The two "fnuz" entries below are commented out, so those names are unmapped here. | ||||||||||||||||||
| CUDA_DTYPE_MAP = { | ||||||||||||||||||
| "float16": "__half", | ||||||||||||||||||
| "bfloat16": "__nv_bfloat16", | ||||||||||||||||||
| "float8_e4m3fn": "__nv_fp8_e4m3", | ||||||||||||||||||
| # "float8_e4m3fnuz": "__nv_fp8_e4m3", | ||||||||||||||||||
| "float8_e5m2": "__nv_fp8_e5m2", | ||||||||||||||||||
| # "float8_e5m2fnuz": "__nv_fp8_e5m2", | ||||||||||||||||||
| "float8_e8m0fnu": "__nv_fp8_e8m0", | ||||||||||||||||||
|
Comment on lines
+41
to
+45
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The commented-out lines for
Suggested change
|
||||||||||||||||||
| # NOTE(review): "float4_e2m1" lacks the "fn" suffix that "float4_e2m1fn_x2" carries - confirm this key matches the dtype strings callers pass. | ||||||||||||||||||
| "float4_e2m1": "__nv_fp4_e2m1", | ||||||||||||||||||
| "float4_e2m1fn_x2": "__nv_fp4x2_e2m1", | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| # Dtype name -> HIP C++ type name. to_cpp_dtype() uses this map only when | ||||||||||||||||||
| # the detected backend is "rocm" and the name is not in CPU_DTYPE_MAP. | ||||||||||||||||||
| # NOTE(review): "float4_e2m1" lacks the "fn" suffix that "float4_e2m1fn_x2" carries - confirm this key matches the dtype strings callers pass. | ||||||||||||||||||
| ROCM_DTYPE_MAP = { | ||||||||||||||||||
| "float16": "__half", | ||||||||||||||||||
| "bfloat16": "__hip_bfloat16", | ||||||||||||||||||
| "float8_e4m3fn": "__hip_fp8_e4m3", | ||||||||||||||||||
| "float8_e4m3fnuz": "__hip_fp8_e4m3_fnuz", | ||||||||||||||||||
| "float8_e5m2": "__hip_fp8_e5m2", | ||||||||||||||||||
| "float8_e5m2fnuz": "__hip_fp8_e5m2_fnuz", | ||||||||||||||||||
| "float4_e2m1": "__hip_fp4_e2m1", | ||||||||||||||||||
| "float4_e2m1fn_x2": "__hip_fp4x2_e2m1", | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @functools.lru_cache(maxsize=None) | ||||||||||||||||||
| def _determine_backend_once() -> Literal["cpu", "cuda", "rocm"]: | ||||||||||||||||||
| """Detect the GPU backend through PyTorch; the result is cached after the first call. | ||||||||||||||||||
|
|
||||||||||||||||||
| Returns "cuda" when ``torch.cuda.is_available()`` is true and ``torch.version.cuda`` is set, | ||||||||||||||||||
| "rocm" when it is true and ``torch.version.hip`` is set instead, and "cpu" in every other | ||||||||||||||||||
| case, including when ``torch`` cannot be imported. | ||||||||||||||||||
| """ | ||||||||||||||||||
| try: | ||||||||||||||||||
| import torch # noqa: PLC0415 | ||||||||||||||||||
|
|
||||||||||||||||||
| if torch.cuda.is_available(): | ||||||||||||||||||
| if torch.version.cuda is not None: | ||||||||||||||||||
| return "cuda" | ||||||||||||||||||
| elif torch.version.hip is not None: | ||||||||||||||||||
| return "rocm" | ||||||||||||||||||
| except ImportError: | ||||||||||||||||||
| # torch is optional: without it only the CPU mappings are usable. | ||||||||||||||||||
| pass | ||||||||||||||||||
| return "cpu" | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def to_cpp_dtype(dtype_str: str | Any) -> str: | ||||||||||||||||||
| """Convert a dtype to its corresponding C++ dtype string. | ||||||||||||||||||
|
|
||||||||||||||||||
| A non-string argument is converted with ``str()``, and a leading ``torch.`` prefix | ||||||||||||||||||
| is stripped. The resulting name is looked up in ``CPU_DTYPE_MAP`` first; only if it | ||||||||||||||||||
| is missing there is the backend detected and ``CUDA_DTYPE_MAP`` or ``ROCM_DTYPE_MAP`` | ||||||||||||||||||
| consulted. | ||||||||||||||||||
|
|
||||||||||||||||||
| Parameters | ||||||||||||||||||
| ---------- | ||||||||||||||||||
| dtype_str : `str` or `torch.dtype` | ||||||||||||||||||
| The dtype string or object to convert. Any object is accepted; only its | ||||||||||||||||||
| ``str()`` form is used. | ||||||||||||||||||
|
|
||||||||||||||||||
| Returns | ||||||||||||||||||
| ------- | ||||||||||||||||||
| str | ||||||||||||||||||
| The corresponding C++ dtype string. | ||||||||||||||||||
|
|
||||||||||||||||||
| Raises | ||||||||||||||||||
| ------ | ||||||||||||||||||
| ValueError | ||||||||||||||||||
| If the name is in neither ``CPU_DTYPE_MAP`` nor the map of the detected backend. | ||||||||||||||||||
| On the "cpu" backend no GPU map is consulted at all. | ||||||||||||||||||
|
|
||||||||||||||||||
| """ | ||||||||||||||||||
| if not isinstance(dtype_str, str): | ||||||||||||||||||
| dtype_str = str(dtype_str) | ||||||||||||||||||
| if dtype_str.startswith("torch."): | ||||||||||||||||||
| dtype_str = dtype_str[6:] | ||||||||||||||||||
| cpp_str = CPU_DTYPE_MAP.get(dtype_str) | ||||||||||||||||||
| if cpp_str is not None: | ||||||||||||||||||
| return cpp_str | ||||||||||||||||||
| backend = _determine_backend_once() | ||||||||||||||||||
| if backend in ("cuda", "rocm"): | ||||||||||||||||||
| dtype_map = CUDA_DTYPE_MAP if backend == "cuda" else ROCM_DTYPE_MAP | ||||||||||||||||||
| cpp_str = dtype_map.get(dtype_str) | ||||||||||||||||||
| if cpp_str is not None: | ||||||||||||||||||
| return cpp_str | ||||||||||||||||||
| raise ValueError(f"Unsupported dtype string: {dtype_str} for {backend = }") | ||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
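The trailing comment on the `hip_bfloat16` forward declaration ("i don't know why this is a struct instead of alias...") appears to be a personal note and is too informal for source code. It would be best to remove it, or reword it as a factual remark, to maintain a professional tone in the codebase.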