diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 84a8ea9f93..4251d1fc70 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1649,6 +1649,8 @@ def max_pool_with_indices_backward_meta(
 nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
 pad = _register_torch_operation("pad", module=torch.nn.functional)
 scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
+if hasattr(torch.nn.functional, "scaled_mm"):
+    scaled_mm = _register_torch_operation("scaled_mm", module=torch.nn.functional)
 softmax = _register_torch_operation("softmax", like=ltorch._softmax)
@@ -1975,6 +1977,8 @@ def adaptive_avg_pool2d_bwd_wrapper(
 pad_prim_impl = ex.register_operator("torch_pad_prim_impl", meta=prims.pad.meta, fn=_pad_prim_impl)
 _register_implementation(prims.pad, pad_prim_impl, checker=_always_executable)
 _register_implementation(ltorch._softmax, checker=_always_executable, execution_transform=_softmax_transform)
+if hasattr(torch.nn.functional, "scaled_mm"):
+    _register_implementation(ltorch.scaled_mm, scaled_mm, checker=_always_executable)
 _register_implementation(ltorch.scaled_dot_product_attention, scaled_dot_product_attention, checker=_always_executable)
diff --git a/thunder/tests/test_ops.py b/thunder/tests/test_ops.py
index a79d856422..8251087321 100644
--- a/thunder/tests/test_ops.py
+++ b/thunder/tests/test_ops.py
@@ -1,11 +1,15 @@
 from collections.abc import Callable
 import os
+import functools
 
 import numpy as np
 import pytest
 import torch
 from torch.testing import assert_close
+
+if hasattr(torch.nn.functional, "scaled_mm"):
+    from torch.nn.functional import ScalingType, SwizzleType
 
 import thunder
 import thunder.core.devices as devices
 import thunder.core.dtypes as dtypes
@@ -419,6 +423,302 @@ def fn(a):
     assert_close(b, b_ref)
 
 
+def _cuda_version_tuple() -> tuple[int, int] | None:
+    if torch.version.cuda is None:
+        return None
+    parts = torch.version.cuda.split(".")
+    try:
+        major = int(parts[0])
+        minor = int(parts[1]) if len(parts) > 1 else 0
+        return major, minor
+    except ValueError:
+        return None
+
+
+def _require_scaled_mm(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not hasattr(torch.nn.functional, "scaled_mm"):
+            pytest.skip("torch.nn.functional.scaled_mm is not available in this PyTorch build")
+        return fn(*args, **kwargs)
+
+    return wrapper
+
+
+def _ensure_fp8_tensorwise(device: torch.device) -> None:
+    if torch.cuda.get_device_capability(device) < (8, 9):
+        pytest.skip("scaled_mm tensor-wise support requires SM89 or newer")
+
+
+def _require_fp8_tensorwise(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        device = torch.device("cuda")
+        _ensure_fp8_tensorwise(device)
+        return fn(*args, **kwargs)
+
+    return wrapper
+
+
+def _require_fp8_rowwise(device: torch.device) -> None:
+    _ensure_fp8_tensorwise(device)
+    if torch.cuda.get_device_capability(device) < (9, 0):
+        pytest.skip("row-wise scaled_mm requires SM90 or newer")
+    cuda_version = _cuda_version_tuple()
+    if cuda_version is not None and cuda_version < (12, 9):
+        pytest.skip("row-wise scaled_mm requires CUDA 12.9 or newer")
+
+
+def _require_fp8_blockwise(device: torch.device) -> None:
+    _require_fp8_rowwise(device)
+
+
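+# NOTE: each scaled_mm test below first computes a reference result with eager
+# torch.nn.functional.scaled_mm (skipping when the local CUDA/cuBLAS stack rejects the
+# configuration), then runs the same function through thunder.jit and compares the outputs.
+
+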
torch.device("cuda") + + def reference_fn(mat_a, mat_b, scale_a, scale_b): + return torch.nn.functional.scaled_mm( + mat_a, + mat_b, + scale_a, + ScalingType.TensorWise, + scale_b, + ScalingType.TensorWise, + swizzle_a=SwizzleType.NO_SWIZZLE, + swizzle_b=SwizzleType.NO_SWIZZLE, + output_dtype=torch.bfloat16, + ) + + M, K, N = 16, 32, 16 + mat_a = torch.randn(M, K, device=device, dtype=torch.float32) + mat_b = torch.randn(K, N, device=device, dtype=torch.float32) + mat_a_lp = mat_a.to(torch.float8_e4m3fn) + mat_b_lp = mat_b.to(torch.float8_e4m3fn) + scale_a = torch.tensor(1.0, device=device, dtype=torch.float32) + scale_b = torch.tensor(1.0, device=device, dtype=torch.float32) + + try: + expected = reference_fn(mat_a_lp, mat_b_lp, scale_a, scale_b) + except (NotImplementedError, RuntimeError) as exc: + pytest.skip(str(exc)) + + jf = thunder.jit(reference_fn) + result = jf(mat_a_lp, mat_b_lp, scale_a, scale_b) + assert_close(result, expected) + + +# Adapted from https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py#L862-L910 +@requiresCUDA +@_require_fp8_tensorwise +@_require_scaled_mm +def test_scaled_mm_matches_scaled_data(): + device = torch.device("cuda") + + def quantize_to_fp8(tensor): + dtype = torch.float8_e4m3fn + max_val = torch.finfo(dtype).max + amax = tensor.abs().max() + encode = (max_val / torch.clamp(amax, min=1e-12)).to(torch.float32) + quant = torch.clamp(tensor * encode, min=-max_val, max=max_val).to(dtype) + decode = encode.reciprocal() + return quant, decode, encode + + def scaled_mm_fp8(mat_a, mat_b, scale_a, scale_b, *, out_dtype): + return torch.nn.functional.scaled_mm( + mat_a, + mat_b, + scale_a, + ScalingType.TensorWise, + scale_b, + ScalingType.TensorWise, + swizzle_a=SwizzleType.NO_SWIZZLE, + swizzle_b=SwizzleType.NO_SWIZZLE, + output_dtype=out_dtype, + ) + + M, K, N = 32, 64, 32 + mat_a = torch.randn(M, K, device=device, dtype=torch.float32) + mat_b_base = torch.randn(N, K, device=device, dtype=torch.float32) + + mat_a_lp, decode_a, encode_a = quantize_to_fp8(mat_a) + mat_b_lp_pre, decode_b, encode_b = quantize_to_fp8(mat_b_base) + # To use cublaslt, the second matrix needs to be column-major. + mat_b_lp = mat_b_lp_pre.t() + + try: + reference = scaled_mm_fp8(mat_a_lp, mat_b_lp, decode_a, decode_b, out_dtype=torch.float32) + except (NotImplementedError, RuntimeError) as exc: + pytest.skip(str(exc)) + + jf = thunder.jit(lambda a, b, sa, sb: scaled_mm_fp8(a, b, sa, sb, out_dtype=torch.float32)) + thunder_out = jf(mat_a_lp, mat_b_lp, decode_a, decode_b) + + assert_close(thunder_out, reference) + + +@requiresCUDA +@_require_scaled_mm +def test_scaled_mm_rowwise_matches_torch(): + device = torch.device("cuda") + _require_fp8_rowwise(device) + + def reference_fn(mat_a, mat_b, scale_a, scale_b): + return torch.nn.functional.scaled_mm( + mat_a, + mat_b, + scale_a, + ScalingType.RowWise, + scale_b, + ScalingType.RowWise, + swizzle_a=SwizzleType.NO_SWIZZLE, + swizzle_b=SwizzleType.NO_SWIZZLE, + output_dtype=torch.bfloat16, + ) + + M, K, N = 16, 32, 16 + mat_a = torch.randn(M, K, device=device, dtype=torch.float32) + mat_b_base = torch.randn(N, K, device=device, dtype=torch.float32) + mat_a_lp = mat_a.to(torch.float8_e4m3fn) + # To use cublaslt, the second matrix needs to be column-major. 
+@requiresCUDA
+@_require_scaled_mm
+def test_scaled_mm_rowwise_matches_torch():
+    device = torch.device("cuda")
+    _require_fp8_rowwise(device)
+
+    def reference_fn(mat_a, mat_b, scale_a, scale_b):
+        return torch.nn.functional.scaled_mm(
+            mat_a,
+            mat_b,
+            scale_a,
+            ScalingType.RowWise,
+            scale_b,
+            ScalingType.RowWise,
+            swizzle_a=SwizzleType.NO_SWIZZLE,
+            swizzle_b=SwizzleType.NO_SWIZZLE,
+            output_dtype=torch.bfloat16,
+        )
+
+    M, K, N = 16, 32, 16
+    mat_a = torch.randn(M, K, device=device, dtype=torch.float32)
+    mat_b_base = torch.randn(N, K, device=device, dtype=torch.float32)
+    mat_a_lp = mat_a.to(torch.float8_e4m3fn)
+    # To use cublaslt, the second matrix needs to be column-major.
+    mat_b_lp = mat_b_base.to(torch.float8_e4m3fn).t()
+    scale_a = torch.ones((M, 1), device=device, dtype=torch.float32)
+    scale_b = torch.ones((1, N), device=device, dtype=torch.float32)
+
+    try:
+        expected = reference_fn(mat_a_lp, mat_b_lp, scale_a, scale_b)
+    except (NotImplementedError, RuntimeError) as exc:
+        pytest.skip(str(exc))
+
+    jf = thunder.jit(reference_fn)
+    result = jf(mat_a_lp, mat_b_lp, scale_a, scale_b)
+    assert_close(result, expected)
+
+
+@requiresCUDA
+@_require_scaled_mm
+def test_scaled_mm_rowwise_matches_scaled_data():
+    device = torch.device("cuda")
+    _require_fp8_rowwise(device)
+
+    dtype_fp8 = torch.float8_e4m3fn
+    max_val = torch.finfo(dtype_fp8).max
+
+    def rowwise_quantize(tensor, *, dim):
+        amax = tensor.abs().amax(dim=dim, keepdim=True)
+        encode = (max_val / torch.clamp(amax, min=1e-12)).to(torch.float32)
+        quant = torch.clamp(tensor * encode, min=-max_val, max=max_val).to(dtype_fp8)
+        decode = encode.reciprocal()
+        return quant, decode, encode
+
+    def scaled_mm_rowwise(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
+        return torch.nn.functional.scaled_mm(
+            mat_a,
+            mat_b,
+            scale_a,
+            ScalingType.RowWise,
+            scale_b,
+            ScalingType.RowWise,
+            swizzle_a=SwizzleType.NO_SWIZZLE,
+            swizzle_b=SwizzleType.NO_SWIZZLE,
+            output_dtype=out_dtype,
+        )
+
+    M, K, N = 32, 64, 32
+    mat_a = torch.randn(M, K, device=device, dtype=torch.bfloat16)
+    mat_b = torch.randn(K, N, device=device, dtype=torch.bfloat16)
+
+    mat_a_lp, decode_a, encode_a = rowwise_quantize(mat_a.to(torch.float32), dim=1)
+    mat_b_lp, decode_b, encode_b = rowwise_quantize(mat_b.to(torch.float32), dim=0)
+
+    try:
+        reference = scaled_mm_rowwise(mat_a_lp, mat_b_lp, decode_a, decode_b, out_dtype=torch.bfloat16)
+    except (NotImplementedError, RuntimeError) as exc:
+        pytest.skip(str(exc))
+
+    jf = thunder.jit(lambda a, b, sa, sb: scaled_mm_rowwise(a, b, sa, sb, out_dtype=torch.bfloat16))
+    thunder_out = jf(mat_a_lp, mat_b_lp, decode_a, decode_b)
+
+    reference_f32 = reference.to(torch.float32)
+    thunder_out_f32 = thunder_out.to(torch.float32)
+
+    assert_close(thunder_out_f32, reference_f32, atol=3e-2, rtol=3e-2)
+
+
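+# _blockwise_quantize splits an (M, K) tensor into (block_rows x block_cols) tiles, computes one
+# encode scale per tile as fp8_max / amax(tile), and returns the clamped FP8 tensor together with
+# the per-tile encode scales; the decode scales handed to scaled_mm are their reciprocals.
+
+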
+def _blockwise_quantize(tensor: torch.Tensor, block_rows: int, block_cols: int) -> tuple[torch.Tensor, torch.Tensor]:
+    dtype_fp8 = torch.float8_e4m3fn
+    max_val = torch.finfo(dtype_fp8).max
+
+    M, K = tensor.shape
+    assert M % block_rows == 0 and K % block_cols == 0
+
+    reshaped = tensor.reshape(M // block_rows, block_rows, K // block_cols, block_cols)
+    amax = reshaped.abs().amax(dim=(1, 3), keepdim=True)
+    encode = (max_val / torch.clamp(amax, min=1e-12)).to(torch.float32)
+    quant = torch.clamp(reshaped * encode, min=-max_val, max=max_val).to(dtype_fp8)
+
+    return quant.reshape(M, K), encode.reshape(M // block_rows, K // block_cols).to(tensor.device)
+
+
+@requiresCUDA
+@_require_scaled_mm
+@pytest.mark.parametrize("output_dtype", [torch.bfloat16])
+@pytest.mark.parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
+def test_scaled_mm_blockwise_matches_torch(output_dtype, lhs_block, rhs_block):
+    device = torch.device("cuda")
+    _require_fp8_blockwise(device)
+
+    M, K, N = 256, 256, 256
+    mat_a = torch.randn(M, K, device=device, dtype=output_dtype).pow(3)
+    mat_b_rows = torch.randn(N, K, device=device, dtype=output_dtype).pow(3)
+
+    mat_a_lp, encode_a = _blockwise_quantize(mat_a.to(torch.float32), lhs_block, 128)
+    mat_b_lp_rows, encode_b = _blockwise_quantize(mat_b_rows.to(torch.float32), rhs_block, 128)
+    mat_b_lp = mat_b_lp_rows.t().contiguous()
+
+    scale_a = encode_a.reciprocal().contiguous()
+    scale_b = encode_b.reciprocal().t().contiguous()
+
+    recipe_map = {
+        1: ScalingType.BlockWise1x128,
+        128: ScalingType.BlockWise128x128,
+    }
+
+    try:
+        expected = torch.nn.functional.scaled_mm(
+            mat_a_lp,
+            mat_b_lp,
+            scale_a,
+            recipe_map[lhs_block],
+            scale_b,
+            recipe_map[rhs_block],
+            swizzle_a=SwizzleType.NO_SWIZZLE,
+            swizzle_b=SwizzleType.NO_SWIZZLE,
+            output_dtype=output_dtype,
+        )
+    except (RuntimeError, NotImplementedError, ValueError) as exc:
+        pytest.skip(str(exc))
+
+    fn = thunder.jit(
+        lambda a, b, sa, sb: torch.nn.functional.scaled_mm(
+            a,
+            b,
+            sa,
+            recipe_map[lhs_block],
+            sb,
+            recipe_map[rhs_block],
+            swizzle_a=SwizzleType.NO_SWIZZLE,
+            swizzle_b=SwizzleType.NO_SWIZZLE,
+            output_dtype=output_dtype,
+        )
+    )
+    thunder_out = fn(mat_a_lp, mat_b_lp, scale_a, scale_b)
+    assert_close(thunder_out, expected)
+
+
 # https://github.com/Lightning-AI/lightning-thunder/issues/1857
 def test_max_with_int():
     def f(x, ids):
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 81e0d199bd..a526d4d809 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -5990,6 +5990,104 @@ def mse_loss(
         raise ValueError(f"Reduction argument {reduction} to mse_loss is not supported")
 
 
+if hasattr(torch.nn.functional, "scaled_mm"):
+    from torch.nn.functional import ScalingType
+    from torch.nn.functional import SwizzleType
+
+    @torchsymbol(torch.nn.functional.scaled_mm)
+    def scaled_mm(
+        mat_a: TensorLike,
+        mat_b: TensorLike,
+        scale_a: TensorLike | list[TensorLike],
+        scale_recipe_a: ScalingType | list[ScalingType],
+        scale_b: TensorLike | list[TensorLike],
+        scale_recipe_b: ScalingType | list[ScalingType],
+        swizzle_a: SwizzleType | list[SwizzleType] | None = None,
+        swizzle_b: SwizzleType | list[SwizzleType] | None = None,
+        bias: TensorLike | None = None,
+        output_dtype: torch.dtype | None = torch.bfloat16,
+        contraction_dim: Sequence[int] = (),
+        use_fast_accum: bool = False,
+    ) -> TensorLike:
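+        """Thunder's trace-time definition of ``torch.nn.functional.scaled_mm``.
+
+        Checks that ``mat_a`` and ``mat_b`` are 2D ``TensorProxy``s with matching inner
+        dimensions, that ``contraction_dim`` is empty, and that all tensor arguments live on
+        the same device, then returns an ``(M, N)`` ``TensorProxy`` with ``output_dtype``
+        (bfloat16 when ``output_dtype`` is None). Execution is delegated to the torch
+        executor registration added in thunder/executors/torchex.py.
+        """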
scale_a_list, "swizzle_a") + _validate_enum_list(_expand(swizzle_b), scale_b_list, "swizzle_b") + + def _collect_tensor_proxy_args(values: Sequence[Any]) -> list[TensorProxy]: + return [t for t in values if isinstance(t, TensorProxy)] + + tensor_args: list[TensorProxy] = [mat_a, mat_b] + tensor_args += _collect_tensor_proxy_args(scale_a_list) + tensor_args += _collect_tensor_proxy_args(scale_b_list) + if isinstance(bias, TensorProxy): + tensor_args.append(bias) + utils.check_same_device(*tensor_args) + + result_dtype = to_dtype(output_dtype or torch.bfloat16) + requires_grad = ( + mat_a.requires_grad or mat_b.requires_grad or (isinstance(bias, TensorProxy) and bias.requires_grad) + ) + + m = mat_a.shape[0] + n = mat_b.shape[1] + return TensorProxy( + shape=(m, n), + device=mat_a.device, + dtype=result_dtype, + requires_grad=requires_grad, + ) + + # TODO Add annotations # NOTE The scale parameter is kwarg-only in PyTorch @torchsymbol(torch.nn.functional.scaled_dot_product_attention, tags=(prims.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,))