forked from microsoft/SparTA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
71 lines (62 loc) · 2.16 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from setuptools import setup, find_packages
import jinja2
import torch
from torch.utils.cpp_extension import BuildExtension
version = '1.0'
ext_modules = []
if torch.cuda.is_available():
from torch.utils.cpp_extension import CUDAExtension
major, minor = torch.cuda.get_device_capability()
with open(os.path.join('csrc', 'moe_sparse_forward_kernel.cu.j2')) as f:
moe_template = f.read()
moe_kernel = jinja2.Template(moe_template).render({'FP16': major >= 7})
os.makedirs(os.path.join('csrc', 'build'), exist_ok=True)
with open(os.path.join('csrc', 'build', 'moe_sparse_forward_kernel.cu'), 'w') as f:
f.write(moe_kernel)
moe_ext = CUDAExtension(
name='sparta.sp_moe_ops',
sources=[
os.path.join('csrc', 'moe_sparse_forward.cpp'),
os.path.join('csrc', 'build', 'moe_sparse_forward_kernel.cu'),
],
extra_compile_args=[
'-std=c++17',
'-O3',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
],
)
ext_modules.append(moe_ext)
seqlen_dynamic_attention_ext = CUDAExtension(
name='sparta.sp_attn_ops',
sources=[
os.path.join('csrc', 'seqlen_dynamic_sparse_attention_forward.cpp'),
os.path.join('csrc', 'seqlen_dynamic_sparse_attention_forward_kernel.cu'),
],
extra_compile_args=['-std=c++17', '-O3'],
)
ext_modules.append(seqlen_dynamic_attention_ext)
setup(
name='SparTA',
version=version,
description='Deployment tool',
author='MSRA',
author_email='[email protected]',
packages=find_packages(exclude=['test', 'test.*', 'examples', 'examples.*']),
install_requires=[
'jinja2',
'pycuda', # 'pip install pycuda' works for most cases
'nni',
],
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension},
include_package_data=True,
package_data={
'sparta.kernels.templates': ['*.j2'],
'sparta.kernels.look_up_tables': ['*.csv'],
'sparta.tesa.templates': ['*.j2'],
},
)