diff --git a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py index 985bcd82a..118c56fd3 100644 --- a/tests/cuda/test_cudnn.py +++ b/tests/cuda/test_cudnn.py @@ -92,6 +92,11 @@ def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, s ], ) def test_cudnn_conv2d_gemm(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): + # Disable TF32 operations on Ampere architecture to avoid losing precision. + import os + + os.environ["NVIDIA_TF32_OVERRIDE"] = "0" + tx = tw = ty = dtype pad_dim1, pad_dim2 = padding str_dim1, str_dim2 = stride