
Commit e0bdae7

xw285cornell authored and pytorchmergebot committed
[AMD] Turn on TF32 for aten::mm (#139869)
Summary: hipblaslt supports TF32, so add support for it.

Test Plan: CI

Differential Revision: D65435392

Pull Request resolved: pytorch/pytorch#139869
Approved by: https://github.com/leitian
1 parent 5273d8f commit e0bdae7

4 files changed: +56 / -4 lines
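For reference, a minimal sketch of how this change would be exercised from Python on a ROCm build (an illustration, not part of the commit; it assumes an MI300-class GPU and a hipblaslt build with TF32 support):

import os
import torch

# The gate added in Context.cpp reads HIPBLASLT_ALLOW_TF32 once (const static),
# so the variable must be set before the first TF32 query in the process.
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"

# Same knob as on CUDA; without the env var this setter only warns on ROCm.
torch.backends.cuda.matmul.allow_tf32 = True

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
c = a @ b  # float32 matmuls may now use the TF32 compute type via hipblaslt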

aten/src/ATen/Context.cpp (+19)

@@ -186,6 +186,9 @@ bool Context::userEnabledOverrideableSDP() const {
 
 static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
 static constexpr const std::array<const char*, 2> cublas_deterministic_configs = {":4096:8", ":16:8"};
+#ifdef USE_ROCM
+static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
+#endif
 
 bool Context::checkCuBLASConfigDeterministic() {
   // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
@@ -237,10 +240,26 @@ void Context::setBenchmarkLimitCuDNN(int b) {
 }
 
 bool Context::allowTF32CuBLAS() const {
+#ifdef USE_ROCM
+  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
+  if (allow_tf32 != true) {
+    return false;
+  }
+#endif
   return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
 }
 
 void Context::setAllowTF32CuBLAS(bool b) {
+#ifdef USE_ROCM
+  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
+  if (allow_tf32 != true) {
+    TORCH_WARN(
+        "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
+        "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it."
+    );
+    return;
+  }
+#endif
   float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
 }
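As a quick illustration of the gate above (a sketch, not part of the commit): on a ROCm build where HIPBLASLT_ALLOW_TF32 is not set, the setter warns and the getter keeps reporting False.

import torch

# Without HIPBLASLT_ALLOW_TF32=1 in the environment, setAllowTF32CuBLAS()
# issues the TORCH_WARN above and returns early, and allowTF32CuBLAS()
# short-circuits to false.
torch.backends.cuda.matmul.allow_tf32 = True
print(torch.backends.cuda.matmul.allow_tf32)  # expected: False on ROCm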

aten/src/ATen/cuda/CUDABlas.cpp (-4)

@@ -337,11 +337,9 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
   } else if constexpr (std::is_same_v<Dtype, float>) {
-#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
-#endif
   } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
     abcType = CUDA_C_64F;
     computeType = CUBLAS_COMPUTE_64F;
@@ -1237,11 +1235,9 @@ void gemm_and_bias(
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
   } else if constexpr (std::is_same_v<Dtype, float>) {
-#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
-#endif
     abcType = CUDA_R_32F;
   } else if constexpr (std::is_same_v<Dtype, at::Half>) {
     abcType = CUDA_R_16F;
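With the #ifndef USE_ROCM guards gone, float32 GEMMs on ROCm pick CUBLAS_COMPUTE_32F_FAST_TF32 (hipified to HIPBLAS_COMPUTE_32F_FAST_TF32) whenever allowTF32CuBLAS() is true, just as on CUDA. A rough way to observe the numerical effect of the reduced-precision compute type (a sketch; it assumes TF32 is enabled and actually in effect on the device, otherwise the two results will simply match):

import torch

a = torch.randn(2048, 2048, device="cuda")
b = torch.randn(2048, 2048, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = False
ref = a @ b  # full FP32 compute type

torch.backends.cuda.matmul.allow_tf32 = True  # on ROCm also requires HIPBLASLT_ALLOW_TF32=1
fast = a @ b  # TF32 fast path (COMPUTE_32F_FAST_TF32)

# TF32 keeps roughly 10 mantissa bits, so expect a small but nonzero
# relative difference against the FP32 reference.
print(((ref - fast).abs().max() / ref.abs().max()).item())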

test/test_cuda.py (+33)

@@ -483,7 +483,33 @@ def check_workspace_size(inp):
 
         torch._C._cuda_clearCublasWorkspaces()
 
+    @contextlib.contextmanager
+    def _hip_allow_tf32(self):
+        # For HIP/AMDGPU, TF32 is behind a flag because TF32 support is new
+        # and only available on MI300+.
+        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
+
+        try:
+            yield
+        finally:
+            if hip_allow_tf32 is not None:
+                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+            else:
+                del os.environ["HIPBLASLT_ALLOW_TF32"]
+
     def test_cublas_allow_tf32_get_set(self):
+        """
+        We only turn on TF32 for MI300 with a special env var. This is because TF32
+        is only available on MI300+ and is in experimental mode (hipblaslt support
+        is currently WIP).
+        """
+        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+
+        with tf32_ctx():
+            self._test_cublas_allow_tf32_get_set_inner()
+
+    def _test_cublas_allow_tf32_get_set_inner(self):
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
         )
@@ -498,6 +524,12 @@ def test_cublas_allow_tf32_get_set(self):
         torch.backends.cuda.matmul.allow_tf32 = orig
 
     def test_float32_matmul_precision_get_set(self):
+        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+
+        with tf32_ctx():
+            self._test_float32_matmul_precision_get_set_inner()
+
+    def _test_float32_matmul_precision_get_set_inner(self):
         orig = torch.get_float32_matmul_precision()
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
@@ -509,6 +541,7 @@ def test_float32_matmul_precision_get_set(self):
             self.assertEqual(torch.get_float32_matmul_precision(), "highest")
         else:
             self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
+
         for p in ("medium", "high"):
             torch.set_float32_matmul_precision(p)
             self.assertEqual(torch.get_float32_matmul_precision(), p)

torch/utils/hipify/cuda_to_hip_mappings.py (+4)

@@ -7292,6 +7292,10 @@
         "CUBLAS_COMPUTE_32F",
         ("HIPBLAS_COMPUTE_32F", CONV_MATH_FUNC, API_BLAS)
     ),
+    (
+        "CUBLAS_COMPUTE_32F_FAST_TF32",
+        ("HIPBLAS_COMPUTE_32F_FAST_TF32", CONV_MATH_FUNC, API_BLAS)
+    ),
     (
         "CUBLAS_COMPUTE_64F",
         ("HIPBLAS_COMPUTE_64F", CONV_MATH_FUNC, API_BLAS)
