You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
about #6754 .
### Description
show a warning if any thing may enable tf32 is detected
### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: Qingpeng Li <[email protected]>
Copy file name to clipboardExpand all lines: docs/source/precision_accelerating.md
+3-3Lines changed: 3 additions & 3 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -29,11 +29,11 @@ by TF32 mode so the impact is very wide.
29
29
torch.backends.cuda.matmul.allow_tf32 =False# in PyTorch 1.12 and later.
30
30
torch.backends.cudnn.allow_tf32 =True
31
31
```
32
-
Please note that there are environment variables that can override the flags above. For example, the environment variables mentioned in [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/) and `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` used by PyTorch. Thus, in some cases, the flags may be accidentally changed or overridden.
33
-
34
-
We recommend that users print out these two flags for confirmation when unsure.
32
+
Please note that there are environment variables that can override the flags above. For example, the environment variable `NVIDIA_TF32_OVERRIDE` mentioned in [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/) and `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` used by PyTorch. Thus, in some cases, the flags may be accidentally changed or overridden.
35
33
36
34
If you are using an [NGC PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), the container includes a layer `ENV TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=1`.
37
35
The default value `torch.backends.cuda.matmul.allow_tf32` will be overridden to `True`.
38
36
37
+
We recommend that users print out these two flags for confirmation when unsure.
38
+
39
39
If you can confirm through experiments that your model has no accuracy or convergence issues in TF32 mode and you have NVIDIA Ampere GPUs or above, you can set the two flags above to `True` to speed up your model.
0 commit comments