-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathsetup.py
69 lines (64 loc) · 1.64 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# CSPRNG CUDA extensions. Each tuple below is
# (extension module name, .cpp source, .cu kernel source), and each tuple
# becomes one CUDAExtension, in the order listed.
ext_modules = [
    CUDAExtension(name=module_name, sources=[cpp_source, cu_source])
    for module_name, cpp_source, cu_source in (
        (
            "randint_cuda",
            "src/liberate/csprng/randint.cpp",
            "src/liberate/csprng/randint_cuda_kernel.cu",
        ),
        (
            "randround_cuda",
            "src/liberate/csprng/randround.cpp",
            "src/liberate/csprng/randround_cuda_kernel.cu",
        ),
        (
            "discrete_gaussian_cuda",
            "src/liberate/csprng/discrete_gaussian.cpp",
            "src/liberate/csprng/discrete_gaussian_cuda_kernel.cu",
        ),
        (
            "chacha20_cuda",
            "src/liberate/csprng/chacha20.cpp",
            "src/liberate/csprng/chacha20_cuda_kernel.cu",
        ),
    )
]
# NTT CUDA extension, kept in its own list separate from `ext_modules`.
# CUDAExtension's signature is (name, sources, ...), so both are given
# positionally: the extension module name, then its .cpp and .cu sources.
ext_modules_ntt = [
    CUDAExtension(
        "ntt_cuda",
        [
            "src/liberate/ntt/ntt.cpp",
            "src/liberate/ntt/ntt_cuda_kernel.cu",
        ],
    ),
]
if __name__ == "__main__":
    import os

    # All extension sources and build_lib targets in this file are relative
    # paths. Anchor the working directory at this script's folder so that
    # `python path/to/setup.py` also works when invoked from another directory.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    # One setup() run per package, in this order:
    # (distribution name, extension list, output directory).
    # script_args replaces the command line with just `build_ext`.
    # build_ext takes its build_lib from the `build` command's options, so the
    # compiled modules are written into the given source-tree directory.
    for dist_name, extensions, build_lib in (
        ("csprng", ext_modules, "src/liberate/csprng"),
        ("ntt", ext_modules_ntt, "src/liberate/ntt"),
    ):
        setup(
            name=dist_name,
            ext_modules=extensions,
            cmdclass={"build_ext": BuildExtension},
            script_args=["build_ext"],
            options={"build": {"build_lib": build_lib}},
        )