Skip to content

Commit

Permalink
[CuDNN] Increase test tol for sm80 and higher
Browse files Browse the repository at this point in the history
  • Loading branch information
Yudi Sun committed Jul 2, 2024
1 parent 23d80fe commit a48012c
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/cuda/test_cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# limitations under the License.
import pytest
import math
import torch
import hidet
from hidet import ops
from hidet.cuda.cudnn import cudnnDataType
Expand Down Expand Up @@ -92,11 +93,6 @@ def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, s
],
)
def test_cudnn_conv2d_gemm(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol):
# Disable TF32 operations on Ampere architecture to avoid losing precision.
import os

os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

tx = tw = ty = dtype
pad_dim1, pad_dim2 = padding
str_dim1, str_dim2 = stride
Expand Down Expand Up @@ -132,6 +128,9 @@ def test_cudnn_conv2d_gemm(n, c, h, w, k, p, q, r, s, dtype, compute_type, paddi
dil_dim2,
)

if dtype == hidet.float32 and torch.cuda.get_device_capability()[0] >= 8:
tol = 1e-2

hidet.utils.assert_close(actual=tensor_y, expected=golden, rtol=tol, atol=tol)


Expand Down

0 comments on commit a48012c

Please sign in to comment.