diff --git a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py index 118c56fd3..86269a576 100644 --- a/tests/cuda/test_cudnn.py +++ b/tests/cuda/test_cudnn.py @@ -11,6 +11,7 @@ # limitations under the License. import pytest import math +import torch import hidet from hidet import ops from hidet.cuda.cudnn import cudnnDataType @@ -92,11 +93,6 @@ def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, s ], ) def test_cudnn_conv2d_gemm(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): - # Disable TF32 operations on Ampere architecture to avoid losing precision. - import os - - os.environ["NVIDIA_TF32_OVERRIDE"] = "0" - tx = tw = ty = dtype pad_dim1, pad_dim2 = padding str_dim1, str_dim2 = stride @@ -132,6 +128,9 @@ def test_cudnn_conv2d_gemm(n, c, h, w, k, p, q, r, s, dtype, compute_type, paddi dil_dim2, ) + if dtype == hidet.float32 and torch.cuda.get_device_capability()[0] >= 8: + tol = 1e-2 + hidet.utils.assert_close(actual=tensor_y, expected=golden, rtol=tol, atol=tol)