
Commit 31f46fe

DeepSpeed JIT op + PyPI support (#496)

Authored by jeffra, Shaden Smith, and Reza Yazdani
Co-authored-by: Shaden Smith <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>

1 parent 0ad4fd8


59 files changed: +1673 −681 lines

.gitignore

Lines changed: 5 additions & 0 deletions

```diff
@@ -10,6 +10,7 @@ build/
 dist/
 *.so
 deepspeed.egg-info/
+build.txt
 
 # Website
 docs/_site/
@@ -23,3 +24,7 @@ docs/code-docs/build
 
 # Testing data
 tests/unit/saved_checkpoint/
+
+# Dev/IDE data
+.vscode
+.theia
```

.gitmodules

Lines changed: 0 additions & 3 deletions

```diff
@@ -1,6 +1,3 @@
-[submodule "third_party/apex"]
-	path = third_party/apex
-	url = https://github.com/NVIDIA/apex.git
 [submodule "DeepSpeedExamples"]
 	path = DeepSpeedExamples
 	url = https://github.com/microsoft/DeepSpeedExamples
```

MANIFEST.in

Lines changed: 1 addition & 0 deletions

```diff
@@ -0,0 +1 @@
+global-include *.cpp *.h *.cu *.tr *.cuh *.cc *.txt
```

README.md

Lines changed: 34 additions & 13 deletions

````diff
@@ -1,4 +1,5 @@
 [![Build Status](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_apis/build/status/microsoft.DeepSpeed?branchName=master)](https://dev.azure.com/DeepSpeedMSFT/DeepSpeed/_build/latest?definitionId=1&branchName=master)
+[![PyPI version](https://badge.fury.io/py/deepspeed.svg)](https://badge.fury.io/py/deepspeed)
 [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)
 [![License MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE)
 [![Docker Pulls](https://img.shields.io/docker/pulls/deepspeed/deepspeed)](https://hub.docker.com/r/deepspeed/deepspeed)
@@ -31,29 +32,25 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)
 
 
 # News
-* [2020/09/10] [DeepSpeed: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/)
+* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation)
+* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
+* [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/)
 * [Powering 10x longer sequences and 6x faster execution through DeepSpeed Sparse Attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention-news.html)
 * [Training a trillion parameters with pipeline parallelism](https://www.deepspeed.ai/news/2020/09/08/pipeline-parallelism.html)
 * [Up to 5x less communication and 3.4x faster training through 1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-news.html)
 * [10x bigger model training on a single GPU with ZeRO-Offload](https://www.deepspeed.ai/news/2020/09/08/ZeRO-Offload.html)
 * [2020/08/07] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) is now available on-demand
-* [2020/07/24] [DeepSpeed Microsoft Research Webinar](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-On-Demand.html) on August 6th, 2020
-[![DeepSpeed webinar](docs/assets/images/webinar-aug2020.png)](https://note.microsoft.com/MSR-Webinar-DeepSpeed-Registration-Live.html)
-* [2020/05/19] [ZeRO-2 & DeepSpeed: Shattering Barriers of Deep Learning Speed & Scale](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/)
-* [2020/05/19] [An Order-of-Magnitude Larger and Faster Training with ZeRO-2](https://www.deepspeed.ai/news/2020/05/18/zero-stage2.html)
-* [2020/05/19] [The Fastest and Most Efficient BERT Training through Optimized Transformer Kernels](https://www.deepspeed.ai/news/2020/05/18/bert-record.html)
-* [2020/02/13] [Turing-NLG: A 17-billion-parameter language model by Microsoft](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/)
-* [2020/02/13] [ZeRO & DeepSpeed: New system optimizations enable training models with over 100 billion parameters](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)
 
 
 # Table of Contents
 | Section | Description |
 | --------------------------------------- | ------------------------------------------- |
 | [Why DeepSpeed?](#why-deepspeed) | DeepSpeed overview |
-| [Features](#features) | DeepSpeed features |
-| [Further Reading](#further-reading) | DeepSpeed documentation, tutorials, etc. |
-| [Contributing](#contributing) | Instructions for contributing to DeepSpeed |
-| [Publications](#publications) | DeepSpeed publications |
+| [Install](#installation) | Installation details |
+| [Features](#features) | Feature list and overview |
+| [Further Reading](#further-reading) | Documentation, tutorials, etc. |
+| [Contributing](#contributing) | Instructions for contributing |
+| [Publications](#publications) | Publications related to DeepSpeed |
 
 # Why DeepSpeed?
 Training advanced deep learning models is challenging. Beyond model design,
@@ -65,8 +62,32 @@ a large model easily runs out of memory with pure data parallelism and it is
 difficult to use model parallelism. DeepSpeed addresses these challenges to
 accelerate model development *and* training.
 
-# Features
+# Installation
+
+The quickest way to get started with DeepSpeed is via pip; this will install
+the latest release of DeepSpeed, which is not tied to specific PyTorch or CUDA
+versions. DeepSpeed includes several C++/CUDA extensions that we commonly refer
+to as our 'ops'. By default, all of these extensions/ops will be built
+just-in-time (JIT) using [torch's JIT C++ extension loader that relies on
+ninja](https://pytorch.org/docs/stable/cpp_extension.html) to build and
+dynamically link them at runtime.
+
+```bash
+pip install deepspeed
+```
+
+After installation you can validate your install and see which extensions/ops
+your machine is compatible with via the DeepSpeed environment report.
 
+```bash
+ds_report
+```
+
+If you would like to pre-install any of the DeepSpeed extensions/ops (instead
+of JIT compiling) or install pre-compiled ops via PyPI, please see our [advanced
+installation instructions](https://www.deepspeed.ai/tutorials/advanced-install/).
+
+# Features
 Below we provide a brief feature list; see our detailed [feature
 overview](https://www.deepspeed.ai/features/) for descriptions and usage.
 
````
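For context on the JIT mechanism the new installation section describes: the ops are compiled at runtime through torch's C++ extension loader. Below is a minimal sketch of that underlying loader, not DeepSpeed's actual op-builder code; the source file and op names are hypothetical.

```python
# Sketch of torch's JIT C++ extension loader, the mechanism the README's
# installation section refers to. "my_op.cpp" is a hypothetical source file;
# DeepSpeed drives this machinery through its own op builders.
from torch.utils.cpp_extension import load

# Compiles the sources with ninja on first use and caches the resulting
# shared library (under TORCH_EXTENSIONS_DIR when that variable is set).
my_op = load(
    name="my_op",           # name of the generated Python extension module
    sources=["my_op.cpp"],  # C++/CUDA sources to compile and link
    verbose=True,           # print the ninja build log
)
```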

azure-pipelines.yml

Lines changed: 4 additions & 5 deletions

```diff
@@ -43,17 +43,15 @@ jobs:
       conda install -q --yes conda
       conda install -q --yes pip
       conda install -q --yes gxx_linux-64
-      if [[ $(cuda.version) != "10.2" ]]; then conda install --yes -c conda-forge cudatoolkit-dev=$(cuda.version) ; fi
       echo "PATH=$PATH, LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
     displayName: 'Setup environment python=$(python.version) pytorch=$(pytorch.version) cuda=$(cuda.version)'
 
   # Manually install torch/torchvision first to enforce versioning.
   - script: |
       source activate $(conda_env)
       pip install --progress-bar=off torch==$(pytorch.version) torchvision==$(torchvision.version)
-      #-f https://download.pytorch.org/whl/torch_stable.html
-      ./install.sh --local_only
-      #python -I basic_install_test.py
+      pip install .[dev]
+      ds_report
     displayName: 'Install DeepSpeed'
 
   - script: |
@@ -71,7 +69,8 @@ jobs:
 
   - script: |
       source activate $(conda_env)
-      pytest --durations=0 --forked --verbose -x tests/unit/
+      if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
+      TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
     displayName: 'Unit tests'
 
   # - script: |
```
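The unit-test change works because torch's JIT extension loader honors the `TORCH_EXTENSIONS_DIR` environment variable as the root of its build cache, so pointing it at a freshly wiped scratch directory forces a clean rebuild of every op on each CI run. A small Python sketch of the same idea:

```python
# Isolate JIT-built ops in a scratch cache, mirroring the CI script above.
import os
import shutil

ext_dir = "./torch-extensions"
shutil.rmtree(ext_dir, ignore_errors=True)    # equivalent of the `rm -rf` guard

# Must be set before any op is JIT-built; torch reads it at build time.
os.environ["TORCH_EXTENSIONS_DIR"] = ext_dir
```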

basic_install_test.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

bin/ds_report

Lines changed: 6 additions & 0 deletions

```diff
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+
+from deepspeed.env_report import main
+
+if __name__ == '__main__':
+    main()
```
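The new `ds_report` entry point is a thin wrapper, so the same environment report can also be produced programmatically:

```python
# Equivalent of running the ds_report script above from Python.
from deepspeed.env_report import main

main()  # prints the extension/op compatibility and environment report
```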

csrc/adam/compat.h

Lines changed: 14 additions & 0 deletions

```diff
@@ -0,0 +1,14 @@
+/* Copyright 2020 The Microsoft DeepSpeed Team
+   Copyright NVIDIA/apex
+   This file is adapted from fused adam in NVIDIA/apex, commit a109f85
+*/
+
+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+
+#ifdef VERSION_GE_1_3
+#define DATA_PTR data_ptr
+#else
+#define DATA_PTR data
+#endif
```

csrc/adam/custom_cuda_kernel.cu

File mode changed: 100644 → 100755
Lines changed: 2 additions & 17 deletions

```diff
@@ -4,30 +4,15 @@
 
 __global__ void param_update_kernel(const float* input, __half* output, int size)
 {
-    const float4* input_cast = reinterpret_cast<const float4*>(input);
-    float2* output_cast = reinterpret_cast<float2*>(output);
-
     int id = blockIdx.x * blockDim.x + threadIdx.x;
 
-    if (id < size) {
-        float4 data = input_cast[id];
-        float2 cast_data;
-        __half* output_h = reinterpret_cast<__half*>(&cast_data);
-
-        output_h[0] = (__half)data.x;
-        output_h[1] = (__half)data.y;
-        output_h[2] = (__half)data.z;
-        output_h[3] = (__half)data.w;
-
-        output_cast[id] = cast_data;
-    }
+    if (id < size) { output[id] = (__half)input[id]; }
 }
 
 void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
 {
-    int threads = 512;
+    int threads = 1024;
 
-    size /= 4;
     dim3 grid_dim((size - 1) / threads + 1);
     dim3 block_dim(threads);
 
```
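This change drops the float4 vectorization: each thread now converts a single fp32 value to fp16, so the launcher no longer divides `size` by 4 and uses 1024-thread blocks. A quick Python sketch of the resulting launch geometry:

```python
# Grid/block arithmetic used by launch_param_update after this change:
# one thread per element, 1024 threads per block.
def param_update_launch_dims(size: int, threads: int = 1024):
    grid_dim = (size - 1) // threads + 1  # ceil(size / threads)
    return grid_dim, threads

# e.g. 1,000,000 parameters -> (977, 1024); 977 * 1024 >= 1,000,000
print(param_update_launch_dims(1_000_000))
```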

csrc/adam/fused_adam_frontend.cpp

Lines changed: 20 additions & 0 deletions

```diff
@@ -0,0 +1,20 @@
+#include <torch/extension.h>
+
+void multi_tensor_adam_cuda(int chunk_size,
+                            at::Tensor noop_flag,
+                            std::vector<std::vector<at::Tensor>> tensor_lists,
+                            const float lr,
+                            const float beta1,
+                            const float beta2,
+                            const float epsilon,
+                            const int step,
+                            const int mode,
+                            const int bias_correction,
+                            const float weight_decay);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("multi_tensor_adam",
+          &multi_tensor_adam_cuda,
+          "Compute and apply gradient update to parameters for Adam optimizer");
+}
```
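Once JIT-built, the binding above is callable from Python with the argument order declared in the prototype. The sketch below shows a plausible call shape; the source file names, the CUDA-side companion file, and the `[grads, params, exp_avg, exp_avg_sq]` tensor-list ordering follow NVIDIA/apex conventions (from which this op is adapted) and are assumptions here, not details confirmed by this diff.

```python
import torch
from torch.utils.cpp_extension import load

# Hypothetical JIT build of the op from its sources (file names assumed).
fused_adam = load(name="fused_adam",
                  sources=["csrc/adam/fused_adam_frontend.cpp",
                           "csrc/adam/multi_tensor_adam.cu"])

noop_flag = torch.zeros(1, dtype=torch.int, device="cuda")
g, p, m, v = (torch.zeros(8, device="cuda") for _ in range(4))

fused_adam.multi_tensor_adam(
    2048 * 32,             # chunk_size
    noop_flag,             # noop_flag
    [[g], [p], [m], [v]],  # tensor_lists (ordering assumed from apex)
    1e-3,                  # lr
    0.9, 0.999, 1e-8,      # beta1, beta2, epsilon
    1,                     # step
    0,                     # mode (in apex: 0 = L2 decay, 1 = decoupled/AdamW)
    1,                     # bias_correction
    0.0,                   # weight_decay
)
```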
