setup.py: add compile flags for bf16 and fp8.
fanshiqing committed Apr 25, 2024
1 parent 9c8c42a · commit 8040918
Showing 2 changed files with 26 additions and 13 deletions.
24 changes: 12 additions & 12 deletions csrc/permute.cu
@@ -523,7 +523,7 @@ std::tuple<torch::Tensor, torch::Tensor, std::vector<Tensor>> moe_permute_topK_o

break;
}
-// #ifdef ENABLE_BF16
+#ifdef ENABLE_BF16
case at::ScalarType::BFloat16:
{
using dType = cutlass::bfloat16_t;
@@ -545,8 +545,8 @@ std::tuple<torch::Tensor, torch::Tensor, std::vector<Tensor>> moe_permute_topK_o

break;
}
-// #endif
-// #ifdef ENABLE_FP8
+#endif
+#ifdef ENABLE_FP8
case at::ScalarType::Float8_e5m2:
{
using dType = cutlass::float_e5m2_t;
@@ -589,7 +589,7 @@ std::tuple<torch::Tensor, torch::Tensor, std::vector<Tensor>> moe_permute_topK_o

break;
}
-// #endif
+#endif
default:
throw std::runtime_error("Wrong activation tensor type.");
}
@@ -670,7 +670,7 @@ torch::Tensor moe_recover_topK_op(

break;
}
-// #ifdef ENABLE_BF16
+#ifdef ENABLE_BF16
case at::ScalarType::BFloat16:
{
using dType = cutlass::bfloat16_t;
@@ -692,8 +692,8 @@ torch::Tensor moe_recover_topK_op(

break;
}
-// #endif
-// #ifdef ENABLE_FP8
+#endif
+#ifdef ENABLE_FP8
case at::ScalarType::Float8_e5m2:
{
using dType = cutlass::float_e5m2_t;
@@ -736,7 +736,7 @@ torch::Tensor moe_recover_topK_op(

break;
}
-// #endif
+#endif
default:
throw std::runtime_error("Wrong activation tensor type.");
}
@@ -819,7 +819,7 @@ std::tuple<torch::Tensor, torch::Tensor> moe_recover_topK_bwd_op(

break;
}
-// #ifdef ENABLE_BF16
+#ifdef ENABLE_BF16
case at::ScalarType::BFloat16:
{
using dType = cutlass::bfloat16_t;
@@ -844,8 +844,8 @@ std::tuple<torch::Tensor, torch::Tensor> moe_recover_topK_bwd_op(

break;
}
-// #endif
-// #ifdef ENABLE_FP8
+#endif
+#ifdef ENABLE_FP8
case at::ScalarType::Float8_e5m2:
{
using dType = cutlass::float_e5m2_t;
@@ -894,7 +894,7 @@ std::tuple<torch::Tensor, torch::Tensor> moe_recover_topK_bwd_op(

break;
}
-// #endif
+#endif
default:
throw std::runtime_error("Wrong activation tensor type.");
}
15 changes: 14 additions & 1 deletion setup.py
@@ -5,7 +5,18 @@
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


-if os.environ.get("TORCH_CUDA_ARCH_LIST"):
+# Supported NVIDIA GPU architectures.
+NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
+
+# TORCH_CUDA_ARCH_LIST can have one or more architectures,
+# e.g. "9.0" or "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX". Here,
+# the "9.0+PTX" option asks the
+# compiler to additionally include PTX code that can be runtime-compiled
+# and executed on the 8.6 or newer architectures. While the PTX code will
+# not give the best performance on the newer architectures, it provides
+# forward compatibility.
+env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+if env_arch_list:
    # Let PyTorch builder to choose device to target for.
    device_capability = ""
else:
@@ -16,6 +27,8 @@

nvcc_flags = [
"-std=c++17", # NOTE: CUTLASS requires c++17
"-DENABLE_BF16", # Enable BF16 for cuda_version >= 11
# "-DENABLE_FP8", # Enable FP8 for cuda_version >= 11.8
]

if device_capability:
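
For context, the sketch below shows one way the new flags could feed into the CUDA extension build. It is an illustration only, not the repository's actual setup.py: the diff above is truncated before the gencode handling, so the --generate-code construction, the torch.cuda.get_device_capability() fallback, and the placeholder package name are assumptions. It does mirror what the commit records: -DENABLE_BF16 is passed unconditionally, while -DENABLE_FP8 stays commented out as an opt-in.

# Illustrative sketch only; gencode handling and names are assumptions.
import os

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

nvcc_flags = [
    "-std=c++17",     # CUTLASS requires C++17
    "-DENABLE_BF16",  # BF16 kernels need cuda_version >= 11
    # "-DENABLE_FP8", # FP8 kernels need cuda_version >= 11.8; still opt-in
]

device_capability = ""
if not os.environ.get("TORCH_CUDA_ARCH_LIST"):
    # No arch list supplied: target the GPU visible at build time.
    major, minor = torch.cuda.get_device_capability()
    device_capability = f"{major}{minor}"

if device_capability:
    nvcc_flags.append(
        f"--generate-code=arch=compute_{device_capability},"
        f"code=sm_{device_capability}"
    )

setup(
    name="moe_permute_example",  # placeholder package name
    ext_modules=[
        CUDAExtension(
            "moe_permute_example",
            sources=["csrc/permute.cu"],
            extra_compile_args={"cxx": ["-O3"], "nvcc": nvcc_flags},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)

When TORCH_CUDA_ARCH_LIST is set (e.g. TORCH_CUDA_ARCH_LIST="9.0"), the device probe is skipped and PyTorch's extension builder derives the architecture flags from that list instead.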
