206 changes: 206 additions & 0 deletions include/tvm/ffi/extra/dtype.h
@@ -0,0 +1,206 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file tvm/ffi/extra/dtype.h
 * \brief Type traits to map C++ types to DLPack dtypes.
 */
#ifndef TVM_FFI_EXTRA_DTYPE_H_
#define TVM_FFI_EXTRA_DTYPE_H_

#include <dlpack/dlpack.h>

#include <type_traits>

// Common for both CUDA and HIP
struct __half;

// CUDA
struct __nv_fp8_e4m3;
struct __nv_bfloat16;
struct __nv_fp8_e5m2;
struct __nv_fp8_e8m0;
struct __nv_fp4_e2m1;
struct __nv_fp4x2_e2m1;

// HIP
struct __hip_bfloat16;
struct hip_bfloat16; // i don't know why this is a struct instead of alias...
Review comment (Contributor, severity: medium):

This comment appears to be a personal note and is a bit informal for source code. It would be best to remove it to maintain a professional tone in the codebase.

Suggested change:
-struct hip_bfloat16;  // i don't know why this is a struct instead of alias...
+struct hip_bfloat16;

struct __hip_fp8_e4m3;
struct __hip_fp8_e4m3_fnuz;
struct __hip_fp8_e5m2;
struct __hip_fp8_e5m2_fnuz;
struct __hip_fp4_e2m1;
struct __hip_fp4x2_e2m1;

namespace tvm_ffi {

/// \cond Doxygen_Suppress

template <typename T>
struct dtype_trait {};

namespace details::dtypes {

template <typename T>
struct integer_trait {
  static constexpr DLDataType value = {
      /* code = */ std::is_signed_v<T> ? kDLInt : kDLUInt,
      /* bits = */ static_cast<uint8_t>(sizeof(T) * 8),
      /* lanes = */ 1,
  };
};

template <typename T>
struct float_trait {
  static constexpr DLDataType value = {
      /* code = */ kDLFloat,
      /* bits = */ static_cast<uint8_t>(sizeof(T) * 8),
      /* lanes = */ 1,
  };
};

} // namespace details::dtypes

template <>
struct dtype_trait<signed char> : details::dtypes::integer_trait<signed char> {};

template <>
struct dtype_trait<unsigned char> : details::dtypes::integer_trait<unsigned char> {};

template <>
struct dtype_trait<signed short> : details::dtypes::integer_trait<signed short> {};

template <>
struct dtype_trait<unsigned short> : details::dtypes::integer_trait<unsigned short> {};

template <>
struct dtype_trait<signed int> : details::dtypes::integer_trait<signed int> {};

template <>
struct dtype_trait<unsigned int> : details::dtypes::integer_trait<unsigned int> {};

template <>
struct dtype_trait<signed long> : details::dtypes::integer_trait<signed long> {};

template <>
struct dtype_trait<unsigned long> : details::dtypes::integer_trait<unsigned long> {};

template <>
struct dtype_trait<signed long long> : details::dtypes::integer_trait<signed long long> {};

template <>
struct dtype_trait<unsigned long long> : details::dtypes::integer_trait<unsigned long long> {};

template <>
struct dtype_trait<float> : details::dtypes::float_trait<float> {};

template <>
struct dtype_trait<double> : details::dtypes::float_trait<double> {};

// Specialization for bool

template <>
struct dtype_trait<bool> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBool, 8, 1};
};

// Specializations for CUDA

template <>
struct dtype_trait<__half> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat, 16, 1};
};

template <>
struct dtype_trait<__nv_bfloat16> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1};
};

template <>
struct dtype_trait<__nv_fp8_e4m3> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fn, 8, 1};
};

template <>
struct dtype_trait<__nv_fp8_e5m2> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2, 8, 1};
};

template <>
struct dtype_trait<__nv_fp8_e8m0> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e8m0fnu, 8, 1};
};

template <>
struct dtype_trait<__nv_fp4_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 1};
};

template <>
struct dtype_trait<__nv_fp4x2_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 2};
};

// Specializations for HIP

template <>
struct dtype_trait<__hip_bfloat16> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1};
};

template <>
struct dtype_trait<hip_bfloat16> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1};
};

template <>
struct dtype_trait<__hip_fp8_e4m3> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fn, 8, 1};
};

template <>
struct dtype_trait<__hip_fp8_e4m3_fnuz> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fnuz, 8, 1};
};

template <>
struct dtype_trait<__hip_fp8_e5m2> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2, 8, 1};
};

template <>
struct dtype_trait<__hip_fp8_e5m2_fnuz> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2fnuz, 8, 1};
};

template <>
struct dtype_trait<__hip_fp4_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 1};
};

template <>
struct dtype_trait<__hip_fp4x2_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 2};
};

/// \endcond

} // namespace tvm_ffi

#endif // TVM_FFI_EXTRA_DTYPE_H_
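For reference, a minimal usage sketch of the new trait (illustrative only, not part of this PR; dl_dtype_of is a hypothetical helper, and a C++17 toolchain is assumed). It shows that the C++-type-to-DLPack mapping is resolved entirely at compile time:

#include <tvm/ffi/extra/dtype.h>

template <typename T>
constexpr DLDataType dl_dtype_of() {
  // Picks up the dtype_trait specializations above; unmapped types fail to compile
  // because the primary template has no `value` member.
  return tvm_ffi::dtype_trait<T>::value;
}

static_assert(dl_dtype_of<float>().code == kDLFloat && dl_dtype_of<float>().bits == 32);
static_assert(dl_dtype_of<unsigned int>().code == kDLUInt);
static_assert(dl_dtype_of<bool>().bits == 8 && dl_dtype_of<bool>().lanes == 1);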
2 changes: 2 additions & 0 deletions python/tvm_ffi/cpp/__init__.py
@@ -16,11 +16,13 @@
# under the License.
"""C++ integration helpers for building and loading inline modules."""

from .dtype import to_cpp_dtype
from .extension import build, build_inline, load, load_inline

__all__ = [
    "build",
    "build_inline",
    "load",
    "load_inline",
    "to_cpp_dtype",
]
104 changes: 104 additions & 0 deletions python/tvm_ffi/cpp/dtype.py
@@ -0,0 +1,104 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities for C++ dtype conversion."""

from __future__ import annotations

import functools
from typing import Any, Literal

CPU_DTYPE_MAP = {
    "int8": "int8_t",
    "int16": "int16_t",
    "int32": "int32_t",
    "int64": "int64_t",
    "uint8": "uint8_t",
    "uint16": "uint16_t",
    "uint32": "uint32_t",
    "uint64": "uint64_t",
    "float32": "float",
    "float64": "double",
    "bool": "bool",
}

CUDA_DTYPE_MAP = {
    "float16": "__half",
    "bfloat16": "__nv_bfloat16",
    "float8_e4m3fn": "__nv_fp8_e4m3",
    # "float8_e4m3fnuz": "__nv_fp8_e4m3",
    "float8_e5m2": "__nv_fp8_e5m2",
    # "float8_e5m2fnuz": "__nv_fp8_e5m2",
    "float8_e8m0fnu": "__nv_fp8_e8m0",
Review comment on lines +41 to +45 (Contributor, severity: medium):

The commented-out lines for float8_e4m3fnuz and float8_e5m2fnuz can be confusing. If these dtypes are not supported for the CUDA backend, it is cleaner to remove these lines entirely. This avoids ambiguity for future maintainers about whether this is an incomplete feature or an intentional omission.

Suggested change:
-    "float8_e4m3fn": "__nv_fp8_e4m3",
-    # "float8_e4m3fnuz": "__nv_fp8_e4m3",
-    "float8_e5m2": "__nv_fp8_e5m2",
-    # "float8_e5m2fnuz": "__nv_fp8_e5m2",
-    "float8_e8m0fnu": "__nv_fp8_e8m0",
+    "float8_e4m3fn": "__nv_fp8_e4m3",
+    "float8_e5m2": "__nv_fp8_e5m2",
+    "float8_e8m0fnu": "__nv_fp8_e8m0",
"float4_e2m1": "__nv_fp4_e2m1",
"float4_e2m1fn_x2": "__nv_fp4x2_e2m1",
}

ROCM_DTYPE_MAP = {
    "float16": "__half",
    "bfloat16": "__hip_bfloat16",
    "float8_e4m3fn": "__hip_fp8_e4m3",
    "float8_e4m3fnuz": "__hip_fp8_e4m3_fnuz",
    "float8_e5m2": "__hip_fp8_e5m2",
    "float8_e5m2fnuz": "__hip_fp8_e5m2_fnuz",
    "float4_e2m1": "__hip_fp4_e2m1",
    "float4_e2m1fn_x2": "__hip_fp4x2_e2m1",
}


@functools.lru_cache(maxsize=None)
def _determine_backend_once() -> Literal["cpu", "cuda", "rocm"]:
    try:
        import torch  # noqa: PLC0415

        if torch.cuda.is_available():
            if torch.version.cuda is not None:
                return "cuda"
            elif torch.version.hip is not None:
                return "rocm"
    except ImportError:
        pass
    return "cpu"


def to_cpp_dtype(dtype_str: str | Any) -> str:
    """Convert a dtype to its corresponding C++ dtype string.

    Parameters
    ----------
    dtype_str : `str` or `torch.dtype`
        The dtype string or object to convert.

    Returns
    -------
    str
        The corresponding C++ dtype string.

    """
    if not isinstance(dtype_str, str):
        dtype_str = str(dtype_str)
    if dtype_str.startswith("torch."):
        dtype_str = dtype_str[6:]
    cpp_str = CPU_DTYPE_MAP.get(dtype_str)
    if cpp_str is not None:
        return cpp_str
    backend = _determine_backend_once()
    if backend in ("cuda", "rocm"):
        dtype_map = CUDA_DTYPE_MAP if backend == "cuda" else ROCM_DTYPE_MAP
        cpp_str = dtype_map.get(dtype_str)
        if cpp_str is not None:
            return cpp_str
    raise ValueError(f"Unsupported dtype string: {dtype_str} for {backend = }")
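For illustration, a short usage sketch of to_cpp_dtype (not part of this PR); the device-dtype lines assume a CUDA-enabled torch install:

from tvm_ffi.cpp import to_cpp_dtype

assert to_cpp_dtype("int32") == "int32_t"   # CPU table; backend-independent
assert to_cpp_dtype("float64") == "double"

# With a CUDA-enabled torch, the device tables apply:
#   to_cpp_dtype("float16")       -> "__half"
#   to_cpp_dtype(torch.bfloat16)  -> "__nv_bfloat16"  (str(torch.bfloat16) == "torch.bfloat16")
# Anything not found in the applicable tables raises ValueError.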