How to enable TF32 for CUDA GEMM? #916

@ajz34

Description

Hi tch-rs community!

I have a short question: how can TF32 be enabled in tch (for CUDA)? TF32 can make GEMM on recent NVIDIA GPUs (Ampere and later) much faster when full f32 precision is not critical. I searched this crate's code for TF32 and precision, but could not find such an option.
A workaround that enables TF32 only for GEMM (not for cuDNN) would be good enough for my use case.
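
Ideally tch would expose a switch along these lines (hypothetical, purely to illustrate what I am asking for; set_matmul_allow_tf32 is not an existing tch API):

// Hypothetical tch-level flag mirroring PyTorch's
// torch.backends.cuda.matmul.allow_tf32 (not a real tch API today).
tch::Cuda::set_matmul_allow_tf32(true);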


For comparison, in Rust's candle-core this is done by

candle_core::cuda::set_gemm_reduced_precision_f32(true);

In cudarc, TF32 does not appear to be exposed in the cublas wrapper, while the cublaslt wrapper hardcodes it, so using cudarc::cublaslt::safe will automatically run GEMM with TF32.

In Python's PyTorch, this is done by (see https://pytorch.org/docs/stable/notes/cuda.html):

torch.backends.cuda.matmul.allow_tf32 = True
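
Since tch wraps libtorch, one stop-gap I am considering is a tiny FFI shim around the C++ counterpart of that Python flag, at::globalContext().setAllowTF32CuBLAS(bool). This is only a sketch: set_allow_tf32_cublas is my own helper name, and the C++ side has to be compiled against libtorch and linked into the build.

// C++ shim, compiled against libtorch and linked into the crate:
//   #include <ATen/Context.h>
//   extern "C" void set_allow_tf32_cublas(bool enabled) {
//       at::globalContext().setAllowTF32CuBLAS(enabled);
//   }

extern "C" {
    fn set_allow_tf32_cublas(enabled: bool);
}

fn main() {
    // Enable TF32 for all subsequent f32 GEMMs dispatched through cuBLAS.
    unsafe { set_allow_tf32_cublas(true) };
    // ... build and run tch tensors/models as usual ...
}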
