-
Notifications
You must be signed in to change notification settings - Fork 403
/
Copy pathinstall_requirements.py
112 lines (94 loc) · 2.92 KB
/
install_requirements.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import platform
import subprocess
import sys
def pip_install_packages(packages, extra_index_url=None, verbose=False, pre=False):
    """Install each package in *packages* with pip, one at a time.

    pip is invoked as ``<current interpreter> -m pip`` so packages are
    installed into the Python environment running this script, not into
    whichever ``pip`` executable happens to be first on PATH.

    Args:
        packages: iterable of pip requirement specifiers, e.g. ``"torch"``
            or ``"transformers==4.19.2"``.
        extra_index_url: optional additional package index, passed to pip
            as ``--extra-index-url``.
        verbose: when True, the pip command line is printed and pip runs
            without ``-q`` with its output streamed directly to the terminal.
        pre: when True, ``--pre`` is passed so pip may select pre-release
            versions.

    A failure for one package (pip could not be launched, or pip exited
    with a non-zero status) is reported on stdout and the remaining
    packages are still attempted.
    """
    # Every option except the requirement itself is the same for all
    # packages, so build the shared part of the command once.
    base_cmd = [sys.executable, "-m", "pip", "install"]
    if pre:
        base_cmd.append("--pre")
    if not verbose:
        base_cmd.append("-q")
    if extra_index_url:
        base_cmd.extend(["--extra-index-url", extra_index_url])

    for package in packages:
        print(f"..installing {package}")
        cmd = base_cmd + [package]
        if verbose:
            print(cmd)
        try:
            # Verbose: pip writes straight to the terminal (nothing captured).
            # Quiet: output is captured so stderr can be shown on failure.
            result = subprocess.run(cmd, capture_output=not verbose, text=True)
        except OSError as e:
            # The interpreter/pip process could not be started at all.
            print(f"failed to install {package}: {e}")
            continue
        # subprocess.run does not raise on a non-zero exit, so check it here;
        # otherwise a failed install would pass silently in quiet mode.
        if result.returncode != 0:
            print(f"failed to install {package}: pip exited with code {result.returncode}")
            if not verbose and result.stderr:
                print(result.stderr)
def install_requirements(verbose=False):
    """Install PyTorch, a set of common packages, and xformers via pip.

    The installs run in three phases:

    1. torch / torchvision / torchaudio from the PyTorch nightly CUDA 12.1
       wheel index, with pre-release versions allowed.
    2. A fixed list of common packages, some of them version-pinned.
    3. xformers (pinned), plus triton on Linux only.

    Args:
        verbose: forwarded to every ``pip_install_packages`` call, so pip
            output is shown for all three phases, not just the first.
    """
    # Detect System
    os_system = platform.system()
    print(f"system detected: {os_system}")

    # Install pytorch. pre=True lets pip pick the pre-release/dev versions
    # served by the nightly index.
    torch = [
        "torch",
        "torchvision",
        "torchaudio",
    ]
    extra_index_url = "https://download.pytorch.org/whl/nightly/cu121"
    pip_install_packages(torch, extra_index_url=extra_index_url, verbose=verbose, pre=True)

    # List of common packages to install
    common = [
        "clean-fid",
        "colab-convert",
        "einops",
        "ftfy",
        "ipython",
        "ipywidgets",
        "jsonmerge",
        "jupyterlab",
        "jupyter_http_over_ws",
        "kornia",
        "matplotlib",
        "notebook",
        "numexpr",
        "omegaconf",
        "opencv-python",
        "pandas",
        "pytorch_lightning==1.7.7",
        "resize-right",
        "scikit-image==0.19.3",
        "scikit-learn",
        "timm",
        "torchdiffeq",
        "transformers==4.19.2",
        "safetensors",
        "albumentations",
        "more_itertools",
        "devtools",
        "validators",
        "numpngw",
        "open-clip-torch",
        "torchsde",
        "ninja",
        "pydantic",
    ]
    pip_install_packages(common, verbose=verbose)

    # Xformers install. triton is only distributed as Linux wheels, so it
    # is requested on Linux alone; every other OS gets just xformers.
    linux_xformers = [
        "triton",
        "xformers==0.0.21.dev546",
    ]
    other_xformers = [
        "xformers==0.0.21.dev546",
    ]
    xformers = linux_xformers if os_system == "Linux" else other_xformers
    pip_install_packages(xformers, verbose=verbose)
if __name__ == "__main__":
    # Script entry point: read the single --verbose flag and run every
    # install phase with it.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--verbose',
        action='store_true',
        help='print pip install stuff',
    )
    options = cli.parse_args()
    install_requirements(verbose=options.verbose)