Skip to content
This repository was archived by the owner on Jan 6, 2023. It is now read-only.

Commit 7e1f0ac

Browse files
aivanoufacebook-github-bot
authored andcommitted
Torchelastic bring back setup.py
Summary: Restore setup.py that requires for building docs. Reviewed By: kiukchung Differential Revision: D27974755 fbshipit-source-id: 2c2ecd9f9b4a9cd0f3415991ff05270ebb6ef195
1 parent c633529 commit 7e1f0ac

File tree

4 files changed

+113
-2
lines changed

4 files changed

+113
-2
lines changed

__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# All rights reserved.
5+
#
6+
# This source code is licensed under the BSD-style license found in the
7+
# LICENSE file in the root directory of this source tree.

setup.py

+49-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,54 @@
77
# LICENSE file in the root directory of this source tree.
88

99

10+
import sys
11+
12+
from setuptools import find_packages, setup
13+
14+
15+
def get_version():
16+
return "0.2.3.dev0"
17+
18+
1019
if __name__ == "__main__":
11-
raise RuntimeError(
12-
"Torchelastic got merged with pytorch. Use pytorch: https://pytorch.org/"
20+
if sys.version_info < (3, 8):
21+
sys.exit("python >= 3.8 required for torchelastic")
22+
23+
with open("README.md", encoding="utf8") as f:
24+
readme = f.read()
25+
26+
with open("requirements.txt") as f:
27+
reqs = f.read()
28+
29+
version = get_version()
30+
print("-- Building version: " + version)
31+
32+
setup(
33+
# Metadata
34+
name="torchelastic",
35+
version=version,
36+
author="PyTorch Elastic Devs",
37+
author_email="[email protected]",
38+
description="PyTorch Elastic Training",
39+
long_description=readme,
40+
long_description_content_type="text/markdown",
41+
url="https://github.com/pytorch/elastic",
42+
license="BSD-3",
43+
keywords=["pytorch", "machine learning", "elastic", "distributed"],
44+
python_requires=">=3.8",
45+
install_requires=reqs.strip().split("\n"),
46+
include_package_data=True,
47+
packages=find_packages(exclude=("*.test", "aws*", "*.fb")),
48+
test_suite="torchelastic.tsm.test.suites.unittests",
49+
# PyPI package information.
50+
classifiers=[
51+
"Development Status :: 4 - Beta",
52+
"Intended Audience :: Developers",
53+
"Intended Audience :: Science/Research",
54+
"License :: OSI Approved :: BSD License",
55+
"Programming Language :: Python :: 3",
56+
"Programming Language :: Python :: 3.8",
57+
"Topic :: System :: Distributed Computing",
58+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
59+
],
1360
)

torchelastic/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# All rights reserved.
5+
#
6+
# This source code is licensed under the BSD-style license found in the
7+
# LICENSE file in the root directory of this source tree.

torchelastic/tsm/test/suites.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# All rights reserved.
5+
#
6+
# This source code is licensed under the BSD-style license found in the
7+
# LICENSE file in the root directory of this source tree.
8+
9+
import os
10+
import random
11+
import unittest
12+
from itertools import chain
13+
14+
15+
def _circleci_parallelism(suite):
16+
"""Allow for parallelism in CircleCI for speedier tests.."""
17+
if int(os.environ.get("CIRCLE_NODE_TOTAL", 0)) <= 1:
18+
# either not running on circleci, or we're not using parallelism.
19+
return suite
20+
# tests are automatically sorted by discover, so we will get the same ordering
21+
# on all hosts.
22+
total = int(os.environ["CIRCLE_NODE_TOTAL"])
23+
index = int(os.environ["CIRCLE_NODE_INDEX"])
24+
25+
# right now each test is corresponds to a /file/. Certain files are slower than
26+
# others, so we want to flatten it
27+
tests = [testfile._tests for testfile in suite._tests]
28+
tests = list(chain.from_iterable(tests))
29+
random.Random(42).shuffle(tests)
30+
tests = [t for i, t in enumerate(tests) if i % total == index]
31+
return unittest.TestSuite(tests)
32+
33+
34+
def unittests():
35+
"""
36+
Short tests.
37+
38+
Runs on CircleCI on every commit. Returns everything in the tests root directory.
39+
"""
40+
test_loader = unittest.TestLoader()
41+
test_suite = test_loader.discover(
42+
"torchelastic/tsm", pattern="*_test.py", top_level_dir="."
43+
)
44+
test_suite = _circleci_parallelism(test_suite)
45+
return test_suite
46+
47+
48+
if __name__ == "__main__":
49+
runner = unittest.TextTestRunner()
50+
runner.run(unittests())

0 commit comments

Comments
 (0)