Hi tch-rs community!
I'd like to ask a short question: how can TF32 be enabled in tch (for CUDA)? This option can make recent NVIDIA GPUs dramatically faster when a small loss of precision is acceptable. I searched for "TF32" and "precision" in this crate's code, but could not find such an option.
My current workaround gets TF32 for GEMM, but not for cuDNN.
For comparison, in Rust's candle-core this is done with:

```rust
candle_core::cuda::set_gemm_reduced_precision_f32(true);
```

In cudarc, TF32 does not seem to be exposed in the cuBLAS wrapper, and it is enforced (hardcoded) in the cuBLASLt wrapper, so using `cudarc::cublaslt::safe` will automatically call GEMM with TF32.
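For context, here is a minimal sketch of how that candle switch might be used end to end (assuming candle-core is built with its CUDA feature; the tensor setup is purely illustrative):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Global switch quoted above: subsequent f32 GEMMs may use reduced
    // (TF32-style) precision on supporting NVIDIA GPUs.
    candle_core::cuda::set_gemm_reduced_precision_f32(true);

    let device = Device::new_cuda(0)?;
    let a = Tensor::randn(0f32, 1f32, (1024, 1024), &device)?;
    let b = Tensor::randn(0f32, 1f32, (1024, 1024), &device)?;
    let c = a.matmul(&b)?; // this matmul can now take the TF32 path
    println!("result shape: {:?}", c.shape());
    Ok(())
}
```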
In Python's PyTorch, this is done by (https://pytorch.org/docs/stable/notes/cuda.html):

```python
torch.backends.cuda.matmul.allow_tf32 = True
```
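Notably, the same PyTorch notes page also documents a separate cuDNN-side switch, which covers exactly the part my workaround is missing:

```python
import torch

# TF32 for matmul (cuBLAS) and for convolutions (cuDNN) are toggled separately.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```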