
Commit ca96867

Tf32 warnings (#6816)
About #6754.

### Description

Show a warning if anything that may enable TF32 mode is detected.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Qingpeng Li <[email protected]>
1 parent cb257d2 commit ca96867
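As context for the description above, the new check runs once when `monai` is imported. Below is a minimal sketch of how the warning could be observed; it assumes an NVIDIA Ampere-or-later GPU, CUDA 11+, and a fresh Python process in which `monai` has not yet been imported.

```python
import os
import warnings

# e.g. the variable baked into NGC PyTorch containers; must be set before importing monai
os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] = "1"

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import monai  # detect_default_tf32() is invoked at import time

print([str(w.message) for w in caught])  # expect a TF32-related warning on Ampere or later GPUs
```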

11 files changed: +177 additions, -35 deletions

docs/source/index.rst

Lines changed: 2 additions & 2 deletions

@@ -60,9 +60,9 @@ Technical documentation is available at `docs.monai.io <https://docs.monai.io>`_
 
 .. toctree::
   :maxdepth: 1
-  :caption: Precision and Performance
+  :caption: Precision and Accelerating
 
-  precision_performance
+  precision_accelerating
 
 .. toctree::
   :maxdepth: 1

docs/source/installation.md

Lines changed: 2 additions & 2 deletions

@@ -254,11 +254,11 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
 - The options are
 
 ```
-[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips]
+[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml]
 ```
 
 which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr` and `lpips` respectively.
+`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips` and `nvidia-ml-py` respectively.
 
 
 - `pip install 'monai[all]'` installs all the optional dependencies.
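As a side note, a quick way to confirm that the new optional dependency is importable is the sketch below; it uses `monai.utils.optional_import`, which is already part of MONAI. The PyPI package is `nvidia-ml-py` while the module it provides is `pynvml`.

```python
from monai.utils import optional_import

# returns (module, True) if pynvml can be imported, (placeholder, False) otherwise
pynvml, has_pynvml = optional_import("pynvml")
print("pynvml available:", has_pynvml)
```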

docs/source/precision_accelerating.md

Lines changed: 3 additions & 3 deletions

@@ -29,11 +29,11 @@ by TF32 mode so the impact is very wide.
 torch.backends.cuda.matmul.allow_tf32 = False  # in PyTorch 1.12 and later.
 torch.backends.cudnn.allow_tf32 = True
 ```
-Please note that there are environment variables that can override the flags above. For example, the environment variables mentioned in [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/) and `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` used by PyTorch. Thus, in some cases, the flags may be accidentally changed or overridden.
-
-We recommend that users print out these two flags for confirmation when unsure.
+Please note that there are environment variables that can override the flags above. For example, the environment variable `NVIDIA_TF32_OVERRIDE` mentioned in [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/) and `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` used by PyTorch. Thus, in some cases, the flags may be accidentally changed or overridden.
 
 If you are using an [NGC PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), the container includes a layer `ENV TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=1`.
 The default value `torch.backends.cuda.matmul.allow_tf32` will be overridden to `True`.
 
+We recommend that users print out these two flags for confirmation when unsure.
+
 If you can confirm through experiments that your model has no accuracy or convergence issues in TF32 mode and you have NVIDIA Ampere GPUs or above, you can set the two flags above to `True` to speed up your model.
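Following the recommendation in this page, a short script such as the sketch below can be used to confirm the current flag values and the relevant override variables (the flag attributes are the standard PyTorch ones quoted in the diff above).

```python
import os

import torch

# the two flags the documentation recommends printing for confirmation
print("matmul allow_tf32:", torch.backends.cuda.matmul.allow_tf32)
print("cudnn  allow_tf32:", torch.backends.cudnn.allow_tf32)

# environment variables that may override the flags
for var in ("NVIDIA_TF32_OVERRIDE", "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"):
    print(var, "=", os.environ.get(var))
```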

monai/__init__.py

Lines changed: 10 additions & 0 deletions

@@ -78,3 +78,13 @@
     "utils",
     "visualize",
 ]
+
+try:
+    from .utils.tf32 import detect_default_tf32
+
+    detect_default_tf32()
+except BaseException:
+    from .utils.misc import MONAIEnvVars
+
+    if MONAIEnvVars.debug():
+        raise
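The `try/except` guard above means a failure inside `detect_default_tf32()` is silently ignored during import and only re-raised when MONAI's debug mode is on. A hedged sketch of that behaviour, assuming `MONAIEnvVars.debug()` reflects the `MONAI_DEBUG` environment variable:

```python
import os

# Assumption: MONAIEnvVars.debug() returns True when MONAI_DEBUG is set to a truthy value.
os.environ["MONAI_DEBUG"] = "True"

import monai  # with debug on, an error raised by detect_default_tf32() propagates instead of being swallowed
```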

monai/utils/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -115,6 +115,7 @@
     require_pkg,
     run_debug,
     run_eval,
+    version_geq,
     version_leq,
 )
 from .nvtx import Range
@@ -128,6 +129,7 @@
     torch_profiler_time_end_to_end,
 )
 from .state_cacher import StateCacher
+from .tf32 import detect_default_tf32, has_ampere_or_later
 from .type_conversion import (
     convert_data_type,
     convert_to_cupy,

monai/utils/module.py

Lines changed: 52 additions & 16 deletions

@@ -25,7 +25,7 @@
 from pydoc import locate
 from re import match
 from types import FunctionType, ModuleType
-from typing import Any, cast
+from typing import Any, Iterable, cast
 
 import torch
 
@@ -55,6 +55,7 @@
     "get_package_version",
     "get_torch_version_tuple",
     "version_leq",
+    "version_geq",
     "pytorch_after",
 ]
 
@@ -518,24 +519,11 @@ def get_torch_version_tuple():
     return tuple(int(x) for x in torch.__version__.split(".")[:2])
 
 
-def version_leq(lhs: str, rhs: str) -> bool:
+def parse_version_strs(lhs: str, rhs: str) -> tuple[Iterable[int | str], Iterable[int | str]]:
     """
-    Returns True if version `lhs` is earlier or equal to `rhs`.
-
-    Args:
-        lhs: version name to compare with `rhs`, return True if earlier or equal to `rhs`.
-        rhs: version name to compare with `lhs`, return True if later or equal to `lhs`.
-
+    Parse the version strings.
     """
-
-    lhs, rhs = str(lhs), str(rhs)
-    pkging, has_ver = optional_import("pkg_resources", name="packaging")
-    if has_ver:
-        try:
-            return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs))
-        except pkging.version.InvalidVersion:
-            return True
-
     def _try_cast(val: str) -> int | str:
         val = val.strip()
         try:
@@ -554,7 +542,28 @@ def _try_cast(val: str) -> int | str:
     # parse the version strings in this basic way without `packaging` package
     lhs_ = map(_try_cast, lhs.split("."))
     rhs_ = map(_try_cast, rhs.split("."))
+    return lhs_, rhs_
+
 
+def version_leq(lhs: str, rhs: str) -> bool:
+    """
+    Returns True if version `lhs` is earlier or equal to `rhs`.
+
+    Args:
+        lhs: version name to compare with `rhs`, return True if earlier or equal to `rhs`.
+        rhs: version name to compare with `lhs`, return True if later or equal to `lhs`.
+
+    """
+
+    lhs, rhs = str(lhs), str(rhs)
+    pkging, has_ver = optional_import("pkg_resources", name="packaging")
+    if has_ver:
+        try:
+            return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs))
+        except pkging.version.InvalidVersion:
+            return True
+
+    lhs_, rhs_ = parse_version_strs(lhs, rhs)
     for l, r in zip(lhs_, rhs_):
         if l != r:
             if isinstance(l, int) and isinstance(r, int):
@@ -564,6 +573,33 @@ def _try_cast(val: str) -> int | str:
     return True
 
 
+def version_geq(lhs: str, rhs: str) -> bool:
+    """
+    Returns True if version `lhs` is later or equal to `rhs`.
+
+    Args:
+        lhs: version name to compare with `rhs`, return True if later or equal to `rhs`.
+        rhs: version name to compare with `lhs`, return True if earlier or equal to `lhs`.
+
+    """
+    lhs, rhs = str(lhs), str(rhs)
+    pkging, has_ver = optional_import("pkg_resources", name="packaging")
+    if has_ver:
+        try:
+            return cast(bool, pkging.version.Version(lhs) >= pkging.version.Version(rhs))
+        except pkging.version.InvalidVersion:
+            return True
+
+    lhs_, rhs_ = parse_version_strs(lhs, rhs)
+    for l, r in zip(lhs_, rhs_):
+        if l != r:
+            if isinstance(l, int) and isinstance(r, int):
+                return l > r
+            return f"{l}" > f"{r}"
+
+    return True
+
+
 @functools.lru_cache(None)
 def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: str | None = None) -> bool:
     """

monai/utils/tf32.py

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import os
import warnings

__all__ = ["has_ampere_or_later", "detect_default_tf32"]


@functools.lru_cache(None)
def has_ampere_or_later() -> bool:
    """
    Check if there is any Ampere or later GPU.
    """
    import torch

    from monai.utils.module import optional_import, version_geq

    if not (torch.version.cuda and version_geq(f"{torch.version.cuda}", "11.0")):
        return False

    pynvml, has_pynvml = optional_import("pynvml")
    if not has_pynvml:  # assuming that the user has an Ampere or later GPU
        return True

    try:
        pynvml.nvmlInit()
        for i in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            major, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            if major >= 8:
                return True
    except BaseException:
        pass
    finally:
        pynvml.nvmlShutdown()

    return False


@functools.lru_cache(None)
def detect_default_tf32() -> bool:
    """
    Detect if there is anything that may enable TF32 mode by default.
    If any, show a warning message.
    """
    may_enable_tf32 = False
    try:
        if not has_ampere_or_later():
            return False

        from monai.utils.module import pytorch_after

        if pytorch_after(1, 7, 0) and not pytorch_after(1, 12, 0):
            warnings.warn(
                "torch.backends.cuda.matmul.allow_tf32 = True by default.\n"
                "  This value defaults to True when PyTorch version in [1.7, 1.11] and may affect precision.\n"
                "  See https://docs.monai.io/en/latest/precision_accelerating.html#precision-and-accelerating"
            )
            may_enable_tf32 = True

        override_tf32_env_vars = {"NVIDIA_TF32_OVERRIDE": "1", "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE": "1"}
        for name, override_val in override_tf32_env_vars.items():
            if os.environ.get(name) == override_val:
                warnings.warn(
                    f"Environment variable `{name} = {override_val}` is set.\n"
                    f"  This environment variable may enable TF32 mode accidentally and affect precision.\n"
                    f"  See https://docs.monai.io/en/latest/precision_accelerating.html#precision-and-accelerating"
                )
                may_enable_tf32 = True

        return may_enable_tf32
    except BaseException:
        from monai.utils.misc import MONAIEnvVars

        if MONAIEnvVars.debug():
            raise
        return False
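A brief usage sketch for the new helpers. Normally `detect_default_tf32()` is called once at `import monai`; calling it manually as below is only for illustration, and both results are cached by `functools.lru_cache`.

```python
from monai.utils import detect_default_tf32, has_ampere_or_later

print("Ampere or later GPU present:", has_ampere_or_later())
print("something may enable TF32 by default:", detect_default_tf32())  # emits the warnings shown above if True
```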

requirements-dev.txt

Lines changed: 1 addition & 0 deletions

@@ -55,3 +55,4 @@ typeguard<3 # https://github.com/microsoft/nni/issues/5457
 filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
 zarr
 lpips==0.1.4
+nvidia-ml-py

setup.cfg

Lines changed: 3 additions & 0 deletions

@@ -82,6 +82,7 @@ all =
     onnxruntime; python_version <= '3.10'
     zarr
     lpips==0.1.4
+    nvidia-ml-py
 nibabel =
     nibabel
 ninja =
@@ -153,6 +154,8 @@ zarr =
     zarr
 lpips =
     lpips==0.1.4
+pynvml =
+    nvidia-ml-py
 # # workaround https://github.com/Project-MONAI/MONAI/issues/5882
 # MetricsReloaded =
 #     MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded

Lines changed: 7 additions & 2 deletions

@@ -16,7 +16,7 @@
 
 from parameterized import parameterized
 
-from monai.utils import version_leq
+from monai.utils import version_geq, version_leq
 
 
 # from pkg_resources
@@ -76,10 +76,15 @@ def _pairwise(iterable):
 
 class TestVersionCompare(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
-    def test_compare(self, a, b, expected=True):
+    def test_compare_leq(self, a, b, expected=True):
         """Test version_leq with `a` and `b`"""
         self.assertEqual(version_leq(a, b), expected)
 
+    @parameterized.expand(TEST_CASES)
+    def test_compare_geq(self, a, b, expected=True):
+        """Test version_geq with `b` and `a`"""
+        self.assertEqual(version_geq(b, a), expected)
+
 
 if __name__ == "__main__":
     unittest.main()
