-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
55 lines (51 loc) · 2.22 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import re
from pathlib import Path

from setuptools import setup, find_packages
# Check PyTorch version.
# Supported range is [pytorch_version_l, pytorch_version_u); the same bounds
# are used for the torch requirement passed to setup().
pytorch_version_l = '1.11.0'
pytorch_version_u = '2.1.0'  # excluded
msg_install_pytorch = (f'It is recommended to manually install PyTorch '
                       f'(>={pytorch_version_l},<{pytorch_version_u}) suitable '
                       'for your system ahead: https://pytorch.org/get-started.\n')


def version_tuple(version_string):
    """Return the numeric release components of a version string for ordering.

    Only the leading ``N(.N)*`` prefix is used, so local build tags and
    pre-release/dev suffixes are ignored (``'2.0.1+cu118' -> (2, 0, 1)``,
    ``'1.13.0a0+git' -> (1, 13)``). Trailing zero components are dropped so
    that ``'2.1'`` and ``'2.1.0'`` compare equal, which matches comparing
    zero-padded release numbers.

    Comparing these tuples avoids plain string comparison, under which
    ``'1.9.0' < '1.11.0'`` is False.

    Args:
        version_string: A version string (or a ``str`` subclass such as
            ``torch.__version__``).

    Returns:
        A tuple of ints; an empty tuple if no leading numeric part is found.
    """
    match = re.match(r'\d+(?:\.\d+)*', str(version_string).strip())
    if match is None:
        return ()
    parts = [int(part) for part in match.group(0).split('.')]
    # Drop trailing zeros so (2, 1) == (2, 1, 0) under tuple comparison.
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


try:
    import torch
except ModuleNotFoundError:
    print(f'PyTorch is not installed. {msg_install_pytorch}')
else:
    # Advisory only: a version outside the range is reported, not rejected.
    installed_torch_version = version_tuple(torch.__version__)
    if installed_torch_version < version_tuple(pytorch_version_l):
        print(f'PyTorch version {torch.__version__} is too low. '
              + msg_install_pytorch)
    elif installed_torch_version >= version_tuple(pytorch_version_u):
        print(f'PyTorch version {torch.__version__} is too high. '
              + msg_install_pytorch)
# Resolve paths relative to this setup.py rather than the current working
# directory, so the build also works when invoked from another directory.
this_directory = Path(__file__).parent

# Read the package version textually from auto_LiRPA/__init__.py (no import
# of the package, and no eval of file contents).
init_source = (this_directory / 'auto_LiRPA' / '__init__.py').read_text(
    encoding='utf-8')
version_match = re.search(r"^__version__\s*=\s*['\"]([^'\"]+)['\"]",
                          init_source, re.MULTILINE)
if version_match is None:
    raise RuntimeError('Unable to find __version__ in auto_LiRPA/__init__.py')
version = version_match.group(1)

# Read explicitly as UTF-8 so the build does not depend on the locale encoding.
long_description = (this_directory / 'README.md').read_text(encoding='utf-8')
print(f'Installing auto_LiRPA {version}')
# Runtime requirements. The torch bound reuses pytorch_version_l /
# pytorch_version_u so it stays in sync with the PyTorch version message.
# NOTE(review): pytest, pylint and pytest-order look like development/test
# tools; consider moving them to extras_require after confirming nothing
# imports them at runtime.
requirements = [
    f'torch>={pytorch_version_l},<{pytorch_version_u}',
    'torchvision>=0.9',
    'numpy>=1.20',
    'packaging>=20.0',
    'pytest>=5.0',
    'pylint>=2.15',
    'pytest-order>=1.0.0',
    'appdirs>=1.4',
    'pyyaml>=5.0',
    'ninja>=1.10',
    'tqdm>=4.64',
]

# Static project metadata handed to setuptools.
package_metadata = dict(
    name='auto_LiRPA',
    version=version,
    description='A library for Automatic Linear Relaxation based Perturbation Analysis (LiRPA) on general computational graphs, with a focus on adversarial robustness verification and certification of deep neural networks.',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/KaidiXu/auto_LiRPA',
    author='Kaidi Xu, Zhouxing Shi, Huan Zhang, Yihan Wang, Shiqi Wang, Linyi Li, Jinqi (Kathryn) Chen, Zhuolin Yang',
    platforms=['any'],
    license='BSD',
)

setup(
    packages=find_packages(),
    install_requires=requirements,
    **package_metadata,
)