From c99e973825874dbc33f9e35da653225ec080088c Mon Sep 17 00:00:00 2001 From: jiamingwang-mt Date: Mon, 2 Feb 2026 15:34:38 +0800 Subject: [PATCH] musa support --- transformer_engine/musa/__init__.py | 193 ++ transformer_engine/musa/common/CMakeLists.txt | 195 ++ .../common/activation/activation_template.h | 74 + .../musa/common/activation/gelu.mu | 60 + .../musa/common/activation/relu.mu | 60 + .../musa/common/activation/swiglu.mu | 34 + .../comm_gemm_overlap/comm_gemm_overlap.cpp | 1274 ++++++++ .../userbuffers/ipcsocket.cc | 1 + .../comm_gemm_overlap/userbuffers/ipcsocket.h | 1 + .../userbuffers/userbuffers-host.cpp | 628 ++++ .../userbuffers/userbuffers.h | 335 ++ .../userbuffers/userbuffers.mu | 2790 +++++++++++++++++ transformer_engine/musa/common/common.h | 517 +++ transformer_engine/musa/common/common.mu | 148 + .../musa/common/fused_attn/fused_attn.cpp | 857 +++++ .../musa/common/fused_attn/thd_utils.h | 250 ++ .../musa/common/fused_attn/thd_utils.mu | 78 + .../musa/common/fused_rope/fused_rope.mu | 366 +++ .../scaled_aligned_causal_masked_softmax.mu | 568 ++++ .../fused_softmax/scaled_masked_softmax.mu | 850 +++++ .../scaled_upper_triang_masked_softmax.mu | 615 ++++ .../musa/common/gemm/mudnn_gemm.cpp | 383 +++ .../include/transformer_engine/activation.h | 273 ++ .../common/include/transformer_engine/cast.h | 219 ++ .../transformer_engine/cast_transpose_noop.h | 46 + .../transformer_engine/comm_gemm_overlap.h | 301 ++ .../include/transformer_engine/fused_attn.h | 548 ++++ .../include/transformer_engine/fused_rope.h | 129 + .../common/include/transformer_engine/gemm.h | 124 + .../transformer_engine/normalization.h | 156 + .../include/transformer_engine/padding.h | 51 + .../include/transformer_engine/permutation.h | 46 + .../include/transformer_engine/recipe.h | 80 + .../include/transformer_engine/softmax.h | 132 + .../include/transformer_engine/swizzle.h | 37 + .../transformer_engine/transformer_engine.h | 618 ++++ .../include/transformer_engine/transpose.h | 325 ++ transformer_engine/musa/common/nvtx.h | 24 + .../musa/common/permutation/permutation.mu | 373 +++ .../common/permutation/permutation_mask.mu | 699 +++++ .../musa/common/recipe/__init__.py | 198 ++ .../musa/common/recipe/delayed_scaling.mu | 420 +++ .../musa/common/recipe/recipe_common.muh | 74 + .../musa/common/swizzle/swizzle.mu | 338 ++ .../musa/common/transformer_engine.cpp | 416 +++ .../musa/common/transpose/cast_transpose.h | 28 + .../musa/common/transpose/cast_transpose.mu | 359 +++ .../common/transpose/cast_transpose_fusion.mu | 1414 +++++++++ .../common/transpose/multi_cast_transpose.mu | 341 ++ .../common/transpose/rtc/cast_transpose.mu | 129 + .../transpose/rtc/cast_transpose_fusion.mu | 255 ++ .../musa/common/transpose/rtc/transpose.mu | 101 + .../musa/common/transpose/transpose.mu | 301 ++ .../musa/common/transpose/transpose_fusion.mu | 501 +++ transformer_engine/musa/common/util/cast.mu | 147 + .../musa/common/util/cast_gated_kernels.muh | 1093 +++++++ .../musa/common/util/cast_kernels.muh | 1291 ++++++++ .../musa/common/util/dequantize_kernels.muh | 367 +++ transformer_engine/musa/common/util/logging.h | 79 + transformer_engine/musa/common/util/math.h | 1 + .../common/util/mtfp8_blockwise_quantize.muh | 378 +++ .../musa/common/util/mtfp8_cast.muh | 104 + .../musa/common/util/mtfp8_cast_transpose.h | 12 + .../musa/common/util/mtfp8_cast_transpose.mu | 402 +++ .../musa/common/util/mtfp8_dequantize.mu | 102 + .../common/util/mtfp8_groupwise_quantize.muh | 329 ++ .../musa/common/util/mtfp8_utils.muh | 65 + transformer_engine/musa/common/util/mudnn.h | 90 + .../musa/common/util/musa_driver.cpp | 108 + .../musa/common/util/musa_driver.h | 62 + .../musa/common/util/musa_runtime.cpp | 194 ++ .../musa/common/util/musa_runtime.h | 74 + .../musa/common/util/padding.mu | 219 ++ .../musa/common/util/pybind_helper.h | 111 + transformer_engine/musa/common/util/rtc.cpp | 237 ++ transformer_engine/musa/common/util/rtc.h | 184 ++ transformer_engine/musa/common/util/string.h | 1 + .../musa/common/util/string_header.h.in | 1 + .../musa/common/util/system.cpp | 1 + transformer_engine/musa/common/util/system.h | 1 + .../musa/common/util/vectorized_pointwise.h | 597 ++++ transformer_engine/musa/common/utils.muh | 989 ++++++ transformer_engine/musa/pytorch/__init__.py | 0 transformer_engine/musa/pytorch/attention.py | 1066 +++++++ .../musa/pytorch/cpp_extensions/__init__.py | 0 .../musa/pytorch/cpp_extensions/cast.py | 19 + .../musa/pytorch/csrc/common.cpp | 249 ++ transformer_engine/musa/pytorch/csrc/common.h | 328 ++ .../musa/pytorch/csrc/extensions.h | 474 +++ .../pytorch/csrc/extensions/activation.cpp | 118 + .../pytorch/csrc/extensions/apply_rope.cpp | 223 ++ .../musa/pytorch/csrc/extensions/attention.mu | 1011 ++++++ .../musa/pytorch/csrc/extensions/bias.cpp | 51 + .../musa/pytorch/csrc/extensions/cast.cpp | 129 + .../csrc/extensions/comm_gemm_overlap.cpp | 327 ++ .../fp8_block_scaling_partial_cast.mu | 229 ++ .../musa/pytorch/csrc/extensions/gemm.cpp | 404 +++ .../musa/pytorch/csrc/extensions/misc.cpp | 11 + .../multi_tensor/multi_tensor_adam.mu | 644 ++++ .../multi_tensor_compute_scale.mu | 68 + .../multi_tensor_l2norm_kernel.mu | 412 +++ .../multi_tensor/multi_tensor_scale_kernel.mu | 120 + .../multi_tensor/multi_tensor_sgd_kernel.mu | 203 ++ .../pytorch/csrc/extensions/normalization.cpp | 275 ++ .../musa/pytorch/csrc/extensions/padding.cpp | 80 + .../pytorch/csrc/extensions/permutation.mu | 314 ++ .../musa/pytorch/csrc/extensions/pybind.cpp | 362 +++ .../pytorch/csrc/extensions/quantizer.cpp | 324 ++ .../musa/pytorch/csrc/extensions/recipe.cpp | 48 + .../musa/pytorch/csrc/extensions/softmax.cpp | 247 ++ .../musa/pytorch/csrc/extensions/swizzle.cpp | 120 + .../pytorch/csrc/extensions/transpose.cpp | 434 +++ .../csrc/extensions/type_converters.cpp | 109 + .../musa/pytorch/csrc/extensions/util.cpp | 16 + .../musa/pytorch/csrc/multi_tensor_apply.muh | 141 + transformer_engine/musa/pytorch/csrc/pybind.h | 119 + .../musa/pytorch/csrc/type_shim.h | 1 + transformer_engine/musa/pytorch/csrc/util.h | 1 + .../musa/pytorch/distributed.py | 681 ++++ transformer_engine/musa/pytorch/fp8.py | 261 ++ .../musa/pytorch/module/__init__.py | 0 .../musa/pytorch/module/base.py | 82 + .../musa/pytorch/module/grouped_linear.py | 680 ++++ .../musa/pytorch/module/linear.py | 592 ++++ .../musa/pytorch/ops/__init__.py | 0 transformer_engine/musa/pytorch/ops/op.py | 87 + .../musa/pytorch/tensor/__init__.py | 9 + .../musa/pytorch/tensor/mtfp8_tensor.py | 510 +++ .../musa/pytorch/tensor/mtfp8_tensor_base.py | 111 + transformer_engine/musa/pytorch/utils.py | 28 + .../dot_product_attention/backends.py | 33 +- .../dot_product_attention/context_parallel.py | 46 +- .../dot_product_attention.py | 27 +- .../dot_product_attention/softmax.py | 2 +- .../attention/dot_product_attention/utils.py | 47 +- .../pytorch/attention/inference.py | 20 +- .../pytorch/attention/multi_head_attention.py | 10 +- transformer_engine/pytorch/attention/rope.py | 4 +- transformer_engine/pytorch/cpu_offload.py | 188 ++ transformer_engine/pytorch/module/_common.py | 35 + transformer_engine/pytorch/module/base.py | 36 +- .../pytorch/module/grouped_linear.py | 12 +- .../pytorch/module/layernorm.py | 4 +- .../pytorch/module/layernorm_linear.py | 10 +- .../pytorch/module/layernorm_mlp.py | 22 +- transformer_engine/pytorch/module/linear.py | 10 +- transformer_engine/pytorch/module/rmsnorm.py | 4 +- transformer_engine/pytorch/ops/op.py | 67 + transformer_engine/pytorch/tensor/__init__.py | 3 +- .../pytorch/tensor/_internal/__init__.py | 4 + .../tensor/_internal/float8_tensor_base.py | 137 + .../tensor/_internal/mxfp8_tensor_base.py | 134 + .../pytorch/tensor/float8_tensor.py | 2 +- transformer_engine/pytorch/utils.py | 7 + 154 files changed, 39021 insertions(+), 133 deletions(-) create mode 100644 transformer_engine/musa/__init__.py create mode 100644 transformer_engine/musa/common/CMakeLists.txt create mode 100644 transformer_engine/musa/common/activation/activation_template.h create mode 100644 transformer_engine/musa/common/activation/gelu.mu create mode 100644 transformer_engine/musa/common/activation/relu.mu create mode 100644 transformer_engine/musa/common/activation/swiglu.mu create mode 100644 transformer_engine/musa/common/comm_gemm_overlap/comm_gemm_overlap.cpp create mode 120000 transformer_engine/musa/common/comm_gemm_overlap/userbuffers/ipcsocket.cc create mode 120000 transformer_engine/musa/common/comm_gemm_overlap/userbuffers/ipcsocket.h create mode 100644 transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp create mode 100644 transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers.h create mode 100644 transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers.mu create mode 100644 transformer_engine/musa/common/common.h create mode 100644 transformer_engine/musa/common/common.mu create mode 100644 transformer_engine/musa/common/fused_attn/fused_attn.cpp create mode 100644 transformer_engine/musa/common/fused_attn/thd_utils.h create mode 100644 transformer_engine/musa/common/fused_attn/thd_utils.mu create mode 100644 transformer_engine/musa/common/fused_rope/fused_rope.mu create mode 100644 transformer_engine/musa/common/fused_softmax/scaled_aligned_causal_masked_softmax.mu create mode 100644 transformer_engine/musa/common/fused_softmax/scaled_masked_softmax.mu create mode 100644 transformer_engine/musa/common/fused_softmax/scaled_upper_triang_masked_softmax.mu create mode 100644 transformer_engine/musa/common/gemm/mudnn_gemm.cpp create mode 100644 transformer_engine/musa/common/include/transformer_engine/activation.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/cast.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/cast_transpose_noop.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/comm_gemm_overlap.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/fused_attn.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/fused_rope.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/gemm.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/normalization.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/padding.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/permutation.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/recipe.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/softmax.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/swizzle.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/transformer_engine.h create mode 100644 transformer_engine/musa/common/include/transformer_engine/transpose.h create mode 100644 transformer_engine/musa/common/nvtx.h create mode 100644 transformer_engine/musa/common/permutation/permutation.mu create mode 100644 transformer_engine/musa/common/permutation/permutation_mask.mu create mode 100644 transformer_engine/musa/common/recipe/__init__.py create mode 100644 transformer_engine/musa/common/recipe/delayed_scaling.mu create mode 100644 transformer_engine/musa/common/recipe/recipe_common.muh create mode 100644 transformer_engine/musa/common/swizzle/swizzle.mu create mode 100644 transformer_engine/musa/common/transformer_engine.cpp create mode 100644 transformer_engine/musa/common/transpose/cast_transpose.h create mode 100644 transformer_engine/musa/common/transpose/cast_transpose.mu create mode 100644 transformer_engine/musa/common/transpose/cast_transpose_fusion.mu create mode 100644 transformer_engine/musa/common/transpose/multi_cast_transpose.mu create mode 100644 transformer_engine/musa/common/transpose/rtc/cast_transpose.mu create mode 100644 transformer_engine/musa/common/transpose/rtc/cast_transpose_fusion.mu create mode 100644 transformer_engine/musa/common/transpose/rtc/transpose.mu create mode 100644 transformer_engine/musa/common/transpose/transpose.mu create mode 100644 transformer_engine/musa/common/transpose/transpose_fusion.mu create mode 100644 transformer_engine/musa/common/util/cast.mu create mode 100644 transformer_engine/musa/common/util/cast_gated_kernels.muh create mode 100644 transformer_engine/musa/common/util/cast_kernels.muh create mode 100644 transformer_engine/musa/common/util/dequantize_kernels.muh create mode 100644 transformer_engine/musa/common/util/logging.h create mode 120000 transformer_engine/musa/common/util/math.h create mode 100644 transformer_engine/musa/common/util/mtfp8_blockwise_quantize.muh create mode 100644 transformer_engine/musa/common/util/mtfp8_cast.muh create mode 100644 transformer_engine/musa/common/util/mtfp8_cast_transpose.h create mode 100644 transformer_engine/musa/common/util/mtfp8_cast_transpose.mu create mode 100644 transformer_engine/musa/common/util/mtfp8_dequantize.mu create mode 100644 transformer_engine/musa/common/util/mtfp8_groupwise_quantize.muh create mode 100644 transformer_engine/musa/common/util/mtfp8_utils.muh create mode 100644 transformer_engine/musa/common/util/mudnn.h create mode 100644 transformer_engine/musa/common/util/musa_driver.cpp create mode 100644 transformer_engine/musa/common/util/musa_driver.h create mode 100644 transformer_engine/musa/common/util/musa_runtime.cpp create mode 100644 transformer_engine/musa/common/util/musa_runtime.h create mode 100644 transformer_engine/musa/common/util/padding.mu create mode 100644 transformer_engine/musa/common/util/pybind_helper.h create mode 100644 transformer_engine/musa/common/util/rtc.cpp create mode 100644 transformer_engine/musa/common/util/rtc.h create mode 120000 transformer_engine/musa/common/util/string.h create mode 120000 transformer_engine/musa/common/util/string_header.h.in create mode 120000 transformer_engine/musa/common/util/system.cpp create mode 120000 transformer_engine/musa/common/util/system.h create mode 100644 transformer_engine/musa/common/util/vectorized_pointwise.h create mode 100644 transformer_engine/musa/common/utils.muh create mode 100644 transformer_engine/musa/pytorch/__init__.py create mode 100644 transformer_engine/musa/pytorch/attention.py create mode 100644 transformer_engine/musa/pytorch/cpp_extensions/__init__.py create mode 100644 transformer_engine/musa/pytorch/cpp_extensions/cast.py create mode 100644 transformer_engine/musa/pytorch/csrc/common.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/common.h create mode 100644 transformer_engine/musa/pytorch/csrc/extensions.h create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/activation.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/apply_rope.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/attention.mu create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/bias.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/cast.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/comm_gemm_overlap.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.mu create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/gemm.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/misc.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.mu create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.mu create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.mu create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.mu create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.mu create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/normalization.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/padding.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/permutation.mu create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/pybind.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/quantizer.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/recipe.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/softmax.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/swizzle.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/transpose.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/type_converters.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/extensions/util.cpp create mode 100644 transformer_engine/musa/pytorch/csrc/multi_tensor_apply.muh create mode 100644 transformer_engine/musa/pytorch/csrc/pybind.h create mode 120000 transformer_engine/musa/pytorch/csrc/type_shim.h create mode 120000 transformer_engine/musa/pytorch/csrc/util.h create mode 100644 transformer_engine/musa/pytorch/distributed.py create mode 100644 transformer_engine/musa/pytorch/fp8.py create mode 100644 transformer_engine/musa/pytorch/module/__init__.py create mode 100644 transformer_engine/musa/pytorch/module/base.py create mode 100644 transformer_engine/musa/pytorch/module/grouped_linear.py create mode 100644 transformer_engine/musa/pytorch/module/linear.py create mode 100644 transformer_engine/musa/pytorch/ops/__init__.py create mode 100644 transformer_engine/musa/pytorch/ops/op.py create mode 100644 transformer_engine/musa/pytorch/tensor/__init__.py create mode 100644 transformer_engine/musa/pytorch/tensor/mtfp8_tensor.py create mode 100644 transformer_engine/musa/pytorch/tensor/mtfp8_tensor_base.py create mode 100644 transformer_engine/musa/pytorch/utils.py create mode 100644 transformer_engine/pytorch/tensor/_internal/__init__.py create mode 100644 transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py create mode 100644 transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py diff --git a/transformer_engine/musa/__init__.py b/transformer_engine/musa/__init__.py new file mode 100644 index 0000000000..c14679d888 --- /dev/null +++ b/transformer_engine/musa/__init__.py @@ -0,0 +1,193 @@ +import sys +import torch +import torch.utils +import torch.utils.data +import torch_musa + + +def patch_before_import_te(): + from .pytorch import attention + from .pytorch import tensor + from .pytorch import fp8 + from .pytorch import distributed + from .pytorch.module import base + from .pytorch.ops import op + from .pytorch.cpp_extensions import cast + from .pytorch.module import linear + from .pytorch.module import grouped_linear + from .pytorch import utils + +def patch_after_import_torch(): + def hook_cuda_device(device): + if isinstance(device, str) and device.startswith("cuda"): + return device.replace("cuda", "musa") + if isinstance(device, torch.device) and device.type == "cuda": + return torch.device("musa", device.index) + return device + + def maybe_hook_cuda_args(args, kwargs): + new_args = [] + for arg in args: + new_args.append(hook_cuda_device(arg)) + if "device" in kwargs: + v = kwargs["device"] + kwargs['device'] = hook_cuda_device(v) + return tuple(new_args), kwargs + + torch.cuda.is_available = torch.musa.is_available + torch.cuda.current_device = torch.musa.current_device + torch.cuda.device_count = torch.musa.device_count + torch.cuda.set_device = torch.musa.set_device + torch.cuda.DoubleTensor = torch.musa.DoubleTensor + torch.cuda.FloatTensor = torch.musa.FloatTensor + torch.cuda.LongTensor = torch.musa.LongTensor + torch.cuda.HalfTensor = torch.musa.HalfTensor + torch.cuda.BFloat16Tensor = torch.musa.BFloat16Tensor + torch.cuda.IntTensor = torch.musa.IntTensor + torch.cuda.synchronize = torch.musa.synchronize + torch.cuda.get_rng_state = torch.musa.get_rng_state + torch.cuda.set_rng_state = torch.musa.set_rng_state + torch.cuda.synchronize = torch.musa.synchronize + torch.cuda.empty_cache = torch.musa.empty_cache + torch.Tensor.cuda = torch.Tensor.musa + torch.cuda.manual_seed = torch.musa.manual_seed + torch.cuda.Event = torch.musa.Event + torch.cuda.Stream = torch.musa.Stream + torch.cuda.current_stream = torch.musa.current_stream + torch.cuda.set_stream = torch.musa.set_stream + torch.cuda.get_device_properties = torch.musa.get_device_properties + # add torch.musa.current_devce() to activate torch.musa.default_generators + d = torch.musa.current_device() + torch.cuda.default_generators = torch.musa.default_generators + + torch.cuda.memory_allocated = torch.musa.memory_allocated + torch.cuda.max_memory_allocated = torch.musa.max_memory_allocated + torch.cuda.memory_reserved = torch.musa.memory_reserved + torch.cuda.max_memory_reserved = torch.musa.max_memory_reserved + + # (yehua.zhang) replace lazy_call to avoid cpu memory leak, + # because failure of cuda init in lazy_call will cause endless operation of emplace back. + torch.cuda._lazy_call = torch.musa.core._lazy_init._lazy_call + torch.cuda._lazy_init = torch.musa.core._lazy_init._lazy_init + + original_tensor = torch.tensor + def patched_tensor(*args, **kwargs): + args, kwargs = maybe_hook_cuda_args(args, kwargs) + result = original_tensor(*args, **kwargs) + return result + torch.tensor = patched_tensor + + orig_type = torch.Tensor.type + def musa_type(*args, **kwargs): + result = orig_type(*args, **kwargs) + if isinstance(result, str): + result = result.replace("musa", "cuda") + return result + torch.Tensor.type = musa_type + + original_zeros = torch.zeros + def patched_zeros(*args, **kwargs): + args, kwargs = maybe_hook_cuda_args(args, kwargs) + result = original_zeros(*args, **kwargs) + return result + torch.zeros = patched_zeros + + original_ones = torch.ones + def patched_ones(*args, **kwargs): + args, kwargs = maybe_hook_cuda_args(args, kwargs) + result = original_ones(*args, **kwargs) + return result + torch.ones = patched_ones + + original_empty = torch.empty + def patched_empty(*args, **kwargs): + args, kwargs = maybe_hook_cuda_args(args, kwargs) + result = original_empty(*args, **kwargs) + return result + torch.empty = patched_empty + + original_rand = torch.rand + def patched_rand(*args, **kwargs): + args, kwargs = maybe_hook_cuda_args(args, kwargs) + result = original_rand(*args, **kwargs) + return result + torch.rand = patched_rand + + original_arange = torch.arange + def patched_arange(*args, **kwargs): + args, kwargs = maybe_hook_cuda_args(args, kwargs) + result = original_arange(*args, **kwargs) + return result + torch.arange = patched_arange + + original_empty_like = torch.empty_like + def patched_empty_like(*args, **kwargs): + args, kwargs = maybe_hook_cuda_args(args, kwargs) + result = original_empty_like(*args, **kwargs) + return result + torch.empty_like = patched_empty_like + + original_is_cuda = torch.Tensor.is_cuda + def always_cuda(self): + return True + torch.Tensor.is_cuda = property(always_cuda) + + origin_init_process_group = torch.distributed.init_process_group + def patched_init_process_group(*args, **kwargs): + if 'backend' in kwargs and kwargs['backend'] == 'nccl': + kwargs['backend'] = 'mccl' + result = origin_init_process_group(*args, **kwargs) + return result + torch.distributed.init_process_group = patched_init_process_group + + # def pin_memory(data, device=None): + # return data + # torch.utils.data._utils.pin_memory.pin_memory = pin_memory + + def _pass_pvtx(*args, **kwargs): + return + torch.cuda.nvtx.range_push = _pass_pvtx + torch.cuda.nvtx.range_pop = _pass_pvtx + + torch.cuda.is_current_stream_capturing = lambda: False + + origin_module_to = torch.nn.Module.to + def patched_module_to(self, *args, **kwargs): + args, kwargs = maybe_hook_cuda_args(args, kwargs) + return origin_module_to(self, *args, **kwargs) + torch.nn.Module.to = patched_module_to + + origin_tensor_to = torch.Tensor.to + def patched_tensor_to(self, *args, **kwargs): + args, kwargs = maybe_hook_cuda_args(args, kwargs) + return origin_tensor_to(self, *args, **kwargs) + torch.Tensor.to = patched_tensor_to + + def get_default_device(): + device = torch.device("musa", torch.musa.current_device()) + return device + torch.get_default_device = get_default_device + + def is_autocast_enabled(device_type=None): + return False + torch.is_autocast_enabled = is_autocast_enabled + + import os + #HACK(sherry): enable torch.compile + os.environ["NVTE_TORCH_COMPILE"] = "1" + os.environ["TORCHDYNAMO_DISABLE"] = "0" + #HACK(sherry) + +def py_patch(): + if sys.version_info >= (3.9, 0): + return + import math + def lcm(a, b): + return abs(a * b) // math.gcd(a, b) + math.lcm = lcm + return + + +py_patch() +patch_before_import_te() +patch_after_import_torch() diff --git a/transformer_engine/musa/common/CMakeLists.txt b/transformer_engine/musa/common/CMakeLists.txt new file mode 100644 index 0000000000..969ef6ba52 --- /dev/null +++ b/transformer_engine/musa/common/CMakeLists.txt @@ -0,0 +1,195 @@ +cmake_minimum_required(VERSION 3.21) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(MUSA_DIR "/usr/local/musa") +set(MUSA_ARCH "31") + +# Transformer Engine library +project(transformer_engine LANGUAGES CXX) + +list(APPEND CMAKE_MODULE_PATH "${MUSA_DIR}/cmake") +list(APPEND CMAKE_MODULE_PATH "${MUSA_DIR}/lib/cmake/mudnn") + +find_package(MUSA REQUIRED) +string(APPEND MUSA_MCC_FLAGS " -std=c++${CMAKE_CXX_STANDARD}") +string(APPEND MUSA_MCC_FLAGS " --offload-arch=mp_${MUSA_ARCH}") +# -mllvm -mtgpu-tempint-prealloc=1 just work for MUSA_ARCH=31 +if (MUSA_ARCH STREQUAL "31") + string(APPEND MUSA_MCC_FLAGS " -mllvm -mtgpu-tempint-prealloc=1") +endif() +set(MUSA_VERBOSE_BUILD ON) +set(MUSA_LINK_LIBRARIES_KEYWORD PUBLIC) + +if (CMAKE_BUILD_TYPE STREQUAL "Debug") + string(APPEND MUSA_MCC_FLAGS " -g") +endif() + +set(DEPENDENT_TARGETS) + +find_package(MUSAToolkit REQUIRED) +list(APPEND DEPENDENT_TARGETS MUSA::toolkit) + +include(mudnnTargets) +list(APPEND DEPENDENT_TARGETS mudnn) + +find_package(MCCL REQUIRED) +add_library(MUSA::mccl SHARED IMPORTED) +set_target_properties(MUSA::mccl PROPERTIES + IMPORTED_LOCATION ${MCCL_LIBRARIES} + INTERFACE_INCLUDE_DIRECTORIES ${MCCL_INCLUDE_DIRS} +) +list(APPEND DEPENDENT_TARGETS MUSA::mccl) + +find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) +list(APPEND DEPENDENT_TARGETS Python::Module) + +execute_process( + COMMAND + ${Python_EXECUTABLE} -c "import os, torch_musa;print(os.path.dirname(torch_musa.__file__))" + ERROR_QUIET + OUTPUT_VARIABLE TORCH_MUSA_PYTHONPATH +) +string(REGEX REPLACE "^(.+)\n$" "\\1" TORCH_MUSA_PYTHONPATH ${TORCH_MUSA_PYTHONPATH}) + +add_library(torch_musa_python SHARED IMPORTED) +set_target_properties(torch_musa_python PROPERTIES + IMPORTED_LOCATION "${TORCH_MUSA_PYTHONPATH}/lib/libmusa_python.so" +) +set_property(TARGET torch_musa_python APPEND PROPERTY + INTERFACE_INCLUDE_DIRECTORIES "${TORCH_MUSA_PYTHONPATH}/.." +) +set_property(TARGET torch_musa_python APPEND PROPERTY + INTERFACE_INCLUDE_DIRECTORIES "${TORCH_MUSA_PYTHONPATH}/share/torch_musa_codegen" +) +set_property(TARGET torch_musa_python APPEND PROPERTY + INTERFACE_INCLUDE_DIRECTORIES "${TORCH_MUSA_PYTHONPATH}/share/generated_cuda_compatible/include" +) +set_property(TARGET torch_musa_python APPEND PROPERTY + INTERFACE_INCLUDE_DIRECTORIES "${TORCH_MUSA_PYTHONPATH}/share/generated_cuda_compatible/include/torch/csrc/api/include" +) +list(APPEND DEPENDENT_TARGETS torch_musa_python) + +execute_process( + COMMAND + ${Python_EXECUTABLE} -c "import os, torch;print(os.path.dirname(torch.__file__))" + ERROR_QUIET + OUTPUT_VARIABLE TORCH_PYTHONPATH +) +string(REGEX REPLACE "^(.+)\n$" "\\1" TORCH_PYTHONPATH ${TORCH_PYTHONPATH}) + +add_library(torch_python SHARED IMPORTED) +set_target_properties(torch_python PROPERTIES + IMPORTED_LOCATION "${TORCH_PYTHONPATH}/lib/libtorch_python.so" +) +list(APPEND DEPENDENT_TARGETS torch_python) + +# Configure Transformer Engine library +set(transformer_engine_SOURCES) +set(PLUGIN_NAME "transformer_engine") +list(APPEND transformer_engine_SOURCES + common.mu + transformer_engine.cpp + activation/gelu.mu + activation/relu.mu + activation/swiglu.mu + comm_gemm_overlap/comm_gemm_overlap.cpp + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/userbuffers/userbuffers.mu + fused_attn/fused_attn.cpp + fused_attn/thd_utils.mu + fused_rope/fused_rope.mu + fused_softmax/scaled_aligned_causal_masked_softmax.mu + fused_softmax/scaled_masked_softmax.mu + fused_softmax/scaled_upper_triang_masked_softmax.mu + gemm/mudnn_gemm.cpp + permutation/permutation.mu + permutation/permutation_mask.mu + recipe/delayed_scaling.mu + swizzle/swizzle.mu + transpose/multi_cast_transpose.mu + transpose/cast_transpose_fusion.mu + transpose/transpose_fusion.mu + transpose/transpose.mu + transpose/cast_transpose.mu + util/cast.mu + util/musa_driver.cpp + util/musa_runtime.cpp + util/padding.mu + util/rtc.cpp + util/system.cpp + util/mtfp8_cast_transpose.mu + util/mtfp8_dequantize.mu +) +set_source_files_properties(${transformer_engine_SOURCES} + PROPERTIES + MUSA_SOURCE_PROPERTY_FORMAT OBJ +) + +musa_add_library(${PLUGIN_NAME} SHARED ${transformer_engine_SOURCES}) +target_include_directories(${PLUGIN_NAME} PUBLIC + "${CMAKE_CURRENT_SOURCE_DIR}/.." + "${CMAKE_CURRENT_SOURCE_DIR}/include" +) +target_link_libraries(${PLUGIN_NAME} PUBLIC ${DEPENDENT_TARGETS}) + +# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI +option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) +if (NVTE_UB_WITH_MPI) + find_package(MPI REQUIRED) + target_link_libraries(${PLUGIN_NAME} PUBLIC MPI::MPI_CXX) + target_compile_definitions(${PLUGIN_NAME} PUBLIC NVTE_UB_WITH_MPI) +endif() + +# Helper functions to make header files with C++ strings +function(make_string_header STRING STRING_NAME) + configure_file( + "util/string_header.h.in" + "string_headers/${STRING_NAME}.h" + @ONLY + ) +endfunction() +function(make_string_header_from_file file_ STRING_NAME) + file(READ "${file_}" STRING) + configure_file( + util/string_header.h.in + "string_headers/${STRING_NAME}.h" + @ONLY + ) +endfunction() + +# Header files with C++ strings +make_string_header( + "${MUSA_DIR}/include" + string_path_musa_include +) +make_string_header_from_file( + transpose/rtc/cast_transpose_fusion.mu + string_code_transpose_rtc_cast_transpose_fusion_mu +) +make_string_header_from_file( + transpose/rtc/cast_transpose.mu + string_code_transpose_rtc_cast_transpose_mu +) +make_string_header_from_file( + transpose/rtc/transpose.mu + string_code_transpose_rtc_transpose_mu +) +make_string_header_from_file( + utils.muh + string_code_utils_muh +) +make_string_header_from_file( + util/math.h + string_code_util_math_h +) +target_include_directories(${PLUGIN_NAME} PRIVATE + "${CMAKE_CURRENT_BINARY_DIR}/string_headers" +) + +set_target_properties(${PLUGIN_NAME} PROPERTIES INSTALL_RPATH_USE_LINK_PATH ON) + +# Install library +install(TARGETS ${PLUGIN_NAME} DESTINATION .) diff --git a/transformer_engine/musa/common/activation/activation_template.h b/transformer_engine/musa/common/activation/activation_template.h new file mode 100644 index 0000000000..f0d636b697 --- /dev/null +++ b/transformer_engine/musa/common/activation/activation_template.h @@ -0,0 +1,74 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file activation_template.h + * \brief Activation functions template. + */ + +#ifndef TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ +#define TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ + +#include +#include + +#include "../common.h" +#include "../util/cast_gated_kernels.muh" +#include "../util/cast_kernels.muh" +#include "../util/math.h" +#include "../util/vectorized_pointwise.h" + +namespace transformer_engine { + +template +void act_fn(const NVTETensor input, NVTETensor output, musaStream_t stream) { + using namespace detail; + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = true; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; + + quantize_helper(input, grad, nullptr, output, dbias, + workspace, stream); +} + +template +void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + using namespace detail; + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + + quantize_helper(input, grad, nullptr, output, dbias, + workspace, stream); +} + +template +void gated_act_fn(const NVTETensor input, NVTETensor output, musaStream_t stream) { + using namespace detail; + constexpr bool IS_DGATED = false; + constexpr NVTETensor grad = nullptr; + + quantize_gated_helper(grad, input, output, stream); +} + +template +void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + using namespace detail; + constexpr bool IS_DGATED = true; + + quantize_gated_helper(grad, input, output, stream); +} + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ diff --git a/transformer_engine/musa/common/activation/gelu.mu b/transformer_engine/musa/common/activation/gelu.mu new file mode 100644 index 0000000000..170f187c61 --- /dev/null +++ b/transformer_engine/musa/common/activation/gelu.mu @@ -0,0 +1,60 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_gelu(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_gelu); + using namespace transformer_engine; + act_fn>(input, output, stream); +} + +void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_dgelu); + using namespace transformer_engine; + dact_fn>(grad, input, output, stream); +} + +void nvte_geglu(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_geglu); + using namespace transformer_engine; + gated_act_fn>(input, output, stream); +} + +void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_dgeglu); + using namespace transformer_engine; + dgated_act_fn, dgelu>(grad, input, output, stream); +} + +void nvte_qgelu(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_qgelu); + using namespace transformer_engine; + act_fn>(input, output, stream); +} + +void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_dqgelu); + using namespace transformer_engine; + dact_fn>(grad, input, output, stream); +} + +void nvte_qgeglu(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_qgeglu); + using namespace transformer_engine; + gated_act_fn>(input, output, stream); +} + +void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_dqgeglu); + using namespace transformer_engine; + dgated_act_fn, dqgelu>(grad, input, output, stream); +} diff --git a/transformer_engine/musa/common/activation/relu.mu b/transformer_engine/musa/common/activation/relu.mu new file mode 100644 index 0000000000..c54e772e35 --- /dev/null +++ b/transformer_engine/musa/common/activation/relu.mu @@ -0,0 +1,60 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_relu(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_relu); + using namespace transformer_engine; + act_fn>(input, output, stream); +} + +void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_drelu); + using namespace transformer_engine; + dact_fn>(grad, input, output, stream); +} + +void nvte_reglu(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_reglu); + using namespace transformer_engine; + gated_act_fn>(input, output, stream); +} + +void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_dreglu); + using namespace transformer_engine; + dgated_act_fn, drelu>(grad, input, output, stream); +} + +void nvte_srelu(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_srelu); + using namespace transformer_engine; + act_fn>(input, output, stream); +} + +void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_dsrelu); + using namespace transformer_engine; + dact_fn>(grad, input, output, stream); +} + +void nvte_sreglu(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_sreglu); + using namespace transformer_engine; + gated_act_fn>(input, output, stream); +} + +void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_dsreglu); + using namespace transformer_engine; + dgated_act_fn, dsrelu>(grad, input, output, stream); +} diff --git a/transformer_engine/musa/common/activation/swiglu.mu b/transformer_engine/musa/common/activation/swiglu.mu new file mode 100644 index 0000000000..4752dfda17 --- /dev/null +++ b/transformer_engine/musa/common/activation/swiglu.mu @@ -0,0 +1,34 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_silu(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_silu); + using namespace transformer_engine; + act_fn>(input, output, stream); +} + +void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_dsilu); + using namespace transformer_engine; + dact_fn>(grad, input, output, stream); +} + +void nvte_swiglu(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_swiglu); + using namespace transformer_engine; + gated_act_fn>(input, output, stream); +} + +void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_dswiglu); + using namespace transformer_engine; + dgated_act_fn, dsilu>(grad, input, output, stream); +} diff --git a/transformer_engine/musa/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/musa/common/comm_gemm_overlap/comm_gemm_overlap.cpp new file mode 100644 index 0000000000..8f8c89b030 --- /dev/null +++ b/transformer_engine/musa/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -0,0 +1,1274 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include +#include + +#include "common/common.h" +#include "common/util/musa_driver.h" +#include "common/util/musa_runtime.h" +#include "common/util/logging.h" +#include "common/util/system.h" +#include "userbuffers/userbuffers.h" + +#define HALF_BYTES 2 +#define UB_MAX_SM 32 + +#define AS_VECTOR(shape) std::vector(shape.data, shape.data + shape.ndim) + +using namespace std::placeholders; + +namespace transformer_engine { + +/*************************************************************************************************** + * Comm+GEMM Overlap Common Core + **************************************************************************************************/ + +bool ubuf_built_with_mpi() { +#ifdef NVTE_UB_WITH_MPI + return true; +#else + return false; +#endif +} + +CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, int tp_size, ExtAllgatherOp allgather_handle, + ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { + // Initialize userbuf communicator + if (!_comm_created) { + if (myrank == 0) { + printf("!!! [UB] Create Userbuffers Communicator\n"); + } +#ifdef NVTE_UB_WITH_MPI + create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); +#else + create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + allgather_handle, barrier_handle, 1, 1, tp_size, 1); +#endif + _comm_created = true; + } + _use_ce = static_cast(use_ce); + _num_comm_sm = num_comm_sm; + _cga_size = comm_cga_size; + + if (gemm_priority == 0 && comm_priority == 0) { + transformer_engine::cuda::stream_priority_range(&_gemm_priority, &_comm_priority); + } else { + _gemm_priority = gemm_priority; + _comm_priority = comm_priority; + } + for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { + musaStream_t stream; + NVTE_CHECK_CUDA(musaStreamCreateWithPriority(&stream, musaStreamNonBlocking, _gemm_priority)); + _stream_compute.push_back(std::move(stream)); + } + + if (use_ce) { + // need tp_size-1 streams for comm with peers + for (int i = 0; i < tp_size - 1; i++) { + musaStream_t stream; + NVTE_CHECK_CUDA(musaStreamCreateWithPriority(&stream, musaStreamNonBlocking, comm_priority)); + _stream_comm_ce.push_back(std::move(stream)); + } + } + + _num_splits = num_splits; + _rank = _ub_comm->myrank; + _tp_size = tp_size; + _tp_id = _rank % _tp_size; + + // Set the number of SMs for GEMM with margin + int sm_count = transformer_engine::cuda::sm_count(); + _math_sms = (set_sm_margin) ? sm_count - num_comm_sm : sm_count; + _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); + + _atomic_gemm = atomic_gemm; + if (_atomic_gemm) { + void *counter_ptr; + size_t counter_bytes = _num_splits * 2 * sizeof(int32_t); + NVTE_CHECK_CUDA(musaMalloc(&counter_ptr, counter_bytes)); + NVTE_CHECK_CUDA(musaMemset(counter_ptr, 0, counter_bytes)); + NVTE_CHECK_CUDA(musaMemset(counter_ptr, 1, counter_bytes / 2)); + _counter = TensorWrapper(counter_ptr, std::vector{static_cast(_num_splits * 2)}, + DType::kInt32); + } + // CUDA event creation + musaEventCreateWithFlags(&_start_compute, 0); + musaEventCreateWithFlags(&_stop_compute, 0); + musaEventCreateWithFlags(&_start_comm, 0); + musaEventCreateWithFlags(&_stop_comm, 0); + + /* + Defining the launcher order between the communication and GEMM kernels + using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1. + The event is used to schedule the communication kernel before the GEMM. + This is needed only for Hopper, which uses persistent CTA execution. + */ + int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); + int runtime_version = 0; + musaRuntimeGetVersion(&runtime_version); + musaDeviceProp deviceProp; + musaGetDeviceProperties(&deviceProp, 0); + if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) { + musaEventCreateWithFlags(&_comm_launch_event, musaEventDisableTiming); + } else { + _comm_launch_event = 0; + } +} + +CommOverlapCore::~CommOverlapCore() { + musaEventDestroy(_stop_comm); + musaEventDestroy(_start_comm); + musaEventDestroy(_stop_compute); + musaEventDestroy(_start_compute); + if (_comm_launch_event) musaEventDestroy(_comm_launch_event); + + if (_atomic_gemm) musaFree(_counter.dptr()); + + for (size_t i = 0; i < _stream_compute.size(); i++) musaStreamDestroy(_stream_compute[i]); + for (size_t i = 0; i < _stream_comm_ce.size(); i++) musaStreamDestroy(_stream_comm_ce[i]); + + if (_comm_created) { +#ifdef NVTE_UB_WITH_MPI + destroy_communicator_mpi(_ub_comm); +#else + destroy_communicator(_ub_comm); +#endif + _comm_created = false; + } +} + +TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset, + const std::vector &chunk_shape) { + TensorWrapper chunk; + for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) { + auto param_type = static_cast(param_id); + auto param = source.get_parameter(param_type); + auto param_dptr = reinterpret_cast(param.data_ptr); + auto param_dtype = static_cast(param.dtype); + auto param_shape = AS_VECTOR(param.shape); + + if (param_dptr != nullptr) { + if (param_type == NVTETensorParam::kNVTERowwiseData || + param_type == NVTETensorParam::kNVTEColumnwiseData) { + // Offset data pointer + param_dptr += chunk_offset * typeToSize(param_dtype); + param_shape = chunk_shape; + + if (param_type == NVTETensorParam::kNVTEColumnwiseData && + source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) { + // Columnwise shape for non-block scaled tensors shifts the last dimension to the front + auto last_dim = param_shape.back(); + param_shape.pop_back(); + param_shape.insert(param_shape.begin(), last_dim); + } + } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING && + (param_type == NVTETensorParam::kNVTERowwiseScaleInv || + param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) { + // Calculate block scaling offset and size + auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv) + ? source.shape().data[0] + : source.columnwise_shape().data[0]; + auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv) + ? chunk_shape.front() + : chunk_shape.back(); + auto chunk_scale_start = chunk_offset / 32; + auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32; + auto chunk_scale_size = chunk_scale_end - chunk_scale_start; + param_dptr += chunk_scale_start * typeToSize(param_dtype); + param_shape = std::vector{chunk_scale_size}; + } + + // Set chunked source parameters into the chunked tensor output + chunk.set_parameter(param_type, reinterpret_cast(param_dptr), param_dtype, + param_shape); + } + } + return chunk; +} + +TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source, + size_t chunk_offset, + const std::vector &chunk_shape) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape); + + // Update chunk with offset data pointers from the communication buffer + auto ubuf_ptr = reinterpret_cast(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size()); + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), + chunk.columnwise_shape()); + } + return chunk; +} + +/*************************************************************************************************** + * Comm+GEMM Overlap Base (Pipelined / Collective) + **************************************************************************************************/ + +CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, int tp_size, ExtAllgatherOp allgather_handle, + ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, + bool rs_overlap_first_gemm) + : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, + gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm) { + _rs_overlap_first_gemm = rs_overlap_first_gemm; + _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); + NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, + "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", + "or 2 (multi-atomic)."); + + NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); + size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + void *buffer_ptr; + _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); + if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); + _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); + + NVTE_CHECK_CUDA( + musaStreamCreateWithPriority(&_stream_comm, musaStreamNonBlocking, _comm_priority)); + NVTE_CHECK_CUDA(musaEventCreateWithFlags(&_start_d2dcopy, 0)); +} + +CommOverlapBase::~CommOverlapBase() { + musaEventDestroy(_start_d2dcopy); + musaStreamDestroy(_stream_comm); +} + +void CommOverlapBase::comm_userbuff_over_ce(void *rs_output, transformer_engine::DType dtype, const int chunk_idx, const int offset, + const int rowelements, const int colelements, const int strideelements, + bool out_of_place, bool comm_rs, bool is_pipeline, musaStream_t compute_stream) { + + assert(dtype == transformer_engine::DType::kFloat16 || dtype == transformer_engine::DType::kBFloat16); + + MUatomicType atomicType = MUatomicType::MU_ATOMIC_TYPE_ATOMIC_ADD_BF16; + if (dtype == transformer_engine::DType::kFloat16) { + atomicType = MUatomicType::MU_ATOMIC_TYPE_ATOMIC_ADD_HF16; + } + + size_t elements = rowelements * colelements; + size_t elements_bytes = elements * _ubuf.element_size(); + size_t slice = elements / _tp_size; + size_t slice_bytes = slice * _ubuf.element_size(); + size_t gpu_flag_offset = NVTE_REG0_OFFSET(_ub_comm) - NVTE_REG0_SINGLENODE + NVTE_MAX_OPS; + void* my_gpu_flag_rs = reinterpret_cast(_ub_comm->gpu_ptrs) + gpu_flag_offset + chunk_idx * sizeof(uint64_t); + void* my_gpu_flag_sync = reinterpret_cast(_ub_comm->gpu_ptrs) + gpu_flag_offset + (chunk_idx + _num_splits) * sizeof(uint64_t); + + // ensure all peer finish the same gemm chunk before RS + if (comm_rs && is_pipeline) { + for (int i = 1; i < _tp_size; i++) { + char* peer_comm_ptr = reinterpret_cast(_ub_comm->peer_ptr[0][(_tp_id + i) % _tp_size]); + void* peer_gpu_flag = peer_comm_ptr + gpu_flag_offset + (chunk_idx + _num_splits) * sizeof(uint64_t); + + NVTE_CHECK_CUDA_DRIVER(muMemoryAtomicValueAsync( + (MUdeviceptr)peer_gpu_flag, + 1, + MUatomicValueType::MU_ATOMIC_VALUE_TYPE_ATOMIC_ADD64, + (MUstream)_stream_comm_ce[i - 1])); + } + for (int i = 1; i < _tp_size; i++) { + NVTE_CHECK_CUDA_DRIVER(muStreamWaitValue64( + (MUstream)_stream_comm_ce[i - 1], + (MUdeviceptr)my_gpu_flag_sync, + (muuint64_t)(_tp_size - 1), + MUstreamWaitValue_flags::MU_STREAM_WAIT_VALUE_EQ)); + } + } + + for (int i = 1; i < _tp_size; i++) { + size_t my_offset = 0; + size_t my_offset_bytes = 0; + if (comm_rs) { + my_offset = offset + _tp_id * slice; + my_offset_bytes = my_offset * _ubuf.element_size(); + } else { + my_offset = offset + ((_tp_id + i) % _tp_size) * slice; + my_offset_bytes = my_offset * _ubuf.element_size(); + } + int peer = (_tp_id + i) % _tp_size; + void* my_ptr = reinterpret_cast(_ub_comm->mem_ptr[_ub_reg]) + my_offset_bytes; + void* peer_ptr = reinterpret_cast(_ub_comm->peer_ptr[_ub_reg][peer]) + my_offset_bytes; + + // pull mode + if (comm_rs) { + NVTE_CHECK_CUDA_DRIVER(muMemoryAtomicAsync( + (MUdeviceptr)my_ptr, + (MUdeviceptr)peer_ptr, + slice, + atomicType, + (MUstream)_stream_comm_ce[i - 1])); + } else { + NVTE_CHECK_CUDA(musaMemcpyAsync( + my_ptr, + peer_ptr, + slice_bytes, + musaMemcpyDeviceToDevice, + _stream_comm_ce[i - 1])); + } + + // TODO: maybe we can remove wait in AG for higher perf + NVTE_CHECK_CUDA_DRIVER(muMemoryAtomicValueAsync( + (MUdeviceptr)my_gpu_flag_rs, + 1, + MUatomicValueType::MU_ATOMIC_VALUE_TYPE_ATOMIC_ADD64, + (MUstream)_stream_comm_ce[i - 1])); + } + + NVTE_CHECK_CUDA_DRIVER(muStreamWaitValue64( + _stream_comm, + (MUdeviceptr)my_gpu_flag_rs, + (muuint64_t)(_tp_size - 1), + MUstreamWaitValue_flags::MU_STREAM_WAIT_VALUE_EQ)); + + //TODO: this sync will affect perf, we try to remove it; but cost will imbalance when we remove it + NVTE_CHECK_CUDA(musaStreamSynchronize(_stream_comm)); + + if (out_of_place) { + void* ubuffer_ptr = reinterpret_cast(_ub_comm->mem_ptr[_ub_reg]) + (offset + _tp_id * slice) * _ubuf.element_size(); + NVTE_CHECK_CUDA(musaMemcpy2DAsync( + (void *)rs_output, + strideelements * _ubuf.element_size(), + (void *)ubuffer_ptr, + colelements * _ubuf.element_size(), + colelements * _ubuf.element_size(), + rowelements / _tp_size, + musaMemcpyDeviceToDevice, + _stream_comm)); + } + } + +/* +** Bulk GEMM + COMM +** This function assumes the communication input is pre-copied to _ubuf +*/ +void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, TensorWrapper &rs_output, + musaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + int m = _ubuf.size(0); + int n = _ubuf.size(1); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(musaEventRecord(_start_comm, stream_main)); + if (_use_ce) { + for (int i = 0; i < _tp_size - 1; i++) { + NVTE_CHECK_CUDA(musaStreamWaitEvent((musaStream_t)_stream_comm_ce[i], _start_comm, 0)); + } + } + NVTE_CHECK_CUDA(musaStreamWaitEvent((musaStream_t)_stream_comm, _start_comm, 0)); + + // Communication: AG and RS + int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size + if (comm_type == CommOverlapType::AG) { + if (_use_ce) { + comm_userbuff_over_ce(nullptr, A.dtype(), 0, 0, m, n, n, false, false, false, (musaStream_t)stream_main); + } else { + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (musaEvent_t)_comm_launch_event); + } + } else { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + comm_elements *= 2; + assert(rs_output.numel() == _ubuf.numel() / _tp_size); + assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); + assert(rs_output.element_size() == 2); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + reducescatter2_userbuff_fp8<__mt_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, + comm_elements, _ub_comm, _stream_comm, + (musaEvent_t)_comm_launch_event); + } else { + if (_use_ce) { + comm_userbuff_over_ce(nullptr, A.dtype(), 0, 0, m, n, n, false, true, false, (musaStream_t)stream_main); + } else { + reducescatter2_userbuff_inplace(_ub_reg, A.dtype(), 0, comm_elements, _ub_comm, _stream_comm, + (musaEvent_t)_comm_launch_event); + } + + } + } + + assert(pre_gelu_out.numel() == 0); + // When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch + if (_comm_launch_event) + NVTE_CHECK_CUDA(musaStreamWaitEvent((musaStream_t)stream_main, _comm_launch_event, 0)); + nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, + grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, + stream_main); + + _ub_comm->sms = ori_sms; + + if (_use_ce) { + size_t gpu_flag_offset = NVTE_REG0_OFFSET(_ub_comm) - NVTE_REG0_SINGLENODE + NVTE_MAX_OPS; + void* my_gpu_flag_rs = reinterpret_cast(_ub_comm->gpu_ptrs) + gpu_flag_offset; + void* my_gpu_flag_sync = reinterpret_cast(_ub_comm->gpu_ptrs) + gpu_flag_offset + _num_splits * sizeof(uint64_t); + NVTE_CHECK_CUDA_DRIVER(muStreamWriteValue64( + (MUstream)_stream_comm, + (MUdeviceptr)my_gpu_flag_sync, + 0, + MUstreamWriteValue_flags::MU_STREAM_WRITE_VALUE_DEFAULT)); + NVTE_CHECK_CUDA_DRIVER(muStreamWriteValue64( + (MUstream)_stream_comm, + (MUdeviceptr)my_gpu_flag_rs, + 0, + MUstreamWriteValue_flags::MU_STREAM_WRITE_VALUE_DEFAULT)); + } + NVTE_CHECK_CUDA(musaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // CommOverlapBase::bulk_overlap + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + musaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n = _ubuf.size(0); + size_t m_chunk = m / _num_splits; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Get input, output, and workspace data pointers + char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); + char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + + // Reset atomic counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _num_splits, false, stream_main); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(musaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_compute[0], _start_compute, 0)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_comm, _start_compute, 0)); + + assert(pre_gelu_out.numel() == 0); + + auto output_d = get_buffer_chunk_like(D, 0, {n, m}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); + nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(), + _stream_compute[0]); + + for (int i = 0; i < _num_splits; i++) { + if (_rs_kernel_type == 1) { + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_strided_atomic_fp8( + rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, + &counter_ptr[i], _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _num_splits, &counter_ptr[i], _ub_comm, + _stream_comm); + } + } else if (_rs_kernel_type == 2) { + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_strided_multiatomic_fp8( + rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + counter_ptr, _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, + _num_splits, counter_ptr, _ub_comm, + _stream_comm); + } + break; + } else { + consumer(counter_ptr, i, _stream_comm); + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, D.scale_inv(), + _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, _stream_comm); + } + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + + _ub_comm->sms = ori_sms; + NVTE_CHECK_CUDA(musaEventRecord(_stop_compute, _stream_compute[0])); + NVTE_CHECK_CUDA(musaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_compute, 0)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // split_overlap_rs + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, musaStream_t stream_main) { + // Get GEMM dimensions + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n = _ubuf.size(0); + size_t m_chunk = m / _num_splits; + size_t input_a_chunk_size = m_chunk * k; + size_t output_chunk_size = n * m_chunk; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(musaEventRecord(_start_compute, stream_main)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_comm, _start_compute, 0)); + + assert(pre_gelu_out.numel() == 0); + + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_rs_overlap_first_gemm) { + auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k}); + auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _stream_compute[0]); + + for (int i = 1; i < _num_splits; i++) { + input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); + output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk}); + workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA( + musaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); + if (_use_ce) { + for (int j = 0; j < _tp_size - 1; j++) { + NVTE_CHECK_CUDA(musaStreamWaitEvent((musaStream_t)_stream_comm_ce[j], _start_comm, 0)); + } + } else { + NVTE_CHECK_CUDA(musaStreamWaitEvent((musaStream_t)_stream_comm, _start_comm, 0)); + } + + // Communication chunk + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + if (_use_ce) { + comm_userbuff_over_ce(rs_output_ptr, A.dtype(), i - 1, (i - 1) * output_chunk_size, n, m_chunk, m, true, true, true, + (musaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, A.dtype(), _ub_reg, (i - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm); + } + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + int last_compute_stream_id = + (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); + NVTE_CHECK_CUDA(musaEventRecord(_start_comm, _stream_compute[last_compute_stream_id])); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Last communication chunk with max SM + _ub_comm->sms = UB_MAX_SM; + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, + n, m, _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, A.dtype(), _ub_reg, + (_num_splits - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm); + } + } else { + for (int i = 0; i < _num_splits; i++) { + auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); + auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk}); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA(musaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()])); + if (_use_ce) { + for (int j = 0; j < _tp_size - 1; j++) { + NVTE_CHECK_CUDA(musaStreamWaitEvent((musaStream_t)_stream_comm_ce[j], _start_comm, 0)); + } + } + NVTE_CHECK_CUDA(musaStreamWaitEvent((musaStream_t)_stream_comm, _start_comm, 0)); + + // Communication chunk. Uses MAX_SM at the last chunk + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + if (_use_ce) { + comm_userbuff_over_ce(rs_output_ptr, A.dtype(), i, i * output_chunk_size, n, m_chunk, m, true, true, true, + (musaStream_t)_stream_compute[i % _stream_compute.size()]); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, A.dtype(), _ub_reg, i * output_chunk_size, + m_chunk, n, m, _ub_comm, + (musaStream_t)_stream_comm); + } + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(musaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + if (_use_ce) { + for (size_t i = 0; i < _stream_comm_ce.size(); i++) { + NVTE_CHECK_CUDA(musaEventRecord(_stop_comm, (musaStream_t)_stream_comm_ce[i])); + NVTE_CHECK_CUDA(musaStreamWaitEvent((musaStream_t)stream_main, _stop_comm, 0)); + } + + size_t gpu_flag_offset = NVTE_REG0_OFFSET(_ub_comm) - NVTE_REG0_SINGLENODE + NVTE_MAX_OPS; + for (size_t i = 0; i < _num_splits; i++) { + void* my_gpu_flag_rs = reinterpret_cast(_ub_comm->gpu_ptrs) + gpu_flag_offset + i * sizeof(uint64_t); + void* my_gpu_flag_sync = reinterpret_cast(_ub_comm->gpu_ptrs) + gpu_flag_offset + (i + _num_splits) * sizeof(uint64_t); + NVTE_CHECK_CUDA_DRIVER(muStreamWriteValue64( + (MUstream)_stream_comm, + (MUdeviceptr)my_gpu_flag_sync, + 0, + MUstreamWriteValue_flags::MU_STREAM_WRITE_VALUE_DEFAULT)); + NVTE_CHECK_CUDA_DRIVER(muStreamWriteValue64( + (MUstream)_stream_comm, + (MUdeviceptr)my_gpu_flag_rs, + 0, + MUstreamWriteValue_flags::MU_STREAM_WRITE_VALUE_DEFAULT)); + } + } + NVTE_CHECK_CUDA(musaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // CommOverlapBase::split_overlap_rs + +/*************************************************************************************************** + * Comm+GEMM Overlap P2P Base (Ring-Exchange) + **************************************************************************************************/ + +CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm, bool aggregate) + : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, + gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm) { + _is_p2p = true; + _is_reduce_scatter = comm_type == CommOverlapType::RS; + _aggregate = aggregate; + + // Create workspace tensor with userbuffer + NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); + size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + int buffer_chunk_bytes = buffer_bytes / tp_size; + _num_ubuf_chunks = tp_size; + if (_is_reduce_scatter) { + // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk + // outputs for reduction at the end of the pipelining. + buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); + _num_ubuf_chunks = tp_size * 2 - 1; + } + + void *buffer_ptr; + _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); + if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + _ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + buffer_dtype); + + // Create tensor chunks for easy management + char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); + for (int i = 0; i < _num_ubuf_chunks; i++) { + _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), + {buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype)); + ubuf_byte_ptr += buffer_chunk_bytes; + } + + _rank_round_tp = (_rank / _tp_size) * _tp_size; + _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp; + _prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp; + + _self_chunk_id = _tp_id; + if (_atomic_gemm && !_is_reduce_scatter) { + _use_multiatomic_ag = getenv("NVTE_AG_P2P_MULTI_ATOMIC"); + if (_use_multiatomic_ag) { + _use_ce = 0; + _ub_comm->push = 1; + if (_rank == 0) { + printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); + } + } + _self_chunk_id = 0; + NVTE_CHECK_CUDA(musaMemset(_counter.dptr(), 0, sizeof(int32_t))); + } + + for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + musaStream_t stream; + NVTE_CHECK_CUDA(musaStreamCreateWithPriority(&stream, musaStreamNonBlocking, _comm_priority)); + _stream_send.push_back(std::move(stream)); + } + NVTE_CHECK_CUDA( + musaStreamCreateWithPriority(&_stream_recv, musaStreamNonBlocking, _comm_priority)); + NVTE_CHECK_CUDA( + musaStreamCreateWithPriority(&_stream_comm_ce, musaStreamNonBlocking, _comm_priority)); + NVTE_CHECK_CUDA(musaEventCreateWithFlags(&_stop_send, 0)); + NVTE_CHECK_CUDA(musaEventCreateWithFlags(&_stop_recv, 0)); + NVTE_CHECK_CUDA(musaEventCreateWithFlags(&_stop_comm, 0)); +} + +CommOverlapP2PBase::~CommOverlapP2PBase() { + musaEventDestroy(_stop_recv); + musaEventDestroy(_stop_send); + musaEventDestroy(_stop_comm); + musaStreamDestroy(_stream_recv); + musaStreamDestroy(_stream_comm_ce); + for (size_t i = 0; i < _stream_send.size(); i++) musaStreamDestroy(_stream_send[i]); +} + +TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, + size_t chunk_id) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape())); + + // Update chunk with offset data pointers from the communication buffer + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape()); + } + return chunk; +} + +/* +** Split AllGather + AtomicGEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG +** outputs in each rank to be in the contiguous memory space after all ring exchange phases. +*/ +void CommOverlapP2PBase::atomic_gemm_overlap_ag( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, musaStream_t stream_main) { +/* + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t n_chunk = _ubufs[0].size(0); + assert(pre_gelu_out.numel() == 0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Create an GEMM output buffer with N+1 chunks in a contiguous memory + void *D_buffer_ptr; + int D_chunk_bytes = n_chunk * m * D.element_size(); + NVTE_CHECK_CUDA(musaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); + auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), + D.scale_inv(), D.scale_inv_shape(), D.scaling_mode()); + + // Reset atomic counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _tp_size, true, stream_main); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(musaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_send[0], _start_compute, 0)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_recv, _start_compute, 0)); + + auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape())); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); + + for (int i = 0; i < _tp_size - 1; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = i; + int recv_chunk_id = i + 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + if (_use_multiatomic_ag) { + if (i == 0) { + _ub_comm->use_ce = 0; + userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, + _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, + true, _stream_recv); + } + } else { + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank, + _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank, + _stream_recv); + producer(counter_ptr, recv_chunk_id, _stream_recv); + } + if (i == 0) { + nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false, + _counter.data(), stream_main); + } + } + + // Store the input activation for backprop + if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); + assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); + NVTE_CHECK_CUDA( + musaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), + _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), + musaMemcpyDeviceToDevice, _stream_send[0])); + NVTE_CHECK_CUDA(musaEventRecord(_stop_send, _stream_send[0])); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_send, 0)); + } + + // Copy the first GEMM output chunk to the end chunk position of D_buffer + char *src_ptr = reinterpret_cast(D_buffer.dptr()); + NVTE_CHECK_CUDA(musaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes, + musaMemcpyDeviceToDevice, stream_main)); + + // Return the last N rows of D_buffer + NVTE_CHECK_CUDA(musaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(), + musaMemcpyDeviceToDevice, stream_main)); + + // Clean up buffer allocation + NVTE_CHECK_CUDA(musaFreeAsync(D_buffer_ptr, stream_main)); + + _ub_comm->sms = ori_sms; +*/ +} // CommOverlapP2PBase::atomic_gemm_overlap_ag + +/* +** Split AllGather + GEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG +** outputs in each rank to be in the contiguous memory space after all ring exchange phases. +*/ +void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + musaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const bool do_gelu = pre_gelu_out.numel() > 0; + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + NVTE_CHECK_CUDA(musaEventRecord(_start_compute, stream_main)); + if (_use_ce) { + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_comm_ce, _start_compute, 0)); + } + else { + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_send[0], _start_compute, 0)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_recv, _start_compute, 0)); + } + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + if (_aggregate) { + const int num_steps = _tp_size / 2; + input_chunk_size *= 2; + output_chunk_size *= 2; + + // Initial 1X input chunk exchange between neighboring peers + int send_chunk_id = _tp_id; + int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, + _stream_send[0]); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, + _stream_recv); + NVTE_CHECK_CUDA(musaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); + + int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1; + const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; + const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; + + // Ring exchange of 2X inputs chunks + for (int i = 0; i < num_steps; i++) { + send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; + recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; + send_offset = comm_bytes * send_chunk_id; + recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + auto input_b_chunk = + get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k}); + auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m}); + auto aux_chunk = + (do_gelu) + ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + if (i < num_steps - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, + next_rank, _stream_send[0]); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, + prev_rank, _stream_recv); + NVTE_CHECK_CUDA(musaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); + NVTE_CHECK_CUDA( + musaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(musaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + musaMemcpyDeviceToDevice, _stream_send[0])); + } + } + } else { + for (int i = 0; i < _tp_size; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; + int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k}); + auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m}); + auto aux_chunk = + (do_gelu) + ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + if (i < _tp_size - 1) { + // P2P communication + if (_use_ce) { + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_comm_ce, _start_comm, 0)); + comm_userbuff_over_ce(_ub_reg, recv_offset, _ub_reg, recv_offset, _ubufs[0].numel(), + comm_bytes, _ub_comm, _next_rank, _prev_rank, A.dtype(), _tp_id, + _stream_comm_ce); + + } else { + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, + _next_rank, _stream_send[0]); + NVTE_CHECK_CUDA(musaStreamSynchronize(_stream_send[0])); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + _prev_rank, _stream_recv); + NVTE_CHECK_CUDA(musaStreamSynchronize(_stream_recv)); + NVTE_CHECK_CUDA(musaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); + + } + NVTE_CHECK_CUDA( + musaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(musaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + musaMemcpyDeviceToDevice, _stream_send[0])); + } + } + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(musaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + if (!_use_ce) { + NVTE_CHECK_CUDA(musaEventRecord(_stop_send, _stream_send[0])); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(musaEventRecord(_stop_recv, _stream_recv)); + } else { + NVTE_CHECK_CUDA(musaEventRecord(_stop_recv, _stream_comm_ce)); + } + + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_recv, 0)); +} // CommOverlapP2PBase::split_overlap_ag + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2PBase::atomic_gemm_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + musaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get communication and GEMM input chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Reset counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _tp_size, false, stream_main); + + // Catch up the main stream + NVTE_CHECK_CUDA(musaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_recv, _start_compute, 0)); + + // Atomic GEMM + // Process GEMM chunks in the order that AG+GEMM places the output chunks. + auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape())); + nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, + _math_sms, 0, _tp_size, true, _counter.data(), stream_main); + + // P2P communication chunk + for (int i = 1; i < _tp_size; i++) { + int send_chunk_id = i - 1; + int recv_chunk_id = send_chunk_id + _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + + consumer(counter_ptr, send_chunk_id, _stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + _stream_recv); + } + NVTE_CHECK_CUDA(musaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + _ub_comm->sms = ori_sms; +} + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + musaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get communication and GEMM input chunk sizes + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n_chunk = _ubufs[0].size(0); + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Get input and workspace data pointers + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Catch up the main stream + NVTE_CHECK_CUDA(musaEventRecord(_start_compute, stream_main)); + if (_use_ce) { + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_comm_ce, _start_compute, 0)); + } else { + for (size_t i = 0; i < _stream_send.size(); i++) { + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_send[i], _start_compute, 0)); + } + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_recv, _start_compute, 0)); + } + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + + // GEMM and send/recv chunks + for (int i = 0; i < _tp_size; i++) { + // GEMM chunk + int stream_id = i % _stream_compute.size(); + int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; + + auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k}); + auto output_chunk = get_buffer_chunk_by_id(D, i); + auto workspace_chunk = + get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _stream_compute[stream_id]); + + if (i > 0) { + // P2P communication chunk + int prev_stream_id = (i - 1) % _stream_compute.size(); + int send_offset = comm_bytes * (i - 1); + int recv_offset = comm_bytes * (i - 1 + _tp_size); + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + NVTE_CHECK_CUDA(musaEventRecord(_start_comm, _stream_compute[prev_stream_id])); + if (_use_ce) { + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_comm_ce, _start_comm, 0)); + comm_userbuff_over_ce(_ub_reg, send_offset, _ub_reg, recv_offset, _ubufs[0].numel(), + comm_bytes, _ub_comm, send_rank, recv_rank, A.dtype(), _tp_id, + _stream_comm_ce); + } else { + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_send[prev_stream_id], _start_comm, 0)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(_stream_recv, _start_comm, 0)); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + _stream_send[prev_stream_id]); + NVTE_CHECK_CUDA(musaStreamSynchronize(_stream_send[prev_stream_id])); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + _stream_recv); + NVTE_CHECK_CUDA(musaStreamSynchronize(_stream_recv)); + } + } + } + NVTE_CHECK_CUDA(musaStreamSynchronize(_stream_comm_ce)); + + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(musaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + if (!_use_ce) { + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(musaEventRecord(_stop_send, _stream_send[i])); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_send, 0)); + } + NVTE_CHECK_CUDA(musaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_recv, 0)); + } + else { + NVTE_CHECK_CUDA(musaEventRecord(_stop_comm, _stream_comm_ce)); + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream_main, _stop_comm, 0)); + } + + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + + _ub_comm->sms = ori_sms; +} + +} // namespace transformer_engine diff --git a/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/ipcsocket.cc b/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/ipcsocket.cc new file mode 120000 index 0000000000..b1925f2198 --- /dev/null +++ b/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/ipcsocket.cc @@ -0,0 +1 @@ +../../../../common/comm_gemm_overlap/userbuffers/ipcsocket.cc \ No newline at end of file diff --git a/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/ipcsocket.h b/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/ipcsocket.h new file mode 120000 index 0000000000..6068aec40a --- /dev/null +++ b/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/ipcsocket.h @@ -0,0 +1 @@ +../../../../common/comm_gemm_overlap/userbuffers/ipcsocket.h \ No newline at end of file diff --git a/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp new file mode 100644 index 0000000000..d9f649f630 --- /dev/null +++ b/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -0,0 +1,628 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/util/musa_driver.h" +#include "common/util/musa_runtime.h" +#include "common/util/logging.h" +#include "common/util/system.h" +#include "ipcsocket.h" +#include "userbuffers.h" + +#ifdef NVTE_UB_WITH_MPI +static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD; +static MPI_Comm EXT_COMM_INTRA; +static MPI_Comm EXT_COMM_INTER; + +#define UB_MPI_CHECK(expr) \ + do { \ + const int mpicode = (expr); \ + if (mpicode != MPI_SUCCESS) { \ + char mpimsg[MPI_MAX_ERROR_STRING]; \ + int mpilen; \ + MPI_Error_string(mpicode, mpimsg, &mpilen); \ + std::vector errmsg(1024); \ + snprintf(errmsg.data(), errmsg.size(), "%s:%d in function %s: %s", __FILE__, __LINE__, \ + __func__, mpimsg); \ + throw std::runtime_error(errmsg.data()); \ + } \ + } while (false) + +void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, + ExtComm comm) { + int numranks; + UB_MPI_CHECK(MPI_Comm_size(comm, &numranks)); + assert(globalbytes == numranks * localbytes); + UB_MPI_CHECK( + MPI_Allgather(localdata, localbytes, MPI_BYTE, globaldata, localbytes, MPI_BYTE, comm)); +} + +void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); } +#else +#define EXT_COMM_WORLD "world" +#define EXT_COMM_INTRA "intra" +#define EXT_COMM_INTER "inter" +#endif + +#define MULTICAST_GB_TOTAL 512 + +int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); } + +#define IPCCHECK(cmd) \ + do { \ + ipcSocketResult_t r = cmd; \ + if (r != ipcSocketSuccess) { \ + printf("Failed, UDS error %s:%d '%s'\n", __FILE__, __LINE__, ipcSocketGetErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define IPCCHECKGOTO(call, RES, label) \ + do { \ + RES = call; \ + if (RES != ipcSocketSuccess && RES != ipcSocketInProgress) { \ + goto label; \ + } \ + } while (0); + +int pipe_rank(communicator *comm, int step) { + int mynode = comm->myrank / comm->nvsize; + int mylocal = comm->nvrank; + int numlocal = comm->nvsize; + + int newlocal1 = mylocal + step * comm->ar_nvsize * comm->ar2_nvsize; + int newlocal = (numlocal + (newlocal1 % numlocal)) % numlocal; + int newnode = mynode; + newnode += (newlocal1 - newlocal) / numlocal * comm->num_nodes * comm->num2_nodes; + int allnodes = comm->nranks / comm->nvsize; + newnode = (allnodes + (newnode % allnodes)) % allnodes; + return newnode * numlocal + newlocal; +} + +int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes, int tensorgpus, int tensornodes) { + *comm = new communicator(); + + (*comm)->comm_world = EXT_COMM_WORLD; + (*comm)->_allgather = ext_allgather; + (*comm)->_barrier = ext_barrier; + (*comm)->nranks = numranks; + (*comm)->myrank = myrank; + (*comm)->free_region = 0; + (*comm)->launch_mode = NVTE_LAUNCH_GPU | NVTE_LAUNCH_CPU; + + int cur_dev, ndev; + musaDeviceProp device_prop; + NVTE_CHECK_CUDA(musaGetDevice(&cur_dev)); + NVTE_CHECK_CUDA(musaGetDeviceCount(&ndev)); + NVTE_CHECK_CUDA(musaGetDeviceProperties(&device_prop, cur_dev)); + (*comm)->sm_arch = device_prop.major; + // (*comm)->use_rr_kernel = device_prop.major == 8; + (*comm)->use_rr_kernel = 0; + (*comm)->push = 1; + (*comm)->use_ce = 0; + (*comm)->cga_size = 2; + for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0; + (*comm)->head = 0; + (*comm)->tail = 0; + (*comm)->active_nreqs = 0; + for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1; + + int device_clock = 0; + // 110 sec wait time by default + int sec_timeout = getenv("UB_TIMEOUT") ? atoi(getenv("UB_TIMEOUT")) : 110; + NVTE_CHECK_CUDA(musaDeviceGetAttribute(&device_clock, musaDevAttrClockRate, cur_dev)); + (*comm)->ub_timeout = 1000ull * device_clock * sec_timeout; + if ((*comm)->myrank == 0) { + printf("UB_TIMEOUT is set to %d sec, %" PRIu64 " cycles, freq: %dkhz\n", sec_timeout, + (*comm)->ub_timeout, device_clock); + } + + (*comm)->comm_intra = EXT_COMM_INTRA; + (*comm)->nvrank = mylocal; + (*comm)->nvsize = numlocal; + + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + int core; + if (mylocal == 0) core = 50; + if (mylocal == 1) core = 58; + if (mylocal == 2) core = 18; + if (mylocal == 3) core = 26; + if (mylocal == 4) core = 114; + if (mylocal == 5) core = 122; + if (mylocal == 6) core = 82; + if (mylocal == 7) core = 90; + + CPU_SET(core, &cpuset); + if (!getenv("NVTE_NODOUBLE")) { + if (core > 128) + CPU_SET(core - 128, &cpuset); + else + CPU_SET(core + 128, &cpuset); + } + if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); + + if (ndev == numlocal) { // all visible devices + if (cur_dev != mylocal) + printf("%d: device used %d[%d] ,resetting device to %d\n", myrank, cur_dev, ndev, mylocal); + NVTE_CHECK_CUDA(musaSetDevice(mylocal)); + } + (*comm)->mydev = cur_dev; + // FIXME need to check that numlocal is multiple of pipegpus x tensorgpus + // ar1 is data + int divgpus = pipegpus * tensorgpus; + int datagpus = numlocal / divgpus; + (*comm)->ar_nvsize = datagpus; + (*comm)->ar_firstgpu = mylocal - ((mylocal / tensorgpus) % datagpus) * tensorgpus; + (*comm)->ar_nvrank = (mylocal - (*comm)->ar_firstgpu) / tensorgpus; + // ar2 is tensor + (*comm)->ar2_nvsize = tensorgpus; + (*comm)->ar2_firstgpu = mylocal - mylocal % tensorgpus; + (*comm)->ar2_nvrank = mylocal - (*comm)->ar2_firstgpu; + // ar2 has step equal to ar_nvsize + int allnodes = numranks / numlocal; + int nodeid = myrank / numlocal; + int datanodes = allnodes / pipenodes / tensornodes; + int pipenodegroup_id = myrank / numlocal / (datanodes * tensornodes); + + (*comm)->pipe_id = pipegpus * pipenodegroup_id + mylocal / (datagpus * tensorgpus); + + (*comm)->comm_inter = EXT_COMM_INTER; + (*comm)->first_node = nodeid - mynode; + (*comm)->num_nodes = numnodes; + (*comm)->my_node = mynode; + + (*comm)->num2_nodes = tensornodes; + (*comm)->my2_node = (mynode / datanodes) % tensornodes; + (*comm)->first2_node = mynode - (*comm)->my2_node * datanodes; + + (*comm)->fifo = reinterpret_cast(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS)); + (*comm)->nblocks = 8; + (*comm)->alignblock = 1024 * 512; + (*comm)->minblock = 1024 * 2 * 1024; + (*comm)->asyncblocks = 16; + +#define NBUF 2 + +#if CUDART_VERSION >= 12010 + if (!transformer_engine::getenv("UB_SKIPMC") && + transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) { + // multicast init only for TP ops (____2 operations) + size_t mc_maxsize = MULTICAST_GB_TOTAL * (1ull << 30); + (*comm)->mc_offset = 0; + (*comm)->use_mc = 1; + size_t gran; + CUmulticastObjectProp mcProp = {}; + mcProp.numDevices = (*comm)->ar2_nvsize; + mcProp.size = (*comm)->mc_maxsize; + mcProp.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + + NVTE_CALL_CHECK_CUDA_DRIVER( + muMulticastGetGranularity, &gran, &mcProp, + static_cast(CU_MULTICAST_GRANULARITY_RECOMMENDED)); + mc_maxsize = ((mc_maxsize + gran - 1) / gran) * gran; + mcProp.size = mc_maxsize; + (*comm)->mc_maxsize = mc_maxsize; + + // Broadcast the a POSIX file descriptor from the local root rank to other local ranks. + // NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the + // file descriptor and prevent muMemImportFromShareableHandle() from correctly + // interpreting the file. Instead, we use Unix domain sockets for the kernel to + // recreate the correct file descriptor on every receiving rank. + int fd; + volatile uint32_t abortFlag = 0; + IpcSocketHandle ipcSock = {0}; + uint64_t opId = 0xdeadcafeb000 + (*comm)->ar2_firstgpu; + ipcSocketResult_t ret = ipcSocketSuccess; + IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag)); + (*comm)->_barrier((*comm)->comm_world); + + if ((*comm)->ar2_nvrank == 0) { + NVTE_CALL_CHECK_CUDA_DRIVER(muMulticastCreate, &(*comm)->mc_handle, &mcProp); + NVTE_CALL_CHECK_CUDA_DRIVER( + muMemExportToShareableHandle, reinterpret_cast(&fd), (*comm)->mc_handle, + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), + (uint64_t)0); + + for (int p = 1; p < (*comm)->ar2_nvsize; p++) { + (*comm)->_barrier((*comm)->comm_intra); + IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); + } + } else { + for (int p = 1; p < (*comm)->ar2_nvsize; p++) { + (*comm)->_barrier((*comm)->comm_intra); + if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error); + } + } + + error: + if ((*comm)->ar2_nvrank != 0) { + NVTE_CALL_CHECK_CUDA_DRIVER( + muMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast(fd), + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + } + IPCCHECK(ipcSocketClose(&ipcSock)); + close(fd); + NVTE_CALL_CHECK_CUDA_DRIVER(muMulticastAddDevice, (*comm)->mc_handle, + (MUdeviceptr)(*comm)->mydev); + + MUdeviceptr mc_va; + NVTE_CALL_CHECK_CUDA_DRIVER(muMemAddressReserve, &mc_va, mc_maxsize, (size_t)0, (MUdeviceptr)0U, + (uint64_t)0); + NVTE_CALL_CHECK_CUDA_DRIVER(muMemMap, mc_va, mc_maxsize, (size_t)0, (*comm)->mc_handle, + (uint64_t)0); + + CUmemAccessDesc accessDesc = {}; + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = (*comm)->mydev; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + NVTE_CALL_CHECK_CUDA_DRIVER(muMemSetAccess, mc_va, mc_maxsize, + const_cast(&accessDesc), (size_t)1); + + (*comm)->mc_baseptr = reinterpret_cast(mc_va); + (*comm)->_barrier((*comm)->comm_world); + if (!(*comm)->myrank) printf("MC initialized succesfully, window size = %ld\n", mc_maxsize); + } else { +#endif + if (!(*comm)->myrank) printf("MC NOT initialized and used\n"); + (*comm)->mc_maxsize = 0; + (*comm)->mc_offset = 0; + (*comm)->use_mc = 0; +#if CUDART_VERSION >= 12010 + } +#endif + +#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF) + // peer pointers + op flags + comm buffer + + NVTE_CHECK_CUDA(musaDeviceSynchronize()); + register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true); + NVTE_CHECK_CUDA(musaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int))); + NVTE_CHECK_CUDA(musaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); + NVTE_CHECK_CUDA(musaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); + NVTE_CHECK_CUDA( + musaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); + (*comm)->sms = 16; + (*comm)->threads = 1024; + +#define GPU_PAGE_SHIFT 16 +#define GPU_PAGE_SIZE (1UL << GPU_PAGE_SHIFT) +#define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1) +#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET) + + NVTE_CHECK_CUDA(musaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE)); + NVTE_CHECK_CUDA(musaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE)); + (*comm)->flags = + reinterpret_cast(((MUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); + + using namespace std; + + sched_param param; + pthread_attr_t attr; + pthread_attr_init(&attr); + pthread_attr_getschedparam(&attr, ¶m); + param.sched_priority = sched_get_priority_max(SCHED_FIFO); + + pthread_attr_setschedparam(&attr, ¶m); + + if (getenv("NVTE_UBDEBUG")) + printf( + "%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP " + "%dx%d PIPE_ID %d/%d\n", + myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node, + (*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes, + (*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id, + pipegpus * pipenodes); + fflush(NULL); + + return 0; +} + +int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes) { + return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1); +} + +int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, ExtAllgatherOp ext_allgather, + ExtBarrierOp ext_barrier) { + return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + ext_allgather, ext_barrier, 1, 1, 1, 1); +} + +int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes, + int tensorgpus, int tensornodes) { +#ifdef NVTE_UB_WITH_MPI + // get global numbers + int myrank, numranks; + UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_WORLD, &myrank)); + UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_WORLD, &numranks)); + + // find intranode numbers and make internode communicator + char hostname[MPI_MAX_PROCESSOR_NAME]; + int namelen; + UB_MPI_CHECK(MPI_Get_processor_name(hostname, &namelen)); + + char(*hostnames)[MPI_MAX_PROCESSOR_NAME] = + static_cast(malloc(numranks * MPI_MAX_PROCESSOR_NAME)); + strcpy(hostnames[myrank], hostname); // NOLINT(*) + for (int n = 0; n < numranks; n++) + UB_MPI_CHECK(MPI_Bcast(&(hostnames[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD)); + qsort(hostnames, numranks, MPI_MAX_PROCESSOR_NAME, stringCmp); + + int color = 0; + for (int n = 0; n < numranks; n++) { + if (n > 0 && strcmp(hostnames[n - 1], hostnames[n])) color++; + if (strcmp(hostname, hostnames[n]) == 0) break; + } + free(hostnames); + + int mylocal, numlocal; + UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, color, myrank, &EXT_COMM_INTRA)); + UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTRA, &mylocal)); + UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTRA, &numlocal)); + + // find internode numbers and make internode communicator + NVTE_CHECK_CUDA(musaFree(0)); + int allnodes = numranks / numlocal; + int datanodes = allnodes / pipenodes / tensornodes; + // data reduction group node belongs, equals 0 for all if both pipenodes=1 and tensornodes=1 + int datanodegroup_id = myrank / numlocal / datanodes; + // mpi communicator only needed for SHARP which is always allreduce1/data-parallel + UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, mylocal + numlocal * datanodegroup_id, myrank, + &EXT_COMM_INTER)); + // different rails from same group are in different subcommunicators + int mynode, numnodes; + UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTER, &numnodes)); + UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTER, &mynode)); + + // finally call the abstracted constructor with MPI info + return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + &ub_mpi_allgather, &ub_mpi_barrier, pipegpus, pipenodes, + tensorgpus, tensornodes); +#else + NVTE_ERROR(std::string("Bootstrapping Userbuffers with MPI requires building") + + std::string("Transformer Engine with NVTE_UB_WITH_MPI=1 and MPI_HOME=/path/to/mpi")); +#endif +} + +int create_communicator_grouped_mpi(communicator **comm, int pipegpus, int pipenodes) { + return create_communicator_grouped2_mpi(comm, pipegpus, pipenodes, 1, 1); +} + +int create_communicator_mpi(communicator **comm) { + return create_communicator_grouped2_mpi(comm, 1, 1, 1, 1); +} + +void destroy_communicator(communicator *comm) { + for (int hndl = 0; hndl < comm->free_region; hndl++) { + if (comm->use_mc && comm->mem_dealloc[hndl]) { + for (int rank = 0; rank < comm->nvsize; rank++) { + if (rank == comm->nvrank) { + NVTE_CALL_CHECK_CUDA_DRIVER(muMemRelease, comm->uchandles[hndl][rank]); + } else { + comm->uchandles[hndl][rank] = 0; + } + } + free(reinterpret_cast(comm->uchandles[hndl])); + } else { + for (int rank = 0; rank < comm->nvsize; rank++) { + if (rank != comm->nvrank) { + musaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]); + } else if (comm->mem_dealloc[hndl]) { + NVTE_CHECK_CUDA(musaFree(comm->peer_ptr[hndl][rank])); + } else { + comm->peer_ptr[hndl][rank] = nullptr; // remove reference to external buffer + } + } + } + free(comm->peer_ptr[hndl]); + comm->mem_ptr[hndl] = nullptr; + } + musaFree(reinterpret_cast(comm->recv_id)); + musaFree(reinterpret_cast(comm->send_id)); + if (comm->use_mc) { + NVTE_CALL_CHECK_CUDA_DRIVER(muMemRelease, comm->mc_handle); + } + free(comm->fifo); + delete comm; +} + +void destroy_communicator_mpi(communicator *comm) { +#ifdef NVTE_UB_WITH_MPI + MPI_Comm_free(static_cast(&(comm->comm_inter))); + MPI_Comm_free(static_cast(&(comm->comm_intra))); + destroy_communicator(comm); +#else + NVTE_ERROR(std::string("Communicator is not bootstrapped with MPI and ") + + std::string("can only be deallocated with destroy_communicator().")); +#endif +} + +int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) { + if (comm->free_region > NVTE_MAX_REGIONS) return -1; + int hndl = comm->free_region; + comm->peer_ptr[hndl] = reinterpret_cast(malloc(sizeof(void *) * (comm->nvsize))); + size_t aligned_size = bytes; + comm->memflags[hndl] = 0; + comm->mem_dealloc[hndl] = alloc; + +#if CUDART_VERSION >= 12010 + if (comm->use_mc && alloc) { + int nranks = comm->nvsize; // total GPUs in NVLINK domain + int myrank = comm->nvrank; + void **remptrs = reinterpret_cast(malloc(nranks * sizeof(void *))); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = comm->mydev; + prop.requestedHandleTypes = + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; // CU_MEM_HANDLE_TYPE_FABRIC; + + size_t granularity = 0; + NVTE_CALL_CHECK_CUDA_DRIVER( + muMemGetAllocationGranularity, &granularity, &prop, + static_cast(CU_MULTICAST_GRANULARITY_MINIMUM)); + // MPI_Allreduce MAX of granularity check + aligned_size = (bytes + granularity - 1) / granularity * granularity; + + if (comm->use_mc) { + CUmulticastObjectProp mcProp = {}; + mcProp.numDevices = nranks; + mcProp.size = aligned_size; + mcProp.handleTypes = prop.requestedHandleTypes; + NVTE_CALL_CHECK_CUDA_DRIVER( + muMulticastGetGranularity, &granularity, &mcProp, + static_cast(CU_MULTICAST_GRANULARITY_MINIMUM)); + aligned_size = (aligned_size + granularity - 1) / granularity * granularity; + } + + prop.location.id = comm->mydev; + comm->uchandles[hndl] = reinterpret_cast( + malloc(nranks * sizeof(CUmemGenericAllocationHandle))); + NVTE_CALL_CHECK_CUDA_DRIVER(muMemCreate, &(comm->uchandles[hndl][myrank]), aligned_size, &prop, + (uint64_t)0); + + int *peerfd = reinterpret_cast(malloc(nranks * sizeof(int))); + NVTE_CALL_CHECK_CUDA_DRIVER( + muMemExportToShareableHandle, reinterpret_cast(&peerfd[myrank]), + comm->uchandles[hndl][myrank], + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), + (uint64_t)0); + + volatile uint32_t abortFlag = 0; + IpcSocketHandle ipcSock = {0}; + uint64_t opId = 0xdeadcafebeef; + ipcSocketResult_t ret = ipcSocketSuccess; + + // All-gather POSIX file descriptors across local ranks + IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag)); + for (int p = 1; p < nranks; p++) { + int send_to = (myrank + p) % nranks; + int recv_from = (myrank + nranks - p) % nranks; + comm->_barrier(comm->comm_intra); + IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret, error); + IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error); + } + + error: + IPCCHECK(ipcSocketClose(&ipcSock)); + + for (int p = 0; p < nranks; p++) { + if (p != myrank) + NVTE_CALL_CHECK_CUDA_DRIVER( + muMemImportFromShareableHandle, &comm->uchandles[hndl][p], + reinterpret_cast(peerfd[p]), + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + close(peerfd[p]); + } + free(peerfd); + + MUdeviceptr ptr; + NVTE_CALL_CHECK_CUDA_DRIVER(muMemAddressReserve, &ptr, (size_t)(aligned_size * nranks), + (size_t)0, (MUdeviceptr)0, (uint64_t)0); + comm->ucbase_ptr[hndl] = reinterpret_cast(ptr); + CUmemAccessDesc accessDesc = {}; + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + accessDesc.location.id = comm->mydev; + + for (int i = 0; i < nranks; i++) { + remptrs[i] = reinterpret_cast(ptr + (aligned_size * i)); + NVTE_CALL_CHECK_CUDA_DRIVER(muMemMap, reinterpret_cast(remptrs[i]), aligned_size, + (size_t)0, comm->uchandles[hndl][i], (uint64_t)0); + if (i == comm->nvrank) { + if (hndl) + *gpubuff = remptrs[i]; + else + comm->gpu_ptrs = remptrs[i]; + } + comm->peer_ptr[hndl][i] = remptrs[i]; + } + NVTE_CALL_CHECK_CUDA_DRIVER(muMemSetAccess, ptr, (size_t)(aligned_size * nranks), + const_cast(&accessDesc), (size_t)1); + + if (hndl == 0) NVTE_CHECK_CUDA(musaMemset(comm->gpu_ptrs, 0, aligned_size)); + NVTE_CHECK_CUDA( + musaMemcpy((reinterpret_cast(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)), + remptrs, nranks * sizeof(void *), musaMemcpyHostToDevice)); + free(remptrs); + comm->memflags[hndl] = UB_MEM_UC_CONTIG | UB_MEM_ALLOCATED; + + if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) { + NVTE_CALL_CHECK_CUDA_DRIVER(muMulticastBindMem, comm->mc_handle, comm->mc_offset, + comm->uchandles[hndl][myrank], (size_t)0 /*memOffset*/, + aligned_size, (uint64_t)0); + comm->memflags[hndl] |= UB_MEM_MC_CREATED; + comm->mc_ptr[hndl] = reinterpret_cast(comm->mc_baseptr) + comm->mc_offset; + comm->mc_offset += aligned_size; + } else if (!comm->myrank) { + printf("UB: warning region %d size %ld MB registered without MC access\n", hndl, + aligned_size / 1024 / 1024); + } + + } else { +#endif + if (alloc) { + NVTE_CHECK_CUDA(musaMalloc(gpubuff, bytes)); + NVTE_CHECK_CUDA(musaMemset(*gpubuff, 0, bytes)); + } + + NVTE_CHECK(comm->nvsize <= 8, "CUDA IPC supports only up to 8 GPUs in an NVLink domain."); + musaIpcMemHandle_t memhndl; + NVTE_CHECK_CUDA(musaIpcGetMemHandle(&memhndl, *gpubuff)); + + musaIpcMemHandle_t *tmp = + reinterpret_cast(malloc(comm->nvsize * sizeof(musaIpcMemHandle_t))); + comm->_allgather(reinterpret_cast(tmp), comm->nvsize * sizeof(musaIpcMemHandle_t), + reinterpret_cast(&memhndl), sizeof(musaIpcMemHandle_t), + comm->comm_intra); + + for (int i = 0; i < comm->nvsize; i++) { + if (i != comm->nvrank) { + NVTE_CHECK_CUDA(musaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) + musaIpcMemLazyEnablePeerAccess)); + } + } + comm->peer_ptr[hndl][comm->nvrank] = *gpubuff; + NVTE_CHECK_CUDA(musaDeviceSynchronize()); + + NVTE_CHECK_CUDA(musaMemcpy( + reinterpret_cast(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)), + comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), musaMemcpyHostToDevice)); + + NVTE_CHECK_CUDA(musaDeviceSynchronize()); + free(tmp); +#if CUDART_VERSION >= 12010 + } +#endif + comm->mem_size[hndl] = aligned_size; + + comm->mem_ptr[hndl] = *gpubuff; + + return comm->free_region++; +} diff --git a/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers.h new file mode 100644 index 0000000000..389d52a89f --- /dev/null +++ b/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -0,0 +1,335 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_USERBUFFERS_H_ +#define TRANSFORMER_ENGINE_USERBUFFERS_H_ + +#include +#include +#include + +#include +#include +#include + +#include "common/util/logging.h" +#include "transformer_engine/transformer_engine.h" + +#ifdef NVTE_UB_WITH_MPI +#include +#define ExtComm MPI_Comm +#else +#define ExtComm const char * +#endif + +using ExtAllgatherOp = std::function; +using ExtBarrierOp = std::function; + +#define NVTE_MAX_REGIONS 16 +#define NVTE_MAX_SMS 32 +#define NVTE_MAX_OPS 32 +#define NVTE_MAX_PEERS 8192 +#define NVTE_MAX_REQUESTS 1024 +#define NVTE_LAUNCH_GPU 1 +#define NVTE_LAUNCH_CPU 2 +#define NVTE_MAX_NVLINK 8 + +#define UB_MEM_UC_CONTIG 1 +#define UB_MEM_MC_CREATED 2 +#define UB_MEM_ALLOCATED 4 + +#define NVTE_UB_MEM_UC_CONTIG 1 +#define NVTE_UB_MEM_MC_CREATED 2 +#define NVTE_UB_MEM_ALLOCATED 4 + +// region 0 flag offsets +#define NVTE_REG0_OPFLAGS 1024 +#define NVTE_REG0_RECV (NVTE_REG0_OPFLAGS * userbuffers_op_types) +#define NVTE_REG0_SINGLENODE (2 * NVTE_MAX_NVLINK * NVTE_MAX_SMS + NVTE_MAX_OPS) +#define NVTE_REG0_OFFSET(comm) \ + ((2 * NVTE_MAX_REGIONS) * NVTE_MAX_NVLINK + NVTE_REG0_SINGLENODE * 2 + NVTE_MAX_PEERS) +#define NVTE_REG0_COMMBUFFER 0 +// x3 for [flagptr, ce_start_ptr, ce_end_ptr] +#define NVTE_REG0_FLAGS (NVTE_REG0_RECV + NVTE_MAX_PEERS * NVTE_MAX_REGIONS * 3) +#define NVTE_REG0_IBRS 32 +#define NVTE_REG0_IBAG 512 + +#if defined(UCP) || !defined(NOSHARP) +#undef REG0_COMMBUFFER +#define REG0_COMMBUFFER (1024 * 1024 * 16) +#endif +// gpuflags map offsets +#define NVTE_GF_STATE 16000 +#define NVTE_GF_IBSHARPDONE 0 +#define NVTE_HF_NVRSDONE (userbuffers_op_types + 1) +#define NVTE_HF_NVREDUCEDONE (userbuffers_op_types + 3) +#define NVTE_MAX_SHARP 16 + +typedef struct ub_request { + int optype; + int blocksize; + int basecounter; + int elements; + int handler; + int handler2; + size_t offset; + size_t offset2; + int peer; + // ----execution states + int active, maxcredit; + int nblock, numblocks, unconfirmed_ib_in_flight; +} ub_request; + +enum req_type { + userbuffers_allreduceop_sharp, + userbuffers_sendop, + userbuffers_allreduceop_nonsharp, + userbuffers_allreduceop_nonsharp2, + userbuffers_alltoall, + userbuffers_op_types +}; + +struct communicator { + int myrank, nranks; // global job communicator + int nvrank, nvsize; // single node comm_intra + int free_region; + + int launch_mode; + + void *gpu_ptrs; + int sms, threads; + int use_rr_kernel; // Whether to use RR (or RW) for NVLink-only kernel + int cga_size; + int push, use_ce; + + void *mem_ptr[NVTE_MAX_REGIONS]; + void **peer_ptr[NVTE_MAX_REGIONS]; + + int memflags[NVTE_MAX_REGIONS]; // UC,MC, user/lib allocated + + MUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS]; + void *ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory + size_t mem_size[NVTE_MAX_REGIONS]; + bool mem_dealloc[NVTE_MAX_REGIONS]; + + void *mc_ptr[NVTE_MAX_REGIONS]; + void *mc_baseptr; + MUmemGenericAllocationHandle mc_handle; + size_t mc_offset, mc_maxsize; + int use_mc; // 1: use MC if available, 0: override not to use MC + + int ar_nvsize, ar_firstgpu, + ar_nvrank; // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup + // (_splitar init used) would be equal to (nvsize,0) for regular comm_create + int ar2_nvsize, ar2_firstgpu, ar2_nvrank; // with ar_nvsize as a step + int pipe_id; // which allreduce set of groups (pipeline rank in range of 0..pipeline_size) + int sm_arch; + int num_nodes, my_node, + first_node; // comm_inter communicator, per-rail allreduce (might have subset of nodes) + int num2_nodes, my2_node, first2_node; // with num_nodes as a stride + // max value for running block counters in hostflags + int basecounter[userbuffers_op_types]; // NOLINT(*) + + int *flags, *map_flags; + + void *mem_mr[NVTE_MAX_REGIONS]; + + ub_request *fifo; + int nblocks, alignblock, minblock, asyncblocks, active_nreqs; + ub_request active_req[userbuffers_op_types]; // NOLINT(*) + int padding[7]; + volatile int head; + int padding2[15]; + volatile int tail; + + // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks) + ExtAllgatherOp _allgather; + ExtBarrierOp _barrier; + + ExtComm comm_world; + ExtComm comm_inter; // reduction group communicator (subset of the nodes) along GPU rail + ExtComm comm_intra; // full intranode (all ndev GPUS) +#ifdef NVTE_UB_WITH_MPI + MPI_Request mpihndl[NVTE_MAX_SHARP]; +#endif + + int *send_id, *recv_id; + int mydev; + uint64_t ub_timeout; +}; +typedef struct communicator communicator; + +void producer(void *atomic_ptr, int chunk_i, musaStream_t stream); +void consumer(void *atomic_ptr, int chunk_i, musaStream_t stream); +void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, musaStream_t stream); +void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, musaStream_t stream); + +/* creates communicator, allocates all internal buffers if necessary */ +int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes, int tensorgpus, int tensornodes); + +int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes); + +int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, ExtAllgatherOp ext_allgather, + ExtBarrierOp ext_barrier); + +int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes, + int tensorgpus, int tensornodes); + +int create_communicator_grouped_mpi(communicator **comm, int pipegpus, int pipenodes); + +int create_communicator_mpi(communicator **comm); + +void destroy_communicator(communicator *comm); + +void destroy_communicator_mpi(communicator *comm); + +// int check_user_buffer_registration(void* gpubuff, int bytes, communicator* comm, size_t* offset); +/* + local calls, doesnt communicate between peers + returns handler if buffer is registered already, or -1 if not. + returned offset is offset of gpubuff relative to buffer registered +*/ + +int pipe_rank(communicator *comm, + int step); // helper function to help walk across allreduce1 x allreduce2 groups + // data-parallel and tensor-parallel position within data and tensor + // groups would be preserved + +int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc); +/* returns handler and registers buffers. assumed to be collective i.e. you use same groups and + dont mix buffers for different operations returns -1 if cant register (too many preregistered + regions already) if alloc==true will allocate memory and fill the pointers (required for NVL + SHARP and NSO/MNNVL) +*/ + +// for TP-parallelism, only single node is implemented +void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, musaStream_t stream = 0, + musaEvent_t comm_launch_event = 0); +/* +each Rank input is +allgather2_userbuff_inplace: offset+myrank*elements +allgather2_userbuff_inplace_sliced: offset+myrank*elements*nslices+slice_id*elements + +equivalent codes would be: +for(int slice=0;slice +void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, + const int offset, const int rowelements, + const int colelements, const int strideelements, + communicator *comm, musaStream_t stream = 0, + musaEvent_t comm_launch_event = 0); +template +void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, + const int elements, communicator *comm, musaStream_t stream = 0, + musaEvent_t comm_launch_event = 0); +template +void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler, + const int offset, const int rowelements, + const int colelements, const int strideelements_out, + const int strideelements_in, const int numchunks, + void *counters, communicator *comm, + musaStream_t stream = 0); +template +void reducescatter2_userbuff_strided_multiatomic_fp8( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, musaStream_t stream = 0); +void reducescatter2_userbuff_strided(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, communicator *comm, + musaStream_t stream = 0); +void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, const int numchunks, + void *counters, communicator *comm, + musaStream_t stream = 0); +void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, const int numchunks, + void *counters, communicator *comm, + musaStream_t stream = 0); +/* everything should be 16byte aligned = 8 elts aligned +output is strided: row starts separated by stride elements*/ + +/* inplace allreduce: works only with buffers registered by previous call. offset should be same + * for all peers */ + +// two matching pairs, intended to work as push from sender or pull by receiver +// either way signal is a write by sender meaning +// push model: data arrived and visible at receiver(barrier enforced) +// pull model: data ready to be pulled by receiver(no barrier needed) + +void comm_userbuff_over_ce(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const int elements, const int comm_bytes, + communicator *comm, const int send_peer, const int recv_peer, + transformer_engine::DType dtype, const int _tp_id, musaStream_t stream = 0); +void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes, communicator *comm, + const int peer, musaStream_t stream = 0); +void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes, communicator *comm, + const int peer, musaStream_t stream = 0); +void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset, + const size_t recv_offset, const size_t bytes, communicator *comm, + const int send_peer, const int recv_peer, musaStream_t stream = 0); +void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, + const size_t send_offset, const size_t recv_offset, + const size_t bytes, communicator *comm, const int send_peer, + const int recv_peer, void *counters, musaStream_t stream = 0); +void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler, + const size_t send_offset, const size_t recv_offset, + const size_t bytes, communicator *comm, const int send_peer, + const int recv_peer, const int nchunks, void *counters, + bool shuffle, musaStream_t stream = 0); + +// alltoall split send and recv to allow for overlap +// send kicks in sending data to the destination - invoke on same stream as data generation +// recv returns once data has received +// send and recv can be on different streams +// void userbuffers_alltoall_send(const int srchandler, const size_t srcoffset, const int dsthandler, +// const size_t dstoffset, const size_t bytes, communicator *comm, +// musaStream_t stream = 0); +// void userbuffers_alltoall_recv(communicator *comm, musaStream_t stream = 0); + +// void unregister_user_buffer(int handler); + +void destroy_communicator(communicator *comm); + +template +void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, int input_size, + musaStream_t stream); + +void reduce_bf16(void *input, void *output, int num_inputs, int input_size, musaStream_t stream); + +#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers.mu b/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers.mu new file mode 100644 index 0000000000..b2b1c04496 --- /dev/null +++ b/transformer_engine/musa/common/comm_gemm_overlap/userbuffers/userbuffers.mu @@ -0,0 +1,2790 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include "common/util/system.h" +#include "userbuffers.h" + +const uint64_t global_idf = 0ULL; +#define MAX_THREADS 1024 + +#define CHECK_MUSA_DRIVER(cmd) \ + do { \ + MUresult err = cmd; \ + if (err != MUSA_SUCCESS) { \ + const char *errStr; \ + muGetErrorString(err, &errStr); \ + fprintf(stderr, "MUSA Driver Error at %d:%s\n %s\n", __LINE__, __FILE__, errStr); \ + exit(1); \ + } \ + } while (0) + + +// TODO(yuzhe.wu): replace asm volatile("fence.sc.gpu;\n") temporarily and the correctness needs to be verified +#define ATOMIC_CONSUMER(chunk) \ + if (counters) { \ + if (threadIdx.x == 0 && blockIdx.x == 0) { \ + while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ + } \ + ((unsigned int *)counters)[chunk] = 1; \ + asm volatile("DMA.IDF.SLC.BYPASS %0" :: "R"(global_idf)); \ + } \ + if (blockIdx.x == 0) __syncthreads(); \ + } + +#define ATOMIC_PRODUCER(chunk) \ + if (counters) { \ + ((unsigned int *)counters)[chunk] = 0; \ + } + +// Return true if producer > consumer, otherwise false while preventing integer overflow +// If we expect that producer will be 2B+ messages behind consumer +#define CHECK_IDS(producer, consumer) (((unsigned)(producer) - (unsigned)(consumer)) & (~INT_MAX)) + +// Strip the path from a full filename +// #define FILENAME(file) \ +// ({ \ +// const char *filename = file; \ +// const char *basename = filename; \ +// for (const char *ptr = filename; *ptr != '\0'; ptr++) { \ +// if (*ptr == '/' || *ptr == '\\') { \ +// basename = ptr + 1; \ +// } \ +// } \ +// basename; \ +// }) + +// Printf to provide enough information so it is easier to attribute failures +#define UB_PRINT(message, ...) \ + printf("[%s:%s:%d] " message "\n", __FILE__, __FUNCTION__, __LINE__, __VA_ARGS__) // FIXME(yehua.zhang): FILENAME(__FILE__) compile err because of mcc + +// Report and error on timeout +#define CHECK_TIMEOUT(t, timeout) ((clock64() - (t)) > timeout) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rw(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, const int lineoffset, + const int numlines, void **commbuff, const int handleridx, + const uint64_t ub_timeout) { +// __shared__ int4 *userptr[RANKS]; +// int *flagptr, physgpu, targetgpu, *myptr; +// int *reduceidptr, reduce_id; + +// if (threadIdx.x < RANKS) { +// physgpu = myrank * gpustep + firstrank; +// targetgpu = threadIdx.x * gpustep + firstrank; +// const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; +// myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; +// reduceidptr = myptr - NVTE_MAX_OPS; // +op; +// reduce_id = (*reduceidptr) + 1; +// flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; +// myptr += blockflagoffset; + +// flagptr[physgpu] = reduce_id; +// volatile int *flag = (volatile int *)&(myptr[targetgpu]); +// userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); +// clock_t s = clock64(); +// while (CHECK_IDS(*flag, reduce_id)) { +// if (CHECK_TIMEOUT(s, ub_timeout)) { +// UB_PRINT("[%d] Allreduce reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, +// blockIdx.x, threadIdx.x, reduce_id, *flag); +// break; +// } +// } +// reduce_id++; +// } +// __syncthreads(); + +// int warp = blockIdx.x + (threadIdx.x >> 5); +// int dest[RANKS]; +// #pragma unroll +// for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + +// __syncthreads(); +// for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; +// line += blockDim.x * gridDim.x * RANKS) { +// int4 val[RANKS]; + +// #pragma unroll +// for (int i = 0; i < RANKS; i++) { +// // int dest = (i+myrank+warp)&(RANKS-1); +// val[i] = userptr[dest[i]][lineoffset + line]; +// } + +// int4 sum = val[0]; +// half *s = reinterpret_cast(&sum); + +// #pragma unroll +// for (int i = 1; i < RANKS; i++) { +// half *x = reinterpret_cast(&val[i]); +// #pragma unroll +// for (int j = 0; j < 8; j++) s[j] += x[j]; +// } +// #pragma unroll +// for (int i = 0; i < RANKS; i++) { +// // int dest = (i+myrank+warp)&(RANKS-1); +// userptr[dest[i]][lineoffset + line] = sum; +// } +// } + +// __syncthreads(); +// if (threadIdx.x == 0) __threadfence_system(); +// __syncthreads(); + +// if (threadIdx.x < RANKS) { +// flagptr[physgpu] = reduce_id; +// volatile int *flag = (volatile int *)&myptr[targetgpu]; +// clock_t s = clock64(); +// while (CHECK_IDS(*flag, reduce_id)) { +// if (CHECK_TIMEOUT(s, ub_timeout)) { +// UB_PRINT("[%d] Allreduce Gather: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, +// threadIdx.x, reduce_id, *flag); +// break; +// } +// } +// } +// if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; +} // fp16 inplace reduce kernel (Volta,Hopper) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, const int lineoffset, + const int numlines, void **commbuff, const int handleridx, + const uint64_t ub_timeout) { +// __shared__ int4 *userptr[RANKS]; +// int *flagptr, physgpu, targetgpu, *myptr; +// int *reduceidptr, reduce_id; +// if (threadIdx.x < RANKS) { +// physgpu = myrank * gpustep + firstrank; +// targetgpu = threadIdx.x * gpustep + firstrank; +// const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; +// myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; +// reduceidptr = myptr - NVTE_MAX_OPS; // +op; +// reduce_id = (*reduceidptr) + 1; +// flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; +// myptr += blockflagoffset; + +// flagptr[physgpu] = reduce_id; +// volatile int *flag = (volatile int *)&(myptr[targetgpu]); +// userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); +// clock_t s = clock64(); +// while (CHECK_IDS(*flag, reduce_id)) { +// if (CHECK_TIMEOUT(s, ub_timeout)) { +// UB_PRINT("[%d ]Allreduce reduce-scatter:SM %d [%d]: expecting %d got %d", myrank, +// blockIdx.x, threadIdx.x, reduce_id, *flag); +// break; +// } +// } +// reduce_id++; +// } +// __syncthreads(); + +// int warp = blockIdx.x + (threadIdx.x >> 5); +// int dest[RANKS]; +// #pragma unroll +// for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + +// __syncthreads(); +// for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; +// line += blockDim.x * gridDim.x * RANKS) { +// int4 val[RANKS]; + +// #pragma unroll +// for (int i = 0; i < RANKS; i++) { +// val[i] = userptr[dest[i]][lineoffset + line]; +// } + +// int4 sum = val[0]; +// half *s = reinterpret_cast(&sum); + +// #pragma unroll +// for (int i = 1; i < RANKS; i++) { +// half *x = reinterpret_cast(&val[i]); +// #pragma unroll +// for (int j = 0; j < 8; j++) s[j] += x[j]; +// } + +// userptr[myrank][lineoffset + line] = sum; +// } +// __syncthreads(); +// if (threadIdx.x == 0) __threadfence(); +// __syncthreads(); + +// if (threadIdx.x < RANKS) { +// flagptr[physgpu] = reduce_id; +// volatile int *flag = (volatile int *)&myptr[targetgpu]; +// clock_t s = clock64(); +// while (CHECK_IDS(*flag, reduce_id)) { +// if (CHECK_TIMEOUT(s, ub_timeout)) { +// UB_PRINT("[%d] Allreduce gather: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, +// threadIdx.x, reduce_id, *flag); +// break; +// } +// } +// } + +// int skipmy = 0; +// #pragma unroll +// for (int i = 0; i < RANKS; i++) { +// int dst = (i + warp + myrank) & (RANKS - 1); +// if (dst == myrank) { +// skipmy++; +// continue; +// } +// dest[i - skipmy] = dst; +// } +// __syncthreads(); + +// for (int line = threadIdx.x + blockDim.x * RANKS * blockIdx.x; line < numlines; +// line += blockDim.x * gridDim.x * RANKS) { +// int4 val[RANKS - 1]; + +// #pragma unroll +// for (int i = 0; i < RANKS - 1; i++) { +// val[i] = userptr[dest[i]][lineoffset + line + blockDim.x * dest[i]]; +// } + +// #pragma unroll +// for (int i = 0; i < RANKS - 1; i++) { +// userptr[myrank][lineoffset + line + blockDim.x * dest[i]] = val[i]; +// } +// } +// if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; +} // fp16 inplace reduce kernel (Ampere) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_rs(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx, + const uint64_t ub_timeout) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (CHECK_IDS(*flag, reduce_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } + } + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; + } + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][mylineoffset + line]; + } + + int4 sum = val[0]; + HALF_TYPE *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + HALF_TYPE *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } + + userptr[myrank][mylineoffset + line] = sum; + } + + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; +} // fp16 inplace reduce-scatter kernel + +template +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop( + const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, + const int mylineoffset, const int totallines, const int rowlines, const int skiplines, + void **commbuff, const int handleridx, void *outbuf, const uint64_t ub_timeout) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (CHECK_IDS(*flag, reduce_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } + } + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; + } + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][mylineoffset + line]; + } + + int4 sum = val[0]; + HALF_TYPE *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + HALF_TYPE *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } + + (reinterpret_cast(outbuf))[(line / rowlines) * skiplines + (line % rowlines)] = sum; + } + + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) + +#if __MUSA_ARCH__ >= 900 +// All MC kernels here +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, const int lineoffset, + const int numlines, void **commbuff, const int handleridx, + float4 *mc_ptr, const uint64_t ub_timeout) { +// int *flagptr, physgpu, targetgpu, *myptr; +// int *reduceidptr, reduce_id; + +// if (threadIdx.x < RANKS) { +// physgpu = myrank * gpustep + firstrank; +// targetgpu = threadIdx.x * gpustep + firstrank; +// const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x; +// myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; +// reduceidptr = myptr - NVTE_MAX_OPS; // +op; +// reduce_id = (*reduceidptr) + 1; +// flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset + blockflagoffset; +// myptr += blockflagoffset; + +// flagptr[physgpu] = reduce_id; +// volatile int *flag = (volatile int *)&(myptr[targetgpu]); +// clock_t s = clock64(); +// while (CHECK_IDS(*flag, reduce_id)) { +// if (clock64() - s > ub_timeout) { +// UB_PRINT("Reduce-scatter: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, +// reduce_id, *flag); +// break; +// } +// } +// reduce_id++; +// } +// __syncthreads(); +// #define UNROLL_MC 8 +// const int loop_step0 = blockDim.x * gridDim.x * RANKS; +// const int loop_step = loop_step0 * UNROLL_MC; +// const int start_elem = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); +// const int end_elem = max(start_elem, numlines); +// const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; +// const int end_aligned = start_elem + aligned_elem; + +// for (int line = start_elem; line < end_aligned; line += loop_step) { +// uint4 val[UNROLL_MC]; +// #pragma unroll +// for (int i = 0; i < UNROLL_MC; i++) +// #if defined(NVTE_UB_FP16) +// asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) +// : "l"(mc_ptr + (lineoffset + line + i * loop_step0)) +// : "memory"); +// #else +// asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) +// : "l"(mc_ptr + (lineoffset + line + i * loop_step0)) +// : "memory"); +// #endif +// #pragma unroll +// for (int i = 0; i < UNROLL_MC; i++) +// asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"( +// mc_ptr + (lineoffset + line + i * loop_step0)), +// "r"(val[i].x), "r"(val[i].y), "r"(val[i].z), "r"(val[i].w) +// : "memory"); +// } +// for (int line = end_aligned; line < end_elem; line += loop_step0) { +// uint4 val; +// #if defined(NVTE_UB_FP16) +// asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) +// : "l"(mc_ptr + (lineoffset + line)) +// : "memory"); +// #else +// asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) +// : "l"(mc_ptr + (lineoffset + line)) +// : "memory"); +// #endif +// asm volatile( +// "multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(mc_ptr + (lineoffset + line)), +// "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) +// : "memory"); +// } + +// __syncthreads(); +// if (threadIdx.x == 0) __threadfence_system(); +// __syncthreads(); + +// if (threadIdx.x < RANKS) { +// flagptr[physgpu] = reduce_id; +// volatile int *flag = (volatile int *)&myptr[targetgpu]; +// clock_t s = clock64(); +// while (CHECK_IDS(*flag, reduce_id)) { +// if (clock64() - s > 2ull * ub_timeout) { +// UB_PRINT("Allgather: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id, +// *flag); +// break; +// } +// } +// } +// if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id; +} // fp16 inplace reduce kernel (Hopper) MC + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_rs(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx, float4 *mc_ptr, + const uint64_t ub_timeout) { +// volatile int *flagptr; +// int physgpu, targetgpu, *myptr; +// int *reduceidptr, reduce_id; +// uint4 *localptr = reinterpret_cast(commbuff[myrank * gpustep + firstrank + handleridx]); +// int lastSM = 0; + +// if (threadIdx.x < RANKS) { +// physgpu = myrank * gpustep + firstrank; +// targetgpu = threadIdx.x * gpustep + firstrank; +// myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; +// reduceidptr = myptr - NVTE_MAX_OPS; // +op; +// reduce_id = (*reduceidptr) + 1; +// flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; +// if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; +// volatile int *flag = (volatile int *)&(myptr[targetgpu]); +// clock_t s = clock64(); +// while (CHECK_IDS(*flag, reduce_id)) { +// if (CHECK_TIMEOUT(s, ub_timeout)) { +// UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, +// threadIdx.x, reduce_id, *flag); +// break; +// } +// } +// } +// __syncthreads(); +// if (threadIdx.x == 0) { +// const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; +// int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); +// if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; +// } +// const int loop_step0 = blockDim.x * gridDim.x; +// const int loop_step = loop_step0 * UNROLL_MC; +// const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; +// const int end_elem = max(start_elem, totallines); +// const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; +// const int end_aligned = start_elem + aligned_elem; + +// for (int line = start_elem; line < end_aligned; line += loop_step) { +// uint4 val[UNROLL_MC]; +// #pragma unroll +// for (int i = 0; i < UNROLL_MC; i++) +// #if defined(NVTE_UB_FP16) +// asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) +// : "l"(mc_ptr + (mylineoffset + line + i * loop_step0)) +// : "memory"); +// #else +// asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) +// : "l"(mc_ptr + (mylineoffset + line + i * loop_step0)) +// : "memory"); +// #endif +// #pragma unroll +// for (int i = 0; i < UNROLL_MC; i++) localptr[mylineoffset + line + i * loop_step0] = val[i]; +// } +// for (int line = end_aligned; line < end_elem; line += loop_step0) { +// uint4 val; +// #if defined(NVTE_UB_FP16) +// asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) +// : "l"(mc_ptr + (mylineoffset + line)) +// : "memory"); +// #else +// asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) +// : "l"(mc_ptr + (mylineoffset + line)) +// : "memory"); +// #endif +// localptr[mylineoffset + line] = val; +// } + +// if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; +} // fp16 inplace reduce-scatter kernel MC + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset, + const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, + const int totallines, const int rowlines, + const int skiplines, void **commbuff, + const int handleridx, void *outbuf, float4 *mc_ptr, + const uint64_t ub_timeout) { +// volatile int *flagptr; +// int physgpu, targetgpu, *myptr; +// int *reduceidptr, reduce_id; +// int lastSM = 0; + +// if (threadIdx.x < RANKS) { +// physgpu = myrank * gpustep + firstrank; +// targetgpu = threadIdx.x * gpustep + firstrank; +// myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; +// reduceidptr = myptr - NVTE_MAX_OPS; // +op; +// reduce_id = (*reduceidptr) + 1; +// flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; +// if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; +// volatile int *flag = (volatile int *)&(myptr[targetgpu]); +// clock_t s = clock64(); +// while (CHECK_IDS(*flag, reduce_id)) { +// if (CHECK_TIMEOUT(s, ub_timeout)) { +// UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, +// threadIdx.x, reduce_id, *flag); +// break; +// } +// } +// } +// __syncthreads(); +// if (threadIdx.x == 0) { +// const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; +// int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); +// if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; +// } + +// const int loop_step0 = blockDim.x * gridDim.x; +// const int loop_step = loop_step0 * UNROLL_MC; +// const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; +// const int end_elem = max(start_elem, totallines); +// const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; +// const int end_aligned = start_elem + aligned_elem; +// for (int line = start_elem; line < end_aligned; line += loop_step) { +// uint4 val[UNROLL_MC]; +// #pragma unroll +// for (int i = 0; i < UNROLL_MC; i++) +// #if defined(NVTE_UB_FP16) +// asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) +// : "l"(mc_ptr + (mylineoffset + line + i * loop_step0)) +// : "memory"); +// #else +// asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val[i].x), "=r"(val[i].y), "=r"(val[i].z), "=r"(val[i].w) +// : "l"(mc_ptr + (mylineoffset + line + i * loop_step0)) +// : "memory"); +// #endif +// #pragma unroll +// for (int i = 0; i < UNROLL_MC; i++) +// (reinterpret_cast(outbuf))[((line + i * loop_step0) / rowlines) * skiplines + +// ((line + i * loop_step0) % rowlines)] = val[i]; +// } +// for (int line = end_aligned; line < end_elem; line += loop_step0) { +// uint4 val; +// #if defined(NVTE_UB_FP16) +// asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) +// : "l"(mc_ptr + (mylineoffset + line)) +// : "memory"); +// #else +// asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];" +// : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) +// : "l"(mc_ptr + (mylineoffset + line)) +// : "memory"); +// #endif +// reinterpret_cast(outbuf)[(line / rowlines) * skiplines + (line % rowlines)] = val; +// } + +// if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) fp16 MC + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_ag(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx, uint4 *mc_ptr, + const uint64_t ub_timeout) { +// volatile int *flagptr; +// int physgpu, targetgpu, *myptr; +// int *reduceidptr, reduce_id; +// uint4 *localptr = reinterpret_cast(commbuff[myrank * gpustep + firstrank + handleridx]); + +// if (threadIdx.x < RANKS) { +// physgpu = myrank * gpustep + firstrank; +// targetgpu = threadIdx.x * gpustep + firstrank; +// myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; +// reduceidptr = myptr - NVTE_MAX_OPS; // +op; +// reduce_id = (*reduceidptr) + 1; +// flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; +// } +// __syncthreads(); + +// const int loop_step0 = blockDim.x * gridDim.x; +// const int loop_step = loop_step0 * UNROLL_MC; +// const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; +// const int end_elem = max(start_elem, totallines); +// const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; +// const int end_aligned = start_elem + aligned_elem; +// for (int line = start_elem; line < end_aligned; line += loop_step) { +// uint4 val[UNROLL_MC]; +// #pragma unroll +// for (int i = 0; i < UNROLL_MC; i++) val[i] = localptr[mylineoffset + line + i * loop_step0]; +// #pragma unroll +// for (int i = 0; i < UNROLL_MC; i++) +// asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"( +// mc_ptr + (mylineoffset + line + i * loop_step0)), +// "r"(val[i].x), "r"(val[i].y), "r"(val[i].z), "r"(val[i].w) +// : "memory"); +// } +// for (int line = end_aligned; line < end_elem; line += loop_step0) { +// uint4 val = localptr[mylineoffset + line]; +// asm volatile( +// "multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(mc_ptr + (mylineoffset + line)), +// "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w) +// : "memory"); +// } + +// __syncthreads(); +// if (threadIdx.x == 0) __threadfence_system(); +// __syncthreads(); + +// __shared__ int lastSM; +// if (threadIdx.x == 0) { +// const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; +// int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); +// if (old_val + adder == NVTE_MAX_SMS * reduce_id) +// lastSM = 1; +// else +// lastSM = 0; +// } +// __syncthreads(); +// if (lastSM && threadIdx.x < RANKS) { +// if (threadIdx.x == 0) *reduceidptr = reduce_id; +// flagptr[physgpu] = reduce_id; +// volatile int *flag = (volatile int *)&myptr[targetgpu]; +// clock_t s = clock64(); +// while (CHECK_IDS(*flag, reduce_id)) { +// if (CHECK_TIMEOUT(s, ub_timeout)) { +// UB_PRINT("[%d] Allgather: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, threadIdx.x, +// reduce_id, *flag); +// break; +// } +// } +// } +} // fp16 inplace allgather kernel (Hopper) MC + +#else +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, const int lineoffset, + const int numlines, void **commbuff, const int handleridx, + float4 *mc_ptr, const uint64_t ub_timeout) {} +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset, + const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, + const int totallines, const int rowlines, + const int skiplines, void **commbuff, + const int handleridx, void *outbuf, float4 *mc_ptr, + const uint64_t ub_timeout) {} + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_ag(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx, uint4 *mc_ptr, + const uint64_t ub_timeout) {} + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_mc_rs(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx, float4 *mc_ptr, + const uint64_t ub_timeout) {} +#endif + +template +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8( + const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, + const int mylineoffset, const int totallines, const int rowlines, const int skiplines, + void **commbuff, const int handleridx, void *outbuf, float *scale, const uint64_t ub_timeout) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + half hscale = (half)*scale; + + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (CHECK_IDS(*flag, reduce_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } + } + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; + } + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][mylineoffset + line]; + } + + int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}}; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + fp8type *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]); + } + int hline = 2 * line; + (reinterpret_cast(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] = + sum[0]; + hline++; + (reinterpret_cast(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] = + sum[1]; + } + + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) (fp8->fp16) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8( + const int op, const int flagoffset, const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, const int totallines, const int rowlines, + const int skiplines_out, const int skiplines_in, void **commbuff, const int handleridx, + void *outbuf, float *scale, void *counters, const int numchunks, const int atomicindex, + const uint64_t ub_timeout) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + half hscale = (half)*scale; + + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + // const int blockflagoffset = MAX_NVLINK * 2 * blockIdx.x; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr); + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; // + blockflagoffset; + } + + for (int chunk_i = 0; chunk_i < numchunks; chunk_i++) { + ATOMIC_CONSUMER(chunk_i); + + lastSM = 0; + if (threadIdx.x < RANKS) { + reduce_id++; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (CHECK_IDS(*flag, reduce_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } + } + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), /*numchunks * */ adder); + if (old_val + adder == NVTE_MAX_SMS * (reduce_id /* + numchunks*/)) lastSM = 1; + } + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + const int rowlines_in = rowlines / 2; + const int index_in = skiplines_in == 0 + ? mylineoffset + myrank * totallines + line + : (numchunks <= 1 ? 1 : chunk_i) * mylineoffset + + myrank * (totallines * skiplines_in / rowlines_in) + + (line / rowlines_in) * skiplines_in + (line % rowlines_in); + const int index1_out = chunk_i * mylineoffset * 2 + ((2 * line) / rowlines) * skiplines_out + + ((2 * line) % rowlines); + const int index2_out = chunk_i * mylineoffset * 2 + + ((2 * line + 1) / rowlines) * skiplines_out + + ((2 * line + 1) % rowlines); + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][index_in]; + } + + int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}}; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + fp8type *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]); + } + (reinterpret_cast(outbuf))[index1_out] = sum[0]; + (reinterpret_cast(outbuf))[index2_out] = sum[1]; + } + } + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) (fp8->fp16) + +template +__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride( + const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, + const int mylineoffset, const int totallines, const int rowlines, const int skiplines, + void **commbuff, const int handleridx, void *outbuf, const uint64_t ub_timeout) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (CHECK_IDS(*flag, reduce_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } + } + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; + } + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + int index_in = mylineoffset + myrank * (totallines * skiplines / rowlines) + + (line / rowlines) * skiplines + (line % rowlines); + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][index_in]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } + + int index_out = (line / rowlines) * skiplines + (line % rowlines); + (reinterpret_cast(outbuf))[index_out] = sum; + } + + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) fp16 + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic( + const int op, const int flagoffset, const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, const int totallines, const int rowlines, + const int skiplines, const int numchunks, void **commbuff, const int handleridx, + void *outbuf, void *counters, const uint64_t ub_timeout) { + if (counters) { + if (threadIdx.x == 0) { + // spin-lock on counter from producer + while (0 != (atomicCAS(((unsigned int *)counters), 0, 0))) { + } + + // make sure all threadblocks have read/waited on counters. + atomicInc(((unsigned int *)counters) + numchunks, gridDim.x - 1); + while (0 != (atomicCAS(((unsigned int *)counters) + numchunks, 0, 0))) { + } + + // reset counter for next producer. + ((unsigned int *)counters)[0] = 1; + // TODO(yuzhe.wu): replace asm volatile("fence.sc.gpu;\n") temporarily and the correctness needs to be verified + asm volatile("DMA.IDF.SLC.BYPASS %0" :: "R"(global_idf)); + } + } + __syncthreads(); + + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (CHECK_IDS(*flag, reduce_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } + } + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; + } + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + int index_in = mylineoffset + myrank * (totallines * skiplines / rowlines) + + (line / rowlines) * skiplines + (line % rowlines); + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][index_in]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } + + int index_out = (line / rowlines) * skiplines + (line % rowlines); + (reinterpret_cast(outbuf))[index_out] = sum; + } + + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; +} // fp16 reduce-scatter kernel (out of place) fp16 + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic( + const int op, const int flagoffset, const int firstrank, const int myrank, + const int gpustep, const int mylineoffset, const int totallines, const int rowlines, + const int skiplines, const int numchunks, void **commbuff, const int handleridx, + void *outbuf, void *counters, const uint64_t ub_timeout) { + for (int chunk_i = 0; chunk_i < numchunks; chunk_i++) { + if (counters) { + if (threadIdx.x == 0) { + // spin-lock on counter from producer + while (0 != (atomicCAS(((unsigned int *)counters) + chunk_i, 0, 0))) { + } + + // make sure all threadblocks have read/waited on counters. + atomicInc(((unsigned int *)counters) + numchunks + chunk_i, gridDim.x - 1); + while (0 != (atomicCAS(((unsigned int *)counters) + numchunks + chunk_i, 0, 0))) { + } + + // reset counter for next producer. + ((unsigned int *)counters)[chunk_i] = 1; + // TODO(yuzhe.wu): replace asm volatile("fence.sc.gpu;\n") temporarily and the correctness needs to be verified + asm volatile("DMA.IDF.SLC.BYPASS %0" :: "R"(global_idf)); + } + } + __syncthreads(); + + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int lastSM = 0; + + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&(myptr[targetgpu]); + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + while (CHECK_IDS(*flag, reduce_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; + } + } + } + __syncthreads(); + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1; + } + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS]; + int index_in = chunk_i * mylineoffset + myrank * (totallines * skiplines / rowlines) + + (line / rowlines) * skiplines + (line % rowlines); + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + val[i] = userptr[dest[i]][index_in]; + } + + int4 sum = val[0]; + half *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + half *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } + + int index_out = chunk_i * mylineoffset + (line / rowlines) * skiplines + (line % rowlines); + (reinterpret_cast(outbuf))[index_out] = sum; + } + if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; + } +} // fp16 reduce-scatter kernel (out of place) fp16 + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rr_ag(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx, + const uint64_t ub_timeout) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + clock_t s = clock64(); + } + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; + + int skipmy = 0; +#pragma unroll + for (int i = 0; i < RANKS; i++) { + int dst = (i + warp + myrank) & (RANKS - 1); + if (dst == myrank) { + skipmy++; + continue; + } + dest[i - skipmy] = dst; + } + __syncthreads(); + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines; + line += blockDim.x * gridDim.x) { + int4 val[RANKS - 1]; + +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + val[i] = userptr[dest[i]][mylineoffset + line + totallines * dest[i]]; + } + +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + userptr[myrank][mylineoffset + line + totallines * dest[i]] = val[i]; + } + } + __shared__ int lastSM; + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + else + lastSM = 0; + } + __syncthreads(); + if (lastSM && threadIdx.x < RANKS) { + if (threadIdx.x == 0) *reduceidptr = reduce_id; + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&myptr[targetgpu]; + clock_t s = clock64(); + while (CHECK_IDS(*flag, reduce_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Allgather: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, threadIdx.x, + reduce_id, *flag); + break; + } + } + } +} // fp16 inplace reduce kernel (Ampere) + +template +__global__ void __launch_bounds__(MAX_THREADS) + userbuffers_fp16_sum_inplace_gpu_rw_ag(const int op, const int flagoffset, const int firstrank, + const int myrank, const int gpustep, + const int mylineoffset, const int totallines, + void **commbuff, const int handleridx, + const uint64_t ub_timeout) { + __shared__ int4 *userptr[RANKS]; + volatile int *flagptr; + int physgpu, targetgpu, *myptr; + int *reduceidptr, reduce_id; + int4 *localptr; + if (threadIdx.x < RANKS) { + physgpu = myrank * gpustep + firstrank; + targetgpu = threadIdx.x * gpustep + firstrank; + myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; + reduceidptr = myptr - NVTE_MAX_OPS; // +op; + reduce_id = (*reduceidptr) + 1; + flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; + userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); + } + __syncthreads(); + localptr = userptr[myrank]; + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS - 1]; + int skipmy = 0; +#pragma unroll + for (int i = 0; i < RANKS; i++) { + int dst = (i + warp + myrank) & (RANKS - 1); + if (dst == myrank) { + skipmy++; + continue; + } + dest[i - skipmy] = dst; + } +#define UNROLLAG 4 + __syncthreads(); + const int loop_step0 = blockDim.x * gridDim.x; + const int loop_step = loop_step0 * UNROLLAG; + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = max(start_elem, totallines); + const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; + const int end_aligned = start_elem + aligned_elem; + + for (int line = start_elem; line < end_aligned; line += loop_step) { + int4 val[UNROLLAG]; +#pragma unroll + for (int j = 0; j < UNROLLAG; j++) val[j] = localptr[mylineoffset + line + loop_step0 * j]; + +#pragma unroll + for (int j = 0; j < UNROLLAG; j++) +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + userptr[dest[i]][mylineoffset + line + j * loop_step0] = val[j]; + } + } + + for (int line = end_aligned; line < end_elem; line += loop_step0) { + int4 sum = localptr[mylineoffset + line]; +#pragma unroll + for (int i = 0; i < RANKS - 1; i++) { + userptr[dest[i]][mylineoffset + line] = sum; + } + } + + __syncthreads(); + if (threadIdx.x == 0) __threadfence_system(); + __syncthreads(); + + __shared__ int lastSM; + if (threadIdx.x == 0) { + const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1; + int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder); + if (old_val + adder == NVTE_MAX_SMS * reduce_id) + lastSM = 1; + else + lastSM = 0; + } + __syncthreads(); + if (lastSM && threadIdx.x < RANKS) { + if (threadIdx.x == 0) *reduceidptr = reduce_id; + flagptr[physgpu] = reduce_id; + volatile int *flag = (volatile int *)&myptr[targetgpu]; + clock_t s = clock64(); + while (CHECK_IDS(*flag, reduce_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Allgather: SM %d [%d]: expecting %d got %d", myrank, blockIdx.x, threadIdx.x, + reduce_id, *flag); + break; + } + } + } +} // fp16 inplace allgather kernel (Volta,Hopper) + +// #define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ +// musaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ +// musaLaunchAttribute attribute_ub[2]; \ +// attribute_ub[1].id = musaLaunchAttributeClusterDimension; \ +// attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \ +// attribute_ub[1].val.clusterDim.y = 1; \ +// attribute_ub[1].val.clusterDim.z = 1; \ +// attribute_ub[0].id = musaLaunchAttributeCooperative; \ +// cfg.attrs = attribute_ub; \ +// cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1; + +// TODO(yuzhe.wu): Temporarily disable AttributeCooperative and +// enable it again when the driver supports musaLaunchAttributeCooperative. +#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ + musaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cfg.numAttrs = 0; + +// TODO Temporarily disable CompletionEvent +// #if (MUSART_VERSION >= 12030) +#if 0 +#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ + attribute_ub[2].id = musaLaunchAttributeLaunchCompletionEvent; \ + attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event; +#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 3 +#else +#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) +#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2 +#endif + +#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ + musaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + musaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \ + ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ + attribute_ub[1].id = musaLaunchAttributeClusterDimension; \ + attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \ + attribute_ub[1].val.clusterDim.y = 1; \ + attribute_ub[1].val.clusterDim.z = 1; \ + attribute_ub[0].id = musaLaunchAttributeCooperative; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH; + +#define callranks_ag(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + (comm->use_rr_kernel ? 0 : arg4 * arg7); \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + uint64_t arg10 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag \ + : userbuffers_fp16_sum_inplace_gpu_rw_ag), \ + kernelArgs)); \ + } + +#define callranks_agMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + uint4 *arg10 = reinterpret_cast(comm->mc_ptr[handler]); \ + uint64_t arg11 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11)}; \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_ag), kernelArgs)); \ + } + + #define callranks_rs(x, is_bf16) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + uint64_t arg10 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ + if(is_bf16) { \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs), \ + kernelArgs)); \ + } else { \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs), \ + kernelArgs));} \ + } + +#define callranks_rsMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + void *arg10 = comm->mc_ptr[handler]; \ + uint64_t arg11 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11)}; \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs), kernelArgs)); \ + } + + #define callranks_rs_oop(x, is_bf16) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + uint64_t arg13 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13)}; \ + if(is_bf16) { \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop), \ + kernelArgs)); \ + } else { \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop), \ + kernelArgs));} \ + } + +#define callranks_rs_oop_fp8(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ + arg6 = offset / 16 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + float *arg13 = scale; \ + uint64_t arg14 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8), \ + kernelArgs)); \ + } + +#define callranks_rs_oopMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + void *arg13 = comm->mc_ptr[handler]; \ + uint64_t arg14 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_atomic_fp8(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ + arg6 = offset / 16, arg8 = rowelements / 8, arg9 = strideelements_out / 8, \ + arg10 = strideelements_in / 16; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + float *arg14 = scale; \ + void *arg15 = counters; \ + int arg16 = numchunks, arg17 = atomicindex; \ + uint64_t arg18 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15), reinterpret_cast(&arg16), \ + reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast( \ + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_stride(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + uint64_t arg13 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13)}; \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_stride_atomic(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + void *arg14 = counters; \ + uint64_t arg15 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15)}; \ + NVTE_CHECK_CUDA(musaLaunchKernelExC( \ + &cfg, \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic), \ + kernelArgs)); \ + } + +#define callranks_rs_oop_stride_multiatomic(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + void *arg14 = counters; \ + uint64_t arg15 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15)}; \ + NVTE_CHECK_CUDA( \ + musaLaunchKernelExC(&cfg, \ + reinterpret_cast( \ + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic), \ + kernelArgs)); \ + } + +void reducescatter2_userbuff_strided(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, communicator *comm, + musaStream_t stream) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8) +} +void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, const int numchunks, + void *counters, communicator *comm, + musaStream_t stream) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4) + callranks_rs_oop_stride_atomic(8) +} + +template +void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, const int handler, + const int offset, const int rowelements, + const int colelements, + const int strideelements_out, + const int strideelements_in, const int numchunks, + const int atomicindex, void *counters, + communicator *comm, musaStream_t stream) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + assert(comm->sm_arch >= 9); + if (elements < 128) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8) +} + +template +void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler, + const int offset, const int rowelements, + const int colelements, const int strideelements_out, + const int strideelements_in, const int numchunks, + void *counters, communicator *comm, + musaStream_t stream) { + reducescatter2_userbuff_strided_universal_fp8( + output, scale, handler, offset, rowelements, colelements, strideelements_out, + strideelements_in, 1, numchunks, counters /*nullptr*/, comm, stream); +} + +template +void reducescatter2_userbuff_strided_multiatomic_fp8( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, musaStream_t stream) { + reducescatter2_userbuff_strided_universal_fp8( + output, scale, handler, offset, rowelements, colelements, strideelements_out, + strideelements_in, numchunks, 0, counters /*nullptr*/, comm, stream); +} + +void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, const int numchunks, + void *counters, communicator *comm, + musaStream_t stream) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4) + callranks_rs_oop_stride_multiatomic(8) +} + +void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, + communicator *comm, musaStream_t stream, + musaEvent_t comm_launch_event) { + const int op = userbuffers_allreduceop_nonsharp2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) + } + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) + } + } +} + +void allgather2_userbuff_inplace_sliced(const int handler, const int offset, const int elements, + communicator *comm, const int slice_id, const int nslices, + musaStream_t stream) { + const int op = userbuffers_allreduceop_nonsharp2; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + int peerelements = elements / ar_nvsize; + int saverrkernel = comm->use_rr_kernel; + comm->use_rr_kernel = 0; + allgather2_userbuff_inplace( + handler, offset + ar_nvrank * peerelements * (nslices - 1) + slice_id * peerelements, + elements, comm, stream); + comm->use_rr_kernel = saverrkernel; +} + +void reducescatter2_userbuff_inplace(const int handler, transformer_engine::DType dtype, + const int offset, const int elements, + communicator *comm, musaStream_t stream, + musaEvent_t comm_launch_event) { + const int op = userbuffers_allreduceop_nonsharp2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + assert(dtype == transformer_engine::DType::kFloat16 || dtype == transformer_engine::DType::kBFloat16); + + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + } else { + if (dtype == transformer_engine::DType::kFloat16) { + callranks_rs(2, 0) callranks_rs(4, 0) callranks_rs(8, 0) + } else { + callranks_rs(2, 1) callranks_rs(4, 1) callranks_rs(8, 1) + } + } + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + } else { + if (dtype == transformer_engine::DType::kFloat16) { + callranks_rs(2, 0) callranks_rs(4, 0) callranks_rs(8, 0) + } else { + callranks_rs(2, 1) callranks_rs(4, 1) callranks_rs(8, 1) + } + } + } +} +void reducescatter2_userbuff_stridedoutput(void *output, transformer_engine::DType dtype, + const int handler, const int offset, + const int rowelements, const int colelements, + const int strideelements, communicator *comm, + musaStream_t stream, musaEvent_t comm_launch_event) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + + if (elements < 64) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + assert(dtype == transformer_engine::DType::kFloat16 || dtype == transformer_engine::DType::kBFloat16); + + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + } else { + if (dtype == transformer_engine::DType::kFloat16) { + callranks_rs_oop(2, 0) callranks_rs_oop(4, 0) callranks_rs_oop(8, 0) + } else { + callranks_rs_oop(2, 1) callranks_rs_oop(4, 1) callranks_rs_oop(8, 1) + } + } + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + } else { + if (dtype == transformer_engine::DType::kFloat16) { + callranks_rs_oop(2, 0) callranks_rs_oop(4, 0) callranks_rs_oop(8, 0) + } else { + callranks_rs_oop(2, 1) callranks_rs_oop(4, 1) callranks_rs_oop(8, 1) + } + } + } +} +void reducescatter2_userbuff(void *output, transformer_engine::DType dtype, const int handler, const int offset, const int elements, + communicator *comm, musaStream_t stream, + musaEvent_t comm_launch_event) { + reducescatter2_userbuff_stridedoutput(output, dtype, handler, offset, elements, 1, 0, comm, stream, + comm_launch_event); +} + +template +void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, + const int offset, const int rowelements, + const int colelements, const int strideelements, + communicator *comm, musaStream_t stream, + musaEvent_t comm_launch_event) { + const int elements = rowelements * colelements; + const int op = userbuffers_allreduceop_nonsharp2; + const int ar_firstgpu = + op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; + const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; + const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; + const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; + assert(comm->sm_arch >= 9); + if (elements < 128) return; + int sms = ar_nvsize == 1 ? 2 : comm->sms; + int warps = comm->threads / 32; + if (warps < ar_nvsize) warps = ar_nvsize; + + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + } +} + +template void reducescatter2_userbuff_stridedoutput_fp8<__mt_fp8_e5m2>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements, communicator *comm, musaStream_t stream, + musaEvent_t comm_launch_event); + +template void reducescatter2_userbuff_stridedoutput_fp8<__mt_fp8_e4m3>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements, communicator *comm, musaStream_t stream, + musaEvent_t comm_launch_event); + +template +void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, + const int elements, communicator *comm, musaStream_t stream, + musaEvent_t comm_launch_event) { + reducescatter2_userbuff_stridedoutput_fp8(output, scale, handler, offset, elements, 1, 0, + comm, stream, comm_launch_event); +} + +template void reducescatter2_userbuff_fp8<__mt_fp8_e5m2>(void *output, float *scale, + const int handler, const int offset, + const int elements, communicator *comm, + musaStream_t stream, + musaEvent_t comm_launch_event); +template void reducescatter2_userbuff_fp8<__mt_fp8_e4m3>(void *output, float *scale, + const int handler, const int offset, + const int elements, communicator *comm, + musaStream_t stream, + musaEvent_t comm_launch_event); + +template void reducescatter2_userbuff_strided_atomic_fp8<__mt_fp8_e4m3>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, musaStream_t stream); +template void reducescatter2_userbuff_strided_atomic_fp8<__mt_fp8_e5m2>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, musaStream_t stream); +template void reducescatter2_userbuff_strided_multiatomic_fp8<__mt_fp8_e4m3>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, musaStream_t stream); +template void reducescatter2_userbuff_strided_multiatomic_fp8<__mt_fp8_e5m2>( + void *output, float *scale, const int handler, const int offset, const int rowelements, + const int colelements, const int strideelements_out, const int strideelements_in, + const int numchunks, void *counters, communicator *comm, musaStream_t stream); + +__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) { + // atomicAdd_system(flagptr, 1); + // TODO(xiaoyang): replace atomicAdd_system temporarily + atomicAdd(flagptr, 1); +} + +__global__ void kuserbuffers_inc(int *id) { atomicAdd(id, 1); } + +__global__ void kuserbuffers_dummy(void) {} + +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_pullrecv(int myrank, int peer, int nvrank, int nvpeer, int *recv_id, int *flagptr, + int4 *srcptr, int4 *dstptr, const int lines, uint64_t ub_timeout) { +#define UNROLLCOPY 8 + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = (end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1)); + const int end_aligned = start_elem + aligned_elem; + + if (threadIdx.x == 0) { + const int signal_id = (*recv_id) + 1; + volatile int *flag = (volatile int *)flagptr; + clock_t s = clock64(); + while (CHECK_IDS(*flag, signal_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT( + "pullrecv [grank dst:%d global src:%d][nvrank(GPU) dst: %d src: %d]: expecting %d," + " observed %d", + myrank, peer, nvrank, nvpeer, signal_id, *flag); + break; + } + } + if (lines == 0) { + *recv_id = signal_id; + return; + } // otherwise need an extra kernel + } + __syncthreads(); + + if (end_elem <= start_elem) return; + + for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x * gridDim.x]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) + dstptr[line] = srcptr[line]; +} + +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_pushsend(int *send_id, int *flagptr, int4 *srcptr, int4 *dstptr, const int lines) { + if (lines) { + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = + ((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1))); + const int end_aligned = start_elem + aligned_elem; + if (end_elem > start_elem) { + for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x * gridDim.x]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) + dstptr[line] = srcptr[line]; + } + __syncthreads(); + if (threadIdx.x) return; + __threadfence_system(); + // atomicAdd_system(flagptr, + // 1); // otherwise need local SM sync before sending flag + // TODO(xiaoyang): replace atomicAdd_system temporarily + atomicAdd(flagptr, 1); + } else { // 0 bytes and 1 SM only + // atomicAdd_system(flagptr, 1); + // TODO(xiaoyang): replace atomicAdd_system temporarily + atomicAdd(flagptr, 1); + } +} + +#define CHECK_CE(ce_start, ce_end) \ + ((ce_start) != nullptr && (ce_end) != nullptr && *(ce_start) != *(ce_end)) + +__global__ void kuserbuffers_pushrecv(int myrank, int peer, int nvrank, int nvpeer, int *recv_id, + int *flagptr, int adder, uint64_t ub_timeout, + int *ce_start_ptr, int *ce_end_ptr) { + const int signal_id = (*recv_id) + adder; + *recv_id = signal_id; + volatile int *flag = (volatile int *)flagptr; + if (*flag >= signal_id) return; + clock_t s = clock64(); + while (CHECK_IDS(volatile_load((int*)flag), signal_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT( + "pushrecv [grank dst:%d global src:%d][nvrank(GPU) dst: %d src: %d]: " + "expecting %d, observed %d", + myrank, peer, nvrank, nvpeer, signal_id, *flag); + if (CHECK_CE(ce_start_ptr, ce_end_ptr)) + UB_PRINT("pushrecv: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", *ce_start_ptr, + *ce_end_ptr); + return; + } + } +} + +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_pushsendrecv(int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr, + const int lines, int send_peer, int recv_peer, int *recv_id, + int *recv_flagptr, int adder, uint64_t ub_timeout, int nv_send, + int nv_recv, int *ce_start_ptr, int *ce_end_ptr) { + if (lines) { + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = + ((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1))); + const int end_aligned = start_elem + aligned_elem; + if (end_elem > start_elem) { + for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + val[i] = srcptr[line + i * blockDim.x * gridDim.x]; + } +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + } + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) { + dstptr[line] = srcptr[line]; + } + } + __syncthreads(); + if (threadIdx.x) return; + __threadfence_system(); + // atomicAdd_system(send_flagptr, + // 1); // otherwise need local SM sync before sending flag + // TODO(xiaoyang): replace atomicAdd_system temporarily + atomicAdd(send_flagptr, 1); + } else { // 0 bytes and 1 SM only + // atomicAdd_system(send_flagptr, 1); + // TODO(xiaoyang): replace atomicAdd_system temporarily + atomicAdd(send_flagptr, 1); + } + + if (blockIdx.x == 0 && threadIdx.x == 0) { + const int signal_id = (*recv_id) + adder; + *recv_id = signal_id; + volatile int *flag = (volatile int *)recv_flagptr; + if (*flag >= signal_id) return; + clock_t s = clock64(); + while (CHECK_IDS(*flag, signal_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT( + "pushsendrecv [sending peer:%d receiving peer:%d][nvrank(GPU) sending peer: %d" + " receiving peer: %d]: expecting %d, observed %d", + send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); + if (CHECK_CE(ce_start_ptr, ce_end_ptr)) + UB_PRINT("pushrecv: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", *ce_start_ptr, + *ce_end_ptr); + return; + } + } + } +} + +__global__ void __launch_bounds__(MAX_THREADS) + kuserbuffers_pushsendrecv_atomic(int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr, + const int lines, int send_peer, int recv_peer, int *recv_id, + int *recv_flagptr, int adder, void *counters, + uint64_t ub_timeout, int nv_send, int nv_recv, + int *ce_start_ptr, int *ce_end_ptr) { + if (lines) { + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = + ((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1))); + const int end_aligned = start_elem + aligned_elem; + if (end_elem > start_elem) { + for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + val[i] = srcptr[line + i * blockDim.x * gridDim.x]; + } +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + dstptr[line + i * blockDim.x * gridDim.x] = val[i]; + } + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) { + dstptr[line] = srcptr[line]; + } + } + __syncthreads(); + if (threadIdx.x) return; + __threadfence_system(); + // atomicAdd_system(send_flagptr, + // 1); // otherwise need local SM sync before sending flag + // TODO(xiaoyang): replace atomicAdd_system temporarily + atomicAdd(send_flagptr, 1); + } else { // 0 bytes and 1 SM only + // atomicAdd_system(send_flagptr, 1); + // TODO(xiaoyang): replace atomicAdd_system temporarily + atomicAdd(send_flagptr, 1); + } + + if (blockIdx.x == 0 && threadIdx.x == 0) { + const int signal_id = (*recv_id) + adder; + *recv_id = signal_id; + volatile int *flag = (volatile int *)recv_flagptr; + clock_t s = clock64(); + while (CHECK_IDS(*flag, signal_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT( + "pushsendrecv atomic [sending peer:%d receiving peer:%d][nvrank(GPU) sending peer:" + " %d receiving peer: %d]: expecting %d, observed %d", + send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); /*return;*/ + if (CHECK_CE(ce_start_ptr, ce_end_ptr)) + UB_PRINT("pushsendrecv atomic: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", + *ce_start_ptr, *ce_end_ptr); + } + } + + // Decrement atomic val to signal current output tile finish + if (counters) { + ((unsigned int *)counters)[0] = 0; + // asm volatile("fence.sc.gpu;\n"); + // TODO(xiaoyang): replace asm volatile("fence.sc.gpu;\n") temporarily and the correctness needs to be verified + asm volatile("DMA.IDF.SLC.BYPASS %0" :: "R"(global_idf)); + } + } +} + +__global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiatomic( + int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr, const int lines, int send_peer, + int recv_peer, int *recv_id, int *recv_flagptr, int adder, void *counters, int nchunks, + int send_stride, int recv_stride, bool shuffle, uint64_t ub_timeout, int nv_send, int nv_recv) { + for (int chunk_i = 0; chunk_i < nchunks - 1; chunk_i++) { + int send_chunk_id = shuffle ? chunk_i : (nchunks + send_peer - chunk_i) % nchunks; + int recv_chunk_id = shuffle ? chunk_i + 1 : (nchunks + send_peer - chunk_i - 1) % nchunks; + int send_offset = (send_chunk_id * send_stride) / 16; + int recv_offset = ((shuffle ? recv_chunk_id : send_chunk_id) * recv_stride) / 16; + + if (lines) { + const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; + const int end_elem = lines; + const int aligned_elem = + ((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1))); + const int end_aligned = start_elem + aligned_elem; + if (end_elem > start_elem) { + for (int line = start_elem; line < end_aligned; + line += blockDim.x * gridDim.x * UNROLLCOPY) { + int4 val[UNROLLCOPY]; +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + val[i] = srcptr[send_offset + line + i * blockDim.x * gridDim.x]; + } +#pragma unroll + for (int i = 0; i < UNROLLCOPY; i++) { + dstptr[recv_offset + line + i * blockDim.x * gridDim.x] = val[i]; + } + } + for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) { + dstptr[recv_offset + line] = srcptr[send_offset + line]; + } + } + __syncthreads(); + if (!threadIdx.x) { + __threadfence_system(); + // atomicAdd_system(send_flagptr, + // 1); // otherwise need local SM sync before sending flag + // TODO(xiaoyang): replace atomicAdd_system temporarily + atomicAdd(send_flagptr, 1); + } + } else { // 0 bytes and 1 SM only + // atomicAdd_system(send_flagptr, 1); + // TODO(xiaoyang): replace atomicAdd_system temporarily + atomicAdd(send_flagptr, 1); + } + + // wait for message to arrive. + if (blockIdx.x == 0 && threadIdx.x == 0) { + const int signal_id = (*recv_id) + adder; + *recv_id = signal_id; + volatile int *flag = (volatile int *)recv_flagptr; + clock_t s = clock64(); + while (CHECK_IDS(*flag, signal_id)) { + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT( + "pushsendrecv multiatomic [sending peer:%d receiving peer:%d][nvrank(GPU)" + " sending peer: %d receiving peer: %d]: expecting %d, observed %d", + send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); /*return;*/ + // CE mode is not supported for multi-atomic, so there is no need to check for a deadlock + return; + } + } + } + + // Producer must update counters. + if (blockIdx.x == 0 && threadIdx.x == 0) { + // Decrement atomic val to signal current output tile finish + if (counters) { + ((unsigned int *)counters)[recv_chunk_id /*chunk_i+1*/] = 0; + // asm volatile("fence.sc.gpu;\n"); + // TODO(xiaoyang): replace asm volatile("fence.sc.gpu;\n") temporarily and the correctness needs to be verified + asm volatile("DMA.IDF.SLC.BYPASS %0" :: "R"(global_idf)); + } + } + + // sync all CTAs before moving to next chunk. + if (threadIdx.x == 0) { + atomicInc(((unsigned int *)counters) + nchunks + chunk_i, gridDim.x - 1); + while (0 != (atomicCAS(((unsigned int *)counters) + nchunks + chunk_i, 0, 0))) { + } + } + __syncthreads(); + } +} + +// Return TRUE if two ranks share the same NV domain +#define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize)) + +// Index corresponds to the type of flag: +// 0 - Send index counter +// 1 - CE start index counter +// 2 - CE end index counter +#define GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsth, index) \ + ((reinterpret_cast((comm)->peer_ptr[0][(peerlocal)])) + \ + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (comm)->myrank * NVTE_MAX_REGIONS + (dsth) + \ + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ + sizeof(int))) + +// Index corresponds to the type of flag: +// 0 - Receive index counter +// 1 - CE start index counter +// 2 - CE end index counter +#define GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsth, index) \ + ((reinterpret_cast((comm)->mem_ptr[0])) + \ + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (recv_peer) * NVTE_MAX_REGIONS + (dsth) + \ + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ + sizeof(int))) + +void comm_userbuff_over_ce(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const int elements, const int comm_bytes, + communicator *comm, const int send_peer, const int recv_peer, + transformer_engine::DType dtype, const int _tp_id, musaStream_t stream) { + +assert(dtype == transformer_engine::DType::kFloat16 || dtype == transformer_engine::DType::kBFloat16); + +MUatomicType atomicType = MUatomicType::MU_ATOMIC_TYPE_ATOMIC_ADD_BF16; +if (dtype == transformer_engine::DType::kFloat16) { +atomicType = MUatomicType::MU_ATOMIC_TYPE_ATOMIC_ADD_HF16; +} +int send_peerlocal = send_peer % comm->nvsize; +int recv_peerlocal = recv_peer % comm->nvsize; + +void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); +void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); + +void *dstptr = reinterpret_cast(comm->mem_ptr[dsthandler]) + dstoffset; +void *srcptr = reinterpret_cast(comm->peer_ptr[srchandler][recv_peerlocal]) + srcoffset; + +// pull mode +CHECK_MUSA_DRIVER(muStreamWaitValue64( + (MUstream)stream, + (MUdeviceptr)flagptr_send, + 0, + MUstreamWaitValue_flags::MU_STREAM_WAIT_VALUE_EQ)); + +CHECK_MUSA_DRIVER(muMemoryAtomicValueAsync( + (MUdeviceptr)flagptr_send, + 1, + MUatomicValueType::MU_ATOMIC_VALUE_TYPE_ATOMIC_ADD64, + (MUstream)stream)); + +CHECK_MUSA_DRIVER(muStreamWaitValue64( + (MUstream)stream, + (MUdeviceptr)flagptr_recv, + 1, + MUstreamWaitValue_flags::MU_STREAM_WAIT_VALUE_EQ)); + +// CHECK_MUSA_DRIVER(muMemoryAtomicAsync( +// (MUdeviceptr)dstptr, +// (MUdeviceptr)srcptr, +// elements, +// atomicType, +// (MUstream)stream)); +NVTE_CHECK_CUDA(musaMemcpyAsync( + dstptr, + srcptr, + comm_bytes, + musaMemcpyDeviceToDevice, + stream)); + +CHECK_MUSA_DRIVER(muStreamWriteValue64( + (MUstream)stream, + (MUdeviceptr)flagptr_recv, + 0, + MUstreamWriteValue_flags::MU_STREAM_WRITE_VALUE_DEFAULT)); +} + + +void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes, communicator *comm, + const int peer, musaStream_t stream) { + int peerlocal = peer % comm->nvsize; + void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0); + // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1); + // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 2); + bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); + + assert(INTRANODE(peer)); + + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; + if (comm->push == 0) { + kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]), + reinterpret_cast(flagptr)); + } else { + void *srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + srcoffset; + void *dstptr = reinterpret_cast(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset; + + if (comm->use_ce) { + // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); + NVTE_CHECK_CUDA(musaMemcpyAsync(dstptr, srcptr, bytes, musaMemcpyDeviceToDevice, stream)); + // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); + } + SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); + int *arg1 = &comm->send_id[peer], *arg2 = reinterpret_cast(flagptr); + int4 *arg3 = reinterpret_cast(srcptr), *arg4 = reinterpret_cast(dstptr); + int arg5 = signalonly ? 0 : bytes / 16; + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), + reinterpret_cast(&arg3), reinterpret_cast(&arg4), + reinterpret_cast(&arg5)}; + NVTE_CHECK_CUDA( + musaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsend), kernelArgs)); + } +} + +void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset, + const size_t recv_offset, const size_t bytes, communicator *comm, + const int send_peer, const int recv_peer, musaStream_t stream) { + bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); + int send_peerlocal = send_peer % comm->nvsize; + int recv_peerlocal = recv_peer % comm->nvsize; + void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); + // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1); + // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2); + void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); + + void *send_srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + send_offset; + void *send_dstptr = + reinterpret_cast(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset; + + if (comm->use_ce) { + // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); + NVTE_CHECK_CUDA( + musaMemcpyAsync(send_dstptr, send_srcptr, bytes, musaMemcpyDeviceToDevice, stream)); + // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); + } + SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); + + int *arg1 = &comm->send_id[send_peer]; + int *arg2 = reinterpret_cast(flagptr_send); + int4 *arg3 = reinterpret_cast(send_srcptr); + int4 *arg4 = reinterpret_cast(send_dstptr); + int arg5 = signalonly ? 0 : bytes / 16; + int arg6 = send_peer; + int arg7 = recv_peer; + int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler]; + int *arg9 = reinterpret_cast(flagptr_recv); + int arg10 = signalonly ? 1 : comm->sms; + uint64_t arg11 = comm->ub_timeout; + int arg12 = send_peerlocal; + int arg13 = recv_peerlocal; + int *arg14 = reinterpret_cast(0 ? // temporary disable + GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1) + : nullptr); + int *arg15 = reinterpret_cast(0 ? // temporary disable + GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2) + : nullptr); + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), + reinterpret_cast(&arg3), reinterpret_cast(&arg4), + reinterpret_cast(&arg5), reinterpret_cast(&arg6), + reinterpret_cast(&arg7), reinterpret_cast(&arg8), + reinterpret_cast(&arg9), reinterpret_cast(&arg10), + reinterpret_cast(&arg11), reinterpret_cast(&arg12), + reinterpret_cast(&arg13), reinterpret_cast(&arg14), + reinterpret_cast(&arg15)}; + NVTE_CHECK_CUDA( + musaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsendrecv), kernelArgs)); +} + +void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, + const size_t send_offset, const size_t recv_offset, + const size_t bytes, communicator *comm, const int send_peer, + const int recv_peer, void *counters, musaStream_t stream) { + assert(comm->push && comm->use_ce == 0); + bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); + + int send_peerlocal = send_peer % comm->nvsize; + int recv_peerlocal = recv_peer % comm->nvsize; + void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); + // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1); + // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2); + void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); + + void *send_srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + send_offset; + void *send_dstptr = + reinterpret_cast(comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset; + if (comm->use_ce) { + // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); + NVTE_CHECK_CUDA( + musaMemcpyAsync(send_dstptr, send_srcptr, bytes, musaMemcpyDeviceToDevice, stream)); + // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); + } + SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); + + int *arg1 = &comm->send_id[send_peer]; + int *arg2 = reinterpret_cast(flagptr_send); + int4 *arg3 = reinterpret_cast(send_srcptr); + int4 *arg4 = reinterpret_cast(send_dstptr); + int arg5 = signalonly ? 0 : bytes / 16; + int arg6 = send_peer; + int arg7 = recv_peer; + int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler]; + int *arg9 = reinterpret_cast(flagptr_recv); + int arg10 = signalonly ? 1 : comm->sms; + void *arg11 = counters; + int arg12 = comm->ub_timeout; + int arg13 = send_peerlocal; + int arg14 = recv_peerlocal; + int *arg15 = reinterpret_cast(0 ? // temporary disable + GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1) + : nullptr); + int *arg16 = reinterpret_cast(0 ? // temporary disable + GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2) + : nullptr); + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), + reinterpret_cast(&arg3), reinterpret_cast(&arg4), + reinterpret_cast(&arg5), reinterpret_cast(&arg6), + reinterpret_cast(&arg7), reinterpret_cast(&arg8), + reinterpret_cast(&arg9), reinterpret_cast(&arg10), + reinterpret_cast(&arg11), reinterpret_cast(&arg12), + reinterpret_cast(&arg13), reinterpret_cast(&arg14), + reinterpret_cast(&arg15), reinterpret_cast(&arg16)}; + NVTE_CHECK_CUDA(musaLaunchKernelExC( + &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_atomic), kernelArgs)); +} + +void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler, + const size_t send_stride, const size_t recv_stride, + const size_t bytes, communicator *comm, const int send_peer, + const int recv_peer, const int nchunks, void *counters, + bool shuffle, musaStream_t stream) { + assert(comm->push && comm->use_ce == 0); + // CE is not supported + + int send_peerlocal = send_peer % comm->nvsize; + int recv_peerlocal = recv_peer % comm->nvsize; + void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); + void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); + + SETUP_LAUNCH_CONFIG(comm->sms, 1024, stream); + + int *arg1 = &comm->send_id[send_peer]; + int *arg2 = reinterpret_cast(flagptr_send); + int4 *arg3 = reinterpret_cast((comm->mem_ptr[srchandler])); + int4 *arg4 = reinterpret_cast((comm->peer_ptr[dsthandler][send_peerlocal])); + int arg5 = bytes / 16; + int arg6 = comm->myrank; + int arg7 = recv_peer; + int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler]; + int *arg9 = reinterpret_cast(flagptr_recv); + int arg10 = comm->sms; + void *arg11 = counters; + int arg12 = nchunks; + int arg13 = send_stride; + int arg14 = recv_stride; + bool arg15 = shuffle; + uint64_t arg16 = comm->ub_timeout; + int arg17 = send_peerlocal; + int arg18 = recv_peerlocal; + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), + reinterpret_cast(&arg3), reinterpret_cast(&arg4), + reinterpret_cast(&arg5), reinterpret_cast(&arg6), + reinterpret_cast(&arg7), reinterpret_cast(&arg8), + reinterpret_cast(&arg9), reinterpret_cast(&arg10), + reinterpret_cast(&arg11), reinterpret_cast(&arg12), + reinterpret_cast(&arg13), reinterpret_cast(&arg14), + reinterpret_cast(&arg15), reinterpret_cast(&arg16), + reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; + NVTE_CHECK_CUDA(musaLaunchKernelExC( + &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_multiatomic), kernelArgs)); +} + +void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes, communicator *comm, + const int peer, musaStream_t stream) { + int peerlocal = peer % comm->nvsize; + void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 0); + bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); + + assert(INTRANODE(peer)); + + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; + if (comm->push == 0) { + void *dstptr = reinterpret_cast(comm->mem_ptr[dsthandler]) + dstoffset; + void *srcptr = reinterpret_cast(comm->peer_ptr[srchandler][peerlocal]) + srcoffset; + + kuserbuffers_pullrecv<<sms, signalonly ? 1 : 1024, 0, stream>>>( + comm->myrank, peer, comm->nvrank, peerlocal, + &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast(flagptr), + reinterpret_cast(srcptr), reinterpret_cast(dstptr), + signalonly ? 0 : bytes / 16, comm->ub_timeout); + if (!signalonly) + kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); + if (comm->use_ce) { + NVTE_CHECK_CUDA(musaMemcpyAsync(dstptr, srcptr, bytes, musaMemcpyDeviceToDevice, stream)); + } + } else { + kuserbuffers_pushrecv<<<1, 1, 0, stream>>>( + comm->myrank, peer, comm->nvrank, peerlocal, + &comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler], reinterpret_cast(flagptr), + signalonly || comm->sms, comm->ub_timeout, + reinterpret_cast(0 ? // temporary disable + GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1) + : nullptr), + reinterpret_cast(0 ? // temporary disable + GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2) + : nullptr)); + } +} + +// producer +static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { + // Decrement atomic val to signal current output tile finish + if (blockIdx.x == 0 && threadIdx.x == 0) { + ((unsigned int *)atomic_ptr)[chunk_i] = 0; + } + + // COMM kernel need to explicitely flash gmem. + // GEMM kernel already executed, and can not see gmem + // change without COMM kernel explicitely make change + // TODO(yuzhe.wu): replace asm volatile("fence.sc.gpu;\n") temporarily and the correctness needs to be verified + asm volatile("DMA.IDF.SLC.BYPASS %0" :: "R"(global_idf)); +} + +// consumer +static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { + // Wait for producer to change the val to 0, which signal producer ready + if (blockIdx.x == 0 && threadIdx.x == 0) { + while (0 != (atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) { + } + ((unsigned int *)atomic_ptr)[chunk_i] = 1; + // TODO(yuzhe.wu): replace asm volatile("fence.sc.gpu;\n") temporarily and the correctness needs to be verified + asm volatile("DMA.IDF.SLC.BYPASS %0" :: "R"(global_idf)); + } +} + +// consumer_batch +static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i, int num_chunks) { + // Wait for producer to change the val to 0, which signal producer ready + if (blockIdx.x == 0 && threadIdx.x == 0) { + for (int i = first_chunk_i; i < num_chunks; i++) { + while (0 != (atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) { + } + ((unsigned int *)atomic_ptr)[i] = 1; + // TODO(yuzhe.wu): replace asm volatile("fence.sc.gpu;\n") temporarily and the correctness needs to be verified + asm volatile("DMA.IDF.SLC.BYPASS %0" :: "R"(global_idf)); + } + } +} + +// reset counters kernel +static __global__ void reset_counters_kernel(void *atomic_ptr, int num_chunks, bool allgather) { + if (blockIdx.x == 0 && threadIdx.x == 0) { +#pragma unroll + for (int i = 0; i < num_chunks; i++) { + ((unsigned int *)atomic_ptr)[i] = 1; + ((unsigned int *)atomic_ptr)[i + num_chunks] = 0; + } + if (allgather) ((unsigned int *)atomic_ptr)[0] = 0; + } +} + +void producer(void *atomic_ptr, int chunk_i, musaStream_t stream) { + dim3 block(1); + dim3 grid(1); + producer_kernel<<>>(atomic_ptr, chunk_i); +} + +void consumer(void *atomic_ptr, int chunk_i, musaStream_t stream) { + dim3 block(1); + dim3 grid(1); + consumer_kernel<<>>(atomic_ptr, chunk_i); +} + +void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, musaStream_t stream) { + dim3 block(1); + dim3 grid(1); + consumer_batch_kernel<<>>(atomic_ptr, first_chunk_i, num_chunks); +} + +void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, musaStream_t stream) { + dim3 block(1); + dim3 grid(1); + reset_counters_kernel<<>>(atomic_ptr, num_chunks, allgather); +} + +template +__global__ void __launch_bounds__(MAX_THREADS / 4) + reduce_fp8_in_bf16_out_musa(void *inputs, void *output, const float *scale, + const int num_inputs, const int input_size) { + const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; + fp8type *inputs_fp8 = reinterpret_cast(inputs); + float accum_buf = static_cast(inputs_fp8[tid]) * (*scale); +#pragma unroll + for (int i = 1; i < num_inputs; i++) { + accum_buf += static_cast(inputs_fp8[tid + input_size * i]) * (*scale); + } + half *output_half = reinterpret_cast(output); + output_half[tid] = (half)accum_buf; +} + +template +void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs, + int input_size, musaStream_t stream) { + size_t num_threads = MAX_THREADS / 4; + size_t num_blocks = (input_size + num_threads - 1) / num_threads; + dim3 block(num_threads); + dim3 grid(num_blocks); + reduce_fp8_in_bf16_out_musa + <<>>(inputs, output, scale, num_inputs, input_size); +} + +template void reduce_fp8_in_bf16_out<__mt_fp8_e4m3>(void *inputs, void *output, float *scale, + int num_inputs, int input_size, + musaStream_t stream); +template void reduce_fp8_in_bf16_out<__mt_fp8_e5m2>(void *inputs, void *output, float *scale, + int num_inputs, int input_size, + musaStream_t stream); + +__global__ void __launch_bounds__(MAX_THREADS / 4) + reduce_bf16_musa(void *inputs, void *output, const int num_inputs, const int input_size) { + const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; + __mt_bfloat16 *inputs_half = reinterpret_cast<__mt_bfloat16 *>(inputs); + float accum_buf = static_cast(inputs_half[tid]); +#pragma unroll + for (int i = 1; i < num_inputs; i++) { + accum_buf += static_cast(inputs_half[tid + input_size * i]); + } + __mt_bfloat16 *output_half = reinterpret_cast<__mt_bfloat16 *>(output); + output_half[tid] = (__mt_bfloat16)accum_buf; +} + +void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, musaStream_t stream) { + size_t num_threads = MAX_THREADS / 4; + size_t num_blocks = (input_size + num_threads - 1) / num_threads; + dim3 block(num_threads); + dim3 grid(num_blocks); + reduce_bf16_musa<<>>(inputs, output, num_inputs, input_size); +} diff --git a/transformer_engine/musa/common/common.h b/transformer_engine/musa/common/common.h new file mode 100644 index 0000000000..54603329fe --- /dev/null +++ b/transformer_engine/musa/common/common.h @@ -0,0 +1,517 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_ +#define TRANSFORMER_ENGINE_COMMON_COMMON_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "./nvtx.h" +#include "./util/musa_driver.h" +#include "./util/logging.h" + +namespace transformer_engine { + +inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { + NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", + end, " in a vector with ", shape.size(), " entries"); + size_t ret = 1; + for (size_t i = begin; i < end; ++i) { + ret *= shape[i]; + } + return ret; +} + +inline size_t product(const std::vector &shape) { + size_t ret = 1; + for (const auto &elem : shape) { + ret *= elem; + } + return ret; +} + +struct SimpleTensor { + void *dptr; + std::vector shape; + DType dtype; + + SimpleTensor(void *dptr, const std::vector &shape, DType dtype) + : dptr(dptr), shape(shape), dtype(dtype) {} + + SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT + : dptr(tensor.data_ptr), + shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim), + dtype(static_cast(tensor.dtype)) {} + + SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} + + operator NVTEBasicTensor() const { + const NVTEShape shape = {this->shape.data(), this->shape.size()}; + return {dptr, static_cast(dtype), shape}; + } + + int numel() const { + size_t acc = 1; + for (const auto &dim : shape) { + acc *= dim; + } + return acc; + } +}; + +struct Tensor { + SimpleTensor data; + SimpleTensor columnwise_data; + SimpleTensor amax; + SimpleTensor scale; + SimpleTensor scale_inv; + SimpleTensor columnwise_scale_inv; + + NVTEScalingMode scaling_mode; + + Tensor() + : data(), + columnwise_data(), + amax(nullptr, {1}, DType::kFloat32), + scale(nullptr, {1}, DType::kFloat32), + scale_inv(nullptr, {1}, DType::kFloat32), + columnwise_scale_inv(nullptr, {1}, DType::kFloat32), + scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} + + int numel() const { + NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr, + "Tensor does not hold any data!"); + size_t acc = 1; + if (data.dptr != nullptr) { + for (const auto &dim : data.shape) { + acc *= dim; + } + return acc; + } + // data is empty, use columnwise_data + for (const auto &dim : columnwise_data.shape) { + acc *= dim; + } + return acc; + } + + bool has_data() const noexcept { return data.dptr != nullptr; } + + bool has_columnwise_data() const noexcept { return columnwise_data.dptr != nullptr; } + + DType dtype() const { + if (has_data()) return data.dtype; + if (has_columnwise_data()) return columnwise_data.dtype; + // Fallback, used e.g. in workspace + return data.dtype; + } + + std::vector shape() const { + switch (scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: + if (!has_data() && has_columnwise_data()) { + std::vector ret; + if (!columnwise_data.shape.empty()) { + for (size_t i = 1; i < columnwise_data.shape.size(); i++) { + ret.push_back(columnwise_data.shape[i]); + } + ret.push_back(columnwise_data.shape.front()); + } + return ret; + } else { + return data.shape; + } + break; + case NVTE_MXFP8_1D_SCALING: + if (!has_data() && has_columnwise_data()) { + return columnwise_data.shape; + } else { + return data.shape; + } + break; + case NVTE_MTFP8_BLOCK_SCALING: { + if (!has_data() && has_columnwise_data()) { + return columnwise_data.shape; + } else { + return data.shape; + } + break; + } + default: + NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", (int)scaling_mode, "\""); + return {}; + } + } + + /*! Matrix height after tensor is flattened to 2D + * + * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted + * as a (D1*D2*...*D(n-1), Dn) matrix. + */ + size_t flat_first_dim() const { + const auto &full_shape = shape(); + size_t ret = 1; + if (!full_shape.empty()) { + for (size_t i = 0; i < full_shape.size() - 1; i++) { + ret *= full_shape[i]; + } + } + return ret; + } + + /*! Matrix width after tensor is flattened to 2D + * + * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted + * as a (D1*D2*...*D(n-1), Dn) matrix. + */ + size_t flat_last_dim() const { + const auto &full_shape = shape(); + if (full_shape.empty()) { + return 1; + } else { + return full_shape.back(); + } + } +}; + +template +constexpr T DIVUP(const T &x, const T &y) { + return (((x) + ((y)-1)) / (y)); +} + +using byte = uint8_t; +using int32 = int32_t; +using int64 = int64_t; +using fp32 = float; +using fp16 = __half; +using bf16 = __mt_bfloat16; +using fp8e4m3 = __mt_fp8_e4m3; +using fp8e5m2 = __mt_fp8_e5m2; +#if CUDA_VERSION >= 12080 +using fp8e8m0 = __nv_fp8_e8m0; +#endif +using e8m0_t = uint8_t; + +namespace detail { + +template +constexpr inline const char *type_name() noexcept; +#define TRANSFORMER_ENGINE_TYPE_NAME(T) \ + template <> \ + inline constexpr const char *type_name() noexcept { \ + return #T; \ + } +TRANSFORMER_ENGINE_TYPE_NAME(uint8_t) +TRANSFORMER_ENGINE_TYPE_NAME(int32_t) +TRANSFORMER_ENGINE_TYPE_NAME(int64_t) +TRANSFORMER_ENGINE_TYPE_NAME(float) +TRANSFORMER_ENGINE_TYPE_NAME(__half) +TRANSFORMER_ENGINE_TYPE_NAME(__mt_bfloat16) +TRANSFORMER_ENGINE_TYPE_NAME(__mt_fp8_e4m3) +TRANSFORMER_ENGINE_TYPE_NAME(__mt_fp8_e5m2) +#if CUDA_VERSION >= 12080 +TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) +#endif +#undef TRANSFORMER_ENGINE_TYPE_NAME + +} // namespace detail + +template +struct TypeInfo { + using types = std::tuple; + + template + struct Helper { + constexpr static DType getType() { + constexpr int i = static_cast(current); + if (std::is_same::type>::value) { + return current; + } else { + return Helper(i + 1)>::getType(); + } + } + }; + + template + struct Helper { + constexpr static DType getType() { return DType::kNumTypes; } + }; + + template + constexpr static DType getType() { + return Helper::getType(); + } + + constexpr static DType dtype = getType(); + constexpr static size_t size = sizeof(T); + constexpr static const char *name = detail::type_name(); +}; + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kByte: { \ + using type = unsigned char; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt32: { \ + using type = int32_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt64: { \ + using type = int64_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E8M0: { \ + using type = byte; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: \ + case DType::kFloat8E4M3: { \ + NVTE_ERROR("FP8 type not instantiated for input."); \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat16: { \ + using type = fp16; \ + __VA_ARGS__; \ + break; \ + } \ + case DType::kBFloat16: { \ + using type = bf16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + NVTE_ERROR("Invalid type for 16 bit."); \ + } + +#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ + switch (SCALE_DIM) { \ + case 1: { \ + constexpr size_t DIM = 1; \ + { __VA_ARGS__ } \ + } break; \ + case 32: { \ + constexpr size_t DIM = 32; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Invalid size of the MX scaling factor."); \ + } \ + } + +#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ + if (CONDITION) { \ + constexpr bool FLAG = true; \ + { __VA_ARGS__ } \ + } else { \ + constexpr bool FLAG = false; \ + { __VA_ARGS__ } \ + } +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +inline size_t alignTo(size_t x) { + size_t r = x % B; + if (r == 0) return x; + + return x + B - r; +} + +template +struct is_fp8 : std::false_type {}; + +template <> +struct is_fp8 : std::true_type {}; + +template <> +struct is_fp8 : std::true_type {}; + +// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors +constexpr size_t scale_tensor_alignment_X_rowwise = 4; +constexpr size_t scale_tensor_alignment_Y_rowwise = 128; +constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_Y_colwise = 4; + +size_t typeToSize(const DType type); + +void CheckNoopTensor(const Tensor &t, const std::string &name); +void CheckInputTensor(const Tensor &t, const std::string &name); +void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); + +bool is_fp8_dtype(const DType t); + +std::string to_string(const DType type); +std::string to_string(const NVTEScalingMode &type); + +inline bool is_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_block_scaling(const NVTEScalingMode &mode) { + return mode != NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { + return is_tensor_scaling(mode); +} + +inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } + +inline bool is_mtfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MTFP8_BLOCK_SCALING; } + +/*! \brief Update a tensor's FP8 scale-inverse + * + * The FP8 scale-inverse (dequantization scaling factor) is updated + * with the reciprocal of the FP8 scale (quantization scaling factor). + */ +void update_tensor_scale_inv(Tensor *t, musaStream_t stream); + +#define NVTE_API_CALL(api_name) \ + transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); + +void checkCuDriverContext(MUstream stream); + +// CUtensorMapDataType get_CUtensorMapDataType(DType dtype); + +inline bool isPointerAligned(const void *const ptr, const int alignment); + +// Set up parameters to create TMA descriptor. +// void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, +// const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, +// const uint32_t shmemX, const uint32_t stride_elems, +// const uint32_t offset_elems, const size_t type_size); + +bool is_supported_by_CC_100(); + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_ diff --git a/transformer_engine/musa/common/common.mu b/transformer_engine/musa/common/common.mu new file mode 100644 index 0000000000..e3d4b83530 --- /dev/null +++ b/transformer_engine/musa/common/common.mu @@ -0,0 +1,148 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include + +#include "./common.h" +#include "./utils.muh" +#include "common/util/musa_runtime.h" +#include "common/util/logging.h" + +namespace transformer_engine { + +namespace { + +__global__ void __launch_bounds__(1) + update_tensor_scale_inv_kernel(const float *__restrict__ scale_ptr, + float *__restrict__ scale_inv_ptr) { + const float scale = scale_ptr == nullptr ? 1 : *scale_ptr; + reciprocal(scale_inv_ptr, scale); +} + +} // namespace + +void update_tensor_scale_inv(Tensor *t, musaStream_t stream) { + if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) { + NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv."); + update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( + reinterpret_cast(t->scale.dptr), + reinterpret_cast(t->scale_inv.dptr)); + } +} + +void checkCuDriverContext(MUstream stream) { + MUcontext ctx; + const MUresult driver_status = cuda_driver::call("muStreamGetCtx", stream, &ctx); + switch (driver_status) { + case MUSA_SUCCESS: + break; + + case MUSA_ERROR_INVALID_CONTEXT: + int current_device; + NVTE_CHECK_CUDA(musaGetDevice(¤t_device)); + NVTE_CALL_CHECK_CUDA_DRIVER(muDevicePrimaryCtxRetain, &ctx, current_device); + NVTE_CALL_CHECK_CUDA_DRIVER(muCtxSetCurrent, ctx); + break; + + default: + const char *desc_NVTE_CHECK_MUSA_DRIVER; + cuda_driver::call("muGetErrorString", driver_status, &desc_NVTE_CHECK_MUSA_DRIVER); + NVTE_ERROR("MUSA Error: ", desc_NVTE_CHECK_MUSA_DRIVER); + } +} + +/* +CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { + static const std::unordered_map dtypeMapping = { + {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32}, + {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16}, + {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16}, + {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}}; + return dtypeMapping.at(dtype); +} +*/ + +inline bool isPointerAligned(const void *const ptr, const int alignment) { + const uint64_t ptr_as_uint = reinterpret_cast(ptr); + return ptr_as_uint % alignment == 0; +} + +/* +// Set up parameters to create TMA descriptor. +void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, + const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, + const uint32_t shmemX, const uint32_t stride_elems, + const uint32_t offset_elems, const size_t type_size) { + // Get a function pointer to the cuTensorMapEncodeTiled driver API + static PFN_cuTensorMapEncodeTiled cuDriverTensorMapEncodeTiled = []() { + void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); + return reinterpret_cast(driver_ptr); + }(); + // rank is the number of dimensions of the array + constexpr uint32_t rank = 2; + uint64_t size[rank] = {globalX, globalY}; + + // The stride is the number of bytes to traverse from the first element of one row to the next + uint64_t stride[rank - 1] = {stride_elems * type_size}; + + // The boxSize is the size of the shared memory buffer that is used as the + // source/destination of a TMA transfer + uint32_t boxSize[rank] = {shmemX, shmemY}; + + // The distance between elements in units of sizeof(element) + uint32_t elemStride[rank] = {1, 1}; + + const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype); + void *dataPtr = + reinterpret_cast(reinterpret_cast(tensor.dptr) + offset_elems * type_size); + + constexpr int TMA_gmem_alignment = 16; // Alignment of the global memory address + NVTE_CHECK(isPointerAligned(dataPtr, TMA_gmem_alignment), + "Tensor data pointer must be 16B aligned"); + + const int TMA_needed_size = TMA_gmem_alignment / type_size; + NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size, + "-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX); + + // Create the tensor descriptor. + NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( + &tensorMap, // CUtensorMap *tensorMap, + tensorDataType, + rank, // cuuint32_t tensorRank, + dataPtr, // void *globalAddress, + size, // const cuuint64_t *globalDim, + stride, // const cuuint64_t *globalStrides, + boxSize, // const cuuint32_t *boxDim, + elemStride, // const cuuint32_t *elementStrides, + // Interleave patterns can be used to accelerate loading of values that + // are less than 4 bytes long. + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + + // Swizzling can be used to avoid shared memory bank conflicts. + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + + // L2 Promotion can be used to widen the effect of a cache-policy to a wider + // set of L2 cache lines. + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + // CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + + // Any element that is outside of bounds will be set to zero by the TMA transfer. + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); +} +*/ + +bool is_supported_by_CC_100() { + // int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); + + // return deviceComputeCapability >= 100; + return false; +} + +} // namespace transformer_engine diff --git a/transformer_engine/musa/common/fused_attn/fused_attn.cpp b/transformer_engine/musa/common/fused_attn/fused_attn.cpp new file mode 100644 index 0000000000..fd28467cbb --- /dev/null +++ b/transformer_engine/musa/common/fused_attn/fused_attn.cpp @@ -0,0 +1,857 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/fused_attn.h" + +#include "../common.h" +// #include "../cudnn_utils.h" +#include "../util/musa_runtime.h" +#include "../util/system.h" +// #include "fused_attn_f16_arbitrary_seqlen.h" +// #include "fused_attn_f16_max512_seqlen.h" +// #include "fused_attn_fp8.h" +// #include "utils.h" + +// map NVTE_QKV_Layout to NVTE_QKV_Layout_Group +NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { + switch (qkv_layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + return NVTE_QKV_Layout_Group::NVTE_3HD; + case NVTE_QKV_Layout::NVTE_SBH3D: + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + return NVTE_QKV_Layout_Group::NVTE_H3D; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + return NVTE_QKV_Layout_Group::NVTE_HD_2HD; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + return NVTE_QKV_Layout_Group::NVTE_HD_H2D; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD; + default: + NVTE_ERROR("qkv_layout not supported!"); + } +} + +// map NVTE_QKV_Layout to NVTE_QKV_Format +NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { + switch (qkv_layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + case NVTE_QKV_Layout::NVTE_SBH3D: + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + return NVTE_QKV_Format::NVTE_SBHD; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + return NVTE_QKV_Format::NVTE_BSHD; + case NVTE_QKV_Layout::NVTE_T3HD: + case NVTE_QKV_Layout::NVTE_TH3D: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + return NVTE_QKV_Format::NVTE_THD; + default: + NVTE_ERROR("qkv_layout not supported!"); + } +} + +// select a backend for fused attention +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( + NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right) { + using namespace transformer_engine; + NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + /* + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + auto cudnn_runtime_version = cudnnGetVersion(); + + // For ragged offsets we only support 32-bit prior to cuDNN 9.5 + // Only used when THD format is requested. + const bool requires_64bit_ragged_offset = + (qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype( + layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q, + max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); + const bool supported_ragged_offset_size = + (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); + + if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && + sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // 8.9: t3hd, max_s=512, d=64, padding + ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && + qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && + max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + (cudnn_runtime_version >= 90700 && + // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // sm90: fwd d<=256, bwd d=128 only + // sm100: fwd d<=128, bwd d<=128 + ((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + !requires_64bit_ragged_offset) { + if (cudnn_runtime_version >= 8900) { + backend = NVTE_Fused_Attn_Backend::NVTE_FP8; + } else { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." + " Please upgrade your cuDNN version if possible." + << std::endl; + } + } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { + bool flag_m512 = false; + bool flag_arb = false; + if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && + (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) && + (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) && + ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || + (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + max_seqlen_q == max_seqlen_kv) || + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) && + ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) || + (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) || + (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) || + (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || + (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && + ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && + !requires_64bit_ragged_offset) { + flag_m512 = true; + } + if ( + // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging + // special conditions for blackwell + // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7 + !(sm_arch_ == 100 && (head_dim_qk > 128 || head_dim_v > 128)) && + // architecture + ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || + (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && + // sequence length + ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) || + (cudnn_runtime_version >= 90000)) && + // number of heads + ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || + (cudnn_runtime_version >= 8907)) && + // head dimension + ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) || + // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // d=256 only supported for forward + (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 && + head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && + // bias type + ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || + (cudnn_runtime_version >= 8906 && + (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || + (bias_type == NVTE_Bias_Type::NVTE_ALIBI && + attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + sm_arch_ >= 90) || + (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || + (cudnn_runtime_version >= 90000 && + (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && + // mask type + // pre-8.9.6: causal + ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || + // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal} + (cudnn_runtime_version >= 8906 && + (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // 9.1: adds thd + {padding, padding_causal} + (cudnn_runtime_version >= 90100 && qkv_format == NVTE_QKV_Format::NVTE_THD && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv) + (cudnn_runtime_version >= 90300 && + (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) + (cudnn_runtime_version >= 90600 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) && + // bias + mask combination + (!(cudnn_runtime_version >= 8906 && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && + bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && + // qkv format + (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || + (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && + ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || + cudnn_runtime_version >= 90600))) && + // sliding window + // pre-9.2: full attn, causal + ((cudnn_runtime_version < 90200 && window_size_left == -1 && + (window_size_right == -1 || window_size_right == 0)) || + // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} + (cudnn_runtime_version >= 90200 && + ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || + ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q == max_seqlen_kv)) && + max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || + qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || + // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} + (cudnn_runtime_version >= 90600 && + ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || + ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + // TODO(cyang): fix bug for BRCM + cross-attention on sm100 + (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv && + cudnn_runtime_version <= 90700) || + cudnn_runtime_version > 90700)))) || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + (sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv && + cudnn_runtime_version <= 90700) || + cudnn_runtime_version > 90700))))) && + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + dropout == 0.0)))) && + // check 64-bit ragged offset support + (supported_ragged_offset_size)) { + flag_arb = true; + } + if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } + if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + if (flag_arb == true) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } else if ((flag_arb == false) && (flag_m512 == true)) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; + } + int env_backend = static_cast(backend); + env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); + if (((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) && + flag_m512) || + ((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) && + flag_arb)) { + backend = static_cast(env_backend); + } + } + if (cudnn_runtime_version < 8901 && + backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." + " Please upgrade your cuDNN version if possible." + << std::endl; + } + if (cudnn_runtime_version < 8900 && + backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." + " Please upgrade your cuDNN version if possible." + << std::endl; + } + } else { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + */ + return backend; +} + +// NVTE fused attention FWD with packed QKV +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + const NVTETensor rng_state, size_t max_seqlen, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, musaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); + using namespace transformer_engine; +/* + const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); + const Tensor *input_cu_seqlens_padded = reinterpret_cast(cu_seqlens_padded); + const Tensor *input_rng_state = reinterpret_cast(rng_state); + const Tensor *input_QKV = reinterpret_cast(QKV); + const Tensor *input_Bias = reinterpret_cast(Bias); + Tensor *input_output_S = reinterpret_cast(S); + Tensor *output_O = reinterpret_cast(O); + Tensor *wkspace = reinterpret_cast(workspace); + + auto ndim = input_QKV->data.shape.size(); + size_t b = input_cu_seqlens->data.shape[0] - 1; + size_t h = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + h = input_QKV->data.shape[ndim - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + h = input_QKV->data.shape[ndim - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + } + size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } + + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); + + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, + max_seqlen, d, d, window_size_left, window_size_right); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { +#if (CUDNN_VERSION >= 8901) + fused_attn_max_512_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_QKV, input_Bias, + output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#if (CUDNN_VERSION >= 8900) + fused_attn_arbitrary_seqlen_fwd_qkvpacked( + b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, + Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, + stream, handle); +#else + NVTE_ERROR( + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { +#if (CUDNN_VERSION >= 8900) + fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_QKV, input_output_S, output_O, + Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace, + stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); +#endif + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } +*/ +} +// NVTE fused attention BWD with packed QKV +void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, + const NVTETensor S, NVTETensor dP, + const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, + NVTETensor dBias, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, size_t max_seqlen, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + bool deterministic, NVTETensor workspace, musaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); + using namespace transformer_engine; +/* + const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); + const Tensor *input_cu_seqlens_padded = reinterpret_cast(cu_seqlens_padded); + const Tensor *input_QKV = reinterpret_cast(QKV); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQKV = reinterpret_cast(dQKV); + Tensor *output_dBias = reinterpret_cast(dBias); + Tensor *wkspace = reinterpret_cast(workspace); + + auto ndim = input_QKV->data.shape.size(); + size_t b = input_cu_seqlens->data.shape[0] - 1; + size_t h = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + h = input_QKV->data.shape[ndim - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + h = input_QKV->data.shape[ndim - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + } + size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } + + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); + + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, + max_seqlen, d, d, window_size_left, window_size_right); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { +#if (CUDNN_VERSION >= 8901) + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd_qkvpacked( + b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, + input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#if (CUDNN_VERSION >= 8900) + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *input_Bias, *input_rng_state; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + } else { + input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + } + fused_attn_arbitrary_seqlen_bwd_qkvpacked( + b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO, + input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, + input_rng_state, wkspace, stream, handle); +#else + const char *err_msg = + "cuDNN 8.9.0 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. \n"; + NVTE_ERROR(err_msg); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { +#if (CUDNN_VERSION >= 8900) + const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, + input_S, input_output_dP, output_dQKV, input_cu_seqlens, + input_rng_state, wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); +#endif + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } +*/ +} +// NVTE fused attention FWD with packed KV +void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, musaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); + using namespace transformer_engine; +/* + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); + const Tensor *input_rng_state = reinterpret_cast(rng_state); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_KV = reinterpret_cast(KV); + const Tensor *input_Bias = reinterpret_cast(Bias); + Tensor *input_output_S = reinterpret_cast(S); + Tensor *output_O = reinterpret_cast(O); + Tensor *wkspace = reinterpret_cast(workspace); + + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + auto ndim = input_Q->data.shape.size(); + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t d = input_Q->data.shape[ndim - 1]; + auto ndim_kv = input_KV->data.shape.size(); + size_t h_kv = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + h_kv = input_KV->data.shape[ndim_kv - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + h_kv = input_KV->data.shape[ndim_kv - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_KV->data.shape[0]; + } + + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_KV->data.dtype); + + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, + max_seqlen_kv, d, d, window_size_left, window_size_right); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { +#if (CUDNN_VERSION >= 8901) + fused_attn_max_512_fwd_kvpacked( + b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#if (CUDNN_VERSION >= 8903) + fused_attn_arbitrary_seqlen_fwd_kvpacked( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, + input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, + handle); +#else + NVTE_ERROR( + "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { +#if (CUDNN_VERSION >= 8900) + fused_attn_fp8_fwd_kvpacked( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); +#endif + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } +*/ +} +// NVTE fused attention BWD with packed KV +void nvte_fused_attn_bwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, + const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, + NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); + using namespace transformer_engine; +/* + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_KV = reinterpret_cast(KV); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQ = reinterpret_cast(dQ); + Tensor *output_dKV = reinterpret_cast(dKV); + Tensor *output_dBias = reinterpret_cast(dBias); + Tensor *wkspace = reinterpret_cast(workspace); + + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + auto ndim = input_Q->data.shape.size(); + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t d = input_Q->data.shape[ndim - 1]; + auto ndim_kv = input_KV->data.shape.size(); + size_t h_kv = 0; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + h_kv = input_KV->data.shape[ndim_kv - 2]; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + h_kv = input_KV->data.shape[ndim_kv - 3]; + } else { + NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_KV->data.shape[0]; + } + + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_KV->data.dtype); + + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, + max_seqlen_kv, d, d, window_size_left, window_size_right); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { +#if (CUDNN_VERSION >= 8901) + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd_kvpacked( + b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#if (CUDNN_VERSION >= 8903) + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *input_Bias, *input_rng_state; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + } else { + input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + } + fused_attn_arbitrary_seqlen_bwd_kvpacked( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, + input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); +#else + const char *err_msg = + "cuDNN 8.9.3 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. \n"; + NVTE_ERROR(err_msg); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { +#if (CUDNN_VERSION >= 8900) + const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, + output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); +#endif + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } +*/ +} +// NVTE fused attention FWD with separate Q, K and V +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd); + using namespace transformer_engine; +/* + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); + const Tensor *input_rng_state = reinterpret_cast(rng_state); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_K = reinterpret_cast(K); + const Tensor *input_V = reinterpret_cast(V); + const Tensor *input_Bias = reinterpret_cast(Bias); + Tensor *input_output_S = reinterpret_cast(S); + Tensor *output_O = reinterpret_cast(O); + Tensor *wkspace = reinterpret_cast(workspace); + + auto ndim = input_Q->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim - 2]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_K->data.shape[0]; + } + + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); + + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, + max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { +#if (CUDNN_VERSION >= 8901) + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, + input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#if (CUDNN_VERSION >= 8900) + fused_attn_arbitrary_seqlen_fwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, + input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR( + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { +#if (CUDNN_VERSION >= 8900) + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); +#endif + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } +*/ +} +// NVTE fused attention BWD with separate Q, K and V +void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, + const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, + NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, + size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd); + using namespace transformer_engine; +/* + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_K = reinterpret_cast(K); + const Tensor *input_V = reinterpret_cast(V); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQ = reinterpret_cast(dQ); + Tensor *output_dK = reinterpret_cast(dK); + Tensor *output_dV = reinterpret_cast(dV); + Tensor *output_dBias = reinterpret_cast(dBias); + Tensor *wkspace = reinterpret_cast(workspace); + + auto ndim = input_Q->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim - 2]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_K->data.shape[0]; + } + + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); + + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, + max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { +#if (CUDNN_VERSION >= 8901) + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { +#if (CUDNN_VERSION >= 8900) + Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *input_Bias, *input_rng_state; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + input_Bias = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + } else { + input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + } + fused_attn_arbitrary_seqlen_bwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, + input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, + output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); +#else + const char *err_msg = + "cuDNN 8.9.0 is required for BF16/FP16 fused attention " + "with arbitrary sequence length. \n"; + NVTE_ERROR(err_msg); +#endif + } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { +#if (CUDNN_VERSION >= 8900) + const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, + output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); +#endif + } else { + NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + } +*/ +} diff --git a/transformer_engine/musa/common/fused_attn/thd_utils.h b/transformer_engine/musa/common/fused_attn/thd_utils.h new file mode 100644 index 0000000000..1340772f00 --- /dev/null +++ b/transformer_engine/musa/common/fused_attn/thd_utils.h @@ -0,0 +1,250 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_ + +#include +#include + +namespace transformer_engine { +namespace fused_attn { + +/*************************************************************************************************** + * Support THD format for Context Parallel: Binary search an array for a target value + **************************************************************************************************/ + +__forceinline__ __device__ int binary_search(int target, int *array, int len) { + int left = 1, right = len - 1; + while (left < right) { + int mid = (left + right) / 2; + if (array[mid] <= target) { + left = mid + 1; + } else { + right = mid; + } + } + return left - 1; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Generate partitioned indices for input tokens + **************************************************************************************************/ + +__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, + int total_tokens, int world_size, int rank); + +/*************************************************************************************************** + * Support THD format for Context Parallel: Read the half of a THD tensor + **************************************************************************************************/ + +__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, + int hidden_size_in_bytes, int half_idx, + int dim_size_of_token); + +/*************************************************************************************************** + * Support THD format for Context Parallel: softmax_lse related operations + **************************************************************************************************/ + +struct LseCorrectionFunctor { + __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, + size_t half_idx) { + double val = lse[idx]; + float val_per_step = half_lse[half_idx]; + double max_scale = max(val, val_per_step); + double min_scale = min(val, val_per_step); + lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); + } +}; + +struct ReadLseFunctor { + __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, + size_t half_idx) { + half_lse[half_idx] = lse[idx]; + } +}; + +template +__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, + int num_heads, int lse_seqlen, int second_half_lse_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + int num_total_tokens = cu_seqlens_s[batch]; + + for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, half_idx; + if constexpr (lse_packed) { + idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1]; + half_idx = head_id * second_half_lse_seqlen + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + + idx = row * lse_seqlen + col + seq_len; + half_idx = row * second_half_lse_seqlen + col; + } + + Functor::run(lse, half_lse, idx, half_idx); + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Out correction in forward + **************************************************************************************************/ + +template +__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, + float *lse_per_step, int *cu_seqlens, int batch, + int num_heads, int dim_per_head, int lse_seqlen, + int lse_per_step_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); + } + __syncthreads(); + + int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; + int lane_id = threadIdx.x % tile_size; + int num_tiles = (blockDim.x * gridDim.x) / tile_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); + + for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, idx_per_step; + + if constexpr (lse_packed) { + idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx_per_step = head_id * lse_per_step_seqlen + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + idx = row * lse_seqlen + col + seq_len * only_second_half; + idx_per_step = row * lse_per_step_seqlen + col; + } + float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); + + idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx = (idx * num_heads + head_id) * dim_per_head; + idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; + dtype *cur_out = out + idx; + dtype *cur_out_per_step = out_per_step + idx_per_step; + + for (int j = lane_id; j < num_loops_per_head; j += tile_size) { + float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; + float4 data = reinterpret_cast(cur_out)[j]; + dtype *p_per_step = reinterpret_cast(&data_per_step); + dtype *p = reinterpret_cast(&data); + for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { + p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp); + } + reinterpret_cast(cur_out)[j] = data; + } + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Gradients correction in backward + **************************************************************************************************/ + +struct EmptyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} +}; + +struct CopyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { + reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; + } +}; + +template +struct AddFunctor { + __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { + float4 d_ = reinterpret_cast(token)[idx]; + dtype *p_ = reinterpret_cast(&d_); + + float4 d = reinterpret_cast(token_per_step)[idx]; + dtype *p = reinterpret_cast(&d); + +#pragma unroll + for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { + p_[i] += p[i]; + } + + reinterpret_cast(token)[idx] = d_; + } +}; + +template +__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens, + int batch, int hidden_size, int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + if constexpr (functor_idx < 2) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } else { + cu_seqlens_s[i] = cu_seqlens[i]; + } + } + __syncthreads(); + + int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; + int lane_id = threadIdx.x % group_size; + int num_groups = (blockDim.x * gridDim.x) / group_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size; + if constexpr (functor_idx < 2) { + grad_per_step = grad_per_step + offset / 2 * blockIdx.y; + } else { + grad_per_step = grad_per_step + offset * blockIdx.y; + } + grad = grad + offset * blockIdx.y; + + for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + + int token_offset; + bool is_first_half; + if constexpr (functor_idx < 2) { + token_offset = cu_seqlens_s[seq_id + functor_idx]; + is_first_half = (functor_idx == 0); + } else { + token_offset = 0; + int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); + } + + dtype *token = &grad[(token_id + token_offset) * static_cast(hidden_size)]; + dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; + for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { + if (is_first_half) { + Functor_0::run(token, token_per_step, idx); + } else { + Functor_1::run(token, token_per_step, idx); + } + } + } +} + +} // namespace fused_attn +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/musa/common/fused_attn/thd_utils.mu b/transformer_engine/musa/common/fused_attn/thd_utils.mu new file mode 100644 index 0000000000..9358eb7c48 --- /dev/null +++ b/transformer_engine/musa/common/fused_attn/thd_utils.mu @@ -0,0 +1,78 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +// #include "../cudnn_utils.h" +#include "thd_utils.h" + +#include + +namespace transformer_engine { +namespace fused_attn { + +__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, + int total_tokens, int world_size, int rank) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + int seqlen = cu_seqlens[i]; + // Currently we assume that each sequence length is divisible by (world_size*2) since we have + // to distribute each sequence evenly to different GPUs. + assert(seqlen % (world_size * 2) == 0); + cu_seqlens_s[i] = seqlen / world_size; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + + for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + int index = token_id - cu_seqlens_s[seq_id]; + int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; + index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; + output[token_id] = index; + } +} + +__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, + int hidden_size_in_bytes, int half_idx, + int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + int laneid = threadIdx.x % 32; + int num_warps = (blockDim.x * gridDim.x) / 32; + int num_total_tokens = cu_seqlens_s[batch]; + int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; + half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); + tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); + + for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { + int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); + + size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; + float4 *cur_half_token = + reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); + + offset_in_bytes = + (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; + float4 *cur_token = + reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); + + for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { + cur_half_token[idx] = cur_token[idx]; + } + } +} + +} // namespace fused_attn +} // namespace transformer_engine diff --git a/transformer_engine/musa/common/fused_rope/fused_rope.mu b/transformer_engine/musa/common/fused_rope/fused_rope.mu new file mode 100644 index 0000000000..232fecc205 --- /dev/null +++ b/transformer_engine/musa/common/fused_rope/fused_rope.mu @@ -0,0 +1,366 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../utils.muh" + +namespace transformer_engine { + +template +__device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst, + const int s_id, const int offset_block, + const int offset_block_dst, const int h, const int d, + const int d2, const int stride_h, const int stride_d, + const int o_stride_h, const int o_stride_d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + float v_cos, v_sin; + sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos); +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + float v_src = src[offset_src]; + float v_src_rotate = (d_id + d2 / 2 < d2) + ? -static_cast(src[offset_src + (d2 / 2) * stride_d]) + : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } + } + + // copy the rest + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_head = offset_block + h_id * stride_h; + int offset_head_dst = offset_block_dst + h_id * o_stride_h; +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + } + } + } +} + +template +__device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst, + const int s_id, const int offset_block, + const int offset_block_dst, const int h, const int d, + const int d2, const int stride_h, const int stride_d, + const int o_stride_h, const int o_stride_d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + float v_cos = cosf(freqs[s_id * d2 + d_id]); + float v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) + : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + float v_src = src[offset_src]; + float v_src_rotate = (d_id + d2 / 2 < d2) ? src[offset_src + (d2 / 2) * stride_d] + : src[offset_src + (d2 / 2 - d2) * stride_d]; + dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } + } + + // handle the tail + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_head = offset_block + h_id * stride_h; + int offset_head_dst = offset_block_dst + h_id * o_stride_h; +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + } + } + } +} + +template +__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, + const int h, const int d, const int d2, + const int stride_s, const int stride_b, + const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int offset_block = s_id * stride_s + b_id * stride_b; + int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; + fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2, + stride_h, stride_d, o_stride_h, o_stride_d); +} + +template +__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, + const int h, const int d, const int d2, + const int stride_s, const int stride_b, + const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int offset_block = s_id * stride_s + b_id * stride_b; + int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; + fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2, + stride_h, stride_d, o_stride_h, o_stride_d); +} + +template +__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens, + const float *freqs, scalar_t *dst, const int cp_size, + const int cp_rank, const int h, const int d, + const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int start = cu_seqlens[b_id] / cp_size; + int end = cu_seqlens[b_id + 1] / cp_size; + int t_id = s_id + start; + if (t_id >= end) return; + int offset_block = t_id * stride_t; + int offset_block_dst = t_id * o_stride_t; + + int s_id_for_freqs; + if (cp_size > 1) { + int cur_seqlens = end - start; + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + s_id_for_freqs = s_id; + } + fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d, + d2, stride_h, stride_d, o_stride_h, o_stride_d); +} + +template +__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens, + const float *freqs, scalar_t *dst, const int cp_size, + const int cp_rank, const int h, const int d, + const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int start = cu_seqlens[b_id] / cp_size; + int end = cu_seqlens[b_id + 1] / cp_size; + int t_id = s_id + start; + if (t_id >= end) return; + int offset_block = t_id * stride_t; + int offset_block_dst = t_id * o_stride_t; + + int s_id_for_freqs; + if (cp_size > 1) { + int cur_seqlens = end - start; + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + s_id_for_freqs = s_id; + } + fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d, + d2, stride_h, stride_d, o_stride_h, o_stride_d); +} + +template +void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, musaStream_t stream) { + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + + fused_rope_forward_kernel<<>>( + input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d); + NVTE_CHECK_CUDA(musaGetLastError()); +} + +template +void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs, + scalar_t *input_grads, const int s, const int b, const int h, + const int d, const int d2, const int stride_s, const int stride_b, + const int stride_h, const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, const int o_stride_d, + musaStream_t stream) { + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + + fused_rope_backward_kernel<<>>( + output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d, + o_stride_s, o_stride_b, o_stride_h, o_stride_d); + NVTE_CHECK_CUDA(musaGetLastError()); +} + +template +void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens, + const float *freqs, scalar_t *output, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, + musaStream_t stream) { + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(max_s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + + fused_rope_thd_forward_kernel<<>>( + input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d, + o_stride_t, o_stride_h, o_stride_d); + NVTE_CHECK_CUDA(musaGetLastError()); +} + +template +void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, + const float *freqs, scalar_t *input_grads, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, + musaStream_t stream) { + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(max_s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + + fused_rope_thd_backward_kernel<<>>( + output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h, + stride_d, o_stride_t, o_stride_h, o_stride_d); + NVTE_CHECK_CUDA(musaGetLastError()); +} + +void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s, + const int b, const int h, const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, const int o_stride_h, + const int o_stride_d, musaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, scalar_t, + fused_rope_forward_launcher(reinterpret_cast(input.data.dptr), + reinterpret_cast(freqs.data.dptr), + reinterpret_cast(output->data.dptr), s, b, h, d, d2, + stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, stream);); +} + +void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, musaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + output_grads.data.dtype, scalar_t, + fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(freqs.data.dptr), + reinterpret_cast(input_grads->data.dptr), s, b, h, d, + d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, stream);); +} + +void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, + Tensor *output, const int cp_size, const int cp_rank, const int max_s, + const int b, const int h, const int d, const int d2, const int stride_t, + const int stride_h, const int stride_d, const int o_stride_t, + const int o_stride_h, const int o_stride_d, musaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, scalar_t, + fused_rope_thd_forward_launcher(reinterpret_cast(input.data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), + reinterpret_cast(freqs.data.dptr), + reinterpret_cast(output->data.dptr), cp_size, + cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, + o_stride_t, o_stride_h, o_stride_d, stream);); +} + +void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens, + const Tensor &freqs, Tensor *input_grads, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, musaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + output_grads.data.dtype, scalar_t, + fused_rope_thd_backward_launcher(reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), + reinterpret_cast(freqs.data.dptr), + reinterpret_cast(input_grads->data.dptr), + cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, + stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); +} + +} // end namespace transformer_engine + +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, musaStream_t stream) { + NVTE_API_CALL(nvte_fused_rope_forward); + using namespace transformer_engine; + fused_rope_forward(*reinterpret_cast(input), + *reinterpret_cast(freqs), reinterpret_cast(output), + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, stream); +} + +void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, + NVTETensor input_grads, const int s, const int b, const int h, + const int d, const int d2, const int stride_s, const int stride_b, + const int stride_h, const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, const int o_stride_d, + musaStream_t stream) { + NVTE_API_CALL(nvte_fused_rope_backward); + using namespace transformer_engine; + fused_rope_backward(*reinterpret_cast(output_grads), + *reinterpret_cast(freqs), + reinterpret_cast(input_grads), s, b, h, d, d2, stride_s, stride_b, + stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream); +} + +void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor output, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, musaStream_t stream) { + NVTE_API_CALL(nvte_fused_rope_thd_forward); + using namespace transformer_engine; + fused_rope_thd_forward(*reinterpret_cast(input), + *reinterpret_cast(cu_seqlens), + *reinterpret_cast(freqs), + reinterpret_cast(output), cp_size, cp_rank, max_s, b, h, d, d2, + stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); +} + +void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor input_grads, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, musaStream_t stream) { + NVTE_API_CALL(nvte_fused_rope_thd_backward); + using namespace transformer_engine; + fused_rope_thd_backward( + *reinterpret_cast(output_grads), + *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), + reinterpret_cast(input_grads), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, + stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); +} diff --git a/transformer_engine/musa/common/fused_softmax/scaled_aligned_causal_masked_softmax.mu b/transformer_engine/musa/common/fused_softmax/scaled_aligned_causal_masked_softmax.mu new file mode 100644 index 0000000000..5065de4e0c --- /dev/null +++ b/transformer_engine/musa/common/fused_softmax/scaled_aligned_causal_masked_softmax.mu @@ -0,0 +1,568 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../utils.muh" + +namespace transformer_engine { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { + *((uint64_t *)dst) = *((uint64_t *)src); // NOLINT(*) +} + +template <> +__device__ __inline__ void copy_vector(fp16 *dst, const fp16 *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(fp16 *dst, const fp16 *src) { + *((uint64_t *)dst) = *((uint64_t *)src); // NOLINT(*) +} + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { + *((uint32_t *)dst) = *((uint32_t *)src); // NOLINT(*) +} + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(bf16 *dst) { + *dst = 0.0f; +} + +template <> +__device__ __inline__ void copy_zero_vector(bf16 *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); // NOLINT(*) +} + +template <> +__device__ __inline__ void copy_zero_vector(fp16 *dst) { + *dst = 0.0f; +} + +template <> +__device__ __inline__ void copy_zero_vector(fp16 *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); // NOLINT(*) +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { +#if 1 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_ROWS; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with the following additional features + * 1) input scaling + * 2) implicit causal masking + * + * works for all cases: + * k > q + * k < q + * k = q + * + * where: + * microbatches = batches * attn_heads * query_seq_len + * rows = query_seq_len + * cols = key_seq_len + */ +template +__global__ void scaled_aligned_causal_masked_softmax_warp_forward(output_t *dst, const input_t *src, + const acc_t scale, + const int microbatches, + const int rows, const int cols) { + // 1) WARP_WIDTH must match the value of warp_size + // 2) WARP_ROWS must match the value of rows_per_warp + // of the dispatch_scaled_aligned_causal_masked_softmax_forward method. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_WIDTH = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_WIDTH; + constexpr int WARP_ROWS = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + const int global_row_idx = (blockIdx.x * blockDim.y + threadIdx.y) * WARP_ROWS; + const int col = threadIdx.x * ELEMENTS_PER_LDG_STG; + + const size_t thread_offset = global_row_idx * cols + col; + + src += thread_offset; + dst += thread_offset; + + // load data from global memory into registers WITH scaling + acc_t elements[WARP_ROWS][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + const int i = microbatch % rows; // local row index of attention matrix + const int masked_elements = i + cols - rows + 1; + + if (microbatch >= microbatches) { + break; + } + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < masked_elements) { + copy_vector(temp_data, src + itr_idx); +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (j + element < masked_elements) { + elements[w][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[w][it + element] = (acc_t)(-10'000); + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[w][it + element] = (acc_t)(-10'000); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_ROWS]; +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + max_value[w] = elements[w][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[w] = (max_value[w] > elements[w][it]) ? max_value[w] : elements[w][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_ROWS]{0.0f}; +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[w][it] = expf((elements[w][it] - max_value[w])); + sum[w] += elements[w][it]; + } + } + warp_reduce(sum); + + output_t out[ELEMENTS_PER_LDG_STG]{0.0f}; +// store result +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + const int i = microbatch % rows; + const int masked_elements = i + cols - rows + 1; + + // out of Attention matrix bounds (rows) + if (microbatch >= microbatches) { + break; + } + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; // index of the first column + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < masked_elements) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (j + element < masked_elements) { + out[element] = elements[w][it + element] / sum[w]; + } else { + out[element] = (output_t)(0.0f); + } + } + copy_vector(dst + itr_idx, out); + } else if (j < cols) { + copy_zero_vector(dst + itr_idx); + } else { + break; + } + } + } +} + +template +__global__ void scaled_aligned_causal_masked_softmax_warp_backward( + output_t *gradInput, const input_t *grad, const input_t *softmax_output, const acc_t scale, + const int microbatches, const int rows, const int cols) { + // 1) WARP_WIDTH must match the value of warp_size + // 2) WARP_ROWS must match the value of rows_per_warp + // of the dispatch_scaled_aligned_causal_masked_softmax_forward method. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_WIDTH = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_WIDTH; + constexpr int WARP_ROWS = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + const int global_row_idx = (blockIdx.x * blockDim.y + threadIdx.y) * WARP_ROWS; + const int col = threadIdx.x * ELEMENTS_PER_LDG_STG; + + const size_t thread_offset = global_row_idx * cols + col; + + grad += thread_offset; + softmax_output += thread_offset; + gradInput += thread_offset; + + // load data from global memory into registers + acc_t grad_reg[WARP_ROWS][WARP_ITERATIONS]{0.0f}; + acc_t softmax_output_reg[WARP_ROWS][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + const int i = microbatch % rows; // local row index of attention matrix + const int masked_elements = i + cols - rows + 1; + + if (microbatch >= microbatches) { + break; + } + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; // index of the first column + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < masked_elements) { + copy_vector(temp_grad, grad + itr_idx); + copy_vector(temp_output, softmax_output + itr_idx); +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (j + element < masked_elements) { + softmax_output_reg[w][it + element] = (acc_t)temp_output[element]; + grad_reg[w][it + element] = + (acc_t)temp_grad[element] * softmax_output_reg[w][it + element]; + } + } + } + } + } + + acc_t sum[WARP_ROWS]; +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + sum[w] = grad_reg[w][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[w] += grad_reg[w][it]; + } + } + + warp_reduce(sum); + +// store result +#pragma unroll + for (int w = 0; w < WARP_ROWS; ++w) { + const int microbatch = global_row_idx + w; + if (microbatch >= microbatches) { + break; + } + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + const int j = col + it * WARP_WIDTH; // index of the first column + const int itr_idx = w * cols + it * WARP_WIDTH; + + if (j < cols) { + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[w][it + element] - + softmax_output_reg[w][it + element] * sum[w])); + } + copy_vector(gradInput + itr_idx, out); + } + } + } +} + +template +void call_kernel_scaled_aligned_causal_masked_softmax_forward( + dim3 grid_size, dim3 block_size, const int shmem_size, musaStream_t stream, output_t *dst, + const input_t *src, const acc_t scale, const int microbatches, const int query_seq_len, + const int key_seq_len) { + scaled_aligned_causal_masked_softmax_warp_forward + <<>>(dst, src, scale, microbatches, query_seq_len, + key_seq_len); +} + +template +void call_kernel_scaled_aligned_causal_masked_softmax_backward( + dim3 grid_size, dim3 block_size, const int shmem_size, musaStream_t stream, output_t *gradInput, + const input_t *grad, const input_t *output, const acc_t scale, const int microbatches, + const int query_seq_len, const int key_seq_len) { + scaled_aligned_causal_masked_softmax_warp_backward + <<>>(gradInput, grad, output, scale, microbatches, + query_seq_len, key_seq_len); +} + +template +struct FunctionWrapper { + using ForwardType = + std::function; + using BackwardType = std::function; +}; + +constexpr int MIN_SUPPORTED_POWER = 4; +constexpr int MAX_SUPPORTED_POWER = 14; +constexpr int MIN_POWER = MIN_SUPPORTED_POWER - 1; +constexpr int MAX_POWER = MAX_SUPPORTED_POWER + 1; + +// Recursively instantiate the function for the limit of "log2_elements", +// i.e. "MAX_POWER" defined above. +template +struct CompileTimeLoopForward { + using ForwardFuncType = typename FunctionWrapper::ForwardType; + static void populate(std::array *arr) { + CompileTimeLoopForward::populate(arr); + (*arr)[log2_elements] = + &call_kernel_scaled_aligned_causal_masked_softmax_forward; + } +}; + +template +struct CompileTimeLoopForward { + using ForwardFuncType = typename FunctionWrapper::ForwardType; + static void populate(std::array *arr) { (*arr)[MIN_POWER] = nullptr; } +}; + +template +struct CompileTimeLoopBackward { + using BackwardFuncType = typename FunctionWrapper::BackwardType; + static void populate(std::array *arr) { + CompileTimeLoopBackward::populate(arr); + (*arr)[log2_elements] = + &call_kernel_scaled_aligned_causal_masked_softmax_backward; + } +}; + +template +struct CompileTimeLoopBackward { + using BackwardFuncType = typename FunctionWrapper::BackwardType; + static void populate(std::array *arr) { + (*arr)[MIN_POWER] = nullptr; + } +}; + +template +void dispatch_scaled_aligned_causal_masked_softmax_forward(output_t *dst, const input_t *src, + const input_t scale, int query_seq_len, + int key_seq_len, int batches, + int attn_heads, musaStream_t stream) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + + if (key_seq_len == 0) { + return; + } + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_WIDTH constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int warp_width = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_ROWS constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int microbatches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = threads_per_block / warp_width; + int microbatches_per_block = warps_per_block * microbatches_per_warp; + int microbatches = batches * attn_heads * query_seq_len; + int blocks = DIVUP(microbatches, microbatches_per_block); + + dim3 block_size(warp_width, warps_per_block); + dim3 grid_size(blocks); + + // create an array of pointers to functions + using ForwardFuncType = typename FunctionWrapper::ForwardType; + static std::array forwardFunctionArray; + static bool is_initialized = false; + if (!is_initialized) { + CompileTimeLoopForward::populate( + &forwardFunctionArray); + is_initialized = true; + } + // Call the corresponding kernel + forwardFunctionArray[log2_elements](grid_size, block_size, 0, stream, dst, src, scale, + microbatches, query_seq_len, key_seq_len); +} + +template +void dispatch_scaled_aligned_causal_masked_softmax_backward( + output_t *grad_input, const input_t *grad, const input_t *output, const acc_t scale, + int query_seq_len, int key_seq_len, int batches, int attn_heads, musaStream_t stream) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + + if (key_seq_len == 0) { + return; + } + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_WIDTH constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int warp_width = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_ROWS constexpr + // value computed inside scaled_aligned_causal_masked_softmax_warp_forward. + int microbatches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = threads_per_block / warp_width; + int microbatches_per_block = warps_per_block * microbatches_per_warp; + int microbatches = batches * attn_heads * query_seq_len; + int blocks = DIVUP(microbatches, microbatches_per_block); + + dim3 block_size(warp_width, warps_per_block); + dim3 grid_size(blocks); + + // create an array of pointers to functions + using BackwardFuncType = typename FunctionWrapper::BackwardType; + static std::array backwardFunctionArray; + static bool is_initialized = false; + if (!is_initialized) { + CompileTimeLoopBackward::populate( + &backwardFunctionArray); + is_initialized = true; + } + // Call the corresponding kernel + backwardFunctionArray[log2_elements](grid_size, block_size, 0, stream, grad_input, grad, output, + scale, microbatches, query_seq_len, key_seq_len); +} + +void scaled_aligned_causal_masked_softmax_forward(const Tensor &input, Tensor *softmax_results, + float scale_factor, musaStream_t stream) { + const int batches = input.data.shape[0]; + const int attn_heads = input.data.shape[1]; + const int query_seq_len = input.data.shape[2]; + const int key_seq_len = input.data.shape[3]; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + input.data.dtype, softmax_type, + dispatch_scaled_aligned_causal_masked_softmax_forward( + reinterpret_cast(softmax_results->data.dptr), + reinterpret_cast(input.data.dptr), scale_factor, query_seq_len, + key_seq_len, batches, attn_heads, stream);); +} + +void scaled_aligned_causal_masked_softmax_backward(Tensor output_grads, const Tensor incoming_grads, + const Tensor softmax_results, float scale_factor, + musaStream_t stream) { + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.data.shape[0]; + const int attn_heads = output_grads.data.shape[1]; + const int query_seq_len = output_grads.data.shape[2]; + const int key_seq_len = output_grads.data.shape[3]; + + // Softmax Grad + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + output_grads.data.dtype, softmax_type, + dispatch_scaled_aligned_causal_masked_softmax_backward( + reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(incoming_grads.data.dptr), + reinterpret_cast(softmax_results.data.dptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, stream);); +} +} // end namespace transformer_engine + +void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input, + NVTETensor softmax_results, + float scale_factor, musaStream_t stream) { + NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_forward); + using namespace transformer_engine; + scaled_aligned_causal_masked_softmax_forward(*reinterpret_cast(input), + reinterpret_cast(softmax_results), + scale_factor, stream); +} + +void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, + musaStream_t stream) { + NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_backward); + using namespace transformer_engine; + scaled_aligned_causal_masked_softmax_backward( + *reinterpret_cast(output_grads), *reinterpret_cast(incoming_grads), + *reinterpret_cast(softmax_results), scale_factor, stream); +} diff --git a/transformer_engine/musa/common/fused_softmax/scaled_masked_softmax.mu b/transformer_engine/musa/common/fused_softmax/scaled_masked_softmax.mu new file mode 100644 index 0000000000..77968344ef --- /dev/null +++ b/transformer_engine/musa/common/fused_softmax/scaled_masked_softmax.mu @@ -0,0 +1,850 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../utils.muh" + +namespace transformer_engine { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { + *((float2 *)dst) = *((float2 *)src); // NOLINT(*) +} + +template <> +__device__ __inline__ void copy_vector(half *dst, const half *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(half *dst, const half *src) { + *((float2 *)dst) = *((float2 *)src); // NOLINT(*) +} + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); // NOLINT(*) +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { +#if 1 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + */ +template +__global__ void scaled_softmax_warp_forward(output_t *dst, const input_t *src, const acc_t scale, + int micro_batch_size, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + size_t first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + size_t thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + src += thread_offset; + dst += thread_offset; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i * element_count + it * WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward(output_t *dst, const input_t *src, + const uint8_t *mask, const acc_t scale, + int micro_batch_size, int element_count, + int pad_batches) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + size_t first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; + size_t pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + size_t thread_offset_src_dst = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + size_t thread_offset_mask = pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + src += thread_offset_src_dst; + dst += thread_offset_src_dst; + mask += thread_offset_mask; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i * element_count + it * WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + // compute scale value to account for full mask + acc_t scale_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0; + } + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] * scale_value[i] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward(output_t *gradInput, const input_t *grad, + const input_t *output, acc_t scale, + int micro_batch_size, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + size_t first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + size_t thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, + grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, + output + i * element_count + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, + out); + } + } + } +} + +template +void dispatch_scaled_softmax_forward(output_t *dst, const input_t *src, const input_t scale, + int query_seq_len, int key_seq_len, int batches, + int attn_heads, musaStream_t stream) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr + // value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_BATCH constexpr + // value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + NVTE_CHECK(query_seq_len % batches_per_block == 0, "Unsupported shape."); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 12: // 4096 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 13: // 8192 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 14: // 16384 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, const uint8_t *mask, + const input_t scale, int query_seq_len, int key_seq_len, + int batches, int attn_heads, int pad_batches, + musaStream_t stream) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr + // value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_BATCH constexpr + // value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + NVTE_CHECK(query_seq_len % batches_per_block == 0, "Unsupported shape."); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 12: // 4096 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 13: // 8192 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + case 14: // 16384 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, + pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward(output_t *grad_input, const input_t *grad, + const input_t *output, const acc_t scale, + int query_seq_len, int key_seq_len, int batches, + int attn_heads, musaStream_t stream) { + NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape."); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr + // value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_BATCH constexpr + // value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 12: // 4096 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 13: // 8192 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + case 14: // 16384 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + key_seq_len); + break; + default: + break; + } + } +} + +void scaled_softmax_forward(const Tensor &input, Tensor *softmax_results, float scale_factor, + musaStream_t stream) { + const int batches = input.data.shape[0]; + const int attn_heads = input.data.shape[1]; + const int query_seq_len = input.data.shape[2]; + const int key_seq_len = input.data.shape[3]; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + input.data.dtype, softmax_type, + dispatch_scaled_softmax_forward( + reinterpret_cast(softmax_results->data.dptr), + reinterpret_cast(input.data.dptr), scale_factor, query_seq_len, + key_seq_len, batches, attn_heads, stream);); +} + +void scaled_softmax_backward(Tensor output_grads, const Tensor incoming_grads, + const Tensor softmax_results, float scale_factor, + musaStream_t stream) { + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.data.shape[0]; + const int attn_heads = output_grads.data.shape[1]; + const int query_seq_len = output_grads.data.shape[2]; + const int key_seq_len = output_grads.data.shape[3]; + + // Softmax Grad + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + output_grads.data.dtype, softmax_type, + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(incoming_grads.data.dptr), + reinterpret_cast(softmax_results.data.dptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, stream);); +} + +void scaled_masked_softmax_forward(const Tensor input, const Tensor mask, Tensor *softmax_results, + float scale_factor, musaStream_t stream) { + const int batches = input.data.shape[0]; + const int pad_batches = mask.data.shape[0]; + const int attn_heads = input.data.shape[1]; + const int query_seq_len = input.data.shape[2]; + const int key_seq_len = input.data.shape[3]; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + input.data.dtype, softmax_type, + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results->data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(mask.data.dptr), scale_factor, query_seq_len, + key_seq_len, batches, attn_heads, pad_batches, stream);); +} + +void scaled_masked_softmax_backward(Tensor output_grads, const Tensor incoming_grads, + const Tensor softmax_results, float scale_factor, + musaStream_t stream) { + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.data.shape[0]; + const int attn_heads = output_grads.data.shape[1]; + const int query_seq_len = output_grads.data.shape[2]; + const int key_seq_len = output_grads.data.shape[3]; + + // Softmax Grad + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + output_grads.data.dtype, softmax_type, + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(incoming_grads.data.dptr), + reinterpret_cast(softmax_results.data.dptr), scale_factor, + query_seq_len, key_seq_len, batches, attn_heads, stream);); +} + +} // end namespace transformer_engine + +void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_results, + float scale_factor, musaStream_t stream) { + NVTE_API_CALL(nvte_scaled_softmax_forward); + using namespace transformer_engine; + scaled_softmax_forward(*reinterpret_cast(input), + reinterpret_cast(softmax_results), scale_factor, stream); +} + +void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, + musaStream_t stream) { + NVTE_API_CALL(nvte_scaled_softmax_backward); + using namespace transformer_engine; + scaled_softmax_backward(*reinterpret_cast(output_grads), + *reinterpret_cast(incoming_grads), + *reinterpret_cast(softmax_results), scale_factor, stream); +} + +void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask, + NVTETensor softmax_results, float scale_factor, + musaStream_t stream) { + NVTE_API_CALL(nvte_scaled_masked_softmax_forward); + using namespace transformer_engine; + scaled_masked_softmax_forward(*reinterpret_cast(input), + *reinterpret_cast(mask), + reinterpret_cast(softmax_results), scale_factor, stream); +} + +void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, NVTETensor output_grads, + float scale_factor, musaStream_t stream) { + NVTE_API_CALL(nvte_scaled_masked_softmax_backward); + using namespace transformer_engine; + scaled_masked_softmax_backward( + *reinterpret_cast(output_grads), *reinterpret_cast(incoming_grads), + *reinterpret_cast(softmax_results), scale_factor, stream); +} diff --git a/transformer_engine/musa/common/fused_softmax/scaled_upper_triang_masked_softmax.mu b/transformer_engine/musa/common/fused_softmax/scaled_upper_triang_masked_softmax.mu new file mode 100644 index 0000000000..73ee9d17e1 --- /dev/null +++ b/transformer_engine/musa/common/fused_softmax/scaled_upper_triang_masked_softmax.mu @@ -0,0 +1,615 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../utils.muh" + +namespace transformer_engine { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(bf16 *dst, const bf16 *src) { + *((float2 *)dst) = *((float2 *)src); // NOLINT(*) +} + +template <> +__device__ __inline__ void copy_vector(fp16 *dst, const fp16 *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(fp16 *dst, const fp16 *src) { + *((float2 *)dst) = *((float2 *)src); // NOLINT(*) +} + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); // NOLINT(*) +} + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(bf16 *dst) { + *dst = 0.0f; +} + +template <> +__device__ __inline__ void copy_zero_vector(bf16 *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); // NOLINT(*) +} + +template <> +__device__ __inline__ void copy_zero_vector(fp16 *dst) { + *dst = 0.0f; +} + +template <> +__device__ __inline__ void copy_zero_vector(fp16 *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); // NOLINT(*) +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { +#if 1 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward(output_t *dst, const input_t *src, + const acc_t scale, + int micro_batch_size, int stride, + int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + size_t first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + size_t thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + src += thread_offset; + dst += thread_offset; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector( + temp_data, src + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0.0f; + } + } + copy_vector( + dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward(output_t *gradInput, + const input_t *grad, + const input_t *output, acc_t scale, + int micro_batch_size, int stride, + int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + size_t first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + size_t thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); + } + copy_vector( + gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_forward(output_t *dst, const input_t *src, + const input_t scale, int softmax_elements, + int softmax_elements_stride, + int attn_batches, musaStream_t stream) { + NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape."); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr + // value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_BATCH constexpr + // value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + NVTE_CHECK(attn_batches % batches_per_block == 0, "Unsupported shape."); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 14: // 16384 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward(output_t *grad_input, const input_t *grad, + const input_t *output, const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches, musaStream_t stream) { + NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape."); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr + // value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP; + + // This value must match the WARP_BATCH constexpr + // value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + NVTE_CHECK(attn_batches % batches_per_block == 0, "Unsupported shape."); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 14: // 16384 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +void scaled_upper_triang_masked_softmax_forward(const Tensor input, Tensor *softmax_results, + float scale_factor, musaStream_t stream) { + const int attn_batches = input.data.shape[0]; + const int seq_len = input.data.shape[1]; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + input.data.dtype, softmax_type, + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results->data.dptr), + reinterpret_cast(input.data.dptr), scale_factor, seq_len, seq_len, + attn_batches, stream);); +} + +void scaled_upper_triang_masked_softmax_backward(Tensor output_grads, const Tensor incoming_grads, + const Tensor softmax_results, float scale_factor, + musaStream_t stream) { + const int attn_batches = output_grads.data.shape[0]; + const int seq_len = output_grads.data.shape[1]; + + // Softmax Grad + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + output_grads.data.dtype, softmax_type, + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(incoming_grads.data.dptr), + reinterpret_cast(softmax_results.data.dptr), scale_factor, seq_len, + seq_len, attn_batches, stream);); +} + +} // end namespace transformer_engine + +void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input, + NVTETensor softmax_results, float scale_factor, + musaStream_t stream) { + using namespace transformer_engine; + scaled_upper_triang_masked_softmax_forward(*reinterpret_cast(input), + reinterpret_cast(softmax_results), + scale_factor, stream); +} + +void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, + musaStream_t stream) { + using namespace transformer_engine; + scaled_upper_triang_masked_softmax_backward( + *reinterpret_cast(output_grads), *reinterpret_cast(incoming_grads), + *reinterpret_cast(softmax_results), scale_factor, stream); +} diff --git a/transformer_engine/musa/common/gemm/mudnn_gemm.cpp b/transformer_engine/musa/common/gemm/mudnn_gemm.cpp new file mode 100644 index 0000000000..0aff2cf5b7 --- /dev/null +++ b/transformer_engine/musa/common/gemm/mudnn_gemm.cpp @@ -0,0 +1,383 @@ +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../util/mudnn.h" +#include "../util/mtfp8_utils.muh" + +namespace transformer_engine { + +namespace { + +using at::musa::InternalMemAlloc; +using at::musa::GetComputeModeFromCtx; +using transformer_engine::musa::Flat2DimShape; +using transformer_engine::musa::CreateMUTensor; +using transformer_engine::musa::ToTorchDtype; +using transformer_engine::musa::SetMUTensorDType; +using mtfp8::next_power_of_2; + +const auto empty_te_tensor = Tensor(); +const auto empty_mu_tensor = at::musa::CreateEmptyMUTensor(); + +std::once_flag init_flag; +musaStream_t compute_streams[num_streams]; +musaEvent_t cublas_event[num_streams]; +bool multistream_to_use; + +void init_streams_and_events() { + for (int i = 0; i < num_streams; i++) { + NVTE_CHECK_CUDA(musaStreamCreateWithPriority(&compute_streams[i], musaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(musaEventCreate(&cublas_event[i])); + } + + multistream_to_use = false; + if (std::getenv("MULTI_STREAM_GROUPGEMM") != nullptr + && std::string(std::getenv("MULTI_STREAM_GROUPGEMM")) == "1") { + multistream_to_use = true; + } +} + +const SimpleTensor* get_data(const Tensor* te_tensor, bool trans) { + if (trans && te_tensor->has_columnwise_data()) { + return &(te_tensor->columnwise_data); + } + return &(te_tensor->data); +} + +struct GEMM_INFO { + const SimpleTensor* data_a = nullptr; + const SimpleTensor* sinv_a = nullptr; + const SimpleTensor* data_b = nullptr; + const SimpleTensor* sinv_b = nullptr; + bool is_per_tensor = true; +}; + +GEMM_INFO get_gemm_info( + const Tensor* a, + bool trans_a, + const Tensor* b, + bool trans_b) { + NVTE_CHECK(a->scaling_mode == b->scaling_mode, + "Inputs A and B to GEMM need to have the same scaling mode!"); + NVTE_CHECK(a->has_data() || a->has_columnwise_data(), "Input A does not hold any data!"); + NVTE_CHECK(b->has_data() || b->has_columnwise_data(), "Input B does not hold any data!"); + + GEMM_INFO info; + info.is_per_tensor = is_tensor_scaling(a->scaling_mode); + if (info.is_per_tensor) { + info.data_a = &(a->data); + info.sinv_a = &(a->scale_inv); + info.data_b = &(b->data); + info.sinv_b = &(b->scale_inv); + return info; + } + + const auto a_data_dim_m = product(a->data.shape, 0, a->data.shape.size() - 1); + const auto a_sinv_dim_m = product(a->scale_inv.shape, 0, a->scale_inv.shape.size() - 1); + const bool weight_is_nn_block = (a_data_dim_m != a_sinv_dim_m); + + if (weight_is_nn_block || trans_a) { + info.data_a = &(a->data); + info.sinv_a = &(a->scale_inv); + } else { + info.data_a = &(a->columnwise_data); + info.sinv_a = &(a->columnwise_scale_inv); + } + + if (trans_b) { + info.data_b = &(b->columnwise_data); + info.sinv_b = &(b->columnwise_scale_inv); + } else { + info.data_b = &(b->data); + info.sinv_b = &(b->scale_inv); + } + + return info; +} + +} // anonymous namespace + +void non_fp8_gemm( + const Tensor* inputA, + bool transa, + const Tensor* inputB, + bool transb, + Tensor* outputD, + const Tensor* biasTensor, + bool accumulate, + int math_sm_count, + musaStream_t stream) { + auto& h = at::GetMudnnHandle(); + h.SetStream(stream); + + const bool has_bias = biasTensor->has_data(); + auto mu_l = CreateMUTensor(*get_data(inputB, transb), Flat2DimShape(inputB)); + auto mu_r = CreateMUTensor(*get_data(inputA, transa), Flat2DimShape(inputA)); + auto mu_b = has_bias ? CreateMUTensor(biasTensor->data) : empty_mu_tensor; + auto mu_o = CreateMUTensor(outputD->data, Flat2DimShape(outputD)); + + ::musa::dnn::MatMul op; + CHECK_MUDNN_STATUS(op.SetTranspose(transb, transa), "SetTranspose"); + CHECK_MUDNN_STATUS( + op.SetComputeMode(GetComputeModeFromCtx(ToTorchDtype(inputB->dtype()))), + "SetComputeMode"); + CHECK_MUDNN_STATUS(op.SetAlpha(1.0), "SetAlpha"); + CHECK_MUDNN_STATUS(op.SetBeta(accumulate ? 1.0 : 0.0), "SetBeta"); + CHECK_MUDNN_STATUS(op.SetGamma(has_bias ? 1.0 : 0.0), "SetGamma"); + + CHECK_MUDNN_STATUS( + op.RunWithBiasAdd( + h, mu_o, mu_l, mu_r, mu_o, mu_b, InternalMemAlloc), + "RunWithBiasAdd"); +} + +void fp8_gemm( + const Tensor* inputA, + bool transa, + const Tensor* inputB, + bool transb, + Tensor* outputD, + const Tensor* biasTensor, + bool accumulate, + int math_sm_count, + musaStream_t stream) { + auto& h = at::GetMudnnHandle(); + h.SetStream(stream); + + const bool has_bias = biasTensor->has_data(); + const bool has_bias_scale = (biasTensor->scale_inv.dptr != nullptr); + + const bool has_output_scale = (outputD->scale.dptr != nullptr); + const bool has_output_amax = (outputD->amax.dptr != nullptr); + + const auto info = get_gemm_info(inputA, transa, inputB, transb); + const auto& data_b = *(info.data_b); + const auto& sinv_b = *(info.sinv_b); + const auto& data_a = *(info.data_a); + const auto& sinv_a = *(info.sinv_a); + + auto mu_l = CreateMUTensor(data_b, Flat2DimShape(inputB)); + auto mu_r = CreateMUTensor(data_a, Flat2DimShape(inputA)); + auto mu_b = has_bias ? CreateMUTensor(biasTensor->data) : empty_mu_tensor; + auto mu_o = CreateMUTensor(outputD->data, Flat2DimShape(outputD)); + if (!has_bias) { + SetMUTensorDType(outputD->dtype(), mu_b); + } + + auto mu_scale_l = CreateMUTensor(sinv_b); + auto mu_scale_r = CreateMUTensor(sinv_a); + auto mu_scale_b = has_bias_scale + ? CreateMUTensor(biasTensor->scale_inv) : empty_mu_tensor; + auto mu_scale_o = has_output_scale + ? CreateMUTensor(outputD->scale): empty_mu_tensor; + auto mu_amax_o = has_output_amax + ? CreateMUTensor(outputD->amax): empty_mu_tensor; + + ::musa::dnn::BatchMatMul op; + CHECK_MUDNN_STATUS(op.SetTranspose(transb, transa), "SetTranspose"); + CHECK_MUDNN_STATUS( + op.SetComputeMode(GetComputeModeFromCtx(ToTorchDtype(inputB->dtype()))), + "SetComputeMode"); + CHECK_MUDNN_STATUS(op.SetAlpha(1.0), "SetAlpha"); + CHECK_MUDNN_STATUS(op.SetBeta(accumulate ? 1.0 : 0.0), "SetBeta"); + CHECK_MUDNN_STATUS(op.SetGamma(has_bias ? 1.0 : 0.0), "SetGamma"); + if (math_sm_count != 0) { + CHECK_MUDNN_STATUS(op.SetMpCountTarget(math_sm_count), "SetMpCountTarget"); + } + + ::musa::dnn::MatMulLtParam param; + if (info.is_per_tensor) { + CHECK_MUDNN_STATUS(param.SetScale(mu_scale_l, mu_scale_r, mu_scale_b, mu_scale_o), "SetScale"); + } else { + NVTE_CHECK(inputB->scale_inv.shape.size() == 2); + const auto tile_size = static_cast(next_power_of_2(inputB->flat_last_dim() / inputB->scale_inv.shape[1])); + CHECK_MUDNN_STATUS(param.SetScale(mu_scale_l, mu_scale_r, mu_scale_b, mu_scale_o, tile_size), "SetScale"); + } + CHECK_MUDNN_STATUS(param.SetAmaxD(mu_amax_o), "SetAmax"); + + CHECK_MUDNN_STATUS(op.RunLt(h, mu_o, mu_l, mu_r, mu_o, mu_b, param, InternalMemAlloc), "RunLt"); +} + +void no_fp8_grad_bias( + const Tensor* gradO, + bool trans, + const Tensor* gradB, + musaStream_t stream) { + using REDUCE_MODE = ::musa::dnn::Reduce::Mode; + const int reduce_dim = trans ? 0 : 1; + + auto& h = at::GetMudnnHandle(); + h.SetStream(stream); + + auto mu_i = CreateMUTensor(gradO->data, Flat2DimShape(gradO)); + auto mu_o = CreateMUTensor(gradB->data, Flat2DimShape(gradB)); + + ::musa::dnn::Reduce rdc; + CHECK_MUDNN_STATUS(rdc.SetMode(REDUCE_MODE::ADD), "SetMode"); + CHECK_MUDNN_STATUS(rdc.SetDim({reduce_dim}), "SetDim"); + CHECK_MUDNN_STATUS(rdc.Run(h, mu_o, mu_i, InternalMemAlloc), "Run"); +} + +} // namespace transformer_engine + +// D = B @ A.T +void mudnn_gemm( + const NVTETensor A, + const NVTETensor B, + NVTETensor D, + const NVTETensor bias, + NVTETensor pre_gelu_out, + bool transa, + bool transb, + bool grad, + NVTETensor workspace, + bool accumulate, + bool use_split_accumulator, + int math_sm_count, + musaStream_t stream) { + using namespace transformer_engine; + + const auto* inputA = reinterpret_cast(A); + const auto* inputB = reinterpret_cast(B); + auto* outputD = reinterpret_cast(D); + const auto* biasTensor = reinterpret_cast(bias); + auto* geluOut = reinterpret_cast(pre_gelu_out); + + NVTE_CHECK(outputD->has_data()); + NVTE_CHECK(!geluOut->has_data(), "Gelu epilogue is not supported!"); + + const auto A_type = inputA->dtype(); + const auto is_fp8_A = is_fp8_dtype(A_type); + + const auto B_type = inputB->dtype(); + const auto is_fp8_B = is_fp8_dtype(B_type); + + NVTE_CHECK( + is_fp8_A == is_fp8_B, + "Inputs to muDNN GEMM must all be non-fp8 or fp8 dtypes!"); + if (!is_fp8_A) { + NVTE_CHECK( + A_type == B_type, + "Both inputs to muDNN non-FP8 GEMM must have the same dtype!"); + } + if (biasTensor->has_data() && !grad) { + NVTE_CHECK( + biasTensor->data.shape.size() == 1 && + biasTensor->data.shape[0] == outputD->flat_last_dim(), + "Mismatch bias shape, expect ", + outputD->flat_last_dim(), + ", but got ", + biasTensor->data.shape[0]); + } + + const auto* fwd_bias = grad ? &transformer_engine::empty_te_tensor : biasTensor; + if (is_fp8_A) { + fp8_gemm(inputA, transa, inputB, transb, outputD, fwd_bias, accumulate, math_sm_count, stream); + } else { + non_fp8_gemm(inputA, transa, inputB, transb, outputD, fwd_bias, accumulate, math_sm_count, stream); + } + + if (!grad || !(biasTensor->has_data())) { + return; + } + + if (!is_fp8_A) { + no_fp8_grad_bias(inputB, transb, biasTensor, stream); + } +} + +void nvte_cublas_gemm( + const NVTETensor A, + const NVTETensor B, + NVTETensor D, + const NVTETensor bias, + NVTETensor pre_gelu_out, + bool transa, + bool transb, + bool grad, + NVTETensor workspace, + bool accumulate, + bool use_split_accumulator, + int math_sm_count, + musaStream_t stream) { + NVTE_API_CALL(nvte_cublas_gemm); + mudnn_gemm( + A, B, D, bias, pre_gelu_out, transa, transb, grad, workspace, + accumulate, use_split_accumulator, math_sm_count, stream); +} + +void nvte_cublas_atomic_gemm( + const NVTETensor A, + const NVTETensor B, + NVTETensor D, + const NVTETensor bias, + NVTETensor pre_gelu_out, + bool transa, + bool transb, + bool grad, + NVTETensor workspace, + bool accumulate, + bool use_split_accumulator, + int math_sm_count, + int m_split, + int n_split, + bool gemm_producer, + const NVTETensor counter, + musaStream_t stream) { + NVTE_API_CALL(nvte_cublas_atomic_gemm); + NVTE_CHECK(false, "atomic_gemm is not supported."); +} + +void nvte_multi_stream_cublas_gemm( + const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + const NVTETensor* bias, + NVTETensor* pre_gelu_out, + const int num_gemms, + bool transa, + bool transb, + bool grad, + NVTETensor* workspace, + bool accumulate, + bool use_split_accumulator, + int math_sm_count, + musaStream_t stream) { + NVTE_API_CALL(nvte_multi_stream_cublas_gemm); + using namespace transformer_engine; + + std::call_once(init_flag, init_streams_and_events); + + int num_stream_used = std::min(num_streams, num_gemms); + // wait for current stream to finish + NVTE_CHECK_CUDA(musaEventRecord(cublas_event[0], stream)); + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA(musaStreamWaitEvent(compute_streams[s], cublas_event[0])); + } + + for (int i = 0; i < num_gemms; i++) { + musaStream_t stream_to_use; + if (multistream_to_use) { + stream_to_use = compute_streams[i % num_streams]; + } else { + stream_to_use = musaStreamDefault; + } + + mudnn_gemm( + A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, + workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, + stream_to_use); // compute_streams[i % num_streams] + } + if (multistream_to_use) { + // record events on compute streams + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA(musaEventRecord(cublas_event[s], compute_streams[s])); + } + // wait for all compute streams to finish + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA(musaStreamWaitEvent(stream, cublas_event[s])); + } + } +} diff --git a/transformer_engine/musa/common/include/transformer_engine/activation.h b/transformer_engine/musa/common/include/transformer_engine/activation.h new file mode 100644 index 0000000000..910bd8aace --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/activation.h @@ -0,0 +1,273 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file activation.h + * \brief Activation functions. + */ + +#ifndef TRANSFORMER_ENGINE_ACTIVATION_H_ +#define TRANSFORMER_ENGINE_ACTIVATION_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ + +/*! \brief Computes activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ + +enum class NVTE_Activation_Type { + GELU, + GEGLU, + SILU, + SWIGLU, + RELU, + REGLU, + QGELU, + QGEGLU, + SRELU, + SREGLU, +}; + +/*! \brief Computes the GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_gelu(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Computes the SiLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_silu(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Computes the ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_relu(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Computes the Quick GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_qgelu(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Computes the Squared ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_srelu(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Computes the GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream); + +/*! \brief Computes the SiLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream); + +/*! \brief Computes the ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream); + +/*! \brief Computes the Quick GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream); + +/*! \brief Computes the Squared ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream); + +/*! \brief Computes the gated GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_geglu(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Computes the gated Swish activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_swiglu(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Computes the gated ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_reglu(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Computes the gated Quick GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_qgeglu(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Computes the gated Squared ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_sreglu(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Computes the gated GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream); + +/*! \brief Computes the gated Swish activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream); + +/*! \brief Computes the gated ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream); + +/*! \brief Computes the gated Quick GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream); + +/*! \brief Computes the gated Squared ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + musaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_ACTIVATION_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/cast.h b/transformer_engine/musa/common/include/transformer_engine/cast.h new file mode 100644 index 0000000000..52c21267cf --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/cast.h @@ -0,0 +1,219 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast.h + * \brief Functions to cast to/from FP8/MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_CAST_H_ +#define TRANSFORMER_ENGINE_CAST_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer) + * The implementation is per the microscaling format MXFP8 defined by the OCP specification: + * https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + * + * Supported modes of scaling (live scaling): + * 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes: + * - the scaled output tensor + * - the corresponding scaling factors + * The scaling factors are computed for blocks of the shape [1,32] + * (i.e., each scaling factor spans 32 contiguous elements along rows). + * + * 2) Columwise scaling (along the dim=1) computes one set of the output data. + * The scaling factors are computed for blocks of the shape [32,1] + * (i.e., each scaling factor spans 32 contiguous elements along columns). + * + * 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1) + * computes two sets of the output data: both 1) and 2). + * + * The shape of the MX block must be specified in the 'output' argument, + * and can be either [1,32] or [32,1] as no other shapes are currently supported. + * + * To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter + * of the output tensor should be set to 0. + */ + +/*! \brief Casts input tensor to FP8/MXFP8. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] noop Noop tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + musaStream_t stream); + +/*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workplace, musaStream_t stream); + +/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the GeLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the SiLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the ReLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Quick GeLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Squared ReLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Casts input tensor from reduced to higher precision. + * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, + * the block dequantization (MXFP8) of the specified shape of the block will be used. + * In case of the MXFP8 dequantization, the dequantized values are stored to the rowwise + * data of the output tensor, regardless of whether the row- or columnwise scaling is used. + * + * \param[in] input Input FP8/MXFP8 tensor to be cast. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dequantize(const NVTETensor input, NVTETensor output, musaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_CAST_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/musa/common/include/transformer_engine/cast_transpose_noop.h new file mode 100644 index 0000000000..6dc4d8f2f1 --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/cast_transpose_noop.h @@ -0,0 +1,46 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file transpose_with_noop.h + * \brief Functions handling transposes with no-op. + */ + +#ifndef TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ +#define TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Transposes the input, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * + * \param[in] input Input tensor. + * \param[in] noop Noop tensor. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, + musaStream_t stream); + +/*! \brief Casts and transposes the input, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * + * \param[in] input Input tensor. + * \param[in] noop Noop tensor. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, + musaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/musa/common/include/transformer_engine/comm_gemm_overlap.h new file mode 100644 index 0000000000..f796bf8af2 --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/comm_gemm_overlap.h @@ -0,0 +1,301 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_ +#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_ + +#include +#include +#include + +#include + +#include "common/comm_gemm_overlap/userbuffers/userbuffers.h" + +#define NVTE_COMM_OVERLAP_MAX_STREAMS 3 + +// Index corresponds to the type of flag: +// 0 - Receive index counter +// 1 - CE start index counter +// 2 - CE end index counter +#define GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsth, index) \ + ((reinterpret_cast((comm)->mem_ptr[0])) + \ + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (recv_peer) * NVTE_MAX_REGIONS + (dsth) + \ + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ + sizeof(int))) + + +namespace transformer_engine { + +/* \brief Check if Userbufers bootstraps with direct calls to MPI collectives. + * This can turned on by building Transformer Engine with the `NVTE_UB_WITH_MPI=1` option. + * + * \return True if Userbuffers is built with MPI + */ +bool ubuf_built_with_mpi(); + +enum class CommOverlapType { RS = 0, AG = 1 }; + +enum class CommOverlapAlgo { + BULK_OVERLAP_AG = 0, + BULK_OVERLAP_RS = 1, + SPLIT_PIPELINED_AG_P2P = 2, + SPLIT_PIPELINED_RS = 3, + SPLIT_PIPELINED_RS_P2P = 4, + ATOMIC_GEMM_RS = 5, + ATOMIC_GEMM_AG_P2P = 6, + ATOMIC_GEMM_RS_P2P = 7 +}; + +class CommOverlapCore { + protected: + static inline communicator *_ub_comm{nullptr}; + static inline bool _comm_created{false}; + + int _rank; + int _tp_id; + int _tp_size; + int _num_splits; + int _math_sms; + int _num_comm_sm; + int _cga_size; + int _use_ce; + int _ub_reg; + int _gemm_priority; + int _comm_priority; + bool _atomic_gemm{false}; + bool _is_p2p{false}; + + TensorWrapper _ubuf; + TensorWrapper _counter; + float *_ubuf_scale_inv; + bool _ubuf_scale_inv_initialized{false}; + + std::vector _stream_compute; + std::vector _stream_comm_ce; + musaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; + + public: + CommOverlapCore() {} // dummy constructor for exposing type to Python + + CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, + int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm); + + virtual ~CommOverlapCore(); + + void set_ubuf_scale_inv(float *scale_inv) { + _ubuf_scale_inv = scale_inv; + _ubuf_scale_inv_initialized = true; + } + + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, + const std::vector &shape); + + TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, + const std::vector &shape); + + bool is_atomic_gemm() { return _atomic_gemm; } + + bool is_p2p_overlap() { return _is_p2p; } + + bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, + TensorWrapper &rs_output, musaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, musaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, musaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, musaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + musaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } +}; // CommOverlapCore + +class CommOverlapBase : public CommOverlapCore { + protected: + int _rs_kernel_type; + bool _rs_overlap_first_gemm; + musaStream_t _stream_comm; + musaEvent_t _start_d2dcopy; + + public: + CommOverlapBase() {} // dummy constructor for exposing type to Python + + CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, + int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, + bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = false, + bool rs_overlap_first_gemm = false); + + virtual ~CommOverlapBase(); + + void comm_userbuff_over_ce(void *rs_output, transformer_engine::DType dtype, const int chunk_idx, const int offset, + const int rowelements, const int colelements, const int strideelements, + bool out_of_place, bool comm_rs, bool is_pipeline, musaStream_t compute_stream); + + /* + ** Bulk GEMM + COMM + ** This function assumes the communication input is pre-copied to _ubuf + */ + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + musaStream_t stream_main) override; + + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + musaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + musaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + musaStream_t stream_main) override; + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + musaStream_t stream_main) override; +}; // CommOverlapBase + +class CommOverlapP2PBase : public CommOverlapCore { + protected: + bool _is_reduce_scatter{false}; + bool _use_multiatomic_ag{false}; + bool _aggregate; + int _next_rank; + int _prev_rank; + int _rank_round_tp; + int _num_ubuf_chunks; + int _self_chunk_id; + std::vector _ubufs; + std::vector _stream_send; + musaStream_t _stream_recv; + musaStream_t _stream_comm_ce; + musaEvent_t _stop_send, _stop_recv, _stop_comm; + + public: + CommOverlapP2PBase() {} // dummy constructor for exposing type to Python + + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, + int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, + int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, + bool atomic_gemm = false, bool aggregate = false); + + virtual ~CommOverlapP2PBase(); + + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); + + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + musaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + /* + ** Split AllGather + AtomicGEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG + ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. + */ + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + musaStream_t stream_main) override; + + /* + ** Split AllGather + GEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG + ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. + */ + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + musaStream_t stream_main) override; + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + musaStream_t stream_main) override; + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + musaStream_t stream_main) override; +}; // CommOverlapP2PBase + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/fused_attn.h b/transformer_engine/musa/common/include/transformer_engine/fused_attn.h new file mode 100644 index 0000000000..9af6e940e9 --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/fused_attn.h @@ -0,0 +1,548 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file fused_attn.h + * \brief Enums and functions for fused attention. + */ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ + +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \enum NVTE_QKV_Layout + * \brief Memory layouts of QKV tensors. + * `S`, `B`, `H`, `D`, and `T` stand for sequence length, batch size, number of heads, + * head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. + * `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length + * or padded to the same length, and `THD`-based layouts are used when sequences have + * different lengths in a batch. + */ +enum NVTE_QKV_Layout { + NVTE_SB3HD = 0, /*!< SB3HD layout */ + NVTE_SBH3D = 1, /*!< SBH3D layout */ + NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ + NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ + NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ + NVTE_BS3HD = 5, /*!< BS3HD layout */ + NVTE_BSH3D = 6, /*!< BSH3D layout */ + NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ + NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ + NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ + NVTE_T3HD = 10, /*!< T3HD layout */ + NVTE_TH3D = 11, /*!< TH3D layout */ + NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ + NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ + NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ +}; + +/*! \enum NVTE_QKV_Layout_Group + * \brief QKV layout groups + */ +enum NVTE_QKV_Layout_Group { + /*! 3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD */ + NVTE_3HD = 0, + /*! H3D QKV layouts, i.e. BSH3D, SBH3D, TH3D */ + NVTE_H3D = 1, + /*! HD_2HD QKV layouts, i.e. BSHD_BS2HD, SBHD_SB2HD, THD_T2HD */ + NVTE_HD_2HD = 2, + /*! HD_H2D QKV layouts, i.e. BSHD_BSH2D, SBHD_SBH2D, THD_TH2D */ + NVTE_HD_H2D = 3, + /*! HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */ + NVTE_HD_HD_HD = 4, +}; + +/*! \enum NVTE_QKV_Format + * \brief QKV formats + */ +enum NVTE_QKV_Format { + /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD */ + NVTE_SBHD = 0, + /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD */ + NVTE_BSHD = 1, + /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */ + NVTE_THD = 2, +}; + +/*! \enum NVTE_Bias_Type + * \brief Bias types + */ +enum NVTE_Bias_Type { + /*! No bias */ + NVTE_NO_BIAS = 0, + /*! Bias before scale */ + NVTE_PRE_SCALE_BIAS = 1, + /*! Bias after scale */ + NVTE_POST_SCALE_BIAS = 2, + /*! ALiBi */ + NVTE_ALIBI = 3, +}; + +/*! \enum NVTE_Mask_Type + * \brief Attention mask types + */ +enum NVTE_Mask_Type { + /*! No masking */ + NVTE_NO_MASK = 0, + /*! Padding attention mask */ + NVTE_PADDING_MASK = 1, + /*! Causal attention mask (aligned to the top left corner) */ + NVTE_CAUSAL_MASK = 2, + /*! Padding and causal attention mask (aligned to the top left corner) */ + NVTE_PADDING_CAUSAL_MASK = 3, + /*! Causal attention mask (aligned to the bottom right corner) */ + NVTE_CAUSAL_BOTTOM_RIGHT_MASK = 4, + /*! Padding and causal attention mask (aligned to the bottom right corner) */ + NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5, +}; + +/*! \enum NVTE_Fused_Attn_Backend + * \brief Fused attention backends + */ +enum NVTE_Fused_Attn_Backend { + /*! No supported backend */ + NVTE_No_Backend = -1, + /*! cuDNN-based FP16/BF16 fused attention for <= 512 sequence length */ + NVTE_F16_max512_seqlen = 0, + /*! cuDNN-based FP16/BF16 fused attention for any sequence length */ + NVTE_F16_arbitrary_seqlen = 1, + /*! cuDNN-based FP8 fused attention for <= 512 sequence length */ + NVTE_FP8 = 2, +}; + +/*! \brief Get QKV layout group for a given QKV layout. + * + * \param[in] qkv_layout QKV layout, e.g. sbh3d. + * + * \return qkv layout group, e.g. h3d. + */ +NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout); + +/*! \brief Get QKV format for a given QKV layout. + * + * \param[in] qkv_layout QKV layout, e.g. sbh3d. + * + * \return qkv format, e.g. sbhd. + */ +NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout); + +/*! \brief Get fused attention backend based on input parameters. + * + * \param[in] q_dtype The data type of Tensor Q. + * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] qkv_layout The layout of Tensors Q, K, V. + * \param[in] bias_type The attention bias type. + * \param[in] attn_mask_type The attention mask type. + * \param[in] dropout The dropout probability. + * \param[in] num_attn_heads The number of heads in Q. + * \param[in] num_gqa_groups The number of heads in K, V. + * \param[in] max_seqlen_q The sequence length of Q. + * \param[in] max_seqlen_kv The sequence length of K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + */ +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( + NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right); + +/*! \brief Compute dot product attention with packed QKV input. + * + * Computes: + * - P = Q * Transpose(K) + Bias + * - S = ScaleMaskSoftmax(P) + * - D = Dropout(S) + * - O = D * Transpose(V) + * + * Support Matrix: + \verbatim + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | + | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | + | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | + \endverbatim + * + * Notes: + * + * Tensor `cu_seqlens_padded` helps identify the correct offsets of different sequences + * in tensors Q, K, V and O. + * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, + * the offset tensor is not used in the attention calculation and can be set to empty `NVTETensor`. + * When the QKV format is `thd`, this tensor should follow the following rules. + * When there is no padding between sequences, the offset tensor should be equal to `cu_seqlens`, + * When there is padding between sequences, users are responsible to adjust the offsets as needed. + * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have + * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. + * + * \param[in] QKV The QKV tensor in packed format, H3D or 3HD. + * \param[in] Bias The Bias tensor. + * \param[in,out] S The S tensor. + * \param[out] O The output O tensor. + * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, + * e.g. M, ZInv, rng_state. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen Max sequence length used for computing, + * it may be >= max(seqlen_i) for i=0,...batch_size-1. + * \param[in] is_training Whether this is in training mode or inference. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, + NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + const NVTETensor rng_state, size_t max_seqlen, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, musaStream_t stream); + +/*! \brief Compute the backward of the dot product attention with packed QKV input. + * + * Support Matrix: + \verbatim + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | + | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | + | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | + \endverbatim + * + * Notes: + * + * Tensor `cu_seqlens_padded` helps identify the correct offsets of different sequences + * in tensors Q, K, V and O. + * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, + * the offset tensor is not used in the attention calculation and can be set to empty `NVTETensor`. + * When the QKV format is `thd`, this tensor should follow the following rules. + * When there is no padding between sequences, the offset tensor should be equal to `cu_seqlens`, + * When there is padding between sequences, users are responsible to adjust the offsets as needed. + * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have + * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. + * + * \param[in] QKV The QKV tensor in packed format, H3D or 3HD. + * \param[in] O The O tensor from forward. + * \param[in] dO The gradient of the O tensor. + * \param[in] S The S tensor. + * \param[in,out] dP The gradient of the P tensor. + * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, + * e.g. M, ZInv, rng_state. + * \param[out] dQKV The gradient of the QKV tensor. + * \param[out] dBias The gradient of the Bias tensor. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1]. + * \param[in] max_seqlen Max sequence length used for computing, + * it may be >= max(seqlen_i) for i=0,...batch_size-1. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] deterministic Whether to execute with deterministic behaviours. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, + const NVTETensor S, NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV, + NVTETensor dBias, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, size_t max_seqlen, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + bool deterministic, NVTETensor workspace, musaStream_t stream); + +/*! \brief Compute dot product attention with packed KV input. + * + * Computes: + * - P = Q * Transpose(K) + Bias + * - S = ScaleMaskSoftmax(P) + * - D = Dropout(S) + * - O = D * Transpose(V) + * + * Support Matrix: + \verbatim + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | + | 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | + \endverbatim + * + * Notes: + * + * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` + * help identify the correct offsets of different sequences in tensors Q, K, V and O. + * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, + * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. + * When the QKV format is `thd`, these tensors should follow the following rules. + * When there is no padding between sequences, the offset tensors should be equal to + * `cu_seqlens_q` and `cu_seqlens_kv` respectively. + * When there is padding between sequences, users are responsible to adjust the offsets as needed. + * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have + * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. + * + * \param[in] Q The Q tensor, in HD layouts. + * \param[in] KV The KV tensor, in 2HD or H2D layouts. + * \param[in] Bias The Bias tensor. + * \param[in,out] S The S tensor. + * \param[out] O The output O tensor. + * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, + * e.g. M, ZInv, rng_state. + * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. + * \param[in] max_seqlen_kv Max sequence length used for computing for KV. + * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. + * \param[in] is_training Whether this is in training mode or inference. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] deterministic Whether to execute with deterministic behaviours. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, + NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, musaStream_t stream); + +/*! \brief Compute the backward of the dot product attention with packed KV input. + * + * Support Matrix: + \verbatim + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | + | 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | + \endverbatim + * + * Notes: + * + * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` + * help identify the correct offsets of different sequences in tensors Q, K, V and O. + * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, + * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. + * When the QKV format is `thd`, these tensors should follow the following rules. + * When there is no padding between sequences, the offset tensors should be equal to + * `cu_seqlens_q` and `cu_seqlens_kv` respectively. + * When there is padding between sequences, users are responsible to adjust the offsets as needed. + * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have + * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. + * + * \param[in] Q The Q tensor, in HD layouts. + * \param[in] KV The KV tensor, in H2D or 2HD layouts. + * \param[in] O The O tensor from forward. + * \param[in] dO The gradient of the O tensor. + * \param[in] S The S tensor. + * \param[in,out] dP The gradient of the P tensor. + * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, + * e.g. M, ZInv, rng_state. + * \param[out] dQ The gradient of the Q tensor. + * \param[out] dKV The gradient of the KV tensor. + * \param[out] dBias The gradient of the Bias tensor. + * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. + * \param[in] max_seqlen_kv Max sequence length used for computing for KV. + * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] deterministic Whether to execute with deterministic behaviours. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_bwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, + const NVTETensor S, NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, + NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Compute dot product attention with separate Q, K and V. + * + * Computes: + * - P = Q * Transpose(K) + Bias + * - S = ScaleMaskSoftmax(P) + * - D = Dropout(S) + * - O = D * Transpose(V) + * + * Support Matrix: + \verbatim + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | + | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | + | | | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | | | | | | + | | | BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | | | | | | + | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | + \endverbatim + * + * Notes: + * + * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` + * help identify the correct offsets of different sequences in tensors Q, K, V and O. + * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, + * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. + * When the QKV format is `thd`, these tensors should follow the following rules. + * When there is no padding between sequences, the offset tensors should be equal to + * `cu_seqlens_q` and `cu_seqlens_kv` respectively. + * When there is padding between sequences, users are responsible to adjust the offsets as needed. + * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have + * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. + * + * \param[in] Q The Q tensor. + * \param[in] K The K tensor. + * \param[in] V The V tensor. + * \param[in] Bias The Bias tensor. + * \param[in,out] S The S tensor. + * \param[out] O The output O tensor. + * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, + * e.g. M, ZInv, rng_state. + * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. + * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. + * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. + * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. + * \param[in] is_training Whether this is in training mode or inference. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensors' layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Compute the backward of the dot product attention with separate Q, K and V. + * + * Support Matrix: + \verbatim + | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | + | 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | + | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | + | | | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | | | | | | + | | | BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | | | | | | + | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | + \endverbatim + * + * Notes: + * + * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` + * help identify the correct offsets of different sequences in tensors Q, K, V and O. + * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, + * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. + * When the QKV format is `thd`, these tensors should follow the following rules. + * When there is no padding between sequences, the offset tensors should be equal to + * `cu_seqlens_q` and `cu_seqlens_kv` respectively. + * When there is padding between sequences, users are responsible to adjust the offsets as needed. + * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have + * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. + * + * \param[in] Q The Q tensor. + * \param[in] K The K tensor. + * \param[in] V The V tensor. + * \param[in] O The O tensor from forward. + * \param[in] dO The gradient of the O tensor. + * \param[in] S The S tensor. + * \param[in,out] dP The gradient of the P tensor. + * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, + * e.g. M, ZInv, rng_state. + * \param[out] dQ The gradient of the Q tensor. + * \param[out] dK The gradient of the K tensor. + * \param[out] dV The gradient of the V tensor. + * \param[out] dBias The gradient of the Bias tensor. + * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. + * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. + * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. + * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensors' layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] deterministic Whether to execute with deterministic behaviours. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, + NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, + size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, NVTETensor workspace, + musaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/transformer_engine/musa/common/include/transformer_engine/fused_rope.h b/transformer_engine/musa/common/include/transformer_engine/fused_rope.h new file mode 100644 index 0000000000..5c66d8c062 --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/fused_rope.h @@ -0,0 +1,129 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_ +#define TRANSFORMER_ENGINE_FUSED_ROPE_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Apply rotary positional embedding to the input tensor. + * + * \param[in] input Input tensor for fused rope. + * \param[in] freqs The freqs tensor. + * \param[out] output Output tensor. + * \param[in] s Length of the s dimension of input. + * \param[in] b Length of the b dimension of input. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] stride_s Stride of the s dimension of input. + * \param[in] stride_b Stride of the b dimension of input. + * \param[in] stride_h Stride of the h dimension of input. + * \param[in] stride_d Stride of the d dimension of input. + * \param[in] o_stride_s Stride of the s dimension of output. + * \param[in] o_stride_b Stride of the b dimension of output. + * \param[in] o_stride_h Stride of the h dimension of output. + * \param[in] o_stride_d Stride of the d dimension of output. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, musaStream_t stream); + +/*! \brief Compute the backward of the fused rope. + * + * \param[in] output_grads Incoming gradient tensor for backward. + * \param[in] freqs The freqs tensor. + * \param[out] input_grads Input gradient tensor to calculate. + * \param[in] s Length of the s dimension of output_grads. + * \param[in] b Length of the b dimension of output_grads. + * \param[in] h Length of the h dimension of output_grads. + * \param[in] d Length of the d dimension of output_grads. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] stride_s Stride of the s dimension of output_grads. + * \param[in] stride_b Stride of the b dimension of output_grads. + * \param[in] stride_h Stride of the h dimension of output_grads. + * \param[in] stride_d Stride of the d dimension of output_grads. + * \param[in] o_stride_s Stride of the s dimension of input_grads. + * \param[in] o_stride_b Stride of the b dimension of input_grads. + * \param[in] o_stride_h Stride of the h dimension of input_grads. + * \param[in] o_stride_d Stride of the d dimension of input_grads. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, + NVTETensor input_grads, const int s, const int b, const int h, + const int d, const int d2, const int stride_s, const int stride_b, + const int stride_h, const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, const int o_stride_d, + musaStream_t stream); + +/*! \brief Apply rotary positional embedding to the input tensor in thd format. + * + * \param[in] input Input tensor for fused rope. + * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. + * \param[in] freqs The freqs tensor. + * \param[out] output Output tensor. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] max_s Max sequence length. + * \param[in] b Batch size. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] stride_t Stride of the t dimension of input. + * \param[in] stride_h Stride of the h dimension of input. + * \param[in] stride_d Stride of the d dimension of input. + * \param[in] o_stride_t Stride of the t dimension of output. + * \param[in] o_stride_h Stride of the h dimension of output. + * \param[in] o_stride_d Stride of the d dimension of output. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor output, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, musaStream_t stream); + +/*! \brief Compute the backward of the fused rope in thd format. + * + * \param[in] output_grads Incoming gradient tensor for backward. + * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. + * \param[in] freqs The freqs tensor. + * \param[out] input_grads Input gradient to calculate. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] max_s Max sequence length. + * \param[in] b Batch size. + * \param[in] h Length of the h dimension of output_grads. + * \param[in] d Length of the d dimension of output_grads. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] stride_t Stride of the t dimension of output_grads. + * \param[in] stride_h Stride of the h dimension of output_grads. + * \param[in] stride_d Stride of the d dimension of output_grads. + * \param[in] o_stride_t Stride of the t dimension of input_grads. + * \param[in] o_stride_h Stride of the h dimension of input_grads. + * \param[in] o_stride_d Stride of the d dimension of input_grads. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor input_grads, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, + const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, + const int o_stride_d, musaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_FUSED_ROPE_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/gemm.h b/transformer_engine/musa/common/include/transformer_engine/gemm.h new file mode 100644 index 0000000000..21e3df6bee --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/gemm.h @@ -0,0 +1,124 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file gemm.h + * \brief Functions for matrix multiplication. + */ + +#ifndef TRANSFORMER_ENGINE_GEMM_H_ +#define TRANSFORMER_ENGINE_GEMM_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations. + * + * Computes: + * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors + * - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty + * - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors + * + * \param[in] A The A matrix. + * \param[in] B The B matrix. + * \param[in,out] D Output matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_gelu_out Output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of the + * gradient computation. + * \param[out] workspace Workspace tensor. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM. + * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, + NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, + NVTETensor workspace, bool accumulate, bool use_split_accumulator, + int math_sm_count, musaStream_t stream); + +/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters. + * + * \warning Cublas atomic gemm uses a beta API and is not tested for all use cases. + * + * Computes: + * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors + * - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty + * - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors + * + * \param[in] A The A matrix. + * \param[in] B The B matrix. + * \param[in,out] D Output matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_gelu_out Output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of the + * gradient computation. + * \param[out] workspace Workspace tensor. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM. + * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) + * \param[in] m_split Number of chunks/splits along m-dimension for Atomic GEMM. + * \param[in] n_split Number of chunks/splits along n-dimension for Atomic GEMM. + * \param[in] gemm_producer Whether Atomic GEMM is the producer or consumer. + * \param[in,out] counter counter[chunk_i]=0 indicates chunk_i has been produced. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, + const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, + bool transb, bool grad, NVTETensor workspace, bool accumulate, + bool use_split_accumulator, int math_sm_count, int m_split, + int n_split, bool gemm_producer, const NVTETensor counter, + musaStream_t stream); + +/*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations, + * on multiple streams. + * + * Computes: + * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors + * - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty + * - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors + * + * \param[in] A The list of A matrices. + * \param[in] B The list of B matrices. + * \param[in,out] D List of output matrices. + * \param[in] bias List of bias tensors. + * \param[in,out] pre_gelu_out List of output matrix before GELU activation. + * \param[in] num_gemms Number of GEMMs to compute. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of the + * gradient computation. + * \param[out] workspace List of workspace tensors. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM. + * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) + * \param[in] stream CUDA stream to wait on. + */ +void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + const NVTETensor* bias, NVTETensor* pre_gelu_out, + const int num_gemms, bool transa, bool transb, bool grad, + NVTETensor* workspace, bool accumulate, + bool use_split_accumulator, int math_sm_count, + musaStream_t stream); +#ifdef __cplusplus +} // extern "C" +#endif + +/*! \namespace transformer_engine + */ +namespace transformer_engine { + +constexpr int num_streams = 4; + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GEMM_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/normalization.h b/transformer_engine/musa/common/include/transformer_engine/normalization.h new file mode 100644 index 0000000000..e75b941744 --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/normalization.h @@ -0,0 +1,156 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file normalization.h + * \brief LayerNorm and RMSNorm functions. + */ + +#ifndef TRANSFORMER_ENGINE_NORMALIZATION_H_ +#define TRANSFORMER_ENGINE_NORMALIZATION_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Compute LayerNorm on the input. + * + * The formula used: + * @f[ + * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta + * @f] + * + * Calling this function with workspace set to empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] x Input tensor of shape [N, H]. + * \param[in] gamma Gamma tensor of shape [H]. + * \param[in] beta Beta tensor of shape [H]. + * \param[in] epsilon Value added to denominator for numerical stability. + * \param[in,out] z Output tensor of shape [N, H]. + * \param[out] mu Mean of the input calculated over the last dimension. + * Shape: [N]. + * \param[out] rsigma Inverse of the variance of the input calculated over + * the last dimension. Shape: [N]. + * \param[out] workspace Workspace tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, + const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, musaStream_t stream); + +/*! \brief Compute backward of LayerNorm. + * + * This function computes the gradient of function: + * @f[ + * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta + * @f] + * else + * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. + * + * Calling this function with workspace set to empty tensor will not perform the operation, + * but instead set the shape and type of these tensors to the required values. + * + * \param[in] dz Incoming gradient tensor of shape [N, H]. + * \param[in] x Forward input tensor of shape [N, H]. + * \param[in] mu Mean of the input calculated over the last dimension. + * Shape: [N]. + * \param[in] rsigma Inverse of the variance of the input calculated over + * the last dimension. Shape: [N]. + * \param[in] gamma Gamma tensor of shape [H]. + * \param[out] dx Output gradient of shape [N, H]. + * \param[out] dgamma Gradient for gamma tensor of shape [H]. + * \param[out] dbeta Gradient for beta tensor of shape [H]. + * \param[out] workspace Workspace tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor mu, + const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx, + NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + musaStream_t stream); + +/*! \brief Compute RMSNorm. + * + * The formula used: + * @f[ + * y = \frac{x}{RMS_\varepsilon(x)}\gamma + * @f] + * where + * @f[ + * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon} + * @f] + * + * Calling this function with workspace and barrier set to empty tensor will not + * perform the operation, but instead set the shape and type of the workspace + * and barrier tensors to the required values. + * + * \param[in] x Input tensor of shape [N, H]. + * \param[in] gamma Gamma tensor of shape [H]. + * \param[in] epsilon Value added to denominator for numerical stability. + * \param[in,out] z Output tensor of shape [N, H]. + * \param[out] rsigma Reciprocal of the root mean square of the input + * calculated over the last dimension. Shape: [N]. + * \param[out] workspace Workspace tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z, + NVTETensor rsigma, NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, musaStream_t stream); + +/*! \brief Compute backward of RMSNorm. + * + * This function computes the gradient of function: + * @f[ + * y = \frac{x}{RMS_\varepsilon(x)}\gamma + * @f] + * where + * @f[ + * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon} + * @f] + * with respect to \f$x\f$ and \f$gamma\f$. + * + * Calling this function with workspace, barrier, dgamma_part set + * to empty tensor will not perform the operation, but instead set the shape and type + * of these tensors to the required values. + * + * \param[in] dz Incoming gradient tensor of shape [N, H]. + * \param[in] x Forward input tensor of shape [N, H]. + * \param[in] rsigma Reciprocal of the root mean square of the input + * calculated over the last dimension. Shape: [N]. + * \param[in] gamma Gamma tensor of shape [H]. + * \param[out] dx Output gradient of shape [N, H]. + * \param[out] dgamma Gradient for gamma tensor of shape [H]. + * \param[out] workspace Workspace tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, + const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma, + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, musaStream_t stream); + +/*! \brief Helper to enable cuDNN backend for normalization + * + * \param[in] bool Enable if True + */ +void nvte_enable_cudnn_norm_fwd(bool enable); +void nvte_enable_cudnn_norm_bwd(bool enable); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_NORMALIZATION_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/padding.h b/transformer_engine/musa/common/include/transformer_engine/padding.h new file mode 100644 index 0000000000..1be7619c87 --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/padding.h @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file padding.h + * \brief Functions handling padding. + */ + +#ifndef TRANSFORMER_ENGINE_PADDING_H_ +#define TRANSFORMER_ENGINE_PADDING_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Padding multiple tensors. + * + * NOTE: Padding mode only support bottom. + * + * For example, 3x3 matrix pad to 4x3 matrix. + * + * source + * | 1 | 2 | 3 | + * | 4 | 5 | 6 | + * | 7 | 8 | 9 | + * + * destination + * | 1 | 2 | 3 | + * | 4 | 5 | 6 | + * | 7 | 8 | 9 | + * | 0 | 0 | 0 | + * + * \param[in] num_tensors Number of tensors. + * \param[in] input_list List of 2D input tensors. + * \param[in,out] output_list List of padded tensors. Dimensions + * match tensors in input_list. + * \param[in] padded_num_rows_list List of padded num rows corresponding to input tensors. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* padded_num_rows_list, musaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_PADDING_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/permutation.h b/transformer_engine/musa/common/include/transformer_engine/permutation.h new file mode 100644 index 0000000000..ffd09e19d4 --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/permutation.h @@ -0,0 +1,46 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PERMUTATION_H_ +#define TRANSFORMER_ENGINE_PERMUTATION_H_ + +#include "transformer_engine.h" + +void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, + NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, + const NVTETensor input_fwd, const int num_rows, const int topK, + const int num_cols, const int num_out_tokens, musaStream_t stream = nullptr); + +void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor prob, const int num_rows, const int topK, const int num_cols, + musaStream_t stream = nullptr); + +void nvte_permute_mask(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor probs, NVTETensor permuted_probs, const int num_tokens, + const int num_experts, const int num_out_tokens, const int hidden_size, + musaStream_t stream = nullptr); +void nvte_unpermute_mask(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor merging_probs, const NVTETensor permuted_probs, + NVTETensor unpermuted_probs, const int num_tokens, const int num_experts, + const int hidden_size, musaStream_t stream = nullptr); +// HACK(sherry): +void nvte_permute_mask_high_precision_probs(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor probs, NVTETensor permuted_probs, const int num_tokens, + const int num_experts, const int num_out_tokens, const int hidden_size, + musaStream_t stream = nullptr); +void nvte_unpermute_mask_high_precision_probs(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor merging_probs, const NVTETensor permuted_probs, + NVTETensor unpermuted_probs, const int num_tokens, const int num_experts, + const int hidden_size, musaStream_t stream = nullptr); +// HACK(sherry) + +void nvte_unpermute_mask_bwd_with_merging_probs( + const NVTETensor fwd_output_grad, NVTETensor fwd_input_grad, const NVTETensor fwd_input, + const NVTETensor merging_probs, NVTETensor merging_probs_grad, NVTETensor row_id_map, + const int num_tokens, const int num_experts, const int hidden_size, + musaStream_t stream = nullptr); + +#endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/recipe.h b/transformer_engine/musa/common/include/transformer_engine/recipe.h new file mode 100644 index 0000000000..ffd3818e1c --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/recipe.h @@ -0,0 +1,80 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file recipe.h + * \brief Functions handling FP8 recipes. + */ + +#ifndef TRANSFORMER_ENGINE_RECIPE_H_ +#define TRANSFORMER_ENGINE_RECIPE_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Update FP8 scaling factors with delayed scaling recipe. + * + * The amax history is rotated by -1 (e.g. the first entry shifts to + * the last, the last entry shifts to the second to last) and the + * first entry is set to zero. The scaling factor is estimated so the + * FP8 tensor's maximum absolute value is + * @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$. + * + * \param[in] amax_history History of maximum absolute values. + * Shape: [history_length, num_scales] + * \param[in] scale Scaling factor for casting to FP8. Shape: [num_scales] + * \param[out] updated_amax_history Updated history of maximum absolute values. + * Shape: [history_length, num_scales] + * \param[out] updated_scale Updated scaling factor for casting to FP8. + * Shape: [num_scales] + * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and + * "most_recent". + * \param[in] fp8_dtype FP8 datatype. + * \param[in] margin Scaling factor margin. + * \param[in] stream CUDA stream. + */ +void nvte_delayed_scaling_recipe_amax_and_scale_update( + const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history, + NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, + musaStream_t stream); + +/*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. + * + * Operations performed include, updating the most recent amax history + * with the relevant segment of global reduction buffer if it's not 0, + * rotating the amax history based on the rule below, and updating the + * scales. + * + * The amax history is rotated by -1 (e.g. the first entry shifts to + * the last, the last entry shifts to the second to last) and the + * first entry is set to zero. The scaling factor is estimated so the + * FP8 tensor's maximum absolute value is + * @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$. + * + * \param[in] amax_reduction_buffer The contiguous buffer used for amax reduction. + * Shape: [num_scales * num_tensors] + * \param[in,out] amax_histories List of amax histories of maximum absolute values. + * Shape: num_tensors x [history_length, num_scales] + * \param[in,out] scales List of scaling factors for casting to FP8. + * Shape: num_tensors x [num_scales] + * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and + * "most_recent". + * \param[in] fp8_dtype FP8 datatype. + * \param[in] margin Scaling factor margin. + * \param[in] stream CUDA stream. + */ +void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + const NVTETensor amax_reduction_buffer, std::vector amax_histories, + std::vector scales, const char* amax_compute_algo, NVTEDType fp8_dtype, + float margin, musaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_RECIPE_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/softmax.h b/transformer_engine/musa/common/include/transformer_engine/softmax.h new file mode 100644 index 0000000000..485f445969 --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/softmax.h @@ -0,0 +1,132 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_SOFTMAX_H_ +#define TRANSFORMER_ENGINE_SOFTMAX_H_ + +#include +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Compute scaled softmax activation on the input. + * + * \param[in] input Input tensor for softmax. + * \param[out] softmax_results Output tensor. + * \param[in] scale_factor Scalar for the input tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_results, + float scale_factor, musaStream_t stream); + +/*! \brief Compute the backward of the scaled softmax activation. + * + * - `incoming_grads` is the input tensor containing the gradients received from the following layer. + * - `softmax_results` is the output tensor of the corresponding forward softmax operation. + * - `output_grads` is the output tensor containing the computed gradients. + * + * \param[in] incoming_grads Input gradient tensor for backward. + * \param[in] softmax_results Output tensor of softmax forward. + * \param[out] output_grads Output tensor. + * \param[in] scale_factor Scalar for the output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, musaStream_t stream); + +/*! \brief Compute scaled masked softmax activation on the input. + * + * \param[in] input Input tensor for softmax. + * \param[in] mask Mask for the input tensor. + * \param[out] softmax_results Output tensor. + * \param[in] scale_factor Scalar for the input tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask, + NVTETensor softmax_results, float scale_factor, + musaStream_t stream); + +/*! \brief Compute the backward of the scaled masked softmax activation. + * + * - `incoming_grads` is the input tensor containing the gradients received from the following layer. + * - `softmax_results` is the output tensor of the corresponding forward softmax operation. + * - `output_grads` is the output tensor containing the computed gradients. + * + * \param[in] incoming_grads Input gradient tensor for backward. + * \param[in] softmax_results Output tensor of softmax forward. + * \param[out] output_grads Output tensor. + * \param[in] scale_factor Scalar for the output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, NVTETensor output_grads, + float scale_factor, musaStream_t stream); + +/*! \brief Compute scaled softmax activation using a 2D upper triangular mask on the input. + * + * \param[in] input Input tensor for softmax. + * \param[out] softmax_results Output tensor. + * \param[in] scale_factor Scalar for the input tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input, + NVTETensor softmax_results, float scale_factor, + musaStream_t stream); + +/*! \brief Compute the backward of the scaled softmax activation using a 2D upper triangular mask. + * + * - `incoming_grads` is the input tensor containing the gradients received from the following layer. + * - `softmax_results` is the output tensor of the corresponding forward softmax operation. + * - `output_grads` is the output tensor containing the computed gradients. + * + * \param[in] incoming_grads Input gradient tensor for backward. + * \param[in] softmax_results Output tensor of softmax forward. + * \param[out] output_grads Output tensor. + * \param[in] scale_factor Scalar for the output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, + musaStream_t stream); + +/*! \brief Compute scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix. + * + * \param[in] input Input tensor for softmax. + * \param[out] softmax_results Output tensor. + * \param[in] scale_factor Scalar for the input tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input, + NVTETensor softmax_results, + float scale_factor, musaStream_t stream); + +/*! \brief Compute the backward pass of the scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix. + * + * - `incoming_grads` is the input tensor containing the gradients received from the following layer. + * - `softmax_results` is the output tensor of the corresponding forward softmax operation. + * - `output_grads` is the output tensor containing the computed gradients. + * + * \param[in] incoming_grads Input gradient tensor for backward. + * \param[in] softmax_results Output tensor of softmax forward. + * \param[out] output_grads Output tensor. + * \param[in] scale_factor Scalar for the output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incoming_grads, + const NVTETensor softmax_results, + NVTETensor output_grads, float scale_factor, + musaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_SOFTMAX_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/swizzle.h b/transformer_engine/musa/common/include/transformer_engine/swizzle.h new file mode 100644 index 0000000000..45e6bff8e9 --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/swizzle.h @@ -0,0 +1,37 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast.h + * \brief Functions to cast to/from FP8. + */ + +#ifndef TRANSFORMER_ENGINE_SWIZZLE_H_ +#define TRANSFORMER_ENGINE_SWIZZLE_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM + * + * \param[in] input Input tensor with non-swizzled scale_inv. + * \param[in,out] output Output tensor which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. + */ +void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, musaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_SWIZZLE_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/transformer_engine.h b/transformer_engine/musa/common/include/transformer_engine/transformer_engine.h new file mode 100644 index 0000000000..ddbd42e7c4 --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/transformer_engine.h @@ -0,0 +1,618 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file transformer_engine.h + * \brief Base classes and functions of Transformer Engine API. + */ + +#ifndef TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_ +#define TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \enum NVTEDType + * \brief TE datatype. + */ +enum NVTEDType { + kNVTEByte = 0, /*!< Byte */ + kNVTEInt32 = 1, /*!< 32-bit integer */ + kNVTEInt64 = 2, /*!< 64-bit integer */ + kNVTEFloat32 = 3, /*!< 32-bit float */ + kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */ + kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */ + kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */ + kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */ + kNVTEFloat8E8M0 = 8, /*!< 8-bit float (E8M0) */ + kNVTENumTypes /*!< Number of supported types */ +}; + +/*! \struct NVTEShape + * \brief Shape of the tensor. + */ +struct NVTEShape { + /*! \brief Shape data, of size ndim. */ + const size_t *data; + /*! \brief Number of dimensions. */ + size_t ndim; +}; + +/*! \struct NVTEBasicTensor + * \brief A basic tensor type used to populate parameters of NVTETensor. + * It does not own the memory it points to. + */ +struct NVTEBasicTensor { + void *data_ptr; + NVTEDType dtype; + NVTEShape shape; +}; + +/*! \enum NVTETensorParam + * \brief Indicates the kind of the tensor parameter to set/get. + */ +enum NVTETensorParam { + kNVTERowwiseData = 0, /*!< Data usable in rowwise manner */ + kNVTEColumnwiseData = 1, /*!< Data usable in columnwise manner */ + kNVTEScale = 2, /*!< Scale tensor */ + kNVTEAmax = 3, /*!< Amax tensor */ + kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */ + kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ + kNVTENumTensorParams +}; + +/*! \enum NVTEScalingMode + * \brief Granularity of scaling: + */ +enum NVTEScalingMode { + /*! Single scale per tensor, computed in delayed manner. + Used also for high precision data, without scaling */ + NVTE_DELAYED_TENSOR_SCALING = 0, + /*! Single scale per block of 32 elements consecutive in either + rowwise or columnwise direction */ + NVTE_MXFP8_1D_SCALING = 1, + NVTE_MTFP8_BLOCK_SCALING = 2, + NVTE_INVALID_SCALING +}; + +/*! \brief TE Tensor type + * + * NVTETensor is a contiguous tensor type storing a pointer + * to data of a given shape and type. It does not own the + * memory it points to. + */ +typedef void *NVTETensor; + +/*! \brief Create a new TE tensor. + * + * Create a new TE tensor. Before use its parameters need to be set. + * TE tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] scaling_mode Scaling mode of the tensor. + * + * \return A new TE tensor. + */ +NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode); + +/*! \brief Destroy a TE tensor. + * + * Since the TE tensor does not own memory, the underlying + * data is not freed during this operation. + * + * \param[in] tensor Tensor to be destroyed. + */ +void nvte_destroy_tensor(NVTETensor tensor); + +/*! \brief Get a raw pointer to the tensor's rowwise data. + * + * \param[in] tensor Tensor. + * + * \return A raw pointer to tensor's rowwise data. + */ +void *nvte_tensor_data(const NVTETensor tensor); + +/*! \brief Get a raw pointer to the tensor's columnwise data. + * + * \param[in] tensor Tensor. + * + * \return A raw pointer to tensor's columnwise data. + */ +void *nvte_tensor_columnwise_data(const NVTETensor tensor); + +/*! \brief Get a tensor's data shape. + * + * \param[in] tensor Tensor. + * + * \return A shape of the input tensor. + */ +NVTEShape nvte_tensor_shape(const NVTETensor tensor); + +/*! \brief Get a tensor's data shape. + * + * \param[in] tensor Tensor. + * + * \return A shape of the input tensor. + */ +NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor); + +/*! \brief Get a tensor's number of dimensions. + * + * \param[in] tensor Tensor. + * + * \return Number of tensor dimensions. + */ +size_t nvte_tensor_ndims(const NVTETensor tensor); + +/*! \brief Get the size of a specific tensor dimension. + * + * \param[in] tensor Tensor. + * \param[in] size_t Dimension index. + * + * \return Size of the tensor at the specified dimension. + */ +size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim); + +/*! \brief Get a tensor's total number of elements. + * + * \param[in] tensor Tensor. + * + * \return Number of elements in the tensor. + */ +size_t nvte_tensor_numel(const NVTETensor tensor); + +/*! \brief Get the byte size for the tensor's data type. + * + * \param[in] tensor Tensor. + * + * \return Byte size of the tensor's data type. + */ +size_t nvte_tensor_element_size(const NVTETensor tensor); + +/*! \brief Get a tensor's data type. + * + * \param[in] tensor Tensor. + * + * \return A data type of the input tensor. + */ +NVTEDType nvte_tensor_type(const NVTETensor tensor); + +/*! \brief Get a pointer to the tensor's amax data. + * + * \param[in] tensor Tensor. + * + * \return A pointer to tensor's amax data. + */ +float *nvte_tensor_amax(const NVTETensor tensor); + +/*! \brief Get a pointer to the tensor's scale data. + * + * \param[in] tensor Tensor. + * + * \return A pointer to tensor's scale data. + */ +float *nvte_tensor_scale(const NVTETensor tensor); + +/*! \brief Get a pointer to the tensor's inverse of scale data. + * + * \param[in] tensor Tensor. + * + * \return A pointer to tensor's inverse of scale data. + */ +float *nvte_tensor_scale_inv(const NVTETensor tensor); + +/*! \brief Get a tensor's scale_inv shape. + * + * \param[in] tensor Tensor. + * + * \return A scale_inv shape of the input tensor. + */ +NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor); + +/*! \brief Reset tensor value to zero. + * + * \param[in] tensor Tensor. + * + * \return A scale_inv shape of the input tensor. + */ +void nvte_zero_tensor(const NVTETensor tensor, musaStream_t stream); + +/*! \brief Set a parameter of the tensor. + * + * \param[in/out] tensor Tensor. + * \param[in] param_name The parameter to be set. + * \param[in] param The value to be set. + */ +void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, + const NVTEBasicTensor *param); + +/*! \brief Get a value of the parameter of the tensor. + * + * \param[in] tensor Tensor. + * \param[in] param_name The parameter to be set. + */ +NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name); + +/*! \brief Get the granularity of scaling of this tensor. + * + * \param[in] tensor Tensor. + * + * \return A struct containing the granularity of tensor's scaling. + */ +NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor); + +/*! \struct NVTETensorPack + \brief Pack of tensors, generally used for auxiliary outputs. + */ +struct NVTETensorPack { + /*! Max number of tensors in the pack. Assumed <= 10. */ + static const int MAX_SIZE = 10; + /*! Wrappers of tensors. They do not hold the associated memory. */ + NVTETensor tensors[MAX_SIZE]; + /*! Actual number of tensors in the pack, 0 <= size <= MAX_SIZE. */ + size_t size = 0; +}; + +/*! \brief Create `tensors` in NVTETensorPack. + */ +void nvte_tensor_pack_create(NVTETensorPack *pack); + +/*! \brief Destroy `tensors` in NVTETensorPack. + */ +void nvte_tensor_pack_destroy(NVTETensorPack *pack); + +#ifdef __cplusplus +} // extern "C" + +#include + +/*! \namespace transformer_engine + * \brief Namespace containing C++ API of Transformer Engine. + */ +namespace transformer_engine { + +/*! \enum DType + * \brief TE datatype. + */ +enum class DType { + kByte = 0, + kInt32 = 1, + kInt64 = 2, + kFloat32 = 3, + kFloat16 = 4, + kBFloat16 = 5, + kFloat8E4M3 = 6, + kFloat8E5M2 = 7, + kFloat8E8M0 = 8, + kNumTypes +}; + +/*! \struct TensorWrapper + * \brief C++ wrapper for the NVTETensor class. + */ +class TensorWrapper { + public: + /*! \brief Constructs new TensorWrapper. + * + * Create a new TE tensor with a given shape, datatype and data. + * TE tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] dptr Pointer to the tensor data. + * \param[in] shape Shape of the tensor. + * \param[in] dtype Data type of the tensor. + * \param[in] amax_dptr Pointer to the AMAX value. + * \param[in] scale_dptr Pointer to the scale value. + * \param[in] scale_inv_shape Shape of scale_inv + * \param[in] scale_inv_dptr Pointer to the inverse of scale value. + */ + TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr, + float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr, + const NVTEShape scale_inv_shape = defaultShape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) { + tensor_ = nvte_create_tensor(scaling_mode); + NVTEBasicTensor data = {dptr, static_cast(dtype), shape}; + nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data); + NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, defaultShape}; + nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax); + NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, defaultShape}; + nvte_set_tensor_param(&tensor_, kNVTEScale, &scale); + NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape}; + nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv); + } + + /*! \brief Constructs new TensorWrapper. + * + * Create a new TE tensor with a given shape, datatype and data. + * TE tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] dptr Pointer to the tensor data. + * \param[in] shape Shape of the tensor. + * \param[in] dtype Data type of the tensor. + * \param[in] amax_dptr Pointer to the AMAX value. + * \param[in] scale_dptr Pointer to the scale value. + * \param[in] scale_inv_shape Shape of scale_inv + * \param[in] scale_inv_dptr Pointer to the inverse of scale value. + */ + TensorWrapper(void *dptr, const std::vector &shape, const DType dtype, + float *amax_dptr = nullptr, float *scale_dptr = nullptr, + float *scale_inv_dptr = nullptr, const std::vector &scale_inv_shape = {1}, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr, + scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()}, + scaling_mode) {} + + /*! \brief Constructs new empty TensorWrapper. + * + * Create a new empty TE tensor which holds nothing. + */ + explicit TensorWrapper(const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_tensor(scaling_mode)) {} + + /*! \brief TensorWrapper destructor. */ + ~TensorWrapper() { nvte_destroy_tensor(tensor_); } + + TensorWrapper &operator=(const TensorWrapper &other) = delete; + TensorWrapper(const TensorWrapper &other) = delete; + + /*! \brief Constructs new TensorWrapper from existing TensorWrapper. + * + * Pass an existing TE tensor to a new TensorWrapper. + * + * \param[in,out] other The source of the data. + */ + TensorWrapper(TensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing TensorWrapper. + * + * Change ownership of an existing TE tensor. + * + * \param[in,out] other The source of the data. + */ + TensorWrapper &operator=(TensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + TensorWrapper &set_parameter(const NVTETensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_tensor_param(&tensor_, param, &data); + return *this; + } + + template + TensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTERowwiseData, dptr, type, shape); + } + + template + TensorWrapper &set_columnwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseData, dptr, type, shape); + } + + template + TensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEScale, dptr, type, shape); + } + + template + TensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEAmax, dptr, type, shape); + } + + template + TensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTERowwiseScaleInv, dptr, type, shape); + } + + template + TensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape); + } + + // Parameter getters + + NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { + return nvte_get_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { return get_parameter(kNVTERowwiseData); } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTERowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEColumnwiseScaleInv); + } + + /*! \brief Get an underlying NVTETensor. + * + * \return NVTETensor held by this TensorWrapper. + */ + NVTETensor data() const noexcept { return tensor_; } + + /*! \brief Get the shape of this TensorWrapper. + * + * \return Shape of this TensorWrapper. + */ + const NVTEShape shape() const noexcept { + if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + return nvte_tensor_shape(tensor_); + } + + /*! \brief Get the shape of this TensorWrapper. + * + * \return Shape of this TensorWrapper. + */ + const NVTEShape columnwise_shape() const noexcept { + if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + return nvte_tensor_columnwise_shape(tensor_); + } + + /*! \brief Get the size of this TensorWrapper in the given dimension. + * + * \param[in] size_t Dimension index. + * + * \return Size of this TensorWrapper in given dimension. + */ + size_t size(const size_t dim) const { + if (tensor_ == nullptr) return 0; + return nvte_tensor_size(tensor_, dim); + } + + /*! \brief Get the number of dimensions for this TensorWrapper. + * + * \return Number of dimensions for this TensorWrapper. + */ + size_t ndim() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_tensor_ndims(tensor_); + } + + /*! \brief Get the number of allocated elements in the tensor. This will return 0 for tensors + * with nullptr data even if the TensorWrapper has a non-zero shape. + * + * + * \return Number of elements in the tensor. + */ + size_t numel() const noexcept { + if (tensor_ == nullptr || this->dptr() == nullptr) return 0; + return nvte_tensor_numel(tensor_); + } + + /*! \brief Get the tensor's element size in bytes. + * + * \return Element size in bytes. + */ + size_t element_size() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_tensor_element_size(tensor_); + } + + /*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr + * data even if the TensorWrapper has a non-zero shape and valid dtype. + * + * \return Total tensor size in bytes. + */ + size_t bytes() const noexcept { + if (tensor_ == nullptr || this->dptr() == nullptr) return 0; + return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_); + } + + /*! \brief Get the data type of this TensorWrapper. + * + * \return Data type of this TensorWrapper. + */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_tensor_type(tensor_)); + } + + /*! \brief Get a raw pointer to the tensor's data. + * + * \return A raw pointer to tensor's data. + */ + void *dptr() const noexcept { + if (tensor_ == nullptr) return nullptr; + return nvte_tensor_data(tensor_); + } + + /*! \brief Get a raw pointer to the tensor's data. + * + * \return A raw pointer to tensor's data. + */ + void *columnwise_dptr() const noexcept { + if (tensor_ == nullptr) return nullptr; + return nvte_tensor_columnwise_data(tensor_); + } + + /*! \brief Get a pointer to the tensor's amax data. + * + * \return A pointer to tensor's amax data. + */ + float *amax() const noexcept { + if (tensor_ == nullptr) return nullptr; + return nvte_tensor_amax(tensor_); + } + + /*! \brief Get a pointer to the tensor's scale data. + * + * \return A pointer to tensor's scale data. + */ + float *scale() const noexcept { + if (tensor_ == nullptr) return nullptr; + return nvte_tensor_scale(tensor_); + } + + /*! \brief Get a pointer to the tensor's inverse of scale data. + * + * \return A pointer to tensor's inverse of scale data. + */ + float *scale_inv() const noexcept { + if (tensor_ == nullptr) return nullptr; + return nvte_tensor_scale_inv(tensor_); + } + + /*! \brief Get the scale_inv_shape of this TensorWrapper. + * + * \return scale_inv_shape of this TensorWrapper. + */ + const NVTEShape scale_inv_shape() const noexcept { + if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + return nvte_tensor_scale_inv_shape(tensor_); + } + + /*! \brief Get a scaling mode of the tensor. + * + * \return Scaling mode of the tensor. + */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_tensor_scaling_mode(tensor_); + } + + void zero_(musaStream_t stream) { nvte_zero_tensor(tensor_, stream); } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = {&defaultData, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { return {s.data(), s.size()}; } + + /*! \brief Wrapped NVTETensor. */ + NVTETensor tensor_ = nullptr; +}; + +} // namespace transformer_engine + +#endif // __cplusplus + +#endif // TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_ diff --git a/transformer_engine/musa/common/include/transformer_engine/transpose.h b/transformer_engine/musa/common/include/transformer_engine/transpose.h new file mode 100644 index 0000000000..850cc4571e --- /dev/null +++ b/transformer_engine/musa/common/include/transformer_engine/transpose.h @@ -0,0 +1,325 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file transpose.h + * \brief Functions handling transposes. + */ + +#ifndef TRANSFORMER_ENGINE_TRANSPOSE_H_ +#define TRANSFORMER_ENGINE_TRANSPOSE_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Cast and transpose the input. + * + * This function casts the input and produces 2 results: + * - rowwise data in `output` is the result of the cast + * - columnwise data in `output` is the transposed result of the cast. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in,out] output Result of the cast and transpose. + * Shape of the rowwise data: [N, H]. + * Shape of the columnwise data: [H, N] + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose(const NVTETensor input, NVTETensor output, musaStream_t stream); + +/*! \brief Transpose the input. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[out] transposed_output Result of the transpose. Shape: [H, N]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_transpose(const NVTETensor input, NVTETensor transposed_output, musaStream_t stream); + +/*! \brief Cast and transpose the input. Additionally, reduce the input along the first dimension. + * + * This function casts the input and produces 2 results: + * - `output` is the result of the cast (rowwise data) and transposed cast (columnwise data) + * - `dbias` is the result of the reduction of the input along the first dimension. + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in,out] output Result of the cast and transpose. + * Shape of the rowwise data: [N, H]. + * Shape of the columnwise data: [H, N] + * \param[out] dbias Result of the reduction of the input along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, musaStream_t stream); + +/*! \brief Transpose the FP8 input. Additionally, reduce the input along the first dimension. + * + * This function takes FP8 input and produces 2 results: + * - `transposed_output` is the transposed result of the input. + * - `dbias` is the result of the reduction of the input along the first dimension. + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in,out] transposed_output Result of the transpose. Shape: [H, N]. + * \param[out] dbias Result of the reduction of the input along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, musaStream_t stream); + +/*! \brief Cast and transpose multiple tensors. + * + * \param[in] num_tensors Number of tensors. + * \param[in] input_list List of 2D input tensors. + * \param[in,out] output_list List of casted tensors. Dimensions + * of their rowwise data members match + * tensors in input_list. Dimensions of + * their columnwise data members are + * transposed. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, + NVTETensor* output_list, musaStream_t stream); + +/*! \brief Compute backward of GeLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the GeLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Compute backward of SiLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the SiLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Compute backward of ReLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the ReLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Compute backward of the Quick GeLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the Quick GeLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Compute backward of the Squared ReLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the Squared ReLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream); + +/*! \brief Computes the gated GeLU activation of the input, additionally casts and transposes + * the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ +void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, musaStream_t stream); + +/*! \brief Computes the gated Swish activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ +void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, musaStream_t stream); + +/*! \brief Computes the gated ReLU activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ +void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, musaStream_t stream); + +/*! \brief Computes the gated Quick GeLU activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ +void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, musaStream_t stream); + +/*! \brief Computes the gated Squared ReLU activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ +void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, musaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_TRANSPOSE_H_ diff --git a/transformer_engine/musa/common/nvtx.h b/transformer_engine/musa/common/nvtx.h new file mode 100644 index 0000000000..05f0003a12 --- /dev/null +++ b/transformer_engine/musa/common/nvtx.h @@ -0,0 +1,24 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_ +#define TRANSFORMER_ENGINE_COMMON_NVTX_H_ + +// #include + +#include + +namespace transformer_engine::nvtx { + +struct NVTXWrapper { + explicit NVTXWrapper(const std::string &name) { /* nvtxRangePush(name.c_str()); */ } + + ~NVTXWrapper() { /* nvtxRangePop(); */ } +}; + +} // namespace transformer_engine::nvtx + +#endif // TRANSFORMER_ENGINE_COMMON_NVTX_H_ diff --git a/transformer_engine/musa/common/permutation/permutation.mu b/transformer_engine/musa/common/permutation/permutation.mu new file mode 100644 index 0000000000..b8994396b7 --- /dev/null +++ b/transformer_engine/musa/common/permutation/permutation.mu @@ -0,0 +1,373 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "../common.h" + +static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map, + const int num_rows, const int topK, + const int num_out_tokens) { + // Each block corresponds to one source token + // row_id_map[topK][num_rows] + const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int idx = bid * blockDim.x + tid; + + if (idx >= num_rows * topK) return; + + int source_row = sorted_row_id[idx]; + int source_token_id = source_row / topK; + int source_topK_id = source_row % topK; + + if (idx >= num_out_tokens) { + // Set the indices of dropped tokens to -1 + row_id_map[source_topK_id * num_rows + source_token_id] = -1; + } else { + // Create a row id map for subsequent unpermute operation + row_id_map[source_topK_id * num_rows + source_token_id] = idx; + } +} + +template +__global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const int *row_id_map, + const float *prob, const int num_rows, const int topK, + const int num_cols) { + extern __shared__ int8_t s_mem[]; + TCompute *s_prob = reinterpret_cast(s_mem); + + // Each block corresponds to one dest token + const int source_token = blockIdx.x; + const int tid = threadIdx.x; + + if (hasProb) { + for (int i = tid; i < topK; i += blockDim.x * blockDim.y) { + // Load all the topK probs related to the source row into smem + s_prob[i] = TCompute(prob[source_token * topK + i]); + } + __syncthreads(); + } + + // Register buffers for vector type (float4) memory access + float4 frag_load_store; + T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + + // Number of elemments in frag_load_store + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + // Traverse along the hidden dimention + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { + TCompute frag_elem[kElementsPerAccess]; + TCompute frag_sum[kElementsPerAccess]; + + int source_row = row_id_map[source_token]; + + // source_row == -1 represents a dropped token + if (source_row != -1) { + const T *source_row_ptr = input + source_row * num_cols; + + // frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + frag_load_store = *(reinterpret_cast(source_row_ptr + i)); + + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = TCompute(frag_load_store_ptr[e]); + } + + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = frag_sum[e] * s_prob[0]; + } + } + } else { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = TCompute(0.0f); + } + } + + for (int k = 1; k < topK; k++) { + source_row = row_id_map[k * num_rows + source_token]; + + if (source_row == -1) continue; + + const T *source_row_ptr = input + source_row * num_cols; + + // frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + frag_load_store = *(reinterpret_cast(source_row_ptr + i)); + + for (int e = 0; e < kElementsPerAccess; e++) { + frag_elem[e] = TCompute(frag_load_store_ptr[e]); + } + + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_elem[e] = frag_elem[e] * s_prob[k]; + } + } + + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = frag_sum[e] + frag_elem[e]; + } + } + + T *dest_row_ptr = unpermuted_output + source_token * num_cols; + + for (int e = 0; e < kElementsPerAccess; e++) { + if constexpr ((std::is_same_v || std::is_same_v) && + (!hasProb)) { + frag_sum[e] = frag_sum[e] / TCompute(topK); + } + frag_load_store_ptr[e] = T(frag_sum[e]); + } + + *reinterpret_cast(dest_row_ptr + i) = frag_load_store; + } +} + +template +__global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *act_grad, + const float *prob, float *prob_grad, const int *row_id_map, + const int num_rows, const int topK, const int num_cols) { + extern __shared__ int8_t s_mem[]; + TCompute *s_prob = reinterpret_cast(s_mem); + + // Each block corresponds to one source token + const int source_token = blockIdx.x; + const int tid = threadIdx.x; + + if (hasProb) { + for (int i = tid; i < topK; i += blockDim.x) { + // Load all the topK probs related to the source row into smem + s_prob[i] = TCompute(prob[source_token * topK + i]); + } + __syncthreads(); + } + + // Accumulators for the calculation of prob_grad + float accum[topKTile] = {0.0f}; + + // Register buffers for vector type (float4) memory access + float4 frag_load_store; + T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + + // Number of elemments in frag_load_store + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + // The starting address of each source row + const T *source_row_ptr = input_bwd + source_token * num_cols; + + // Traverse along the hidden dimention + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { + TCompute frag_src[kElementsPerAccess]; + + // frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + frag_load_store = *(reinterpret_cast(source_row_ptr + i)); + + for (int e = 0; e < kElementsPerAccess; e++) frag_src[e] = TCompute(frag_load_store_ptr[e]); + + int index = source_token; + + // Process each row in the corresponding topK rows + for (int k = 0; k < topKTile; k++) { + if (k == topK) break; + + int dest_row = row_id_map[index]; + index += num_rows; + + if (dest_row != -1) { + if (hasProb) { + // Calculate act_grad in unpermute bwd + for (int e = 0; e < kElementsPerAccess; e++) + frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]); + } else { + // permute fwd + for (int e = 0; e < kElementsPerAccess; e++) frag_load_store_ptr[e] = T(frag_src[e]); + } + + T *dest_row_ptr = act_grad + dest_row * num_cols; + *reinterpret_cast(dest_row_ptr + i) = frag_load_store; + + if (hasProb) { + // Inner product calculation for prob_grad in unpermute bwd + const T *input_fwd_ptr = input_fwd + dest_row * num_cols; + + // frag_load_store = __ldlu(reinterpret_cast(input_fwd_ptr + i)); + frag_load_store = *(reinterpret_cast(input_fwd_ptr + i)); + + TCompute frag_input_fwd[kElementsPerAccess]; + for (int e = 0; e < kElementsPerAccess; e++) + frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]); + + for (int e = 0; e < kElementsPerAccess; e++) { + accum[k] += static_cast(frag_src[e] * frag_input_fwd[e]); + } + } + } + } + } + + if (hasProb) { + for (int k = 0; k < topKTile; k++) { + if (k == topK) break; + // Warp-level reduction + for (int mask = 16; mask > 0; mask /= 2) { + accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32); + } + } + + if (tid == 0) { + for (int k = 0; k < topKTile; k++) { + if (k == topK) break; + prob_grad[source_token * topK + k] = accum[k]; + } + } + } +} + +template +void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, int *row_id_map, + const float *prob, float *prob_grad, const T *input_fwd, + const int num_rows, const int topK, const int num_cols, + const int num_out_tokens, musaStream_t stream) { + using TCompute = typename std::conditional<(std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + if (input_fwd == nullptr) { + // moe_permute_fwd + + int threads = 64; + int blocks = (num_rows * topK + threads - 1) / threads; + + moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, + num_out_tokens); + + blocks = num_rows; + threads = std::min(num_cols / kElementsPerAccess, 1024); + moe_permute_kernel<<>>( + input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + } else { + // moe_unpermute_bwd + + int threads = 32; + int blocks = num_rows; + + if (prob == nullptr) { + // moe_unpermute_bwd without probs + + moe_permute_kernel<<>>( + input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + } else { + // moe_unpermute_bwd with probs + + size_t smem_bytes = topK * sizeof(TCompute); + + if (topK <= 8) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 16) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 32) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 64) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 128) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else { + NVTE_ERROR("topK cannot exceed 128."); + } + } + } +} + +template +void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const float *prob, + const int num_rows, const int topK, const int num_cols, + musaStream_t stream) { + using TCompute = typename std::conditional<(std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + int blocks = num_rows; + int threads = std::min(num_cols / kElementsPerAccess, 1024); + size_t smem_bytes = topK * sizeof(TCompute); + + if (prob == nullptr) { + // moe_permute_bwd + // moe_unpermute_fwd without probs + + moe_unpermute_kernel<<>>( + input, output, row_id_map, nullptr, num_rows, topK, num_cols); + } else { + // moe_unpermute_fwd with probs + + moe_unpermute_kernel<<>>( + input, output, row_id_map, prob, num_rows, topK, num_cols); + } +} + +void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, + NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, + const NVTETensor input_fwd, const int num_rows, const int topK, + const int num_cols, const int num_out_tokens, musaStream_t stream) { + NVTE_API_CALL(nvte_permute); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *sorted_row_id_cu = + reinterpret_cast(sorted_row_id); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *prob_cu = + reinterpret_cast(prob); + const transformer_engine::Tensor *prob_grad_cu = + reinterpret_cast(prob_grad); + const transformer_engine::Tensor *input_fwd_cu = + reinterpret_cast(input_fwd); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_cu->data.dtype, T, + nvte_permute_launcher(reinterpret_cast(input_cu->data.dptr), + reinterpret_cast(output_cu->data.dptr), + reinterpret_cast(sorted_row_id_cu->data.dptr), + reinterpret_cast(row_id_map_cu->data.dptr), + reinterpret_cast(prob_cu->data.dptr), + reinterpret_cast(prob_grad_cu->data.dptr), + reinterpret_cast(input_fwd_cu->data.dptr), num_rows, topK, + num_cols, num_out_tokens, stream);); +} + +void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor prob, const int num_rows, const int topK, const int num_cols, + musaStream_t stream) { + NVTE_API_CALL(nvte_unpermute); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *prob_cu = + reinterpret_cast(prob); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_cu->data.dtype, T, + nvte_unpermute_launcher(reinterpret_cast(input_cu->data.dptr), + reinterpret_cast(output_cu->data.dptr), + reinterpret_cast(row_id_map_cu->data.dptr), + reinterpret_cast(prob_cu->data.dptr), num_rows, topK, + num_cols, stream);); +} diff --git a/transformer_engine/musa/common/permutation/permutation_mask.mu b/transformer_engine/musa/common/permutation/permutation_mask.mu new file mode 100644 index 0000000000..5096258c41 --- /dev/null +++ b/transformer_engine/musa/common/permutation/permutation_mask.mu @@ -0,0 +1,699 @@ +#include + +#include "../common.h" +#include "../util/mtfp8_utils.muh" +#include "../utils.muh" + +// HACK(sherry): support fp32/fp64 router +// input: [num_tokens, hidden_size] @ [stride_input_token, stride_input_hidden] +// row_id_map: [num_experts, num_tokens] +// output: [num_out_tokens, hidden_size] @ [stride_output_token, stride_output_hidden] +// probs: [num_tokens, num_experts] +// permuted_probs: [num_out_tokens] +template +__global__ void permute_with_mask_map_trans( + MUtensorDescriptor out_dev_tensorDesc, MUtensorDescriptor in_dev_tensorDesc, + IdxDtype *row_id_map_ptr, const P_Dtype *probs_ptr, P_Dtype *permuted_probs_ptr, + const int num_tokens, const int num_experts, const int hidden_size, + const int stride_input_token, const int stride_input_hidden, const int stride_output_token, + const int stride_output_hidden, const int stride_probs_token, const int stride_probs_expert, + const int stride_permuted_probs_token) { + using IdxVec = transformer_engine::Vec; + int tidx = threadIdx.x; + int token_id = blockIdx.x; + + int trans_count = hidden_size * sizeof(Dtype); + extern __shared__ __align__(128) char shared_array[]; + Dtype *smem = reinterpret_cast(shared_array); + __musa::async_barrier bar(1); + bar.init_arrival(1); + __syncthreads(); + + int ld_dim = hidden_size; + int ld_pos = token_id * hidden_size; + __musa::memcpy_async(bar, smem, &in_dev_tensorDesc, ld_dim, ld_pos, trans_count, 0, 3, 1); + unsigned phase_id = bar.arrive(); + bar.wait(phase_id); + + IdxDtype dst_row; + for (int expert_id = 0; expert_id < num_experts; expert_id += 1) { + dst_row = row_id_map_ptr[expert_id * num_tokens + token_id]; + if (dst_row != -1) { + int st_dim = hidden_size; + int st_pos = dst_row * hidden_size; + __musa::memcpy(smem, &out_dev_tensorDesc, st_dim, st_pos); + if constexpr (with_permuted_probs) { + if (tidx == 0) { + int prob_offset = token_id * stride_probs_token + expert_id * stride_probs_expert; + P_Dtype prob_val = probs_ptr[prob_offset]; + int permuted_prob_offset = dst_row * stride_permuted_probs_token; + permuted_probs_ptr[permuted_prob_offset] = prob_val; + } + } + } + } +} + +template +__global__ void permute_with_mask_map(MUtensorDescriptor out_dev_tensorDesc, + MUtensorDescriptor in_dev_tensorDesc, + MUtensorDescriptor map_dev_tensorDesc, const P_Dtype *probs_ptr, + P_Dtype *permuted_probs_ptr, const int num_tokens, + const int num_experts, const int hidden_size, + const int stride_input_token, const int stride_input_hidden, + const int stride_output_token, const int stride_output_hidden, + const int stride_probs_token, const int stride_probs_expert, + const int stride_permuted_probs_token) { + using IdxVec = transformer_engine::Vec; + int tidx = threadIdx.x; + int token_id = blockIdx.x; + + int trans_count = hidden_size * sizeof(Dtype); + int trans_count_map = num_experts * sizeof(IdxDtype); + extern __shared__ __align__(128) char shared_array[]; + Dtype *smem = reinterpret_cast(shared_array); + const size_t hidden_size_aligned = + (hidden_size * sizeof(Dtype) + 127) / 128 * 128 / sizeof(Dtype); + IdxDtype *smem_map = reinterpret_cast(smem + hidden_size_aligned); + __musa::async_barrier bar(1); + __musa::async_barrier bar_map(2); + bar.init_arrival(1); + bar_map.init_arrival(1); + __syncthreads(); + + int ld_dim = hidden_size; + int ld_pos = token_id * hidden_size; + int ld_dim_map = num_experts; + int ld_pos_map = token_id * num_experts; + __musa::memcpy_async(bar_map, smem_map, &map_dev_tensorDesc, ld_dim_map, ld_pos_map, + trans_count_map, 0, 3, 1); + unsigned phase_id_map = bar_map.arrive(); + __musa::memcpy_async(bar, smem, &in_dev_tensorDesc, ld_dim, ld_pos, trans_count, 0, 3, 1); + unsigned phase_id = bar.arrive(); + + bar_map.wait(phase_id_map); + bar.wait(phase_id); + IdxVec dst_row_vec; + for (int expert_id = 0; expert_id < num_experts; expert_id += 4) { + dst_row_vec.load_from(smem_map + expert_id); +#pragma unroll + for (int i = 0; i < 4; i++) { + if (dst_row_vec.data.elt[i] != -1) { + int st_dim = hidden_size; + int st_pos = dst_row_vec.data.elt[i] * hidden_size; + __musa::memcpy(smem, &out_dev_tensorDesc, st_dim, st_pos); + if constexpr (with_permuted_probs) { + if (tidx == 0) { + int prob_offset = token_id * stride_probs_token + (expert_id + i) * stride_probs_expert; + P_Dtype prob_val = probs_ptr[prob_offset]; + int permuted_prob_offset = dst_row_vec.data.elt[i] * stride_permuted_probs_token; + permuted_probs_ptr[permuted_prob_offset] = prob_val; + } + } + } + } + } +} + + +// input: [num_out_tokens, hidden_size] +// row_id_map: [num_experts, num_tokens] +// output: [num_tokens, hidden_size] +// merging_probs: [num_tokens, num_experts] +// permuted_probs: [num_out_tokens] +// unpermuted_probs: [num_tokens, num_experts] +template +__global__ void moe_unpermute_mask( + const Dtype *in_ptr, Dtype *out_ptr, IdxDtype *row_id_map_ptr, const P_Dtype *merging_probs_ptr, + const P_Dtype *permuted_probs_ptr, P_Dtype *unpermuted_probs_ptr, const int num_tokens, + const int num_experts, const int hidden_size, const int stride_input_token, + const int stride_input_hidden, const int stride_output_token, const int stride_output_hidden, + const int stride_merging_probs_token, const int stride_merging_probs_expert, + const int stride_permuted_probs_token, const int stride_unpermuted_probs_token, + const int stride_unpermuted_probs_expert) { + using DtypeVec = transformer_engine::Vec; + using ComputeVec = transformer_engine::Vec; + constexpr int idx_vlen = 4; + using IdxVec = transformer_engine::Vec; + int token_id = blockIdx.y; + int tidx = blockIdx.x * blockDim.x + threadIdx.x; + int tidx_vlen = (blockIdx.x * blockDim.x + threadIdx.x) * vlen; + extern __shared__ IdxDtype smem[]; + + if constexpr (trans_row_id_map) { + for (int expert_id = threadIdx.x; expert_id < num_experts; expert_id += blockDim.x) { + memcpy_global2shared(smem + expert_id, row_id_map_ptr + expert_id * num_tokens + token_id, 1); + } + } else { + for (int expert_id = threadIdx.x * idx_vlen; expert_id < num_experts; + expert_id += blockDim.x * idx_vlen) { + memcpy_global2shared(smem + expert_id, row_id_map_ptr + token_id * num_experts + expert_id, + idx_vlen); + } + } + __syncthreads_lm(); + + ComputeVec acc_vec = 0.0f; + if (tidx_vlen < hidden_size) { + IdxVec src_row_vec; + for (int expert_id = 0; expert_id < num_experts; expert_id += idx_vlen) { + src_row_vec.load_from(smem + expert_id); +#pragma unroll + for (int i = 0; i < idx_vlen; i++) { + int unpermuted_offset = token_id * stride_unpermuted_probs_token + + (expert_id + i) * stride_unpermuted_probs_expert; + if (src_row_vec.data.elt[i] != -1) { + int src_offset = + src_row_vec.data.elt[i] * stride_input_token + tidx_vlen * stride_input_hidden; + DtypeVec src_val_vec = (Dtype)(0.0f); + src_val_vec.load_from(in_ptr + src_offset); + + P_Dtype merging_probs_val = (P_Dtype)(1.0f); + if constexpr (with_merging_probs) { + int merging_probs_offset = token_id * stride_merging_probs_token + + (expert_id + i) * stride_merging_probs_expert; + merging_probs_val = merging_probs_ptr[merging_probs_offset]; + } + +#pragma unroll + for (int j = 0; j < vlen; j++) { + acc_vec.data.elt[j] += (float)src_val_vec.data.elt[j] * (float)merging_probs_val; + } + + if constexpr (with_permuted_probs) { + if (tidx == 0) { + unpermuted_probs_ptr[unpermuted_offset] = + permuted_probs_ptr[src_row_vec.data.elt[i] * stride_permuted_probs_token]; + } + } + } else { + if constexpr (with_permuted_probs) { + if (tidx == 0) { + unpermuted_probs_ptr[unpermuted_offset] = (P_Dtype)(0.0f); + } + } + continue; + } + } + } + int dst_offset = token_id * stride_output_token + tidx_vlen * stride_output_hidden; +#pragma unroll + for (int i = 0; i < vlen; i++) { + out_ptr[dst_offset + i] = (Dtype)acc_vec.data.elt[i]; + } + } +} +// HACK(sherry) + +// fwd_input_grad, [num_out_tokens, hidden_size] +// merging_probs_grad, [num_tokens, num_experts] +// fwd_output_grad, [num_tokens, hidden_size] +// fwd_input, [num_out_tokens, hidden_size] +// merging_probs, [num_tokens, num_experts] +// row_id_map, [num_experts, num_tokens] +template +__global__ void moe_unpermute_mask_bwd_with_merging_probs( + const Dtype *fwd_output_grad_ptr, Dtype *fwd_input_grad_ptr, const Dtype *fwd_input_ptr, + const Dtype *merging_probs_ptr, Dtype *merging_probs_grad_ptr, IdxDtype *row_id_map_ptr, + const int num_tokens, const int num_experts, const int hidden_size, + const int stride_fwd_output_grad_token, const int stride_fwd_output_grad_hidden, + const int stride_fwd_input_grad_token, const int stride_fwd_input_grad_hidden, + const int stride_fwd_input_token, const int stride_fwd_input_hidden, + const int stride_merging_probs_token, const int stride_merging_probs_expert, + const int stride_merging_probs_grad_token, const int stride_merging_probs_grad_expert) { + constexpr int idx_vlen = 4; + using ComputeDtype = float; + using DtypeVec = transformer_engine::Vec; + using IdxVec = transformer_engine::Vec; + int token_id = blockIdx.x; + int tidx = threadIdx.x; + int tidx_vlen = (threadIdx.x) * vlen; + extern __shared__ IdxDtype smem[]; + int warp_id = tidx >> 5; + int lane_id = tidx & 31; + ComputeDtype *warpLevelVal = reinterpret_cast(smem + num_experts); + + if constexpr (trans_row_id_map) { + for (int expert_id = threadIdx.x; expert_id < num_experts; expert_id += blockDim.x) { + memcpy_global2shared(smem + expert_id, row_id_map_ptr + expert_id * num_tokens + token_id, 1); + } + } else { + for (int expert_id = threadIdx.x * idx_vlen; expert_id < num_experts; + expert_id += blockDim.x * idx_vlen) { + memcpy_global2shared(smem + expert_id, row_id_map_ptr + token_id * num_experts + expert_id, + idx_vlen); + } + } + __syncthreads_lm(); + + IdxVec dst_row_vec; + for (int expert_id = 0; expert_id < num_experts; expert_id += idx_vlen) { + dst_row_vec.load_from(smem + expert_id); +#pragma unroll + for (int i = 0; i < idx_vlen; i++) { + int probs_grad_offset = token_id * stride_merging_probs_grad_token + + (expert_id + i) * stride_merging_probs_grad_expert; + if (dst_row_vec.data.elt[i] != -1) { + ComputeDtype prob_grad_acc = 0.0f; + for (int hidden_offset = tidx_vlen; hidden_offset < hidden_size; + hidden_offset += blockDim.x * vlen) { + int input_offset = token_id * stride_fwd_output_grad_token + + hidden_offset * stride_fwd_output_grad_hidden; + DtypeVec src_val_vec = (Dtype)(0.0f); + src_val_vec.load_from(fwd_output_grad_ptr + input_offset); + + int merging_prob_offset = + token_id * stride_merging_probs_token + (expert_id + i) * stride_merging_probs_expert; + float merging_prob = (float)(merging_probs_ptr[merging_prob_offset]); + + DtypeVec dst_val_vec = (Dtype)(0.0f); + int output_offset = dst_row_vec.data.elt[i] * stride_fwd_input_grad_token + + hidden_offset * stride_fwd_input_grad_hidden; +#pragma unroll + for (int j = 0; j < vlen; j++) { + dst_val_vec.data.elt[j] = (float)src_val_vec.data.elt[j] * merging_prob; + } + dst_val_vec.store_to(fwd_input_grad_ptr + output_offset); + + int fwd_input_offset = dst_row_vec.data.elt[i] * stride_fwd_input_token + + hidden_offset * stride_fwd_input_hidden; + DtypeVec fwd_input_vec = (Dtype)(0.0f); + fwd_input_vec.load_from(fwd_input_ptr + fwd_input_offset); +#pragma unroll + for (int j = 0; j < vlen; j++) { + prob_grad_acc = + prob_grad_acc + (float)fwd_input_vec.data.elt[j] * (float)src_val_vec.data.elt[j]; + } + } + ComputeDtype sum = prob_grad_acc; + for (int delta = 16; delta > 0; delta >>= 1) { + sum += __shfl_down_sync(0xffffffff, sum, delta); + } + if (lane_id == 0) { + warpLevelVal[warp_id] = sum; + } + __syncthreads_lm(); + if (warp_id == 0) { + sum = (lane_id < 4) ? warpLevelVal[lane_id] : 0; + for (int delta = 2; delta > 0; delta >>= 1) { + sum += __shfl_down_sync(0xffffffff, sum, delta); + } + if (tidx == 0) { + merging_probs_grad_ptr[probs_grad_offset] = (Dtype)sum; + } + } + } else { + merging_probs_grad_ptr[probs_grad_offset] = (Dtype)(0.0f); + } + } + } +} + +// HACK(sherry): support fp32/fp64 router +template +void nvte_permute_mask_launcher(const Dtype *input, Dtype *output, IdxDtype *row_id_map, + const P_Dtype *probs, P_Dtype *permuted_probs, const int num_tokens, + const int num_experts, const int num_out_tokens, + const int hidden_size, musaStream_t stream) { + NVTE_CHECK((hidden_size * sizeof(Dtype)) % 4 == 0, "bytes of hidden_size must be divisible by 4"); + if constexpr (!trans_row_id_map) { + NVTE_CHECK((num_experts * sizeof(IdxDtype)) % 4 == 0, + "bytes of num_experts must be divisible by 4"); + } + + MUtensorDescriptor intensorDesc; + MUtensorDescriptor outtensorDesc; + MUtensorDescriptor maptensorDesc; + MUtensorDescriptorDataType tensorDataType = MU_TENSOR_DESCRIPTOR_DATA_TYPE_BFLOAT16; + MUtensorDescriptorDataType mapDataType = MU_TENSOR_DESCRIPTOR_DATA_TYPE_INT64; + uint32_t tensorRank = 1; + const uint64_t in_globalDim[5] = {static_cast(hidden_size) * num_tokens, 1, 1, 1, 1}; + const uint64_t in_globalStrides[4] = {0, 0, 0, 0}; + const uint64_t out_globalDim[5] = {static_cast(hidden_size) * num_out_tokens, 1, 1, 1, + 1}; + const uint64_t out_globalStrides[4] = {0, 0, 0, 0}; + const uint64_t map_globalDim[5] = {static_cast(num_experts) * num_tokens, 1, 1, 1, 1}; + const uint64_t map_globalStrides[4] = {0, 0, 0, 0}; + MUtensorDescriptorInterleave interleave = MU_TENSOR_DESCRIPTOR_INTERLEAVE_NONE; + uint64_t oobConstantFill = 0; + + transformer_engine::checkCuDriverContext(stream); + NVTE_CHECK_MU(muTensorDescriptorEncode(&intensorDesc, tensorDataType, tensorRank, (void *)input, + in_globalDim, in_globalStrides, interleave, + oobConstantFill)); + NVTE_CHECK_MU(muTensorDescriptorEncode(&outtensorDesc, tensorDataType, tensorRank, (void *)output, + out_globalDim, out_globalStrides, interleave, + oobConstantFill)); + NVTE_CHECK_MU(muTensorDescriptorEncode(&maptensorDesc, mapDataType, tensorRank, + (void *)row_id_map, map_globalDim, map_globalStrides, + interleave, oobConstantFill)); + + const int block_x = 32; + const int grid_x = num_tokens; + dim3 block(block_x, 1); + dim3 grid(grid_x, 1); + + if constexpr (trans_row_id_map) { + int smem_size = hidden_size * sizeof(Dtype); + permute_with_mask_map_trans<<>>( + outtensorDesc, intensorDesc, row_id_map, probs, permuted_probs, num_tokens, num_experts, + hidden_size, hidden_size, 1, hidden_size, 1, num_experts, 1, 1); + } else { + int smem_size = hidden_size * sizeof(Dtype) + num_experts * sizeof(IdxDtype); + permute_with_mask_map<<>>( + outtensorDesc, intensorDesc, maptensorDesc, probs, permuted_probs, num_tokens, num_experts, + hidden_size, hidden_size, 1, hidden_size, 1, num_experts, 1, 1); + } +} + +template +void nvte_unpermute_mask_launcher(const Dtype *input, Dtype *output, IdxDtype *row_id_map, + const P_Dtype *merging_probs, const P_Dtype *permuted_probs, + P_Dtype *unpermuted_probs, const int num_tokens, + const int num_experts, const int hidden_size, + musaStream_t stream) { + constexpr int vlen = 16 / sizeof(Dtype); + int block_x = 128; + int grid_x = transformer_engine::mtfp8::ceil_div(hidden_size, block_x * vlen); + int grid_y = num_tokens; + dim3 block(block_x, 1); + dim3 grid(grid_x, grid_y); + int smem_size = num_experts * sizeof(IdxDtype); + + moe_unpermute_mask<<>>( + input, output, row_id_map, merging_probs, permuted_probs, unpermuted_probs, num_tokens, + num_experts, hidden_size, hidden_size, 1, hidden_size, 1, num_experts, 1, 1, num_experts, 1); +} +// HACK(sherry) + +template +void nvte_unpermute_mask_bwd_with_merging_probs_launcher( + const Dtype *fwd_output_grad, Dtype *fwd_input_grad, const Dtype *fwd_input, + const Dtype *merging_probs, Dtype *merging_probs_grad, IdxDtype *row_id_map, + const int num_tokens, const int num_experts, const int hidden_size, musaStream_t stream) { + NVTE_CHECK(num_experts % 4 == 0, "num_experts must be divisible by 4"); + + constexpr int vlen = 16 / sizeof(Dtype); + int block_x = 128; + int grid_x = num_tokens; + dim3 block(block_x, 1, 1); + dim3 grid(grid_x, 1, 1); + + int smem_size = num_experts * sizeof(IdxDtype) + block_x * sizeof(float); + moe_unpermute_mask_bwd_with_merging_probs + <<>>( + fwd_output_grad, fwd_input_grad, fwd_input, merging_probs, merging_probs_grad, row_id_map, + num_tokens, num_experts, hidden_size, hidden_size, 1, hidden_size, 1, hidden_size, 1, + num_experts, 1, num_experts, 1); +} + +#define CALL_PERMUTE_MASK_LAUNCHER(_PERMUTED_PROBS, _TRANS_ROW_ID_MAP) \ + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( \ + input_cu->data.dtype, T, \ + nvte_permute_mask_launcher( \ + reinterpret_cast(input_cu->data.dptr), \ + reinterpret_cast(output_cu->data.dptr), \ + reinterpret_cast(row_id_map_cu->data.dptr), \ + reinterpret_cast(probs_cu->data.dptr), \ + reinterpret_cast(permuted_probs_cu->data.dptr), num_tokens, num_experts, \ + num_out_tokens, hidden_size, stream);); + +void nvte_permute_mask(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor probs, NVTETensor permuted_probs, const int num_tokens, + const int num_experts, const int num_out_tokens, const int hidden_size, + musaStream_t stream) { + NVTE_API_CALL(nvte_permute_mask); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *probs_cu = + reinterpret_cast(probs); + const transformer_engine::Tensor *permuted_probs_cu = + reinterpret_cast(permuted_probs); + + if (probs_cu->data.dptr != nullptr) { + if (row_id_map_cu->data.shape[0] == num_experts) { + CALL_PERMUTE_MASK_LAUNCHER(true, true); + } else { + CALL_PERMUTE_MASK_LAUNCHER(true, false); + } + } else { + if (row_id_map_cu->data.shape[0] == num_experts) { + CALL_PERMUTE_MASK_LAUNCHER(false, true); + } else { + CALL_PERMUTE_MASK_LAUNCHER(false, false); + } + } +} +#undef CALL_PERMUTE_MASK_LAUNCHER + +#define CALL_UNPERMUTE_MASK_LAUNCHER(_MERGING_PROBS, _PERMUTED_PROBS, _TRANS_ROW_ID_MAP) \ + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( \ + input_cu->data.dtype, T, \ + nvte_unpermute_mask_launcher( \ + reinterpret_cast(input_cu->data.dptr), \ + reinterpret_cast(output_cu->data.dptr), \ + reinterpret_cast(row_id_map_cu->data.dptr), \ + reinterpret_cast(merging_probs_cu->data.dptr), \ + reinterpret_cast(permuted_probs_cu->data.dptr), \ + reinterpret_cast(unpermuted_probs_cu->data.dptr), num_tokens, num_experts, \ + hidden_size, stream);); + +#define CALL_UNPERMUTE_MASK_TRANS_LAUNCHER(_TRANS_ROW_ID_MAP) \ + if (merging_probs_cu->data.dptr != nullptr) { \ + if (permuted_probs_cu->data.dptr != nullptr) { \ + CALL_UNPERMUTE_MASK_LAUNCHER(true, true, _TRANS_ROW_ID_MAP); \ + } else { \ + CALL_UNPERMUTE_MASK_LAUNCHER(true, false, _TRANS_ROW_ID_MAP); \ + } \ + } else { \ + if (permuted_probs_cu->data.dptr != nullptr) { \ + CALL_UNPERMUTE_MASK_LAUNCHER(false, true, _TRANS_ROW_ID_MAP); \ + } else { \ + CALL_UNPERMUTE_MASK_LAUNCHER(false, false, _TRANS_ROW_ID_MAP); \ + } \ + } + +void nvte_unpermute_mask(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor merging_probs, const NVTETensor permuted_probs, + NVTETensor unpermuted_probs, const int num_tokens, const int num_experts, + const int hidden_size, musaStream_t stream) { + NVTE_API_CALL(nvte_unpermute_mask); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *merging_probs_cu = + reinterpret_cast(merging_probs); + const transformer_engine::Tensor *permuted_probs_cu = + reinterpret_cast(permuted_probs); + const transformer_engine::Tensor *unpermuted_probs_cu = + reinterpret_cast(unpermuted_probs); + + if (row_id_map_cu->data.shape[0] == num_experts) { + CALL_UNPERMUTE_MASK_TRANS_LAUNCHER(true); + } else { + CALL_UNPERMUTE_MASK_TRANS_LAUNCHER(false); + } +} + +#undef CALL_UNPERMUTE_MASK_LAUNCHER +#undef CALL_UNPERMUTE_MASK_TRANS_LAUNCHER + + +//HACK(sherry): support fp32/fp64 router +#define PROBS_TYPE_SWITCH(probs_dtype, probs_type, ...) \ + switch (probs_dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat16: { \ + using probs_type = fp16; \ + __VA_ARGS__; \ + break; \ + } \ + case DType::kBFloat16: { \ + using probs_type = bf16; \ + __VA_ARGS__; \ + break; \ + } \ + case DType::kFloat32: { \ + using probs_type = fp32; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + NVTE_ERROR("Invalid probs type."); \ + } +#define TRANSFORMER_ENGINE_PROBS_PERMUTE_TYPE_SWITCH(dtype, type, probs_dtype, probs_type,...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat16: { \ + using type = fp16; \ + PROBS_TYPE_SWITCH(probs_dtype, probs_type, __VA_ARGS__); \ + break; \ + } \ + case DType::kBFloat16: { \ + using type = bf16; \ + PROBS_TYPE_SWITCH(probs_dtype, probs_type, __VA_ARGS__); \ + break; \ + } \ + default: \ + NVTE_ERROR("Invalid type for 16 bit."); \ + } + +#define CALL_HIGH_PRECISION_PROBS_PERMUTE_MASK_LAUNCHER(_PERMUTED_PROBS, _TRANS_ROW_ID_MAP) \ + TRANSFORMER_ENGINE_PROBS_PERMUTE_TYPE_SWITCH( \ + input_cu->data.dtype, T, probs_cu->data.dtype, T_P, \ + nvte_permute_mask_launcher( \ + reinterpret_cast(input_cu->data.dptr), \ + reinterpret_cast(output_cu->data.dptr), \ + reinterpret_cast(row_id_map_cu->data.dptr), \ + reinterpret_cast(probs_cu->data.dptr), \ + reinterpret_cast(permuted_probs_cu->data.dptr), num_tokens, num_experts, \ + num_out_tokens, hidden_size, stream);); + +void nvte_permute_mask_high_precision_probs(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor probs, NVTETensor permuted_probs, const int num_tokens, + const int num_experts, const int num_out_tokens, const int hidden_size, + musaStream_t stream) { + NVTE_API_CALL(nvte_permute_mask_high_precision_probs); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *probs_cu = + reinterpret_cast(probs); + const transformer_engine::Tensor *permuted_probs_cu = + reinterpret_cast(permuted_probs); + + if (probs_cu->data.dptr != nullptr) { + if (row_id_map_cu->data.shape[0] == num_experts) { + CALL_HIGH_PRECISION_PROBS_PERMUTE_MASK_LAUNCHER(true, true); + } else { + CALL_HIGH_PRECISION_PROBS_PERMUTE_MASK_LAUNCHER(true, false); + } + } else { + if (row_id_map_cu->data.shape[0] == num_experts) { + CALL_HIGH_PRECISION_PROBS_PERMUTE_MASK_LAUNCHER(false, true); + } else { + CALL_HIGH_PRECISION_PROBS_PERMUTE_MASK_LAUNCHER(false, false); + } + } +} +#undef CALL_HIGH_PRECISION_PROBS_PERMUTE_MASK_LAUNCHER + +#define CALL_HIGH_PRECISION_PROBS_UNPERMUTE_MASK_LAUNCHER(_MERGING_PROBS, _PERMUTED_PROBS, _TRANS_ROW_ID_MAP) \ + TRANSFORMER_ENGINE_PROBS_PERMUTE_TYPE_SWITCH( \ + input_cu->data.dtype, T,permuted_probs_cu->data.dtype, T_P, \ + nvte_unpermute_mask_launcher( \ + reinterpret_cast(input_cu->data.dptr), \ + reinterpret_cast(output_cu->data.dptr), \ + reinterpret_cast(row_id_map_cu->data.dptr), \ + reinterpret_cast(merging_probs_cu->data.dptr), \ + reinterpret_cast(permuted_probs_cu->data.dptr), \ + reinterpret_cast(unpermuted_probs_cu->data.dptr), num_tokens, num_experts, \ + hidden_size, stream);); + +#define CALL_HIGH_PRECISION_PROBS_UNPERMUTE_MASK_TRANS_LAUNCHER(_TRANS_ROW_ID_MAP) \ + if (merging_probs_cu->data.dptr != nullptr) { \ + if (permuted_probs_cu->data.dptr != nullptr) { \ + CALL_HIGH_PRECISION_PROBS_UNPERMUTE_MASK_LAUNCHER(true, true, _TRANS_ROW_ID_MAP); \ + } else { \ + CALL_HIGH_PRECISION_PROBS_UNPERMUTE_MASK_LAUNCHER(true, false, _TRANS_ROW_ID_MAP); \ + } \ + } else { \ + if (permuted_probs_cu->data.dptr != nullptr) { \ + CALL_HIGH_PRECISION_PROBS_UNPERMUTE_MASK_LAUNCHER(false, true, _TRANS_ROW_ID_MAP); \ + } else { \ + CALL_HIGH_PRECISION_PROBS_UNPERMUTE_MASK_LAUNCHER(false, false, _TRANS_ROW_ID_MAP); \ + } \ + } + +void nvte_unpermute_mask_high_precision_probs(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor merging_probs, const NVTETensor permuted_probs, + NVTETensor unpermuted_probs, const int num_tokens, const int num_experts, + const int hidden_size, musaStream_t stream) { + NVTE_API_CALL(nvte_unpermute_mask_high_precision_probs); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *merging_probs_cu = + reinterpret_cast(merging_probs); + const transformer_engine::Tensor *permuted_probs_cu = + reinterpret_cast(permuted_probs); + const transformer_engine::Tensor *unpermuted_probs_cu = + reinterpret_cast(unpermuted_probs); + + if (row_id_map_cu->data.shape[0] == num_experts) { + CALL_HIGH_PRECISION_PROBS_UNPERMUTE_MASK_TRANS_LAUNCHER(true); + } else { + CALL_HIGH_PRECISION_PROBS_UNPERMUTE_MASK_TRANS_LAUNCHER(false); + } +} + +#undef CALL_HIGH_PRECISION_PROBS_UNPERMUTE_MASK_LAUNCHER +#undef CALL_HIGH_PRECISION_PROBS_UNPERMUTE_MASK_TRANS_LAUNCHER + +#undef PROBS_TYPE_SWITCH +#undef TRANSFORMER_ENGINE_PROBS_PERMUTE_TYPE_SWITCH + +//HACK(sherry) + +#define CALL_UNPERMUTE_MASK_BWD_WITH_MERGING_PROBS_LAUNCHER(_TRANS_ROW_ID_MAP) \ + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( \ + fwd_output_grad_cu->data.dtype, T, \ + nvte_unpermute_mask_bwd_with_merging_probs_launcher( \ + reinterpret_cast(fwd_output_grad_cu->data.dptr), \ + reinterpret_cast(fwd_input_grad_cu->data.dptr), \ + reinterpret_cast(fwd_input_cu->data.dptr), \ + reinterpret_cast(merging_probs_cu->data.dptr), \ + reinterpret_cast(merging_probs_grad_cu->data.dptr), \ + reinterpret_cast(row_id_map_cu->data.dptr), num_tokens, num_experts, \ + hidden_size, stream);); + +void nvte_unpermute_mask_bwd_with_merging_probs( + const NVTETensor fwd_output_grad, NVTETensor fwd_input_grad, const NVTETensor fwd_input, + const NVTETensor merging_probs, NVTETensor merging_probs_grad, NVTETensor row_id_map, + const int num_tokens, const int num_experts, const int hidden_size, musaStream_t stream) { + NVTE_API_CALL(nvte_unpermute_mask_bwd_with_merging_probs); + + const transformer_engine::Tensor *fwd_output_grad_cu = + reinterpret_cast(fwd_output_grad); + const transformer_engine::Tensor *fwd_input_grad_cu = + reinterpret_cast(fwd_input_grad); + const transformer_engine::Tensor *fwd_input_cu = + reinterpret_cast(fwd_input); + const transformer_engine::Tensor *merging_probs_cu = + reinterpret_cast(merging_probs); + const transformer_engine::Tensor *merging_probs_grad_cu = + reinterpret_cast(merging_probs_grad); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + + if (row_id_map_cu->data.shape[0] == num_experts) { + CALL_UNPERMUTE_MASK_BWD_WITH_MERGING_PROBS_LAUNCHER(true); + } else { + CALL_UNPERMUTE_MASK_BWD_WITH_MERGING_PROBS_LAUNCHER(false); + } +} + +#undef CALL_UNPERMUTE_MASK_BWD_WITH_MERGING_PROBS_LAUNCHER diff --git a/transformer_engine/musa/common/recipe/__init__.py b/transformer_engine/musa/common/recipe/__init__.py new file mode 100644 index 0000000000..0bce83d98f --- /dev/null +++ b/transformer_engine/musa/common/recipe/__init__.py @@ -0,0 +1,198 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""This module provides predefined FP8 recipes.""" +from __future__ import annotations +import warnings +from enum import Enum +from typing import Literal, Optional, Union, Callable, NamedTuple +from pydantic.dataclasses import dataclass + + +class _FormatHelper(NamedTuple): + """ + Stores max FP8 values for fprop and bprop a `Format`. + """ + + max_fwd: float + max_bwd: float + + +class Format(Enum): + """ + Supported FP8 formats. + + Values + ------ + E4M3 : + All FP8 tensors are in e4m3 format + E5M2 : + All FP8 tensors are in e5m2 format + HYBRID : + FP8 tensors in the forward pass are in e4m3 format, + FP8 tensors in the backward pass are in e5m2 format + """ + + E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) + E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) + HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) + + +class Recipe: + """ + Base recipe class. + """ + + def mxfp8(self): + """Whether the given recipe is MXFP8 block scaling.""" + return isinstance(self, MXFP8BlockScaling) + + def delayed(self): + """Whether the given recipe is delayed scaling.""" + return isinstance(self, DelayedScaling) + + +@dataclass() +class DelayedScaling(Recipe): + """ + Use the delayed scaling factor strategy. Use scale factor from previous + iteration and record amax history of `amax_history_len` steps. + + Parameters + ---------- + margin : int, default = 0 + Margin for the scaling factor computation. + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID + Controls the FP8 data format used during forward and backward + pass. + amax_history_len : int, default = 1024 + The length of the amax history window used for + scaling factor computation. + amax_compute_algo : {'max', 'most_recent', Callable}, default = 'max' + Algorithm used for choosing the `amax` value for the + scaling factor computation. There are 2 predefined + choices: `max` chooses the largest `amax` in the history + window, while `most_recent` always chooses the most recently + seen value. Alternatively, one may pass a function of the + signature: + + .. code-block:: python + + def amax_compute(amax_history: Tensor) -> Tensor + + where `Tensor` is a framework tensor type. + scaling_factor_compute_algo : Callable, default = None + Algorithm used for computing the new scaling + factor based on the value of `amax`. It should + be a function of the signature: + + .. code-block:: python + + def scaling_factor_compute(amax: Tensor, + old_scaling_factor: Tensor, + fp8_max: Tensor, + recipe: DelayedScaling) -> Tensor + + where `Tensor` is a framework tensor type. + reduce_amax: bool, default = `True` + By default, if `torch.distributed` is initialized, the `amax` value for FP8 + tensors is reduced across the `fp8_group` (specified in the `fp8_autocast` + call). This keeps the amaxes and scaling factors synced across the given + distributed group. If set to `False`, this reduction is skipped and every + GPU maintains local amaxes and scaling factors. To ensure results are + numerically identical across checkpointing boundaries in this case, all + ranks must checkpoint in order to store the local tensors. + fp8_dpa: bool, default = `False` + Whether to enable FP8 dot product attention (DPA). When the model is placed in an + `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the + inputs from higher precision to FP8, performs attention in FP8, and casts tensors + back to higher precision as outputs. FP8 DPA currently is only supported in the + `FusedAttention` backend. + fp8_mha: bool, default = `False` + Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting + operations mentioned above at the DPA boundaries. Currently only standard MHA modules + i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When + `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as + `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. + When `fp8_mha = True, fp8_dpa = True`, it becomes + `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + + Notes + ----- + * By default (when `scaling_factor_compute_algo` is left as `None`) the scaling + factor is computed from the final `amax` value using the formula: + + .. code-block:: python + + FP8_MAX = maximum_representable_value(fp8_format) + new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin) + + * `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are + subject to change in future Transformer Engine releases. + """ + + margin: int = 0 + interval: int = -1 + fp8_format: Format = Format.HYBRID + amax_history_len: int = 1024 + amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max" + scaling_factor_compute_algo: Optional[Callable] = None + reduce_amax: bool = True + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + if self.interval >= 0: + warnings.warn( + "`interval` argument is deprecated and unused. " + "It will be removed in an upcoming release.", + DeprecationWarning, + ) + + def __repr__(self) -> str: + return ( + f"margin={self.margin}, " + f"format={str(self.fp8_format).split('.')[1]}, " + f"amax_history_len={self.amax_history_len}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}" + ) + + +@dataclass() +class MXFP8BlockScaling(Recipe): + """ + Use the MXFP8 scaling factor strategy. + + In this strategy, tensors are scaled in blockwise fashion. Each group + of 32 consecutive values is scaled together using their own scaling + factor. The type of the scaling factor is E8M0 (8 bits of exponent, + 0 bits of mantissa), equivalent to scaling by a power of 2. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), in this recipe the quantized tensor and its transpose + are not numerically equivalent. Due to this, when Transformer Engine + needs both the MXFP8 tensor and its transpose (e.g. to calculate both + forward and backward pass), during the quantization both versions are + computed from the high precision input to avoid double quantization + errors. + + Parameters + ---------- + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 + Controls the FP8 data format used during forward and backward + pass. + """ + + margin: int = 0 + fp8_format: Format = Format.E4M3 + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + + def __repr__(self) -> str: + return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," diff --git a/transformer_engine/musa/common/recipe/delayed_scaling.mu b/transformer_engine/musa/common/recipe/delayed_scaling.mu new file mode 100644 index 0000000000..76e0051881 --- /dev/null +++ b/transformer_engine/musa/common/recipe/delayed_scaling.mu @@ -0,0 +1,420 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/musa_runtime.h" +#include "../util/logging.h" + +namespace transformer_engine { +namespace delayed_scaling_recipe { + +namespace { + +// amax value to use for updating scaling factor +enum class AmaxComputeAlgo { INVALID, MOST_RECENT, MAX }; + +const char* dtype_name(DType dtype) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, Type, + return TypeInfo::name;); // NOLINT(*) + return ""; +} + +// Maximum representable value of an FP8 dtype +inline float fp8_dtype_max(DType dtype) { + switch (dtype) { + case DType::kFloat8E4M3: + return 448; + case DType::kFloat8E5M2: + return 57344; + default: + NVTE_ERROR("Expected FP8 dtype, but got ", dtype_name(dtype)); + } + return 0; +} + +// struct for amax parameters +struct AmaxParam { + int num_scale = 0; + float* amax_history = nullptr; + float* scale = nullptr; +}; + +// dummy struct for kernel_bulk's other params +struct OtherParams { + float* a; + size_t b; + AmaxComputeAlgo c; + float d; +}; + +#if CUDART_VERSION >= 12010 +constexpr size_t max_constant_memory_per_kernel = 32768; +constexpr size_t AMAX_PARAMS_LIMIT = + (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); +#else +constexpr size_t max_constant_memory_per_kernel = 4096; +constexpr size_t AMAX_PARAMS_LIMIT = + (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); +#endif + +struct AmaxParams { + AmaxParam param[AMAX_PARAMS_LIMIT]; +}; + +namespace amax_and_scale_update_impl { + +// CUDA block size +constexpr size_t bsize = 256; + +/* CUDA kernel to update amax history and FP8 scaling factors + * + * Block dims: bsize x 1 x 1 + * + * Grid dims: num_scales x 1 x 1 + */ +__global__ void __launch_bounds__(bsize) + kernel(const float* amax_history_ptr, const float* scale_ptr, float* updated_amax_history_ptr, + float* updated_scale_ptr, size_t amax_history_length, size_t amax_history_stride, + AmaxComputeAlgo amax_compute_algo, float scaled_max) { + const size_t tid = threadIdx.x; + const size_t bid = blockIdx.x; + + // Update amax + float amax = 0; + { + // Roll amax history + const auto* amax_history = amax_history_ptr + bid; + auto* updated_amax_history = updated_amax_history_ptr + bid; + const auto last_amax = amax_history[0]; + const auto& length = amax_history_length; + const auto& stride = amax_history_stride; + for (size_t off = 0; off < length; off += bsize) { + const size_t i = off + tid; + float a = 0; + if (i < length) { + a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax; + amax = fmaxf(amax, a); + } + __syncthreads(); // In case roll is in-place + if (i < length) { + updated_amax_history[i * stride] = (i > 0) ? a : 0; + } + } + + // Compute amax to use for scaling factor + switch (amax_compute_algo) { + case AmaxComputeAlgo::MOST_RECENT: + amax = last_amax; + break; + case AmaxComputeAlgo::MAX: { + __shared__ float shared_amax[bsize]; + shared_amax[tid] = amax; + __syncthreads(); +#pragma unroll + for (size_t off = bsize / 2; off > 0; off /= 2) { + if (tid < off) { + shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]); + } + __syncthreads(); + } + amax = shared_amax[tid]; + } break; + default: + amax = 0; + } + } + + // Update scale + if (tid == 0) { + // Update scale + float scale; + if (isfinite(amax) && amax > 0) { + scale = scaled_max / amax; + } else { + scale = scale_ptr[bid]; + } + // When the amax is too tiny that the scale becoming infinite in FP32, + // we set the scale to the max value of FP32. In this case, the tensor’s + // amax won't get mapped to the FP8 max representable, but rather + // something below that, but this is the best thing we can do. + if (isinf(scale)) { + scale = std::numeric_limits::max(); + } + updated_scale_ptr[bid] = scale; + } +} + +/* CUDA kernel to bulk-update amax history and FP8 scaling factors + * + * Block dims: bsize x 1 x 1 + * + * Grid dims: num_tensors x 1 x 1 + */ +__global__ void __launch_bounds__(bsize) + kernel_bulk(float* amax_reduction_buffer, AmaxParams p, size_t amax_history_length, + AmaxComputeAlgo amax_compute_algo, float scaled_max) { + const size_t bid = blockIdx.x; + const size_t tid = threadIdx.x; + const int num_scale = p.param[bid].num_scale; + + int offset_in_buffer = 0; + for (int j = 0; j < bid; j++) { + offset_in_buffer += p.param[j].num_scale; + } + + for (int count = 0; count < num_scale; count++) { + // Update amax + float amax = 0; + { + // Roll amax history + const auto& length = amax_history_length; + const auto& stride = p.param[bid].num_scale; + auto* amax_history = p.param[bid].amax_history + count; + const auto last_amax = ((amax_reduction_buffer != nullptr) && + (amax_reduction_buffer[offset_in_buffer + count] != 0.0f)) + ? amax_reduction_buffer[offset_in_buffer + count] + : amax_history[0]; + if (last_amax != 0.0f) { + for (size_t off = 0; off < length; off += bsize) { + const size_t i = off + tid; + float a = 0; + if (i < length) { + a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax; + amax = fmaxf(amax, a); + } + __syncthreads(); // Inplace roll + if (i < length) { + amax_history[i * stride] = (i > 0) ? a : 0; + } + } + } + + // Compute amax to use for scaling factor + switch (amax_compute_algo) { + case AmaxComputeAlgo::MOST_RECENT: + amax = last_amax; + break; + case AmaxComputeAlgo::MAX: { + __shared__ float shared_amax[bsize]; + shared_amax[tid] = amax; + __syncthreads(); +#pragma unroll + for (size_t off = bsize / 2; off > 0; off /= 2) { + if (tid < off) { + shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]); + } + __syncthreads(); + } + amax = shared_amax[tid]; + } break; + default: + amax = 0; + } + } + + // Update scale + if (tid == 0) { + // Computing the scaling factor requires consideration of the following scenarios: + // 1. amax == 0: + // No action is possible, set scale to the previous scale (or 1). + // 2. 0 < amax < tiny_amax + // The amax is too tiny that the scale becomes infinite in FP32. + // Set scale = FP32_max + // 3. tiny_amax <= amax < FP32_max: + // Set scale = FP8_max (or scaled_max) / amax + // 4. When amax == inf or amax == nan: + // No action is possible, set scale to the previous scale (or 1). + + float scale; + if (isfinite(amax) && amax > 0) { + scale = scaled_max / amax; + } else { + scale = p.param[bid].scale[count]; + } + // When the amax is too tiny that the scale becoming infinite in FP32, + // we set the scale to the max value of FP32. In this case, the tensor’s + // amax won't get mapped to the FP8 max representable, but rather + // something below that, but this is the best thing we can do. + if (isinf(scale)) { + scale = std::numeric_limits::max(); + } + p.param[bid].scale[count] = scale; + } + } +} + +} // namespace amax_and_scale_update_impl + +} // namespace + +void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, + Tensor* updated_amax_history_, Tensor* updated_scale_, + const std::string& amax_compute_algo, DType fp8_dtype, float margin, + musaStream_t stream) { + auto& updated_amax_history = *updated_amax_history_; + auto& updated_scale = *updated_scale_; + + // Check tensors + NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(), + " dims"); + const size_t amax_history_length = amax_history.data.shape[0]; + const size_t num_scales = amax_history.data.shape[1]; + NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ", + dtype_name(amax_history.data.dtype), "."); + NVTE_CHECK(scale.numel() == num_scales, "Expected ", num_scales, " elements, ", "but found ", + scale.numel(), "."); + NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), "."); + NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ", + updated_amax_history.data.shape.size(), " dims."); + NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ", + amax_history_length, ", ", "but found ", updated_amax_history.data.shape[0]); + NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales, "Expected ", num_scales, ", ", + "but found ", updated_amax_history.data.shape[1]); + NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ", + dtype_name(updated_amax_history.data.dtype), "."); + NVTE_CHECK(updated_scale.numel() == num_scales, "Expected ", num_scales, " elements, ", + "but found ", updated_scale.numel(), "."); + NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ", + dtype_name(updated_scale.data.dtype), "."); + + // amax value to use for updating scaling factor + AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; + if (amax_compute_algo == "max") { + amax_compute_algo_ = AmaxComputeAlgo::MAX; + } else if (amax_compute_algo == "most_recent") { + amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT; + } else { + NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")"); + } + + // Expected maximum value after scale is applied + const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); + + // Launch CUDA kernel + constexpr size_t block_size = amax_and_scale_update_impl::bsize; + const size_t grid_size = num_scales; + amax_and_scale_update_impl::kernel<<>>( + static_cast(amax_history.data.dptr), static_cast(scale.data.dptr), + static_cast(updated_amax_history.data.dptr), + static_cast(updated_scale.data.dptr), amax_history_length, num_scales, + amax_compute_algo_, scaled_max); + NVTE_CHECK_CUDA(musaGetLastError()); +} + +void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + const std::string& amax_compute_algo, DType fp8_dtype, + float margin, musaStream_t stream) { + using namespace transformer_engine; + + // amax value to use for updating scaling factor + AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; + if (amax_compute_algo == "max") { + amax_compute_algo_ = AmaxComputeAlgo::MAX; + } else if (amax_compute_algo == "most_recent") { + amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT; + } else { + NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")"); + } + + // Expected maximum value after scale is applied + const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); + + // Number of tensors in the bulk + const size_t num_tensors = amax_histories.size(); + size_t num_remaining_tensors = num_tensors; + const int num_kernels = (num_tensors + AMAX_PARAMS_LIMIT - 1) / AMAX_PARAMS_LIMIT; + size_t amax_history_length = 0; + if (num_tensors > 0) { + amax_history_length = amax_histories[0]->data.shape[0]; + } + + // amax parameters + float* amax_buffer = static_cast(amax_reduction_buffer.data.dptr); + AmaxParams p; + for (int iter = 0; iter < num_kernels; iter++) { + size_t kernel_num_scales = 0; + size_t kernel_num_tensors = + (iter == (num_kernels - 1)) ? num_remaining_tensors : AMAX_PARAMS_LIMIT; + for (size_t pi = 0; pi < kernel_num_tensors; pi++) { + size_t i = iter * AMAX_PARAMS_LIMIT + pi; + + // Check tensors + int num_scale = amax_histories[i]->data.shape[1]; + NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, "Found ", + dtype_name(amax_histories[i]->data.dtype), "."); + NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ", + amax_histories[i]->data.shape.size(), " dims"); + NVTE_CHECK(amax_histories[i]->numel() == amax_history_length * num_scale, "Expected ", + amax_history_length * num_scale, " elements, ", "but found ", + amax_histories[i]->numel(), "."); + NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ", + dtype_name(scales[i]->data.dtype), "."); + NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(), + " dims"); + NVTE_CHECK(scales[i]->numel() == num_scale, "Expected ", num_scale, " elements, ", "Found ", + scales[i]->numel(), "."); + + // amax parameters + kernel_num_scales += num_scale; + p.param[pi].num_scale = num_scale; + p.param[pi].amax_history = static_cast(amax_histories[i]->data.dptr); + p.param[pi].scale = static_cast(scales[i]->data.dptr); + } + + // Launch CUDA kernel + size_t grid_size = kernel_num_tensors; + const size_t block_size = amax_and_scale_update_impl::bsize; + amax_and_scale_update_impl::kernel_bulk<<>>( + amax_buffer, p, amax_history_length, amax_compute_algo_, scaled_max); + NVTE_CHECK_CUDA(musaGetLastError()); + + // shift amax buffer pointer + if (amax_buffer != nullptr) { + amax_buffer += kernel_num_scales; + } + num_remaining_tensors -= AMAX_PARAMS_LIMIT; + } +} + +} // namespace delayed_scaling_recipe +} // namespace transformer_engine + +void nvte_delayed_scaling_recipe_amax_and_scale_update( + const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history, + NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, + musaStream_t stream) { + NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); + using namespace transformer_engine; + delayed_scaling_recipe::amax_and_scale_update( + *reinterpret_cast(amax_history), *reinterpret_cast(scale), + reinterpret_cast(updated_amax_history), reinterpret_cast(updated_scale), + amax_compute_algo, static_cast(fp8_dtype), margin, stream); +} + +void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + const NVTETensor amax_reduction_buffer, std::vector amax_histories, + std::vector scales, const char* amax_compute_algo, NVTEDType fp8_dtype, + float margin, musaStream_t stream) { + NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction); + using namespace transformer_engine; + size_t num_tensors = amax_histories.size(); + std::vector t_amax_histories, t_scales; + for (size_t i = 0; i < num_tensors; i++) { + t_amax_histories.push_back(reinterpret_cast(amax_histories[i])); + t_scales.push_back(reinterpret_cast(scales[i])); + } + delayed_scaling_recipe::amax_and_scale_update_after_reduction( + *reinterpret_cast(amax_reduction_buffer), t_amax_histories, t_scales, + amax_compute_algo, static_cast(fp8_dtype), margin, stream); +} diff --git a/transformer_engine/musa/common/recipe/recipe_common.muh b/transformer_engine/musa/common/recipe/recipe_common.muh new file mode 100644 index 0000000000..ef7d246758 --- /dev/null +++ b/transformer_engine/musa/common/recipe/recipe_common.muh @@ -0,0 +1,74 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ +#define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ + +#include "common/common.h" + +namespace transformer_engine { + +__device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8, + bool force_pow_2_scales, float epsilon, + float value_for_inf) { + // NOTE: NAN amax evaluates false for <, handled further down. + if (amax < epsilon) { + amax = epsilon; + } + + float scale = 1.f; + + if (isinf(amax) || amax == 0.f || isnan(amax)) { + return scale; + } + + // Here we don't use "scale = max_fp8 / amax" because it has different results with/without + // "--use_fast_math". + // "__fdiv_rn" has the same behavior with "max_fp8 / amax" when not using fast math. + scale = __fdiv_rn(max_fp8, amax); + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. + if (isinf(scale)) { + // use fp32 max to represent the scale + scale = value_for_inf; + } + if (force_pow_2_scales) { + uint32_t scale_bits = *reinterpret_cast(&scale); + scale_bits &= 0xFF800000; + // If the exponent was zero, we have a logic error. + __builtin_assume(scale_bits != 0 || scale == 0.0); + __builtin_assume(scale_bits != 0x80000000); + scale = *reinterpret_cast(&scale_bits); + } + + return scale; +} + +// Calculate the quantization scale for an individual data element +// given the amax(abs(tile)) value for a given quantization tile. +// +// +// Arguments: +// IType: data type of the tensor being quantized (float or bf16) +// OType: quantized data type (e4m3 or e5m2) +// amax: The evaluation of amax(abs(tile)) for the quantization tile. +// eps: An epsilon used as a floor for amax. +// pow_2_scaling: Whether to force the scale to be a power of 2. +template +__device__ __forceinline__ float compute_scale_from_types(const float amax, const float eps, + const float pow_2_scaling) { + constexpr float fp8_max = TypeInfo::max_finite_value; + // NOTE: We're relying on compute_scale_from_amax to have behavior where it + // clips the mantissa of the max_finite_value if power of 2 scaling applies. + constexpr float value_for_inf = TypeInfo::max_finite_value; + return compute_scale_from_amax(amax, fp8_max, pow_2_scaling, eps, value_for_inf); +} + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ + \ No newline at end of file diff --git a/transformer_engine/musa/common/swizzle/swizzle.mu b/transformer_engine/musa/common/swizzle/swizzle.mu new file mode 100644 index 0000000000..4f7df5a411 --- /dev/null +++ b/transformer_engine/musa/common/swizzle/swizzle.mu @@ -0,0 +1,338 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "transformer_engine/transformer_engine.h" + +namespace { + +constexpr int TB_DIM = 32; +constexpr int NEW_SF_TILE_DIM_K = 16; +constexpr int N_SF_PER_TD_PER_TILE = 4; + +// output is in ~K-major interleaved blocks +constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; +constexpr int NEW_SF_TILE_DIM_M_I32 = 32; + +template +__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { + // inp, 4-byte chunks [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15] + // out, swapping byte to form new 4-byte chunks [0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15] + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t new_regs[kVectorSize]; + int32_t* regs = reinterpret_cast(regs_vec); + +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { +#pragma unroll + for (int j = 0; j < N_SF_PER_TD_PER_TILE; j++) { + new_regs[i * N_SF_PER_TD_PER_TILE + j] = + (((regs[i + 0 * N_TILE_PER_TD] >> 8 * j) & 0xFF)) | + (((regs[i + 1 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 8) | + (((regs[i + 2 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 16) | + (((regs[i + 3 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 24); + } + } +#pragma unroll + for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i]; +} + +template +__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, + const int K) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + + // input is in M-major + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4; + constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K; + + const int M_i32 = M / 4; + const int K_i32 = K; + + int m_tiles_in_tb = N_TILE_PER_TD; + int k_tiles_in_tb = TB_DIM; + if (blockIdx.x == gridDim.x - 1) { + k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; + } + if (blockIdx.y == gridDim.y - 1) { + m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; + } + + const int32_t* input_i32 = reinterpret_cast(input) + + blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + + blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + int32_t* output_i32[N_TILE_PER_TD]; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + output_i32[i] = reinterpret_cast(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 + + (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; + } + extern __shared__ int slm[]; + + // load, global -> regs + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 && + threadIdx.y < k_tiles_in_tb) { +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + regs_vec[i] = __ldg(reinterpret_cast( + input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD)); + } + + // local shuffle + regs_shuffle_with_bit_shifts(regs_vec); + + // store, regs -> shared + int tM = threadIdx.x * N_SF_PER_TD; + int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 + + tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int i = 0; i < N_SF_PER_TD; i++) { + /* TODO rotate_i */ + slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 + + ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] = + reinterpret_cast(regs_vec)[i]; + } + } + __syncthreads(); + + // store, shared -> global + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + __align__(16) int4* output_v4i = reinterpret_cast(output_i32[i]); + __align__(16) int4* slm_v4i = + reinterpret_cast(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4; + j += blockDim.x * blockDim.y) { + output_v4i[j] = slm_v4i[j]; + } + } +} + +template +__device__ inline void regs_shuffle(LType* regs_vec) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + if constexpr (N_TILE_PER_TD == 1) return; + + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t tmp[kVectorSize]; + int32_t* ptr = reinterpret_cast(regs_vec); +#pragma unroll + for (int i = 0; i < kVectorSize; i++) + tmp[i % N_TILE_PER_TD * N_SF_PER_TD_PER_TILE + i / N_TILE_PER_TD] = ptr[i]; + +#pragma unroll + for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i]; +} + +template +__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, + const int K) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + // input is in K-major + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M; + + int n_tiles_in_tb = N_TILES_IN_TB; + const int K_i32 = K / 4; + if (blockIdx.x == gridDim.x - 1) { + n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; + } + + const int* input_i32 = reinterpret_cast(input) + + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB; + int* output_i32 = reinterpret_cast(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + + blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + + extern __shared__ int4 slm_v4i[]; + + // load, global -> regs + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + regs_vec[i] = __ldg(reinterpret_cast( + input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD)); + } + + // shuffle regs + regs_shuffle(regs_vec); + +// store, regs -> shared +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { + /* TODO rotate i */ + slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] = + reinterpret_cast(regs_vec)[i]; + } + } + __syncthreads(); + + // store, shared -> global + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; + __align__(16) int4* output_v4i = reinterpret_cast(output_i32); +#pragma unroll + for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) { + output_v4i[i] = slm_v4i[i]; + } +} + +} // namespace + +namespace transformer_engine { + +void swizzle_scaling_factors(const Tensor* input, Tensor* output, musaStream_t stream) { + if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { + NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + "."); + } + + // Do nothing if tensor is empty + if (input->data.numel() == 0) { + return; + } + + CheckInputTensor(*input, "scaling_factor_input"); + CheckInputTensor(*output, "scaling_factor_output"); + + auto& scaling_mode = input->scaling_mode; + + // 1D block scaling, row-wise or colum-wise + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + const int m = + input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; + const int k = + input->has_data() ? input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0]; + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + if (output->has_data()) { + NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), + output->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), + output->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } + + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + + dim3 block_size(TB_DIM, TB_DIM); + if (input->has_data()) { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + /* there is no int3 and misaligned if using int4/int2 */ + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + switch (vec_load_size) { + case 4: + musaFuncSetAttribute(swizzle_row_scaling_kernel, + musaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + case 2: + musaFuncSetAttribute(swizzle_row_scaling_kernel, + musaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + case 1: + musaFuncSetAttribute(swizzle_row_scaling_kernel, + musaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + if (input->has_columnwise_data()) { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + switch (vec_load_size) { + case 4: + musaFuncSetAttribute(swizzle_col_scaling_kernel, + musaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + case 2: + musaFuncSetAttribute(swizzle_col_scaling_kernel, + musaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + case 1: + musaFuncSetAttribute(swizzle_col_scaling_kernel, + musaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + + // 2D block scaling + } else { + NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); + } + musaError_t err = musaGetLastError(); + if (err != musaSuccess) { + printf("CUDA Error: %s\n", musaGetErrorString(err)); + exit(-1); + } +} +} // namespace transformer_engine + +/* + * WIP (Phuong): + * - Opt for bank conflicts + * - Adding swizzle for 2d-block scaling. +*/ +void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_scaling_factors); + using namespace transformer_engine; + swizzle_scaling_factors(reinterpret_cast(input), reinterpret_cast(output), + stream); +} diff --git a/transformer_engine/musa/common/transformer_engine.cpp b/transformer_engine/musa/common/transformer_engine.cpp new file mode 100644 index 0000000000..594c0ef323 --- /dev/null +++ b/transformer_engine/musa/common/transformer_engine.cpp @@ -0,0 +1,416 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include + +#include "common.h" + +namespace transformer_engine { + +size_t typeToSize(const DType type) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, + return TypeInfo::size;); // NOLINT(*) +} + +bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; } + +std::string to_string(const DType type) { + switch (type) { + case DType::kByte: + return "Byte"; + case DType::kBFloat16: + return "BFloat16"; + case DType::kFloat16: + return "Float16"; + case DType::kFloat32: + return "Float32"; + case DType::kFloat8E4M3: + return "Float8E4M3"; + case DType::kFloat8E5M2: + return "Float8E5M2"; + case DType::kFloat8E8M0: + return "Float8E8M0"; + case DType::kInt32: + return "Int32"; + case DType::kInt64: + return "Int64"; + default: + return concat_strings("Invalid type ", static_cast(type)); + } +} + +std::string to_string(const NVTEScalingMode &mode) { + switch (mode) { + case NVTE_DELAYED_TENSOR_SCALING: + return "Delayed Tensor Scaling"; + case NVTE_MXFP8_1D_SCALING: + return "MXFP8 1D Scaling"; + case NVTE_INVALID_SCALING: + return "Invalid Scaling"; + } + return "Invalid Scaling"; +} + +void CheckNoopTensor(const Tensor &t, const std::string &name) { + if (t.data.dptr != nullptr) { + NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(), + "."); + NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name, + " noop. Expected kFloat32."); + } +} + +void CheckScaleTensorShape(const Tensor &t, const std::string &name) { + NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!"); + if (is_tensor_scaling(t.scaling_mode)) { + // per-tensor scaling + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name, + "\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected (1), got ", + t.columnwise_scale_inv.shape, ")"); + } + } else { + if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { + // Need (4, 128) alignment even for e8 scaling factor + auto block_alignment = std::vector{128ul, 4ul}; + size_t expected_x, expected_y, alignment; + + if (t.has_data()) { + alignment = block_alignment[0]; + expected_x = + DIVUP(DIVUP(t.flat_first_dim(), static_cast(1)), alignment) * alignment; + alignment = block_alignment[1]; + expected_y = + DIVUP(DIVUP(t.flat_last_dim(), static_cast(32)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; + NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid scale_inv shape (expected ", expected, ", got ", + t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + alignment = block_alignment[1]; + expected_x = + DIVUP(DIVUP(t.flat_first_dim(), static_cast(32)), alignment) * alignment; + alignment = block_alignment[0]; + expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; + NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", + t.columnwise_scale_inv.shape, ")"); + } + } else if (is_mtfp_scaling(t.scaling_mode)) { + NVTE_CHECK(t.amax.dptr == nullptr && t.scale.dptr == nullptr); + } + } +} + +void CheckInputTensor(const Tensor &t, const std::string &name) { + const DType type = t.dtype(); + if (is_fp8_dtype(type)) { + // FP8 input needs to have scale_inv + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor input ", name, + "_scale_inverse has invalid dtype " + "(expected Float32 or Byte, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || + t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor input ", name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float32 or Byte, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } + } else { + NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name); + NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name); + NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); + NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, + "Scale_inv is not supported for non-FP8 input ", name); + } + NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!"); + + CheckScaleTensorShape(t, name); +} + +void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) { + const DType type = t.dtype(); + if (is_fp8_dtype(type)) { + // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax + if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output ", name, " must have amax tensor"); + NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ", + to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")"); + NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name, + " (expected 1 entry, got shape=", t.amax.shape, ")"); + } + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor output ", name, + "_scale_inverse has invalid dtype " + "(expected Float32 or Float8E8M0, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || + t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor output ", name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float32 or Float8E8M0, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } + } else { + NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); + NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); + NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); + NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, + "Scale_inv is not supported for non-FP8 input ", name); + } + + if (!allow_empty) { + NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!"); + } + + CheckScaleTensorShape(t, name); +} + +} // namespace transformer_engine + +NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) { + transformer_engine::Tensor *ret = new transformer_engine::Tensor; + ret->scaling_mode = scaling_mode; + return ret; +} + +void nvte_destroy_tensor(NVTETensor tensor) { + if (tensor == nullptr) return; + auto *t = reinterpret_cast(tensor); + delete t; +} + +NVTEDType nvte_tensor_type(const NVTETensor tensor) { + if (tensor == nullptr) return kNVTEFloat32; + return static_cast( + reinterpret_cast(tensor)->dtype()); +} + +NVTEShape nvte_tensor_shape(const NVTETensor tensor) { + if (tensor == nullptr) return {nullptr, 0}; + const auto &t = *reinterpret_cast(tensor); + NVTEShape ret; + + // FP8 tensor keeps shape in rowwise data + if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + return ret; + } + + // Get shape based on what data is available + if (t.has_data()) { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + return ret; + } + if (t.has_columnwise_data()) { + ret.data = t.columnwise_data.shape.data(); + ret.ndim = t.columnwise_data.shape.size(); + return ret; + } + + // Tensor has no data + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + return ret; +} + +NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { + if (tensor == nullptr) return {nullptr, 0}; + const auto &t = *reinterpret_cast(tensor); + NVTEShape ret; + ret.data = t.columnwise_data.shape.data(); + ret.ndim = t.columnwise_data.shape.size(); + return ret; +} + +size_t nvte_tensor_ndim(const NVTETensor tensor) { + if (tensor == nullptr) return 0; + const auto &t = *reinterpret_cast(tensor); + return t.data.shape.size(); +} + +size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) { + if (tensor == nullptr) return 0; + const auto &t = *reinterpret_cast(tensor); + NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim); + return t.data.shape[dim]; +} + +size_t nvte_tensor_numel(const NVTETensor tensor) { + if (tensor == nullptr) return 0; + const auto &t = *reinterpret_cast(tensor); + size_t numel = 1; + for (auto size : t.data.shape) { + numel *= size; + } + return numel; +} + +size_t nvte_tensor_element_size(const NVTETensor tensor) { + if (tensor == nullptr) return sizeof(float); + const auto &t = *reinterpret_cast(tensor); + return transformer_engine::typeToSize(t.data.dtype); +} + +void *nvte_tensor_data(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + return t.data.dptr; +} + +void *nvte_tensor_columnwise_data(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + return t.columnwise_data.dptr; +} + +float *nvte_tensor_amax(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32, + "Tensor's amax must have Float32 type!"); + return reinterpret_cast(t.amax.dptr); +} + +float *nvte_tensor_scale(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32, + "Tensor's scale must have Float32 type!"); + return reinterpret_cast(t.scale.dptr); +} + +float *nvte_tensor_scale_inv(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + return reinterpret_cast(t.scale_inv.dptr); +} + +void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + return t.columnwise_scale_inv.dptr; +} + +NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { + if (tensor == nullptr) return {nullptr, 0}; + const auto &t = *reinterpret_cast(tensor); + NVTEShape ret; + ret.data = t.scale_inv.shape.data(); + ret.ndim = t.scale_inv.shape.size(); + return ret; +} + +void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, + const NVTEBasicTensor *param) { + NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL."); + NVTE_CHECK(*tensor != nullptr, "Tensor is not allocated."); + auto &t = *reinterpret_cast(*tensor); + switch (param_name) { + case kNVTERowwiseData: + t.data = *param; + break; + case kNVTEColumnwiseData: + t.columnwise_data = *param; + break; + case kNVTEScale: + t.scale = *param; + break; + case kNVTEAmax: + t.amax = *param; + break; + case kNVTERowwiseScaleInv: + t.scale_inv = *param; + break; + case kNVTEColumnwiseScaleInv: + t.columnwise_scale_inv = *param; + break; + default: + NVTE_ERROR("Unknown tensor parameter!"); + } +} + +NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) { + if (tensor == nullptr) { + return {nullptr, kNVTEFloat32, {nullptr, 0}}; + } + const auto &t = *reinterpret_cast(tensor); + switch (param_name) { + case kNVTERowwiseData: + return t.data; + case kNVTEColumnwiseData: + return t.columnwise_data; + case kNVTEScale: + return t.scale; + case kNVTEAmax: + return t.amax; + case kNVTERowwiseScaleInv: + return t.scale_inv; + case kNVTEColumnwiseScaleInv: + return t.columnwise_scale_inv; + default: + NVTE_ERROR("Unknown tensor parameter!"); + } +} + +NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.scaling_mode; +} + +void nvte_tensor_pack_create(NVTETensorPack *pack) { + for (int i = 0; i < pack->MAX_SIZE; i++) { + pack->tensors[i] = reinterpret_cast(new transformer_engine::Tensor); + } +} + +void nvte_tensor_pack_destroy(NVTETensorPack *pack) { + for (int i = 0; i < pack->MAX_SIZE; i++) { + auto *t = reinterpret_cast(pack->tensors[i]); + delete t; + } +} + +void nvte_zero_tensor(const NVTETensor tensor, musaStream_t stream) { + const auto &t = *reinterpret_cast(tensor); + // Zero out tensor data if allocated + if (t.data.dptr != nullptr) { + size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor); + musaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); + } + // Set amax to 0 if allocated + if (t.amax.dptr != nullptr) { + float zero = 0.0f; + musaMemcpyAsync(t.amax.dptr, &zero, sizeof(float), musaMemcpyHostToDevice, stream); + } + musaStreamSynchronize(stream); +} diff --git a/transformer_engine/musa/common/transpose/cast_transpose.h b/transformer_engine/musa/common/transpose/cast_transpose.h new file mode 100644 index 0000000000..78e61ed592 --- /dev/null +++ b/transformer_engine/musa/common/transpose/cast_transpose.h @@ -0,0 +1,28 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ +#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ + +#include "../common.h" + +namespace transformer_engine::detail { + +void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, musaStream_t stream); + +template +void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output, + Tensor *dbias, Tensor *workspace, musaStream_t stream); + +template +void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output, + musaStream_t stream); + +} // namespace transformer_engine::detail + +#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ diff --git a/transformer_engine/musa/common/transpose/cast_transpose.mu b/transformer_engine/musa/common/transpose/cast_transpose.mu new file mode 100644 index 0000000000..e72add0bba --- /dev/null +++ b/transformer_engine/musa/common/transpose/cast_transpose.mu @@ -0,0 +1,359 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include + +#include "../util/rtc.h" +#include "../util/string.h" +#include "../utils.muh" +#include "cast_transpose.h" + +namespace transformer_engine::detail { + +namespace { + +// String with RTC kernel implementation +#include "string_code_transpose_rtc_cast_transpose_mu.h" + +// Hard-coded kernel parameters +using CType = float; +constexpr size_t warps_per_tile = 4; +constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile; + +/* Performance heuristics for optimized kernel parameters */ +struct KernelConfig { + /** Vector load size */ + size_t load_size = 0; + /** Vector store size to transposed output */ + size_t store_size = 0; + + /* Whether config is valid */ + bool valid = false; + /* Number of CUDA blocks */ + size_t num_blocks = 0; + + /* Number of active SMs */ + size_t active_sm_count = 0; + /* Elements per L1 cache load */ + size_t elements_per_load = 0; + /* Elements per L1 cache store to cast output*/ + size_t elements_per_store_c = 0; + /* Elements per L1 cache store to transposed output */ + size_t elements_per_store_t = 0; + + KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t otype_size, + size_t load_size_, size_t store_size_, size_t sm_count) + : load_size{load_size_}, store_size{store_size_} { + // Check that tiles are correctly aligned + constexpr size_t cache_line_size = 128; + if (load_size % itype_size != 0 || store_size % otype_size != 0 || + cache_line_size % itype_size != 0 || cache_line_size % otype_size != 0) { + return; + } + const size_t row_tile_elements = load_size * THREADS_PER_WARP / itype_size; + const size_t col_tile_elements = store_size * THREADS_PER_WARP / otype_size; + valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0); + if (!valid) { + return; + } + + // Number of CUDA blocks + num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements); + + // Parameters for performance model + constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs + active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), sm_count); + elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size) / itype_size); + elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size) / otype_size); + elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size) / otype_size); + } + + /* Compare by estimated cost */ + bool operator<(const KernelConfig &other) const { + if (this->valid && other.valid) { + // cost ~ (1/elements_per_load + // + 1/elements_per_store_c + // + 1/elements_per_store_t) / active_sms + // Note: Integer arithmetic ensures stable ordering + const auto &l1 = this->elements_per_load; + const auto &sc1 = this->elements_per_store_c; + const auto &st1 = this->elements_per_store_t; + const auto &p1 = this->active_sm_count; + const auto &l2 = other.elements_per_load; + const auto &sc2 = other.elements_per_store_c; + const auto &st2 = other.elements_per_store_t; + const auto &p2 = other.active_sm_count; + const auto scale = l1 * sc1 * st1 * p1 * l2 * sc2 * st2 * p2; + const auto cost1 = (scale / l1 + scale / sc1 + scale / st1) / p1; + const auto cost2 = (scale / l2 + scale / sc2 + scale / st2) / p2; + return cost1 < cost2; + } else { + return this->valid && !other.valid; + } + } +}; + +template +__global__ void __launch_bounds__(block_size) cast_transpose_general_kernel( + const IType *__restrict__ const input, const CType *__restrict__ const noop, + OType *__restrict__ const output_c, OType *__restrict__ const output_t, + const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr, + CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + + // Vectorized load/store sizes + constexpr size_t nvec_in = load_size / sizeof(IType); + constexpr size_t nvec_out = store_size / sizeof(OType); + using IVec = Vec; + using OVecT = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr size_t bdimx = THREADS_PER_WARP; + constexpr size_t bdimy = warps_per_tile; + const size_t tid = threadIdx.x; + const size_t tidx = tid % bdimx; + const size_t tidy = tid / bdimx; + const size_t bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out; + constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in; + + // Position of tile within tensor + const size_t num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m; + const size_t tile_id_m = bid % num_tiles_m; + const size_t tile_id_n = bid / num_tiles_m; + const size_t tile_row = tile_id_m * tile_dim_m; + const size_t tile_col = tile_id_n * tile_dim_n; + + // Number of nvec_out x nvec_in subtiles for each thread to + // load/store + constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile; + + // FP8 factors + const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr; + CType amax = 0; + + // Load input and store to registers + // Note: Each thread loads num_iterations subtiles, computes amax, + // casts type, and transposes in registers. + OVecT local_output_t[nvec_in][num_iterations]; +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; +#pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + const size_t row = tile_row + i1 * nvec_out + i2; + const size_t col = tile_col + j1 * nvec_in; + if (row < num_rows) { +#pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + if (col + j2 < row_length) { + const CType in = input[row * row_length + col + j2]; + const OType out = OType(in * scale); + __builtin_assume(amax >= 0); + amax = fmaxf(fabsf(in), amax); + output_c[row * row_length + col + j2] = out; + local_output_t[j2][iter].data.elt[i2] = out; + } + } + } + } + } + + // Copy transposed output from registers to global memory + __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1]; +#pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; + shared_output_t[j1][i1] = local_output_t[j2][iter]; + } + __syncthreads(); +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidx; + const size_t j1 = tidy + iter * bdimy; + const size_t row = tile_row + i1 * nvec_out; + const size_t col = tile_col + j1 * nvec_in + j2; + if (col < row_length) { +#pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + if (row + i2 < num_rows) { + output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2]; + } + } + } + } + __syncthreads(); + } + + // Reduce amax over block + if (amax_ptr != nullptr) { + amax = reduce_max(amax, tidy); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); + } +} + +} // namespace + +void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, musaStream_t stream) { + Tensor &output = *output_; + + CheckNoopTensor(noop, "cast_transpose_noop"); + CheckInputTensor(input, "cast_transpose_input"); + CheckOutputTensor(output, "cast_transpose_output"); + + // Check that inputs and outputs are available + NVTE_CHECK(input.has_data(), "Input is not allocated"); + NVTE_CHECK(output.has_data(), "Output rowwise data is not allocated"); + NVTE_CHECK(output.has_columnwise_data(), "Output columnwise is not allocated"); + + // Flatten tensor to 2D + NVTE_CHECK(input.data.shape == output.data.shape, + "Input and output shapes do not match (input=", input.data.shape, + ", output=", output.data.shape); + const size_t row_length = input.flat_last_dim(); + const size_t num_rows = input.flat_first_dim(); + NVTE_CHECK(output.flat_first_dim() == num_rows && output.flat_last_dim() == row_length, + "Invalid output dimensions (expected ", std::vector{num_rows, row_length}, + ", got ", std::vector{output.flat_first_dim(), output.flat_last_dim()}, ")"); + + // Check that cast and transposed output data matches + NVTE_CHECK(output.data.dtype == output.columnwise_data.dtype, + "Cast and transposed output types must match."); + NVTE_CHECK(output.scale_inv.dptr == output.columnwise_scale_inv.dptr, + "Cast and transposed outputs need to share scale-inverse tensor."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output.dtype(), OutputType, + if (is_delayed_tensor_scaling(output.scaling_mode)) { + constexpr const char *itype_name = TypeInfo::name; + constexpr const char *otype_name = TypeInfo::name; + constexpr size_t itype_size = sizeof(InputType); + constexpr size_t otype_size = sizeof(OutputType); + + // Choose between runtime-compiled or statically-compiled kernel + const bool aligned = + (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); + if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + const size_t sm_count = static_cast(cuda::sm_count()); + auto add_config = [&](size_t load_size, size_t store_size) { + kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size, + store_size, sm_count); + }; + add_config(8, 8); + add_config(4, 8); + add_config(8, 4); + add_config(4, 4); + add_config(2, 8); + add_config(8, 2); + add_config(2, 4); + add_config(4, 2); + add_config(2, 2); + add_config(1, 8); + add_config(8, 1); + add_config(1, 4); + add_config(4, 1); + add_config(1, 2); + add_config(2, 1); + add_config(1, 1); + const auto &kernel_config = + *std::min_element(kernel_configs.begin(), kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + const size_t load_size = kernel_config.load_size; + const size_t store_size = kernel_config.store_size; + const size_t num_blocks = kernel_config.num_blocks; + + // Compile NVRTC kernel if needed and launch + auto &rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = concat_strings( + "cast_transpose" + ",itype=", + itype_name, ",otype=", otype_name, ",load_size=", load_size, + ",store_size=", store_size); + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_cast_transpose_mu; + code = regex_replace(code, "__ITYPE__", itype_name); + code = regex_replace(code, "__OTYPE__", otype_name); + code = regex_replace(code, "__LOAD_SIZE__", load_size); + code = regex_replace(code, "__STORE_SIZE__", store_size); + code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); + code = regex_replace(code, "__BLOCK_SIZE__", block_size); + rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code, + "transformer_engine/common/transpose/rtc/cast_transpose.mu"); + } + rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream, + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(output.data.dptr), + static_cast(output.columnwise_data.dptr), + static_cast(output.scale.dptr), + static_cast(output.amax.dptr), + static_cast(output.scale_inv.dptr), row_length, num_rows); + } else { // Statically-compiled general kernel + constexpr size_t load_size = 4; + constexpr size_t store_size = 4; + constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; + constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; + const int num_blocks = + (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); + cast_transpose_general_kernel + <<>>( + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(output.data.dptr), + static_cast(output.columnwise_data.dptr), + static_cast(output.scale.dptr), + static_cast(output.amax.dptr), + static_cast(output.scale_inv.dptr), row_length, num_rows); + } + } else { + NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode)); + }); // NOLINT(*) + ); // NOLINT(*) +} + +} // namespace transformer_engine::detail + +void nvte_cast_transpose(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose); + using namespace transformer_engine; + auto noop = Tensor(); + transformer_engine::detail::cast_transpose(*reinterpret_cast(input), noop, + reinterpret_cast(output), stream); +} + +void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_with_noop); + using namespace transformer_engine; + transformer_engine::detail::cast_transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), + reinterpret_cast(output), stream); +} diff --git a/transformer_engine/musa/common/transpose/cast_transpose_fusion.mu b/transformer_engine/musa/common/transpose/cast_transpose_fusion.mu new file mode 100644 index 0000000000..92151bf3c9 --- /dev/null +++ b/transformer_engine/musa/common/transpose/cast_transpose_fusion.mu @@ -0,0 +1,1414 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include +#include + +#include "../util/cast_kernels.muh" +#include "../util/math.h" +#include "../util/rtc.h" +#include "../util/string.h" +#include "../utils.muh" +#include "cast_transpose.h" + +namespace transformer_engine { + +namespace detail { + +// String with RTC kernel implementation +#include "string_code_transpose_rtc_cast_transpose_fusion_mu.h" + +// STUFF TO TUNE +constexpr size_t n_warps_per_tile = 8; +constexpr size_t desired_load_size = 8; +constexpr size_t desired_store_size = 8; +constexpr size_t desired_load_size_dact = 4; // dAct fusion kernels use more registers +constexpr size_t desired_store_size_dact = 4; + +constexpr size_t threads_per_warp = static_cast(THREADS_PER_WARP); +constexpr size_t max_threads_per_block = 256; +constexpr size_t reduce_dbias_num_threads = 256; +constexpr size_t cast_transpose_num_threads = n_warps_per_tile * threads_per_warp; +constexpr size_t n_warps_per_block = cast_transpose_num_threads / threads_per_warp; +static_assert(cast_transpose_num_threads <= max_threads_per_block); + +/* Performance heuristics for optimized kernel parameters */ +struct KernelConfig { + size_t load_size = 0; // Vector load size + size_t store_size = 0; // Vector store size to transposed output + + bool valid = false; // Whether config is valid + bool is_dact = false; // Whether dact is used + size_t num_blocks = 0; // Number of CUDA blocks + + size_t active_sm_count = 0; // Number of active SMs + size_t elements_per_load = 0; // Elements per L1 cache load + size_t elements_per_load_dact = 0; // Elements per L1 cache load dact + size_t elements_per_store_c = 0; // Elements per L1 cache store to cast output + size_t elements_per_store_t = 0; // Elements per L1 cache store to transposed output + + KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t itype2_size, + size_t otype_size, size_t load_size_, size_t store_size_, size_t sm_count, + bool is_dact_) + : load_size{load_size_}, store_size{store_size_}, is_dact{is_dact_} { + if (is_dact) { + if (load_size > desired_load_size_dact || store_size > desired_store_size_dact) { + return; + } + } + + // Check that tiles are correctly aligned + constexpr size_t cache_line_size = 128; + if (load_size % itype_size != 0 || store_size % otype_size != 0 || + cache_line_size % itype_size != 0 || cache_line_size % otype_size != 0) { + return; + } + /* row_tile_elements */ + const size_t tile_size_x = (load_size * THREADS_PER_WARP) / itype_size; + /* col_tile_elements */ + const size_t tile_size_y = (store_size * THREADS_PER_WARP) / otype_size; + const size_t num_tiles_x = row_length / tile_size_x; + const size_t num_tiles_y = num_rows / tile_size_y; + + valid = (row_length % tile_size_x == 0 && num_rows % tile_size_y == 0); + if (!valid) { + return; + } + + // Number of CUDA blocks + num_blocks = num_tiles_x * num_tiles_y; + + // Parameters for performance model + constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs + active_sm_count = std::min(DIVUP(num_blocks * n_warps_per_tile, warps_per_sm), sm_count); + elements_per_load = (std::min(cache_line_size, tile_size_x * itype_size) / itype_size); + elements_per_load_dact = (std::min(cache_line_size, tile_size_x * itype2_size) / itype2_size); + elements_per_store_c = (std::min(cache_line_size, tile_size_x * otype_size) / otype_size); + elements_per_store_t = (std::min(cache_line_size, tile_size_y * otype_size) / otype_size); + } + + /* Compare by estimated cost */ + bool operator<(const KernelConfig &other) const { + if (this->valid && other.valid) { + // cost ~ (1/elements_per_load + // + 1/elements_per_load_dact + // + 1/elements_per_store_c + // + 1/elements_per_store_t) / active_sms + // Note: Integer arithmetic ensures stable ordering + const auto &l1 = this->elements_per_load; + const auto &la1 = this->elements_per_load_dact; + const auto &sc1 = this->elements_per_store_c; + const auto &st1 = this->elements_per_store_t; + const auto &p1 = this->active_sm_count; + const auto &l2 = other.elements_per_load; + const auto &la2 = other.elements_per_load_dact; + const auto &sc2 = other.elements_per_store_c; + const auto &st2 = other.elements_per_store_t; + const auto &p2 = other.active_sm_count; + const auto scale1 = l1 * sc1 * st1 * p1 * (is_dact ? la1 : 1); + const auto scale2 = l2 * sc2 * st2 * p2 * (is_dact ? la2 : 1); + const auto scale = scale1 * scale2; + const auto cost1 = + (scale / l1 + scale / sc1 + scale / st1 + (is_dact ? (scale / la1) : 0)) / p1; + const auto cost2 = + (scale / l2 + scale / sc2 + scale / st2 + (is_dact ? (scale / la2) : 0)) / p2; + + return cost1 < cost2; + } else { + return this->valid && !other.valid; + } + } +}; + +template +inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], + OVec (&out_trans)[nvec_in], + CVec &out_dbias, // NOLINT(*) + typename OVec::type *output_cast_tile, + const size_t current_place, const size_t stride, + const CType scale, + CType &amax, // NOLINT(*) + const int dbias_shfl_src_lane, + const bool valid_store) { + using OType = typename OVec::type; + using OVecC = Vec; + + CVec step_dbias; + if constexpr (IS_DBIAS) { + step_dbias.clear(); + } + +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + OVecC out_cast; +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + const CType tmp = in[i].data.elt[j]; + if constexpr (IS_DBIAS) { + step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation + } + out_cast.data.elt[j] = static_cast(tmp * scale); + out_trans[j].data.elt[i] = static_cast(tmp * scale); // thread tile transpose + + __builtin_assume(amax >= 0); + amax = fmaxf(fabsf(tmp), amax); + } + if (IS_FULL_TILE || valid_store) { + out_cast.store_to(output_cast_tile, current_place + stride * i); + } + } + + if constexpr (IS_DBIAS) { +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + CType elt = step_dbias.data.elt[j]; + elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp + out_dbias.data.elt[j] += elt; + } + } +} + +void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/ + Tensor *workspace, const int nvec_out) { + const size_t row_length = cast_output.flat_last_dim(); + const size_t num_rows = cast_output.flat_first_dim(); + + const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + + const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {num_rows_partial_dbias, row_length}; + workspace->data.dtype = DType::kFloat32; + } else { + // Check that workspace matches expected size + const size_t workspace_size = + std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, + std::multiplies()) * + typeToSize(workspace->data.dtype); + const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", + num_rows_partial_dbias, ",", row_length, "), found ())"); + NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", + num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), + "; found dims=", workspace->data.shape, + ", dtype=", typeToSize(workspace->data.dtype), ")"); + } +} + +template +__global__ void __launch_bounds__(reduce_dbias_num_threads) + reduce_dbias_kernel(OutputType *const dbias_output, const ComputeType *const dbias_partial, + const int row_length, const int num_rows) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= row_length) { + return; + } + + const ComputeType *const thread_in_base = dbias_partial + thread_id * nvec; + OutputType *const thread_out_base = dbias_output + thread_id * nvec; + + const int stride_in_vec = row_length / nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < num_rows; ++i) { + ldg_vec.load_from(thread_in_base, i * stride_in_vec); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base, 0); +} + +template +void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_length, + const size_t num_rows, const int nvec_out, musaStream_t stream) { + constexpr int reduce_dbias_store_bytes = 8; // stg.64 + constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(InputType); + + NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape."); + + const size_t reduce_dbias_row_length = row_length; + const size_t reduce_dbias_num_rows = + DIVUP(num_rows, static_cast(nvec_out * THREADS_PER_WARP)); + const size_t reduce_dbias_num_blocks = + DIVUP(row_length, reduce_dbias_num_threads * reduce_dbias_nvec); + + using DbiasOutputType = fp32; + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), + reinterpret_cast(workspace.data.dptr), reduce_dbias_row_length, + reduce_dbias_num_rows); +} + +template +__global__ void __launch_bounds__(cast_transpose_num_threads) + cast_transpose_fused_kernel_notaligned(const Param param, const size_t row_length, + const size_t num_rows, const size_t num_tiles) { + static_assert(!(IS_DACT && IS_ACT), "forward and backward activation are mutually exclusive"); + using IType = typename Param::InputType; + using IType2 = typename Param::InputType2; + using OType = typename Param::OutputType; + using CType = typename Param::ComputeType; + using IVec = Vec; + using IVec2 = Vec; + using OVec = Vec; + using CVec = Vec; + + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = + (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile; + if (tile_id >= num_tiles) { + return; + } + + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const size_t tile_offset = + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP; + const size_t tile_offset_transp = + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP; + + const IType *const my_input_tile = param.input + tile_offset; + const IType2 *const my_act_input_tile = param.act_input + tile_offset; + OType *const my_output_c_tile = param.output_c + tile_offset; + OType *const my_output_t_tile = param.output_t + tile_offset_transp; + CType *const my_partial_dbias_tile = + param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length); + + const size_t stride = row_length / nvec_in; + const size_t output_stride = num_rows / nvec_out; + const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; + const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; + const unsigned int tile_length = + row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest; + const unsigned int tile_height = + row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest; + + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); + + CVec *const my_dbias_scratch = reinterpret_cast(scratch); + + IVec in[2][nvec_out]; + IVec2 act_in[2][nvec_out]; + const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; + constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + OVec out_space[n_iterations][nvec_in]; + + size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; + size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * nvec_out; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + CType amax = 0; + const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1; + + CVec partial_dbias; + if constexpr (IS_DBIAS) { + partial_dbias.clear(); + } + + { + const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height; +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + if (valid_load) { + const size_t ld_offset = current_stride + my_place + stride * i; + in[0][i].load_from(my_input_tile, ld_offset); + if constexpr (IS_DACT) { + act_in[0][i].load_from(my_act_input_tile, ld_offset); + } + } else { + in[0][i].clear(); + if constexpr (IS_DACT) { + act_in[0][i].clear(); + } + } + } + } + +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const size_t current_place = current_stride + my_place; + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { + const bool valid_load = + my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height; +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { + if (valid_load) { + const size_t ld_offset = current_stride + my_place_in + stride * (nvec_out + j); + in[current_in][j].load_from(my_input_tile, ld_offset); + if constexpr (IS_DACT) { + act_in[current_in][j].load_from(my_act_input_tile, ld_offset); + } + } else { + in[current_in][j].clear(); + if constexpr (IS_DACT) { + act_in[current_in][j].clear(); + } + } + } + } + CVec after_dact[nvec_out]; // NOLINT(*) +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { +#pragma unroll + for (unsigned int k = 0; k < nvec_in; ++k) { + if constexpr (IS_DACT) { + after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * + OP(act_in[current_in ^ 1][j].data.elt[k], {}); + } else if constexpr (IS_ACT) { + after_dact[j].data.elt[k] = OP(in[current_in ^ 1][j].data.elt[k], {}); + } else { + after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]); + } + } + } + const int dbias_shfl_src_lane = + (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + constexpr bool IS_FULL_TILE = false; + const bool valid_store = + (my_place < tile_length) && (warp_id_in_tile * n_iterations + i < tile_height); + + cast_and_transpose_regs(after_dact, out_space[i], partial_dbias, + my_output_c_tile, current_place, stride, scale, + amax, dbias_shfl_src_lane, valid_store); + + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += nvec_out * stride; + current_row += nvec_out; + } + + for (unsigned int i = 0; i < nvec_in; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { + const bool valid_store = my_place < tile_height; + if (valid_store) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, + current_stride + my_place); + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); + } + + if constexpr (IS_DBIAS) { + my_dbias_scratch[threadIdx.x] = partial_dbias; + __syncthreads(); + if (warp_id_in_tile == 0) { +#pragma unroll + for (unsigned int i = 1; i < n_warps_per_tile; ++i) { + CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP]; +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + partial_dbias.data.elt[j] += tmp.data.elt[j]; + } + } + if (my_id_in_warp < tile_length) { + partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); + } + } + } + + // Reduce amax over block + if (param.amax != nullptr) { + amax = reduce_max(amax, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(param.amax, amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) { + reciprocal(param.scale_inv, scale); + } +} + +static const char *ActTypeToString[] = { + "none", // 0 + "sigmoid", // 1 + "dsigmoid", // 2 + "gelu", // 3 + "dgelu", // 4 + "qgelu", // 5 + "dqgelu", // 6 + "silu", // 7 + "dsilu", // 8 + "relu", // 9 + "drelu", // 10 + "srelu", // 11 + "dsrelu" // 12 +}; + +template +constexpr int get_activation_type() { + constexpr decltype(OP) ActivationList[] = { + nullptr, // 0 + &sigmoid, // 1 + &dsigmoid, // 2 + &gelu, // 3 + &dgelu, // 4 + &qgelu, // 5 + &dqgelu, // 6 + &silu, // 7 + &dsilu, // 8 + &relu, // 9 + &drelu, // 10 + &srelu, // 11 + &dsrelu // 12 + }; +#pragma unroll + for (int i = 0; i < sizeof(ActivationList) / sizeof(ActivationList[0]); ++i) { + if (OP == ActivationList[i]) { + return i; + } + } + return 0; +} + +template +void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output, + Tensor *dbias, Tensor *workspace, musaStream_t stream) { + // Check tensors, unless querying dbias workspace + if (!IS_DBIAS || workspace->data.dptr != nullptr) { + CheckInputTensor(input, "cast_transpose_fused_input"); + CheckOutputTensor(*output, "output"); + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr && dbias->has_data()); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr && act_input->has_data()); + CheckInputTensor(*act_input, "act_input"); + } + } + + // Check that inputs and outputs are available + NVTE_CHECK(input.has_data(), "Input is not allocated"); + NVTE_CHECK(output->has_data(), "Output rowwise data is not allocated"); + NVTE_CHECK(output->has_columnwise_data(), "Output columnwise data is not allocated"); + + // Flatten tensor to 2D + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes do not match (input=", input.data.shape, + ", output=", output->data.shape); + const size_t row_length = input.flat_last_dim(); + const size_t num_rows = input.flat_first_dim(); + + // Check that cast and transposed output data matches + NVTE_CHECK(output->data.dtype == output->columnwise_data.dtype, + "Cast and transposed output types must match."); + NVTE_CHECK(output->scale_inv.dptr == output->columnwise_scale_inv.dptr, + "Cast and transposed outputs need to share scale-inverse tensor."); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{row_length}, "Wrong shape of DBias."); + } + if constexpr (IS_DACT) { + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OutputType, using InputType2 = InputType; + using Param = CTDBiasDActParam; + + constexpr int itype_size = sizeof(InputType); + constexpr int itype2_size = sizeof(InputType2); + constexpr int otype_size = sizeof(OutputType); + + const bool aligned = + (row_length % THREADS_PER_WARP == 0) && (num_rows % THREADS_PER_WARP == 0); + const bool jit_compiled = aligned && rtc::is_enabled(); + + size_t load_size = (IS_DACT ? desired_load_size_dact : desired_load_size); + size_t store_size = (IS_DACT ? desired_store_size_dact : desired_store_size); + size_t num_blocks; + + if (jit_compiled) { + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + const size_t sm_count = static_cast(cuda::sm_count()); + auto add_config = [&](size_t load_size_config, size_t store_size_config) { + kernel_configs.emplace_back(row_length, num_rows, itype_size, itype2_size, otype_size, + load_size_config, store_size_config, sm_count, IS_DACT); + }; + add_config(8, 8); + add_config(4, 8); + add_config(8, 4); + add_config(4, 4); + add_config(2, 8); + add_config(8, 2); + add_config(2, 4); + add_config(4, 2); + add_config(2, 2); + add_config(1, 8); + add_config(8, 1); + add_config(1, 4); + add_config(4, 1); + add_config(1, 2); + add_config(2, 1); + add_config(1, 1); + + // Select the kernel configuration with the lowest cost + const auto &kernel_config = + *std::min_element(kernel_configs.begin(), kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + load_size = kernel_config.load_size; + store_size = kernel_config.store_size; + num_blocks = kernel_config.num_blocks; + } + + const size_t nvec_in = load_size / itype_size; + const size_t nvec_out = store_size / otype_size; + const size_t tile_size_x = nvec_in * threads_per_warp; + const size_t tile_size_y = nvec_out * threads_per_warp; + const size_t num_tiles_x = DIVUP(row_length, tile_size_x); + const size_t num_tiles_y = DIVUP(num_rows, tile_size_y); + const size_t num_tiles = num_tiles_x * num_tiles_y; + + NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + + if (!jit_compiled) { + num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block); + } if constexpr (IS_DBIAS) { + // Check workspace size + populate_cast_transpose_dbias_workspace_config(*output, workspace, nvec_out); + if (workspace->data.dptr == nullptr) { + return; + } + } + + size_t VecOutputTypeSize; + switch (nvec_out) { + case 1: + VecOutputTypeSize = sizeof(Vec); + break; + case 2: + VecOutputTypeSize = sizeof(Vec); + break; + case 4: + VecOutputTypeSize = sizeof(Vec); + break; + case 8: + VecOutputTypeSize = sizeof(Vec); + break; + } size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile * + (threads_per_warp + 1) * VecOutputTypeSize; + + if constexpr (IS_DBIAS) { + size_t VecComputeTypeSize; + switch (nvec_in) { + case 1: + VecComputeTypeSize = sizeof(Vec); + break; + case 2: + VecComputeTypeSize = sizeof(Vec); + break; + case 4: + VecComputeTypeSize = sizeof(Vec); + break; + case 8: + VecComputeTypeSize = sizeof(Vec); + break; + } + const size_t shared_size_dbias = cast_transpose_num_threads * VecComputeTypeSize; + if (shared_size_transpose < shared_size_dbias) { + shared_size_transpose = shared_size_dbias; + } + } + + Param param; + param.input = reinterpret_cast(input.data.dptr); + param.output_c = reinterpret_cast(output->data.dptr); + param.output_t = reinterpret_cast(output->columnwise_data.dptr); + param.scale_ptr = reinterpret_cast(output->scale.dptr); + param.amax = reinterpret_cast(output->amax.dptr); + param.scale_inv = reinterpret_cast(output->scale_inv.dptr); + if constexpr (IS_DBIAS) { + param.workspace = reinterpret_cast(workspace->data.dptr); + } if constexpr (IS_DACT) { + param.act_input = reinterpret_cast(act_input->data.dptr); + } + + // Runtime-compiled tuned kernel + if (jit_compiled) { + constexpr const char *itype_name = TypeInfo::name; + constexpr const char *itype2_name = TypeInfo::name; + constexpr const char *otype_name = TypeInfo::name; + + int actType = 0; + if constexpr (IS_DACT || IS_ACT) { + // actType = get_activation_type(); + } + + // Compile NVRTC kernel if needed and launch + auto &rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = concat_strings( + "cast_transpose_fusion" + ",itype=", + itype_name, ",itype2=", itype2_name, ",otype=", otype_name, + ",load_size=", load_size, ",store_size=", store_size, ",IS_DBIAS=", IS_DBIAS, + ",IS_DACT=", IS_DACT, ",IS_ACT=", IS_ACT, + ",activationType=", ActTypeToString[actType]); + + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_cast_transpose_fusion_mu; + code = regex_replace(code, "__ITYPE__", itype_name); + code = regex_replace(code, "__ITYPE2__", itype2_name); + code = regex_replace(code, "__OTYPE__", otype_name); + code = regex_replace(code, "__LOAD_SIZE__", load_size); + code = regex_replace(code, "__STORE_SIZE__", store_size); + code = regex_replace(code, "__WARPS_PER_TILE__", n_warps_per_tile); + code = regex_replace(code, "__BLOCK_SIZE__", cast_transpose_num_threads); + code = regex_replace(code, "__IS_DBIAS__", IS_DBIAS); + code = regex_replace(code, "__IS_DACT__", IS_DACT); + code = regex_replace(code, "__IS_ACT__", IS_ACT); + code = regex_replace(code, "__ACTIVATION_TYPE__", actType); + + rtc_manager.compile( + kernel_label, "cast_transpose_fusion_kernel_optimized", code, + "transformer_engine/common/transpose/rtc/cast_transpose_fusion.mu"); + } + + rtc_manager.set_cache_config(kernel_label, MU_FUNC_CACHE_PREFER_SHARED); + + rtc_manager.launch(kernel_label, num_blocks, cast_transpose_num_threads, + shared_size_transpose, stream, param, row_length, num_rows, + num_tiles); + } else { // Statically-compiled general kernel + constexpr size_t load_size = IS_DACT ? desired_load_size_dact : desired_load_size; + constexpr size_t store_size = IS_DACT ? desired_store_size_dact : desired_store_size; + constexpr size_t nvec_in = load_size / itype_size; + constexpr size_t nvec_out = store_size / otype_size; + + NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + + musaFuncSetAttribute( + cast_transpose_fused_kernel_notaligned, + musaFuncAttributePreferredSharedMemoryCarveout, 100); + cast_transpose_fused_kernel_notaligned + <<>>( + param, row_length, num_rows, num_tiles); + } + + if constexpr (IS_DBIAS) { + reduce_dbias(*workspace, dbias, row_length, num_rows, nvec_out, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +__global__ void __launch_bounds__(cast_transpose_num_threads) + dgated_act_cast_transpose_kernel(const IType *const input, const IType *const act_input, + OType *const output_c, OType *const output_t, + const CType *const scale_ptr, CType *const amax, + CType *const scale_inv, const size_t row_length, + const size_t num_rows, const size_t num_tiles) { + using IVec = Vec; + using OVec = Vec; + using CVec = Vec; + + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile; + if (tile_id >= num_tiles) { + return; + } + + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const IType *const my_input_tile = + input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP; + const IType *const my_act_input_tile = + act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP; + const IType *const my_gate_input_tile = + act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP + + row_length; + OType *const my_output_c_tile_0 = + output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP; + OType *const my_output_c_tile_1 = + output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP + + row_length; + OType *const my_output_t_tile_0 = + output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP; + OType *const my_output_t_tile_1 = + output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP + + row_length * num_rows; + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); + + IVec in[2][nvec_out]; + IVec act_in[2][nvec_out]; + IVec gate_in[2][nvec_out]; + const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; + constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + OVec out_space_0[n_iterations][nvec_in]; + OVec out_space_1[n_iterations][nvec_in]; + + const size_t stride = row_length / nvec_in; + const size_t output_stride = num_rows / nvec_out; + size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + const size_t stride2 = 2 * row_length / nvec_in; + size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2; + CType max = 0; + const CType scale = scale_ptr != nullptr ? *scale_ptr : 1; + + CVec partial_dbias; + +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); + act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i); + gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i); + } +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const size_t current_place = current_stride2 + my_place; + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { + in[current_in][j].load_from(my_input_tile, + current_stride + my_place_in + stride * (nvec_out + j)); + act_in[current_in][j].load_from(my_act_input_tile, + current_stride2 + my_place_in + stride2 * (nvec_out + j)); + gate_in[current_in][j].load_from(my_gate_input_tile, + current_stride2 + my_place_in + stride2 * (nvec_out + j)); + } + } + CVec after_dact[nvec_out]; // NOLINT(*) + CVec after_dgate[nvec_out]; // NOLINT(*) +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { +#pragma unroll + for (unsigned int k = 0; k < nvec_in; ++k) { + after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) * + CType(in[current_in ^ 1][j].data.elt[k]) * + CType(gate_in[current_in ^ 1][j].data.elt[k]); + after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * + OP2(act_in[current_in ^ 1][j].data.elt[k], {}); + } + } + OVec out_trans_0[nvec_in]; // NOLINT(*) + OVec out_trans_1[nvec_in]; // NOLINT(*) + + constexpr bool IS_DBIAS = false; + constexpr bool IS_FULL_TILE = true; + constexpr bool valid_store = true; + constexpr int dbias_shfl_src_lane = 0; + + cast_and_transpose_regs(after_dact, out_trans_0, partial_dbias, + my_output_c_tile_0, current_place, stride2, + scale, max, dbias_shfl_src_lane, valid_store); + + cast_and_transpose_regs(after_dgate, out_trans_1, partial_dbias, + my_output_c_tile_1, current_place, stride2, + scale, max, dbias_shfl_src_lane, valid_store); + +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + out_space_0[i][j].data.vec = out_trans_0[j].data.vec; + out_space_1[i][j].data.vec = out_trans_1[j].data.vec; + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += nvec_out * stride; + current_stride2 += nvec_out * stride2; + } + + for (unsigned int i = 0; i < nvec_in; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space_0[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0, + current_stride + my_place); + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space_1[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1, + current_stride + my_place); + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); + } + + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, scale); + } +} + +template +__global__ void __launch_bounds__(cast_transpose_num_threads) + dgated_act_cast_transpose_kernel_notaligned(const IType *const input, + const IType *const act_input, OType *const output_c, + OType *const output_t, const CType *const scale_ptr, + CType *const amax, CType *const scale_inv, + const size_t row_length, const size_t num_rows, + const size_t num_tiles) { + using IVec = Vec; + using OVec = Vec; + using CVec = Vec; + + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = + (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile; + if (tile_id >= num_tiles) return; + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const IType *const my_input_tile = + input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP; + const IType *const my_act_input_tile = + act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP; + const IType *const my_gate_input_tile = + act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP + + row_length; + OType *const my_output_c_tile_0 = + output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP; + OType *const my_output_c_tile_1 = + output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP + + row_length; + OType *const my_output_t_tile_0 = + output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP; + OType *const my_output_t_tile_1 = + output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP + + row_length * num_rows; + const size_t stride = row_length / nvec_in; + const size_t stride2 = 2 * row_length / nvec_in; + const size_t output_stride = num_rows / nvec_out; + const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; + const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; + const unsigned int tile_length = + row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest; + const unsigned int tile_height = + row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest; + + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); + + IVec in[2][nvec_out]; + IVec act_in[2][nvec_out]; + IVec gate_in[2][nvec_out]; + const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; + constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + OVec out_space_0[n_iterations][nvec_in]; + OVec out_space_1[n_iterations][nvec_in]; + + size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; + size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + CType max = 0; + const CType scale = scale_ptr != nullptr ? *scale_ptr : 1; + + CVec partial_dbias; + + { + const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height; +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + if (valid_load) { + in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); + act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i); + gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i); + } else { + in[0][i].clear(); + act_in[0][i].clear(); + gate_in[0][i].clear(); + } + } + } +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const size_t current_place = current_stride2 + my_place; + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { + { + const bool valid_load = + my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height; +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { + if (valid_load) { + in[current_in][j].load_from(my_input_tile, + current_stride + my_place_in + stride * (nvec_out + j)); + act_in[current_in][j].load_from( + my_act_input_tile, current_stride2 + my_place_in + stride2 * (nvec_out + j)); + gate_in[current_in][j].load_from( + my_gate_input_tile, current_stride2 + my_place_in + stride2 * (nvec_out + j)); + } else { + in[current_in][j].clear(); + act_in[current_in][j].clear(); + gate_in[current_in][j].clear(); + } + } + } + } + CVec after_dact[nvec_out]; // NOLINT(*) + CVec after_dgate[nvec_out]; // NOLINT(*) +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { +#pragma unroll + for (unsigned int k = 0; k < nvec_in; ++k) { + after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) * + CType(in[current_in ^ 1][j].data.elt[k]) * + CType(gate_in[current_in ^ 1][j].data.elt[k]); + after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * + OP2(act_in[current_in ^ 1][j].data.elt[k], {}); + } + } + OVec out_trans_0[nvec_in]; // NOLINT(*) + OVec out_trans_1[nvec_in]; // NOLINT(*) + + constexpr bool IS_DBIAS = false; + constexpr bool IS_FULL_TILE = false; + constexpr int dbias_shfl_src_lane = 0; + const bool valid_store = + (my_place < tile_length) && (warp_id_in_tile * n_iterations + i < tile_height); + + cast_and_transpose_regs(after_dact, out_trans_0, partial_dbias, + my_output_c_tile_0, current_place, stride2, + scale, max, dbias_shfl_src_lane, valid_store); + cast_and_transpose_regs(after_dgate, out_trans_1, partial_dbias, + my_output_c_tile_1, current_place, stride2, + scale, max, dbias_shfl_src_lane, valid_store); + +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + out_space_0[i][j].data.vec = out_trans_0[j].data.vec; + out_space_1[i][j].data.vec = out_trans_1[j].data.vec; + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += nvec_out * stride; + current_stride2 += nvec_out * stride2; + } + + for (unsigned int i = 0; i < nvec_in; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space_0[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { + const bool valid_store = my_place < tile_height; + if (valid_store) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0, + current_stride + my_place); + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space_1[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { + const bool valid_store = my_place < tile_height; + if (valid_store) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1, + current_stride + my_place); + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); + } + + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, scale); + } +} + +template +void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output, + musaStream_t stream) { + CheckInputTensor(input, "dgated_act_cast_transpose_input"); + CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input"); + CheckOutputTensor(*output, "dgated_act_cast_transpose_output"); + + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(output->has_data() && output->has_columnwise_data(), + "Both rowwise and columnwise data need to be allocated."); + NVTE_CHECK(output->data.shape.size() == 2, "C output must have 2 dimensions."); + NVTE_CHECK(output->columnwise_data.shape.size() == 2, "T output must have 2 dimensions."); + const size_t row_length = input.data.shape[1]; + const size_t num_rows = input.data.shape[0]; + + NVTE_CHECK(gated_act_input.data.shape[0] == num_rows, "Wrong dimension of output."); + NVTE_CHECK(gated_act_input.data.shape[1] == row_length * 2, "Wrong dimension of output."); + NVTE_CHECK(output->data.shape[0] == num_rows, "Wrong dimension of output."); + NVTE_CHECK(output->data.shape[1] == row_length * 2, "Wrong dimension of output."); + NVTE_CHECK(output->columnwise_data.shape[0] == row_length * 2, "Wrong dimension of T output."); + NVTE_CHECK(output->columnwise_data.shape[1] == num_rows, "Wrong dimension of T output."); + + NVTE_CHECK(input.data.dtype == gated_act_input.data.dtype, "Types of both inputs must match."); + + NVTE_CHECK(output->data.dtype == output->columnwise_data.dtype, + "C and T outputs need to have the same type."); + NVTE_CHECK(output->scale_inv.dptr == output->columnwise_scale_inv.dptr, + "C and T outputs need to share scale inverse tensor."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OutputType, using InputType2 = InputType; + /* dact fusion kernel uses more registers */ + constexpr int desired_load_size_dact = 4; + constexpr int desired_store_size_dact = 4; constexpr int itype_size = sizeof(InputType); + constexpr int otype_size = sizeof(OutputType); + constexpr int nvec_in = desired_load_size_dact / itype_size; + constexpr int nvec_out = desired_store_size_dact / otype_size; + + NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + const size_t n_tiles = + DIVUP(row_length, static_cast(nvec_in * THREADS_PER_WARP)) * + DIVUP(num_rows, static_cast(nvec_out * THREADS_PER_WARP)); + const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; + const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); + + const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && + num_rows % (nvec_out * THREADS_PER_WARP) == 0; + const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile * + (THREADS_PER_WARP + 1) * sizeof(Vec); + if (full_tile) { + musaFuncSetAttribute( + dgated_act_cast_transpose_kernel, + musaFuncAttributePreferredSharedMemoryCarveout, 100); + + dgated_act_cast_transpose_kernel + <<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(gated_act_input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, + n_tiles); + } else { + musaFuncSetAttribute( + dgated_act_cast_transpose_kernel_notaligned, + musaFuncAttributePreferredSharedMemoryCarveout, 100); + dgated_act_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(gated_act_input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, + n_tiles); + }); // NOLINT(*) + ); // NOLINT(*) +} + +// Explicit template instantiation +template void cast_transpose_fused( + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, musaStream_t); +#define NVTE_INSTANTIATE_ACTIVATION(op) \ + template void cast_transpose_fused>( \ + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, musaStream_t); \ + template void cast_transpose_fused>( \ + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, musaStream_t); +NVTE_INSTANTIATE_ACTIVATION(relu); +NVTE_INSTANTIATE_ACTIVATION(srelu); +NVTE_INSTANTIATE_ACTIVATION(gelu); +NVTE_INSTANTIATE_ACTIVATION(qgelu); +NVTE_INSTANTIATE_ACTIVATION(silu); +#undef NVTE_INSTANTIATE_ACTIVATION + +} // namespace detail + +} // namespace transformer_engine + +using ComputeType = typename transformer_engine::fp32; + +void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, musaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + + constexpr const NVTETensor activation_input = nullptr; + + cast_transpose_fused( + *reinterpret_cast(input), reinterpret_cast(activation_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); +} + +void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(act_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); +} + +void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(silu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); +} + +void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_drelu); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(relu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); +} + +void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(srelu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); +} + +void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(qgelu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); +} + +void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, + NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_dgeglu_cast_transpose); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + dgated_act_cast_transpose, gelu>( + *reinterpret_cast(input), *reinterpret_cast(gated_act_input), + reinterpret_cast(output), stream); +} + +void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input, + NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_dswiglu_cast_transpose); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + dgated_act_cast_transpose, silu>( + *reinterpret_cast(input), *reinterpret_cast(swiglu_input), + reinterpret_cast(output), stream); +} + +void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, + NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_dreglu_cast_transpose); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + dgated_act_cast_transpose, relu>( + *reinterpret_cast(input), *reinterpret_cast(gated_act_input), + reinterpret_cast(output), stream); +} + +void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, + NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_dsreglu_cast_transpose); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + dgated_act_cast_transpose, srelu>( + *reinterpret_cast(input), *reinterpret_cast(gated_act_input), + reinterpret_cast(output), stream); +} + +void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, + NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_dqgeglu_cast_transpose); + using namespace transformer_engine; + using namespace transformer_engine::detail; + + dgated_act_cast_transpose, qgelu>( + *reinterpret_cast(input), *reinterpret_cast(gated_act_input), + reinterpret_cast(output), stream); +} diff --git a/transformer_engine/musa/common/transpose/multi_cast_transpose.mu b/transformer_engine/musa/common/transpose/multi_cast_transpose.mu new file mode 100644 index 0000000000..09d48456d4 --- /dev/null +++ b/transformer_engine/musa/common/transpose/multi_cast_transpose.mu @@ -0,0 +1,341 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../utils.muh" + +namespace transformer_engine { + +namespace { + +// Parameters to tune +constexpr int n_warps_per_tile = 4; +constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile; +constexpr int desired_load_size = 8; +constexpr int desired_store_size = 8; +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB + +struct MultiCastTransposeArgs { + // (input) Data buffers for input tensors + void* input_list[kMaxTensorsPerKernel]; + // (output) Data buffers for cast output tensors + void* output_c_list[kMaxTensorsPerKernel]; + // (output) Data buffers for transpose output tensors + void* output_t_list[kMaxTensorsPerKernel]; + // (input) Scaling factor for output tensors + void* scale_list[kMaxTensorsPerKernel]; + // (output) AMAX's of input tensors + void* amax_list[kMaxTensorsPerKernel]; + // (output) Inverse of scaling factor for output tensors + void* scale_inv_list[kMaxTensorsPerKernel]; + // Input matrix heights + int num_rows_list[kMaxTensorsPerKernel]; + // Input matrix widths + int row_length_list[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of CUDA blocks needed for each + // tensor + int block_range[kMaxTensorsPerKernel + 1]; + // Number of tensors being processed by kernel + int num_tensors; +}; + +template +__global__ void __launch_bounds__(threads_per_block) + multi_cast_transpose_kernel(MultiCastTransposeArgs args) { + using IVec = Vec; + using OVecC = Vec; + using OVecT = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr int bdimx = THREADS_PER_WARP; + constexpr int bdimy = n_warps_per_tile; + const int tid = threadIdx.x; + const int tidx = tid % bdimx; + const int tidy = tid / bdimx; + const int bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + constexpr int tile_dim_m = THREADS_PER_WARP * nvec_out; + constexpr int tile_dim_n = THREADS_PER_WARP * nvec_in; + + // Number of nvec_out x nvec_in subtiles for each thread to + // load/store + constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + + // Find tensor corresponding to block + int tensor_id = 0; + while (args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const IType* input = reinterpret_cast(args.input_list[tensor_id]); + OType* output_c = reinterpret_cast(args.output_c_list[tensor_id]); + OType* output_t = reinterpret_cast(args.output_t_list[tensor_id]); + const CType* scale_ptr = reinterpret_cast(args.scale_list[tensor_id]); + const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr; + CType* amax_ptr = reinterpret_cast(args.amax_list[tensor_id]); + CType* scale_inv_ptr = reinterpret_cast(args.scale_inv_list[tensor_id]); + const int num_rows = args.num_rows_list[tensor_id]; + const int row_length = args.row_length_list[tensor_id]; + + // Find position of tile within tensor + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int tile_id = bid - args.block_range[tensor_id]; + const int tile_id_m = tile_id / num_tiles_n; + const int tile_id_n = tile_id % num_tiles_n; + const int tile_row = tile_id_m * tile_dim_m; + const int tile_col = tile_id_n * tile_dim_n; + + // Load input and store to registers + // Note: Each thread loads n_iterations subtiles, casts to output + // type, and transposes in registers. + OVecT local_output_t[nvec_in][n_iterations]; + CType local_amax = 0; +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec_out; ++i2) { + const int row = tile_row + i1 * nvec_out + i2; + const int col = tile_col + j1 * nvec_in; + IVec local_input; + OVecC local_output_c; + if constexpr (aligned) { + local_input.load_from(&input[row * row_length + col]); + } else { + local_input.clear(); + if (row < num_rows) { +#pragma unroll + for (int j2 = 0; j2 < nvec_in; ++j2) { + if (col + j2 < row_length) { + local_input.data.elt[j2] = input[row * row_length + col + j2]; + } + } + } + } +#pragma unroll + for (int j2 = 0; j2 < nvec_in; ++j2) { + const CType x = CType(local_input.data.elt[j2]); + const OType y = OType(scale * x); + local_output_c.data.elt[j2] = y; + local_output_t[j2][iter].data.elt[i2] = y; + __builtin_assume(local_amax >= 0); + local_amax = fmaxf(fabsf(x), local_amax); + } + if constexpr (aligned) { + local_output_c.store_to(&output_c[row * row_length + col]); + } else { + if (row < num_rows) { +#pragma unroll + for (int j2 = 0; j2 < nvec_in; ++j2) { + if (col + j2 < row_length) { + output_c[row * row_length + col + j2] = local_output_c.data.elt[j2]; + } + } + } + } + } + } + + // Copy transposed output from registers to global memory + __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1]; +#pragma unroll + for (int j2 = 0; j2 < nvec_in; ++j2) { +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; + shared_output_t[j1][i1] = local_output_t[j2][iter]; + } + __syncthreads(); +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidx; + const int j1 = tidy + iter * bdimy; + const int row = tile_row + i1 * nvec_out; + const int col = tile_col + j1 * nvec_in + j2; + if constexpr (aligned) { + shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]); + } else { + if (col < row_length) { +#pragma unroll + for (int i2 = 0; i2 < nvec_out; ++i2) { + if (row + i2 < num_rows) { + output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2]; + } + } + } + } + } + __syncthreads(); + } + + // Finalize fp8 factors + local_amax = reduce_max(local_amax, tidy); + if (tid == 0) { + static_assert(std::is_same::value); + if (amax_ptr != nullptr) atomicMaxFloat(amax_ptr, local_amax); + } + if (tile_id == 0 && tid == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); + } +} + +} // namespace + +void multi_cast_transpose(const std::vector input_list, std::vector output_list, + musaStream_t stream) { + // Check that number of tensors is valid + NVTE_CHECK(output_list.size() == input_list.size(), + "Number of input and output tensors must match"); + if (input_list.empty()) { + return; + } + + // Check that tensor properties are valid + DType itype = input_list[0]->data.dtype; + DType otype = output_list[0]->dtype(); + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = *input_list[tensor_id]; + const auto& output = *output_list[tensor_id]; + CheckInputTensor(input, "multi_cast_transpose_input_" + std::to_string(tensor_id)); + CheckInputTensor(output, "multi_cast_transpose_output_" + std::to_string(tensor_id)); + //std::cout << *static_cast(output.data.dptr) << std::endl; + NVTE_CHECK(output.has_data() && output.has_columnwise_data(), + "Both rowwise and columnwise output data needs to be allocated."); + + NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match."); + NVTE_CHECK(output.data.dtype == otype, "C output tensor types do not match."); + NVTE_CHECK(output.data.dtype == otype, "T output tensor types do not match."); + + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions, but shape is ", + input.data.shape); + NVTE_CHECK(output.data.shape == input.data.shape, "C output tensor shape ", output.data.shape, + "does not match input tensor shape ", input.data.shape); + NVTE_CHECK(output.columnwise_data.shape.size() == 2, "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); + NVTE_CHECK(output.columnwise_data.shape[0] == input.data.shape[1], "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); + NVTE_CHECK(output.columnwise_data.shape[1] == input.data.shape[0], "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); + } + + // Input matrices are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + const int tile_dim_m = THREADS_PER_WARP * desired_store_size / typeToSize(otype); + const int tile_dim_n = THREADS_PER_WARP * desired_load_size / typeToSize(itype); + + // Add tensors to kernel argument struct + MultiCastTransposeArgs kernel_args_aligned, kernel_args_unaligned; + kernel_args_aligned.num_tensors = 0; + kernel_args_aligned.block_range[0] = 0; + kernel_args_unaligned.num_tensors = 0; + kernel_args_unaligned.block_range[0] = 0; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + // Launch kernel if argument struct is full + // if (kernel_args_aligned.num_tensors == kMaxTensorsPerKernel) { + // TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + // itype, InputType, + // TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + // otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); + // constexpr int nvec_out = desired_store_size / sizeof(OutputType); + // const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors]; + // multi_cast_transpose_kernel + // <<>>(kernel_args_aligned);); // NOLINT(*) + // ); // NOLINT(*) + // kernel_args_aligned.num_tensors = 0; + // } + // if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) { + // TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + // itype, InputType, + // TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + // otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); + // constexpr int nvec_out = desired_store_size / sizeof(OutputType); + // const int n_blocks = + // kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors]; + // multi_cast_transpose_kernel + // <<>>(kernel_args_unaligned);); // NOLINT(*) + // ); // NOLINT(*) + // kernel_args_unaligned.num_tensors = 0; + // } + + // Calculate number of thread blocks needed for tensor + const int num_rows = input_list[tensor_id]->data.shape[0]; + const int row_length = input_list[tensor_id]->data.shape[1]; + const int num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m; + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int num_tiles = num_tiles_m * num_tiles_n; + + // Figure out whether to use aligned or unaligned kernel + const bool aligned = + ((num_tiles_m * tile_dim_m == num_rows) && (num_tiles_n * tile_dim_n == row_length)); + auto& kernel_args = aligned ? kernel_args_aligned : kernel_args_unaligned; + + // Add tensor to kernel argument struct + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); + kernel_args.output_c_list[pos] = output_list[tensor_id]->data.dptr; + kernel_args.output_t_list[pos] = output_list[tensor_id]->columnwise_data.dptr; + kernel_args.scale_list[pos] = output_list[tensor_id]->scale.dptr; + kernel_args.amax_list[pos] = output_list[tensor_id]->amax.dptr; + kernel_args.scale_inv_list[pos] = output_list[tensor_id]->scale_inv.dptr; + kernel_args.num_rows_list[pos] = num_rows; + kernel_args.row_length_list[pos] = row_length; + kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; + kernel_args.num_tensors++; + } + + // Launch kernel + // if (kernel_args_aligned.num_tensors > 0) { + // TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + // itype, InputType, + // TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + // otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); + // constexpr int nvec_out = desired_store_size / sizeof(OutputType); + // const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors]; + // multi_cast_transpose_kernel + // <<>>(kernel_args_aligned);); // NOLINT(*) + // ); // NOLINT(*) + // } + // if (kernel_args_unaligned.num_tensors > 0) { + // TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + // itype, InputType, + // TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + // otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType); + // constexpr int nvec_out = desired_store_size / sizeof(OutputType); + // const int n_blocks = + // kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors]; + // multi_cast_transpose_kernel + // <<>>(kernel_args_unaligned);); // NOLINT(*) + // ); // NOLINT(*) + // } +} + +} // namespace transformer_engine + +void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, + NVTETensor* output_list, musaStream_t stream) { + NVTE_API_CALL(nvte_multi_cast_transpose); + using namespace transformer_engine; + std::vector input_list_, output_list_; + for (size_t i = 0; i < num_tensors; ++i) { + input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); + output_list_.push_back(reinterpret_cast(output_list[i])); + } + multi_cast_transpose(input_list_, output_list_, stream); +} diff --git a/transformer_engine/musa/common/transpose/rtc/cast_transpose.mu b/transformer_engine/musa/common/transpose/rtc/cast_transpose.mu new file mode 100644 index 0000000000..a47b4d1bc5 --- /dev/null +++ b/transformer_engine/musa/common/transpose/rtc/cast_transpose.mu @@ -0,0 +1,129 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "utils.muh" + +using namespace transformer_engine; + +namespace { + +// Parameters +using CType = float; +using IType = __ITYPE__; +using OType = __OTYPE__; +constexpr size_t load_size = __LOAD_SIZE__; +constexpr size_t store_size = __STORE_SIZE__; +constexpr size_t warps_per_tile = __WARPS_PER_TILE__; +constexpr size_t block_size = __BLOCK_SIZE__; + +} // namespace + +__global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel( + const IType* __restrict__ const input, const CType* __restrict__ const noop, + OType* __restrict__ const output_c, OType* __restrict__ const output_t, + const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr, + CType* __restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + + // Vectorized load/store sizes + constexpr size_t nvec_in = load_size / sizeof(IType); + constexpr size_t nvec_out = store_size / sizeof(OType); + using IVec = Vec; + using OVecC = Vec; + using OVecT = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr size_t bdimx = THREADS_PER_WARP; + constexpr size_t bdimy = warps_per_tile; + const size_t tid = threadIdx.x; + const size_t tidx = tid % bdimx; + const size_t tidy = tid / bdimx; + const size_t bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out; + constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in; + + // Position of tile within tensor + const size_t num_tiles_m = num_rows / tile_dim_m; + const size_t tile_id_m = bid % num_tiles_m; + const size_t tile_id_n = bid / num_tiles_m; + const size_t tile_row = tile_id_m * tile_dim_m; + const size_t tile_col = tile_id_n * tile_dim_n; + + // Number of nvec_out x nvec_in subtiles for each thread to + // load/store + constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile; + + // FP8 factors + const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr; + CType amax = 0; + + // Load input to registers and transpose + // Note: Each thread loads num_iterations subtiles, computes amax, + // casts type, and transposes in registers. + OVecT local_output_t[nvec_in][num_iterations]; +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; +#pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + const size_t row = tile_row + i1 * nvec_out + i2; + const size_t col = tile_col + j1 * nvec_in; + IVec local_input; + OVecC local_output_c; + local_input.load_from(&input[row * row_length + col]); +#pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + const CType in = static_cast(local_input.data.elt[j2]); + const OType out = OType(in * scale); + __builtin_assume(amax >= 0); + amax = fmaxf(fabsf(in), amax); + local_output_c.data.elt[j2] = out; + local_output_t[j2][iter].data.elt[i2] = out; + } + local_output_c.store_to(&output_c[row * row_length + col]); + } + } + + // Copy from registers to shared memory to global memory + __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1]; +#pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; + shared_output_t[j1][i1] = local_output_t[j2][iter]; + } + __syncthreads(); +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidx; + const size_t j1 = tidy + iter * bdimy; + const size_t row = tile_row + i1 * nvec_out; + const size_t col = tile_col + j1 * nvec_in + j2; + shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]); + } + __syncthreads(); + } + + // Reduce amax over block + if (amax_ptr != nullptr) { + amax = reduce_max(amax, tidy); + if (threadIdx.x == 0) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) { + reciprocal(scale_inv_ptr, scale); + } +} diff --git a/transformer_engine/musa/common/transpose/rtc/cast_transpose_fusion.mu b/transformer_engine/musa/common/transpose/rtc/cast_transpose_fusion.mu new file mode 100644 index 0000000000..b32348c5c7 --- /dev/null +++ b/transformer_engine/musa/common/transpose/rtc/cast_transpose_fusion.mu @@ -0,0 +1,255 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "util/math.h" +#include "utils.muh" + +using namespace transformer_engine; + +namespace { + +// Parameters +using CType = float; +using IType = __ITYPE__; +using IType2 = __ITYPE2__; +using OType = __OTYPE__; +constexpr size_t LOAD_SIZE = __LOAD_SIZE__; +constexpr size_t STORE_SIZE = __STORE_SIZE__; +constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__; +constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__; +constexpr bool IS_DBIAS = __IS_DBIAS__; +constexpr bool IS_DACT = __IS_DACT__; +constexpr bool IS_ACT = __IS_ACT__; +static_assert(!(IS_DACT && IS_ACT), "forward and backward activation are mutually exclusive"); +constexpr size_t ACT_TYPE = __ACTIVATION_TYPE__; + +constexpr size_t NVEC_IN = LOAD_SIZE / sizeof(IType); +constexpr size_t NVEC_OUT = STORE_SIZE / sizeof(OType); +using CVec = Vec; +using IVec = Vec; +using IVec2 = Vec; +using OVec = Vec; +using Param = CTDBiasDActParam; + +using OP = CType (*)(const CType, const Empty &); +constexpr OP ActivationList[] = { + nullptr, // 0 + &sigmoid, // 1 + &dsigmoid, // 2 + &gelu, // 3 + &dgelu, // 4 + &qgelu, // 5 + &dqgelu, // 6 + &silu, // 7 + &dsilu, // 8 + &relu, // 9 + &drelu, // 10 + &srelu, // 11 + &dsrelu // 12 +}; + +} // namespace + +inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT], + OVec (&out_trans)[NVEC_IN], + CVec &out_dbias, // NOLINT(*) + typename OVec::type *output_cast_tile, + const size_t current_place, + const size_t stride, const CType scale, + CType &amax, // NOLINT(*) + const int dbias_shfl_src_lane) { + using OVecC = Vec; + + CVec step_dbias; + if constexpr (IS_DBIAS) { + step_dbias.clear(); + } + +#pragma unroll + for (unsigned int i = 0; i < NVEC_OUT; ++i) { + OVecC out_cast; +#pragma unroll + for (unsigned int j = 0; j < NVEC_IN; ++j) { + const CType tmp = in[i].data.elt[j]; + if constexpr (IS_DBIAS) { + step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation + } + out_cast.data.elt[j] = static_cast(tmp * scale); + out_trans[j].data.elt[i] = static_cast(tmp * scale); // thread tile transpose + + __builtin_assume(amax >= 0); + amax = fmaxf(fabsf(tmp), amax); + } + out_cast.store_to(output_cast_tile, current_place + stride * i); + } + + if constexpr (IS_DBIAS) { +#pragma unroll + for (unsigned int j = 0; j < NVEC_IN; ++j) { + CType elt = step_dbias.data.elt[j]; + elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp + out_dbias.data.elt[j] += elt; + } + } +} + +__global__ void __launch_bounds__(BLOCK_SIZE) + cast_transpose_fusion_kernel_optimized(const Param param, const size_t row_length, + const size_t num_rows, const size_t num_tiles) { + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = row_length / (NVEC_IN * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * WARPS_PER_TILE) + warp_id / WARPS_PER_TILE; + if (tile_id >= num_tiles) { + return; + } + + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const size_t tile_offset = + (tile_id_x * NVEC_IN + tile_id_y * row_length * NVEC_OUT) * THREADS_PER_WARP; + const size_t tile_offset_transp = + (tile_id_y * NVEC_OUT + tile_id_x * num_rows * NVEC_IN) * THREADS_PER_WARP; + + const IType *const my_input_tile = param.input + tile_offset; + const IType2 *const my_act_input_tile = param.act_input + tile_offset; + OType *const my_output_c_tile = param.output_c + tile_offset; + OType *const my_output_t_tile = param.output_t + tile_offset_transp; + CType *const my_partial_dbias_tile = + param.workspace + (tile_id_x * (NVEC_IN * THREADS_PER_WARP) + tile_id_y * row_length); + + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / WARPS_PER_TILE * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); + + CVec *const my_dbias_scratch = reinterpret_cast(scratch); + + IVec in[2][NVEC_OUT]; + IVec2 act_in[2][NVEC_OUT]; + + const unsigned int warp_id_in_tile = warp_id % WARPS_PER_TILE; + constexpr unsigned int n_iterations = THREADS_PER_WARP / WARPS_PER_TILE; + OVec out_space[n_iterations][NVEC_IN]; + + const size_t stride = row_length / NVEC_IN; + const size_t output_stride = num_rows / NVEC_OUT; + size_t current_stride = warp_id_in_tile * n_iterations * NVEC_OUT * stride; + size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * NVEC_OUT; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + + CType amax = 0.0f; + const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1; + + CVec partial_dbias; + if constexpr (IS_DBIAS) { + partial_dbias.clear(); + } + +#pragma unroll + for (unsigned int i = 0; i < NVEC_OUT; ++i) { + in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); + if constexpr (IS_DACT) { + act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i); + } + } +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const size_t current_place = current_stride + my_place; + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { +#pragma unroll + for (unsigned int j = 0; j < NVEC_OUT; ++j) { + const size_t ld_offset = current_stride + my_place_in + stride * (NVEC_OUT + j); + in[current_in][j].load_from(my_input_tile, ld_offset); + if constexpr (IS_DACT) { + act_in[current_in][j].load_from(my_act_input_tile, ld_offset); + } + } + } + CVec in_cast_fp32[NVEC_OUT]; // NOLINT(*) +#pragma unroll + for (unsigned int j = 0; j < NVEC_OUT; ++j) { +#pragma unroll + for (unsigned int k = 0; k < NVEC_IN; ++k) { + if constexpr (IS_DACT) { + in_cast_fp32[j].data.elt[k] = + static_cast(in[current_in ^ 1][j].data.elt[k]) * + ActivationList[ACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {}); + } else if constexpr (IS_ACT) { + in_cast_fp32[j].data.elt[k] = + ActivationList[ACT_TYPE](in[current_in ^ 1][j].data.elt[k], {}); + } else { + in_cast_fp32[j].data.elt[k] = static_cast(in[current_in ^ 1][j].data.elt[k]); + } + } + } + + const int dbias_shfl_src_lane = + (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + + cast_and_transpose_regs_optimized(in_cast_fp32, out_space[i], partial_dbias, my_output_c_tile, + current_place, stride, scale, amax, dbias_shfl_src_lane); + + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += NVEC_OUT * stride; + current_row += NVEC_OUT; + } + +#pragma unroll + for (unsigned int i = 0; i < NVEC_IN; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * NVEC_IN; + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, + current_stride + my_place); + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * NVEC_IN; + } + __syncthreads(); + } + + if constexpr (IS_DBIAS) { + my_dbias_scratch[threadIdx.x] = partial_dbias; + __syncthreads(); + if (warp_id_in_tile == 0) { +#pragma unroll + for (unsigned int i = 1; i < WARPS_PER_TILE; ++i) { + CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP]; +#pragma unroll + for (unsigned int j = 0; j < NVEC_IN; ++j) { + partial_dbias.data.elt[j] += tmp.data.elt[j]; + } + } + partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); + } + } + + // Reduce amax over block + if (param.amax != nullptr) { + const CType max_block = reduce_max(amax, warp_id); + if (threadIdx.x == 0) { + atomicMaxFloat(param.amax, max_block); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) { + reciprocal(param.scale_inv, scale); + } +} diff --git a/transformer_engine/musa/common/transpose/rtc/transpose.mu b/transformer_engine/musa/common/transpose/rtc/transpose.mu new file mode 100644 index 0000000000..1e6a31b554 --- /dev/null +++ b/transformer_engine/musa/common/transpose/rtc/transpose.mu @@ -0,0 +1,101 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "utils.muh" + +using namespace transformer_engine; + +namespace { + +// Parameters +using Type = __TYPE__; +constexpr size_t load_size = __LOAD_SIZE__; +constexpr size_t store_size = __STORE_SIZE__; +constexpr size_t warps_per_tile = __WARPS_PER_TILE__; +constexpr size_t block_size = __BLOCK_SIZE__; + +} // namespace + +__global__ void __launch_bounds__(block_size) + transpose_optimized_kernel(const Type* __restrict__ const input, const float* const noop, + Type* __restrict__ const output, const size_t row_length, + const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + + // Vectorized load/store sizes + constexpr size_t nvec_in = load_size / sizeof(Type); + constexpr size_t nvec_out = store_size / sizeof(Type); + using IVec = Vec; + using OVec = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr size_t bdimx = THREADS_PER_WARP; + constexpr size_t bdimy = warps_per_tile; + const size_t tid = threadIdx.x; + const size_t tidx = tid % bdimx; + const size_t tidy = tid / bdimx; + const size_t bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out; + constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in; + + // Position of tile within tensor + const size_t num_tiles_m = num_rows / tile_dim_m; + const size_t tile_id_m = bid % num_tiles_m; + const size_t tile_id_n = bid / num_tiles_m; + const size_t tile_row = tile_id_m * tile_dim_m; + const size_t tile_col = tile_id_n * tile_dim_n; + + // Number of nvec_out x nvec_in subtiles for each thread to + // load/store + constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile; + + // Load input to registers and transpose + // Note: Each thread loads num_iterations subtiles and transposes in + // registers. + OVec local_output[nvec_in][num_iterations]; +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; +#pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + const size_t row = tile_row + i1 * nvec_out + i2; + const size_t col = tile_col + j1 * nvec_in; + IVec local_input; + local_input.load_from(&input[row * row_length + col]); +#pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2]; + } + } + } + + // Copy from registers to shared memory to global memory + __shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP + 1]; +#pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; + shared_output[j1][i1] = local_output[j2][iter]; + } + __syncthreads(); +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidx; + const size_t j1 = tidy + iter * bdimy; + const size_t row = tile_row + i1 * nvec_out; + const size_t col = tile_col + j1 * nvec_in + j2; + shared_output[j1][i1].store_to(&output[col * num_rows + row]); + } + __syncthreads(); + } +} diff --git a/transformer_engine/musa/common/transpose/transpose.mu b/transformer_engine/musa/common/transpose/transpose.mu new file mode 100644 index 0000000000..96641cba01 --- /dev/null +++ b/transformer_engine/musa/common/transpose/transpose.mu @@ -0,0 +1,301 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include + +#include "../common.h" +#include "../util/rtc.h" +#include "../util/string.h" +#include "../utils.muh" + +namespace transformer_engine { + +namespace { + +// String with RTC kernel implementation +#include "string_code_transpose_rtc_transpose_mu.h" + +// Hard-coded kernel parameters +constexpr size_t warps_per_tile = 4; +constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile; + +/* Performance heuristics for optimized kernel parameters */ +struct KernelConfig { + /** Vector load size */ + size_t load_size; + /** Vector store size */ + size_t store_size; + + /* Whether config is valid */ + bool valid = false; + /* Number of CUDA blocks */ + size_t num_blocks = 0; + + /* Number of active SMs */ + size_t active_sm_count = 0; + /* Elements per L1 cache load */ + size_t elements_per_load = 0; + /* Elements per L1 cache store */ + size_t elements_per_store = 0; + + KernelConfig(size_t row_length, size_t num_rows, size_t type_size, size_t load_size_, + size_t store_size_, size_t sm_count) + : load_size{load_size_}, store_size{store_size_} { + // Check that tiles are correctly aligned + constexpr size_t cache_line_size = 128; + if (load_size % type_size != 0 || store_size % type_size != 0 || + cache_line_size % type_size != 0) { + return; + } + const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size; + const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size; + valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0); + if (!valid) { + return; + } + + // Number of CUDA blocks + num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements); + + // Parameters for performance model + constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs + active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), sm_count); + elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) / type_size); + elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size) / type_size); + } + + /* Compare by estimated cost */ + bool operator<(const KernelConfig &other) const { + if (this->valid && other.valid) { + // cost ~ (1/elements_per_load + 1/elements_per_store) / active_sms + // Note: Integer arithmetic ensures stable ordering + const auto &l1 = this->elements_per_load; + const auto &s1 = this->elements_per_store; + const auto &p1 = this->active_sm_count; + const auto &l2 = other.elements_per_load; + const auto &s2 = other.elements_per_store; + const auto &p2 = other.active_sm_count; + const auto scale = l1 * s1 * p1 * l2 * s2 * p2; + const auto cost1 = (scale / l1 + scale / s1) / p1; + const auto cost2 = (scale / l2 + scale / s2) / p2; + return cost1 < cost2; + } else { + return this->valid && !other.valid; + } + } +}; + +template +__global__ void __launch_bounds__(block_size) + transpose_general_kernel(const Type *__restrict__ const input, const fp32 *const noop, + Type *__restrict__ const output, const size_t row_length, + const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + + // Vectorized load/store sizes + constexpr size_t nvec_in = load_size / sizeof(Type); + constexpr size_t nvec_out = store_size / sizeof(Type); + using IVec = Vec; + using OVec = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr size_t bdimx = THREADS_PER_WARP; + constexpr size_t bdimy = warps_per_tile; + const size_t tid = threadIdx.x; + const size_t tidx = tid % bdimx; + const size_t tidy = tid / bdimx; + const size_t bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out; + constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in; + + // Position of tile within tensor + const size_t num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m; + const size_t tile_id_m = bid % num_tiles_m; + const size_t tile_id_n = bid / num_tiles_m; + const size_t tile_row = tile_id_m * tile_dim_m; + const size_t tile_col = tile_id_n * tile_dim_n; + + // Number of nvec_out x nvec_in subtiles for each thread to + // load/store + constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile; + + // Load input and store to registers + // Note: Each thread loads num_iterations subtiles and transposes in + // registers. + OVec local_output[nvec_in][num_iterations]; +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; +#pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + const size_t row = tile_row + i1 * nvec_out + i2; + const size_t col = tile_col + j1 * nvec_in; + IVec local_input; + local_input.clear(); + if (row < num_rows) { +#pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + if (col + j2 < row_length) { + local_input.data.elt[j2] = input[row * row_length + col + j2]; + } + } + } +#pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2]; + } + } + } + + // Copy transposed output from registers to global memory + __shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP + 1]; +#pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; + shared_output[j1][i1] = local_output[j2][iter]; + } + __syncthreads(); +#pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidx; + const size_t j1 = tidy + iter * bdimy; + const size_t row = tile_row + i1 * nvec_out; + const size_t col = tile_col + j1 * nvec_in + j2; + if (col < row_length) { +#pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + if (row + i2 < num_rows) { + output[col * num_rows + row + i2] = shared_output[j1][i1].data.elt[i2]; + } + } + } + } + __syncthreads(); + } +} + +} // namespace + +void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, musaStream_t stream) { + Tensor &output = *output_; + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(output.data.shape.size() == 2, "Output must have 2 dimensions."); + const size_t row_length = input.data.shape[1]; + const size_t num_rows = input.data.shape[0]; + + NVTE_CHECK(output.data.shape[0] == row_length, "Wrong dimension of output."); + NVTE_CHECK(output.data.shape[1] == num_rows, "Wrong dimension of output."); + + NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); + NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); + NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match."); + + if (noop.data.dptr != nullptr) { + NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), "."); + NVTE_CHECK(noop.data.dtype == DType::kFloat32); + NVTE_CHECK(noop.data.dptr != nullptr); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + input.data.dtype, Type, constexpr const char *type_name = TypeInfo::name; + constexpr size_t type_size = sizeof(Type); + + // Choose between runtime-compiled or statically-compiled kernel + const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); + if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + const size_t sm_count = static_cast(cuda::sm_count()); + auto add_config = [&](size_t load_size, size_t store_size) { + kernel_configs.emplace_back(row_length, num_rows, type_size, load_size, store_size, + sm_count); + }; + add_config(8, 8); + add_config(4, 8); + add_config(8, 4); + add_config(4, 4); + add_config(2, 8); + add_config(8, 2); + add_config(2, 4); + add_config(4, 2); + add_config(2, 2); + add_config(1, 8); + add_config(8, 1); + add_config(1, 4); + add_config(4, 1); + add_config(1, 2); + add_config(2, 1); + add_config(1, 1); + const auto &kernel_config = *std::min_element(kernel_configs.begin(), kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + const size_t load_size = kernel_config.load_size; + const size_t store_size = kernel_config.store_size; + const size_t num_blocks = kernel_config.num_blocks; + + // Compile NVRTC kernel if needed and launch + auto &rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = concat_strings( + "transpose" + ",type=", + type_name, ",load_size=", load_size, ",store_size=", store_size); + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_transpose_mu; + code = regex_replace(code, "__TYPE__", type_name); + code = regex_replace(code, "__LOAD_SIZE__", load_size); + code = regex_replace(code, "__STORE_SIZE__", store_size); + code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); + code = regex_replace(code, "__BLOCK_SIZE__", block_size); + rtc_manager.compile(kernel_label, "transpose_optimized_kernel", code, + "transformer_engine/common/transpose/rtc/transpose.mu"); + } + rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream, + static_cast(input.data.dptr), + static_cast(noop.data.dptr), + static_cast(output.data.dptr), row_length, num_rows); + } else { // Statically-compiled general kernel + constexpr size_t load_size = 4; + constexpr size_t store_size = 4; + constexpr size_t row_tile_size = load_size / type_size * THREADS_PER_WARP; + constexpr size_t col_tile_size = store_size / type_size * THREADS_PER_WARP; + const int num_blocks = (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); + transpose_general_kernel + <<>>(static_cast(input.data.dptr), + static_cast(noop.data.dptr), + static_cast(output.data.dptr), + row_length, num_rows); + }); // NOLINT(*) +} + +} // namespace transformer_engine + +void nvte_transpose(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_transpose); + using namespace transformer_engine; + auto noop = Tensor(); + transpose(*reinterpret_cast(input), noop, reinterpret_cast(output), + stream); +} + +void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, + musaStream_t stream) { + NVTE_API_CALL(nvte_transpose_with_noop); + using namespace transformer_engine; + transpose(*reinterpret_cast(input), *reinterpret_cast(noop), + reinterpret_cast(output), stream); +} diff --git a/transformer_engine/musa/common/transpose/transpose_fusion.mu b/transformer_engine/musa/common/transpose/transpose_fusion.mu new file mode 100644 index 0000000000..497e36b972 --- /dev/null +++ b/transformer_engine/musa/common/transpose/transpose_fusion.mu @@ -0,0 +1,501 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../utils.muh" + +namespace transformer_engine { + +template +inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out], + OVec (&out_trans)[nvec_in], + CVec &out_dbias, // NOLINT(*) + const CType scale_inv, + const int dbias_shfl_src_lane) { + using T = typename OVec::type; + using OVecC = Vec; + + CVec step_dbias; + step_dbias.clear(); + +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + const CType tmp = static_cast(in[i].data.elt[j]) * scale_inv; + const T elt_o = in[i].data.elt[j]; + + /* dbias: thread tile local accumulation */ + step_dbias.data.elt[j] += tmp; + + out_trans[j].data.elt[i] = elt_o; // thread tile transpose + } + } + +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + CType elt = step_dbias.data.elt[j]; + elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in warp + out_dbias.data.elt[j] += elt; + } +} + +// STUFF TO TUNE +constexpr unsigned int n_warps_per_tile = 4; +constexpr int desired_load_size = 8; +constexpr int desired_store_size = 8; + +constexpr unsigned int max_threads_per_block = 256; +static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block); +constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP; + +namespace { + +template +struct TDBiasParam { + using InputType = IType; + using OutputType = OType; + using ComputeType = CType; + const IType *input; + OType *output_t; + const CType *scale_inv; + CType *workspace; +}; + +} // namespace + +template +__global__ void __launch_bounds__(cast_transpose_num_threads) + transpose_dbias_kernel(const Param param, const size_t row_length, const size_t num_rows, + const size_t num_tiles) { + using IType = typename Param::InputType; + using OType = typename Param::OutputType; + using CType = typename Param::ComputeType; + using IVec = Vec; + using OVec = Vec; + using CVec = Vec; + + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); + // const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile; + if (tile_id >= num_tiles) return; + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const IType *const my_input_tile = + param.input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP; + OType *const my_output_t_tile = + param.output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP; + CType *const my_partial_dbias_tile = + param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length); + + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); + + CVec *const my_dbias_scratch = reinterpret_cast(scratch); + + IVec in[2][nvec_out]; + const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; + constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + OVec out_space[n_iterations][nvec_in]; + CVec partial_dbias; + + const size_t stride = row_length / nvec_in; + const size_t output_stride = num_rows / nvec_out; + size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1; + + partial_dbias.clear(); + +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); + } +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { + in[current_in][j].load_from(my_input_tile, + current_stride + my_place_in + stride * (nvec_out + j)); + } + } + OVec out_trans[nvec_in]; // NOLINT(*) + transpose_regs_partial_dbias( + in[current_in ^ 1], out_trans, partial_dbias, scale_inv, + (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP); + +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + out_space[i][j].data.vec = out_trans[j].data.vec; + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += nvec_out * stride; + } + + for (unsigned int i = 0; i < nvec_in; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, + current_stride + my_place); + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); + } + + my_dbias_scratch[threadIdx.x] = partial_dbias; + __syncthreads(); + // TODO(ptredak): check if the regular reduction is better + if (warp_id_in_tile == 0) { +#pragma unroll + for (unsigned int i = 1; i < n_warps_per_tile; ++i) { + CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP]; +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + partial_dbias.data.elt[j] += tmp.data.elt[j]; + } + } + + partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); + } +} + +template +__global__ void __launch_bounds__(cast_transpose_num_threads) + transpose_dbias_kernel_notaligned(const Param param, const size_t row_length, + const size_t num_rows, const size_t num_tiles) { + using IType = typename Param::InputType; + using OType = typename Param::OutputType; + using CType = typename Param::ComputeType; + using IVec = Vec; + using OVec = Vec; + using CVec = Vec; + + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = + (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP); + const size_t tile_id = + blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile; + if (tile_id >= num_tiles) return; + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const IType *const my_input_tile = + param.input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP; + OType *const my_output_t_tile = + param.output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP; + CType *const my_partial_dbias_tile = + param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length); + + const size_t stride = row_length / nvec_in; + const size_t output_stride = num_rows / nvec_out; + const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; + const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; + const unsigned int tile_length = + row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest; + const unsigned int tile_height = + row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest; + + OVec *const my_scratch = + reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1); + + CVec *const my_dbias_scratch = reinterpret_cast(scratch); + + IVec in[2][nvec_out]; + const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; + constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + OVec out_space[n_iterations][nvec_in]; + CVec partial_dbias; + + size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; + unsigned int my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1; + + partial_dbias.clear(); + + { + const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height; +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + if (valid_load) { + in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); + } else { + in[0][i].clear(); + } + } + } +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { + const bool valid_load = + my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height; +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { + if (valid_load) { + in[current_in][j].load_from(my_input_tile, + current_stride + my_place_in + stride * (nvec_out + j)); + } else { + in[current_in][j].clear(); + } + } + } + OVec out_trans[nvec_in]; // NOLINT(*) + transpose_regs_partial_dbias( + in[current_in ^ 1], out_trans, partial_dbias, scale_inv, + (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP); + +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + out_space[i][j].data.vec = out_trans[j].data.vec; + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += nvec_out * stride; + } + + for (unsigned int i = 0; i < nvec_in; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP] = out_space[j][i]; + } + __syncthreads(); + my_place = + (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP; + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { + const bool valid_store = my_place < tile_height; + if (valid_store) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, + current_stride + my_place); + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); + } + + my_dbias_scratch[threadIdx.x] = partial_dbias; + __syncthreads(); + // TODO(ptredak): check if the regular reduction is better + if (warp_id_in_tile == 0) { +#pragma unroll + for (unsigned int i = 1; i < n_warps_per_tile; ++i) { + CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP]; +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + partial_dbias.data.elt[j] += tmp.data.elt[j]; + } + } + + if (my_id_in_warp < tile_length) { + partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); + } + } +} + +constexpr size_t reduce_dbias_num_threads = 256; + +template +__global__ void __launch_bounds__(reduce_dbias_num_threads) + reduce_dbias_kernel(OutputType *const dbias_output, const ComputeType *const dbias_partial, + const int row_length, const int num_rows) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= row_length) return; + + const ComputeType *const thread_in_base = dbias_partial + thread_id * nvec; + OutputType *const thread_out_base = dbias_output + thread_id * nvec; + + const int stride_in_vec = row_length / nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < num_rows; ++i) { + ldg_vec.load_from(thread_in_base, i * stride_in_vec); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base, 0); +} + +void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ + Tensor *workspace, const int nvec_out) { + const size_t row_length = input.data.shape[1]; + const size_t num_rows = input.data.shape[0]; + + const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + + const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); + + if (workspace->data.dptr == nullptr) { + // Set workspace size + workspace->data.shape = {num_rows_partial_dbias, row_length}; + workspace->data.dtype = DType::kFloat32; + } else { + // Check that workspace matches expected size + const size_t workspace_size = + std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, + std::multiplies()) * + typeToSize(workspace->data.dtype); + const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", + num_rows_partial_dbias, ",", row_length, "), found ())"); + NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", + num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), + "; found dims=", workspace->data.shape, + ", dtype=", typeToSize(workspace->data.dtype), ")"); + } +} + +template +void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_length, + const size_t num_rows, const int nvec_out, musaStream_t stream) { + constexpr int reduce_dbias_store_bytes = 8; // stg.64 + constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(BiasType); + + NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape."); + + const size_t reduce_dbias_row_length = row_length; + const size_t reduce_dbias_num_rows = + DIVUP(num_rows, static_cast(nvec_out * THREADS_PER_WARP)); + const size_t reduce_dbias_num_blocks = + DIVUP(row_length, reduce_dbias_num_threads * reduce_dbias_nvec); + + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), + reinterpret_cast(workspace.data.dptr), reduce_dbias_row_length, + reduce_dbias_num_rows); +} + +void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias, + Tensor *workspace, musaStream_t stream) { + CheckInputTensor(input, "fp8_transpose_dbias_input"); + CheckOutputTensor(*transposed_output, "transposed_output"); + CheckOutputTensor(*dbias, "dbias"); + + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); + const size_t row_length = input.data.shape[1]; + const size_t num_rows = input.data.shape[0]; + + NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output."); + NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); + + NVTE_CHECK(transposed_output->data.dtype == input.data.dtype, + "T output must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{row_length}, "Wrong shape of DBias."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + dbias->data.dtype, BiasType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, Type, constexpr int type_size = sizeof(Type); + constexpr int nvec_in = desired_load_size / type_size; + constexpr int nvec_out = desired_store_size / type_size; + + // Check workspace size + populate_transpose_dbias_workspace_config(input, workspace, nvec_out); + if (workspace->data.dptr == nullptr) { return; } + + NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); + const size_t n_tiles = + DIVUP(row_length, static_cast(nvec_in * THREADS_PER_WARP)) * + DIVUP(num_rows, static_cast(nvec_out * THREADS_PER_WARP)); + const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; + const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); + + const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && + num_rows % (nvec_out * THREADS_PER_WARP) == 0; + + using ComputeType = fp32; constexpr size_t shared_size_transpose = + cast_transpose_num_threads / n_warps_per_tile * + (THREADS_PER_WARP + 1) * sizeof(Vec); + constexpr size_t shared_size_dbias = + cast_transpose_num_threads * sizeof(Vec); + static_assert(shared_size_transpose >= shared_size_dbias); + using Param = TDBiasParam; Param param; + param.input = reinterpret_cast(input.data.dptr); + param.output_t = reinterpret_cast(transposed_output->data.dptr); + param.scale_inv = + reinterpret_cast(transposed_output->scale_inv.dptr); + param.workspace = reinterpret_cast(workspace->data.dptr); + + if (full_tile) { + musaFuncSetAttribute(transpose_dbias_kernel, + musaFuncAttributePreferredSharedMemoryCarveout, 100); + transpose_dbias_kernel + <<>>( + param, row_length, num_rows, n_tiles); + } else { + musaFuncSetAttribute(transpose_dbias_kernel_notaligned, + musaFuncAttributePreferredSharedMemoryCarveout, 100); + transpose_dbias_kernel_notaligned + <<>>( + param, row_length, num_rows, n_tiles); + } + + reduce_dbias(*workspace, dbias, row_length, num_rows, nvec_out, + stream);); // NOLINT(*) + ); // NOLINT(*) +} + +} // namespace transformer_engine + +void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_output, + NVTETensor dbias, NVTETensor workspace, musaStream_t stream) { + NVTE_API_CALL(nvte_fp8_transpose_dbias); + using namespace transformer_engine; + fp8_transpose_dbias( + *reinterpret_cast(input), reinterpret_cast(transposed_output), + reinterpret_cast(dbias), reinterpret_cast(workspace), stream); +} diff --git a/transformer_engine/musa/common/util/cast.mu b/transformer_engine/musa/common/util/cast.mu new file mode 100644 index 0000000000..d1d7c4609e --- /dev/null +++ b/transformer_engine/musa/common/util/cast.mu @@ -0,0 +1,147 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.muh" +#include "cast_kernels.muh" +#include "dequantize_kernels.muh" +#include "math.h" +// #include "ptx.cuh" +#include "transformer_engine/activation.h" +#include "transformer_engine/transpose.h" + +void nvte_quantize(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_quantize); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; + + detail::quantize_helper(input, grad, nullptr, output, + dbias, workspace, stream); +} + +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + musaStream_t stream) { + NVTE_API_CALL(nvte_quantize_noop); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; + + detail::quantize_helper(input, grad, noop, output, + dbias, workspace, stream); +} + +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, musaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr const NVTETensor activation_input = nullptr; + + detail::quantize_helper( + activation_input, input, nullptr, output, dbias, workspace, stream); +} + +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); +} + +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); +} + +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); +} + +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); +} + +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsrelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); +} + +void nvte_dequantize(const NVTETensor input, NVTETensor output, musaStream_t stream) { + NVTE_API_CALL(nvte_dequantize); + using namespace transformer_engine; + detail::dequantize_helper(*reinterpret_cast(input), + reinterpret_cast(output), stream); +} diff --git a/transformer_engine/musa/common/util/cast_gated_kernels.muh b/transformer_engine/musa/common/util/cast_gated_kernels.muh new file mode 100644 index 0000000000..95cc664d47 --- /dev/null +++ b/transformer_engine/musa/common/util/cast_gated_kernels.muh @@ -0,0 +1,1093 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast_gated_kernels.cuh + * \brief CUDA gated activations kernels to cast to/from FP8/MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ + +#include +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.muh" +#include "math.h" +// #include "ptx.cuh" + +namespace transformer_engine { + +template +__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) { + return DIVUP(static_cast(N), static_cast(M)) * M; +} + +namespace gated_kernels { + +constexpr size_t ALIGNMENT_SIZE = 128; +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 512; +constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 +constexpr size_t BUFFERS_NUM = 2; +constexpr size_t BUFFER_DIM_Y = 32; +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 +static_assert(ITERATIONS >= 1); + +__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } + +/* +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act, + const __grid_constant__ CUtensorMap tensor_map_output_gate, + float *const amax_ptr, float *const scale_inv_ptr, + const float *const scale_ptr, const size_t rows, const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + extern __shared__ char dshmem_unaligned[]; + const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); + const uint64_t dshmem_aligned_as_uint = + DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; + char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + + constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + constexpr size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; + + constexpr size_t in_act_mem = buff_size_aligned_in; + constexpr size_t in_gate_mem = buff_size_aligned_in; + constexpr size_t in_mem = in_act_mem + in_gate_mem; + + constexpr size_t out_act_mem = buff_size_aligned_out; + constexpr size_t out_gate_mem = buff_size_aligned_out; + constexpr size_t out_mem = out_act_mem + out_gate_mem; + + // const size_t in_transaction_size = grad_mem + in_mem; + constexpr size_t in_transaction_size = buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + // uint64_t *mbar = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); + const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + // Prefetch data of the first stage + + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, + TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, + chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } else { + copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const int buff = it % BUFFERS_NUM; + const int next_it = it + 1; + if (next_it < ITERATIONS) { + const int next_buff = next_it % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3( + &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, + &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, + in_transaction_size, &mbar[next_it], is_master_thread); + } else { + copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, + chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, + &mbar[next_it], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + OType *out_act_sh_curr = out_act_sh + buff * buff_elems; + OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + + float after_dact = dact_x * grad_elt * gate_elt; + float after_dgate = act_x * grad_elt; + + out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); + out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); + + amax = fmaxf(amax, fabsf(after_dact)); + amax = fmaxf(amax, fabsf(after_dgate)); + } else { + const float after_act = ActOP(act_elt, {}) * gate_elt; + out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); + amax = fmaxf(amax, fabsf(after_act)); + } + } + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + + // dGeLU + ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, + chunk_it_offset_y, + reinterpret_cast(out_act_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_sh_curr)); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + + const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); + + extern __shared__ char dshmem_unaligned[]; + const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); + const uint64_t dshmem_aligned_as_uint = + DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; + char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + + const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_elems_total = BUFFERS_NUM * buff_elems; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + const size_t out_mem = out_act_mem + out_gate_mem; + + // const size_t in_transaction_size = grad_mem + in_mem; + const size_t in_transaction_size = (IS_DGATED ? 3 : 2) * buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + + OType *out_act_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + + OType *out_act_colwise_sh = out_act_rowwise_sh; + OType *out_gate_colwise_sh = out_gate_rowwise_sh; + + if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); + out_gate_colwise_sh = + reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); + } + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act_rowwise = + reinterpret_cast(&tensor_map_output_act_rowwise); + const uint64_t *TMAP_output_gate_rowwise = + reinterpret_cast(&tensor_map_output_gate_rowwise); + const uint64_t *TMAP_output_act_colwise = + reinterpret_cast(&tensor_map_output_act_colwise); + const uint64_t *TMAP_output_gate_colwise = + reinterpret_cast(&tensor_map_output_gate_colwise); + + __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + const bool is_master_thread = (threadIdx.x == 0); + + if (is_master_thread) { +// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); + + int parity = 0; + + // Prefetch data of the first stage + if (is_master_thread) { + // Initiate bulk tensor copy + // Grad + if constexpr (IS_DGATED) { + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_grad_sh[0]), + TMAP_grad_in, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + } + + // Act + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_act_sh[0]), + TMAP_in_act, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + + // Gate + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_gate_sh[0]), + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[0]); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const int buff = it % BUFFERS_NUM; + const int next_it = it + 1; + const size_t row_base = chunk_offset_Y + it * BUFFER_DIM_Y; + if (next_it < ITERATIONS) { + if (is_master_thread) { + const int next_buff = next_it % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + // Initiate bulk tensor copy + if constexpr (IS_DGATED) { + // Grad + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + } + // Act + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_act_sh[next_buff * buff_elems]), TMAP_in_act, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + // Gate + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_gate_sh[next_buff * buff_elems]), TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[next_it]); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + OType *out_act_rowwise_sh_curr = out_act_rowwise_sh + buff * buff_elems; + OType *out_gate_rowwise_sh_curr = out_gate_rowwise_sh + buff * buff_elems; + OType *out_act_colwise_sh_curr = out_act_colwise_sh + buff * buff_elems; + OType *out_gate_colwise_sh_curr = out_gate_colwise_sh + buff * buff_elems; + + // Assuming one iteration covers exactly 32 rows + const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; + const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; + + float after_dact_reg[BUFFER_STAGES_NUM]; + float after_dgate_reg[BUFFER_STAGES_NUM]; + float thread_Y_mx_block_amax = 0.0f; + float thread_Y_mx_block_amax_gate = 0.0f; + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_dact_reg[stage] = dact_x * grad_elt * gate_elt; + after_dgate_reg[stage] = act_x * grad_elt; + } else { + after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; + } + + if constexpr (USE_ROWWISE_SCALING) { + if constexpr (IS_DGATED) { + // dgate + float amax = fabsf(after_dgate_reg[stage]); + const float mx_block_X_amax = warp_reduce_max_broadcast(amax); + const e8m0_t biased_exponent_X = + float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + + out_gate_rowwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal_X * after_dgate_reg[stage]); + + // Only single thread writes the computed scaling factor + if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent_X; + } + } + float amax = fabsf(after_dact_reg[stage]); + const float mx_block_X_amax = warp_reduce_max_broadcast(amax); + const e8m0_t biased_exponent_X = + float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + + out_act_rowwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal_X * after_dact_reg[stage]); + + // Only single thread writes the computed scaling factor + if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; + const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent_X; + } + } + + if constexpr (USE_COLWISE_SCALING) { + __builtin_assume(thread_Y_mx_block_amax >= 0); + __builtin_assume(thread_Y_mx_block_amax_gate >= 0); + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); + if constexpr (IS_DGATED) { + thread_Y_mx_block_amax_gate = + fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); + } + } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool row_out_of_bounds = (row_base >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + if constexpr (IS_DGATED) { + // Colwise max reduction of the amax element + if (tid_Y > 0) { + stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; + } + __syncthreads(); + if (tid_Y == 0) { +#pragma unroll + for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + thread_Y_mx_block_amax_gate = + fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); + } + stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax + } + __syncthreads(); + + const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + + // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section + if constexpr (!USE_ROWWISE_SCALING) { + __builtin_assume(mx_block_Y_amax >= 0); + } + + const e8m0_t biased_exponent = + float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + // Only single thread writes the computed scaling factor + // Also assuming one iteration covers exactly 32 rows + if ((tid_Y == 0) && !out_of_bounds) { + const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + } + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + out_gate_colwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal * after_dgate_reg[stage]); + } + } + // Colwise max reduction of the amax element + if (tid_Y > 0) { + stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax; + } + __syncthreads(); + if (tid_Y == 0) { +#pragma unroll + for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); + } + stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax + } + __syncthreads(); + + const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + + // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section + if constexpr (!USE_ROWWISE_SCALING) { + __builtin_assume(mx_block_Y_amax >= 0); + } + + const e8m0_t biased_exponent = + float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + // Only single thread writes the computed scaling factor + // Also assuming one iteration covers exactly 32 rows + if ((tid_Y == 0) && !out_of_bounds) { + const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + } + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + out_act_colwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal * after_dact_reg[stage]); + } + } // endif USE_COLWISE_SCALING + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + + // dGeLU + if constexpr (USE_ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_act_rowwise_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_rowwise_sh_curr)); + } + } + + // dGeLU + if constexpr (USE_COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_act_colwise_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_colwise_sh_curr)); + } + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + musaStream_t stream) { + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + + NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act{}; + alignas(64) CUtensorMap tensor_map_output_gate{}; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, + cols, 0, sizeof(IType)); + } + + const uint32_t tensor_stride_elems = output_cols; + + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType)); + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); + const size_t shmem_size = ALIGNMENT_SIZE + grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem); // + mbar_mem; + + cudaFuncSetAttribute( + cast_fp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + cast_fp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + musaStream_t stream) { + const bool USE_ROWWISE_SCALING = output->has_data(); + const bool USE_COLWISE_SCALING = output->has_columnwise_data(); + + if (USE_ROWWISE_SCALING) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (USE_COLWISE_SCALING) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + + // TODO: Make more general + const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1; + const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1; + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; + size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + e8m0_t *const scales_rowwise_ptr = + USE_ROWWISE_SCALING ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_act_colwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(IType)); + } + + const uint32_t tensor_stride_elems = output_cols; + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + + if (USE_ROWWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, + sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, + sizeof(OType)); + } + + if (USE_COLWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, + rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, + 0, sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, + rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, + cols, sizeof(OType)); + } + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + + // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); + // const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem + mbar_mem; + + const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem; + + cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise);); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} +*/ + +template +void cast_gated(const Tensor &input, Tensor *output, musaStream_t stream) { + CheckInputTensor(input, "gated_act_input"); + CheckOutputTensor(*output, "gated_act_output"); + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); + NVTE_CHECK(input.data.shape[0] == output->data.shape[0], + "Input shape[0] must be equal to output shape[0]."); + NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, + "Input shape[1] must be 2x larger than output shape[1]."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + GatedActivationKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], + output->data.shape[1], {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, musaStream_t stream) { + CheckInputTensor(grad, "dgated_act_grad"); + CheckInputTensor(input, "dgated_act_input"); + CheckOutputTensor(*output, "dgated_act_output"); + NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), + "Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(), + ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, + "Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2, + "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes must match. Input shape: ", input.data.shape, + ", output shape: ", output->data.shape, "."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + DGatedActivationKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), + grad.flat_last_dim(), {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + musaStream_t stream) { + checkCuDriverContext(stream); + constexpr bool allow_empty = false; + CheckInputTensor(gated_input, "gated_input"); + CheckOutputTensor(*output, "output", allow_empty); + + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + if constexpr (IS_DGATED) { + CheckInputTensor(grad, "grad"); + NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); + NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); + NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); + NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); + } + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + bool is_fp8_rowwise_output = true; + bool is_fp8_colwise_output = true; + if (output->has_data()) { + is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); + } + if (output->has_columnwise_data()) { + is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); + } + + const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; + + if (is_delayed_tensor_scaling(output->scaling_mode)) { + // if (use_tma_kernels) { + // cast_fp8_gated(grad, gated_input, output, stream); + // } else { + if constexpr (IS_DGATED) { + cast_dgated(grad, gated_input, output, stream); + } else { + cast_gated(gated_input, output, stream); + } + // } + // } else if (is_mxfp_scaling(output->scaling_mode)) { + // if (use_tma_kernels) { + // cast_mxfp8_gated(grad, gated_input, output, stream); + // } else { + // NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", + // "by 32, got input of shape ", gated_input.data.shape); + // } + } else { + NVTE_ERROR("Not supported scaling mode"); + } +} +} // namespace gated_kernels + +namespace detail { + +template +void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, + musaStream_t stream) { + using namespace gated_kernels; + Tensor grad_empty_tensor; + const Tensor &grad_tensor = + IS_DGATED ? *(reinterpret_cast(grad)) : grad_empty_tensor; + const Tensor gated_input_tensor = *reinterpret_cast(gated_input); + Tensor *output_tensor = reinterpret_cast(output); + + if (is_supported_by_CC_100()) { + quantize_gated(grad_tensor, gated_input_tensor, + output_tensor, stream); + } else { + if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { + if constexpr (IS_DGATED) { + cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); + } else { + cast_gated(gated_input_tensor, output_tensor, stream); + } + } else { + // MX scaling + NVTE_ERROR("Not supported by the Arch < 10.0"); + } + } +} +} // namespace detail + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ diff --git a/transformer_engine/musa/common/util/cast_kernels.muh b/transformer_engine/musa/common/util/cast_kernels.muh new file mode 100644 index 0000000000..7f65a6e771 --- /dev/null +++ b/transformer_engine/musa/common/util/cast_kernels.muh @@ -0,0 +1,1291 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast_kernels.cuh + * \brief CUDA kernels to cast to/from FP8/MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ + +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.muh" +#include "math.h" +// #include "ptx.cuh" +#include "transformer_engine/transformer_engine.h" +#include "mtfp8_cast.muh" +#include "mtfp8_cast_transpose.h" + +namespace transformer_engine { + +constexpr size_t MXFP8_CHUNK_DIM_Y = 64; +constexpr size_t MXFP8_CHUNK_DIM_X = 64; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; +constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; +constexpr size_t MXFP8_BUFFERS_NUM = 2; +constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM); + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported +constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 +constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = + MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 +constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = + MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_BUFF_STAGES_NUM = + MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 +constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 +static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); + +/* +template +__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + if (noop != nullptr && noop[0] == 1.0f) return; + } + + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = + SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = + SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; + const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; + const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + // const int thread_offset_X_colwise = tid_colwise_X; + + const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y; + const int dbias_rowwise_block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise; + const int dbias_colwise_offset_Y = blockIdx.y; + const int dbias_colwise_block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X; + const int dbias_stride = cols; + + Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; + float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + partial_dbias_rowwise[i].clear(); + } + } else { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + partial_dbias_colwise[i] = 0; + } + } + } + + // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned + __shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) + OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) + OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1); + + const bool is_master_thread = (threadIdx.x == 0); + + float block_amax = 0; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; +#pragma unroll + for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { + const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; + const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + + const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; + const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = + scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = + scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = + scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = + scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { + const int buff = iter % MXFP8_BUFFERS_NUM; + const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + if (next_iter < MXFP8_ITERATIONS) { + const int next_buff = next_iter % MXFP8_BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, + &mbar[next_iter], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec act_in; + Vec out_c; + + const int iteration_scale_rowwise_offset_Y = + scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + +#pragma unroll + for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = (row >= rows); + + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + } + + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in.data.elt[j]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[j]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + if (!out_of_bounds) { + partial_dbias_rowwise[chunk_X].data.elt[j] += elt; + } + } + in_compute[j] = elt; + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); + const e8m0_t biased_exponent = + float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); + + // Only single thread writes the computed scaling factor + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = exp2f_rcp(biased_exponent); + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); + } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + float in_compute[SCALE_DIM_Y]; + + float amax = 0; +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const size_t row = row_base + i; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if (!out_of_bounds) { + partial_dbias_colwise[chunk_X] += elt; + } + } + in_compute[i] = elt; + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(amax >= 0); + block_amax = fmaxf(block_amax, amax); + + const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = exp2f_rcp(biased_exponent); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + out_colwise_sh[buff][i][tid_colwise_X] = + static_cast(in_compute[i] * block_scale_inverse); + } + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (USE_ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_rowwise_sh[buff])); + } + if constexpr (USE_COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_colwise_sh[buff])); + } + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + } + + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; + constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; + constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; + __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; + + if (tid_rowwise_Y > 0) { +#pragma unroll + for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { + partial_dbias_rowwise[c].store_to( + &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]); + } + } + __syncthreads(); + + if (tid_rowwise_Y == 0) { +#pragma unroll + for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { + Vec other_row_dbias; + const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; + const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; + + const int left_bound = dbias_rowwise_offset_X; + const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + +#pragma unroll + for (int i = 0; i < Y; ++i) { + other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; + } + } + + // Vectorized store when all elements are inside the boundaries + if (right_bound < cols) { + partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]); + } else if (left_bound < cols && right_bound >= cols) { + // Element-by-element store when some elements cross the boundaries + const int in_bound_elts_count = cols - left_bound; + partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0, + in_bound_elts_count); + } + } + } + } else { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; + const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; + const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias_colwise[i]; + } + } + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +*/ + +constexpr size_t FP8_CHUNK_DIM_Y = 128; +constexpr size_t FP8_CHUNK_DIM_X = 128; +constexpr size_t FP8_THREADS_PER_CHUNK = 128; +constexpr size_t FP8_BUFFERS_NUM = 2; +constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); + +constexpr size_t FP8_BUFFER_DIM_Y = 16; +constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 +constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 + +constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); + +/* +template +__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) + cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output, + float *const dbias_workspace, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, + const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const int tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + const int dbias_offset_Y = blockIdx.y + tid_Y; + const int my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; + const bool col_out_of_bounds = my_column >= cols; + const int dbias_stride = cols; + + float partial_dbias = 0.f; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + const int chunk_offset_Y = block_offset_Y; + const int chunk_offset_X = block_offset_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { + const int buff = iter % FP8_BUFFERS_NUM; + const int next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; + if (next_iter < FP8_ITERATIONS) { + const int next_buff = next_iter % FP8_BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = row >= rows; + const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; + + float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if constexpr (IS_DACT) { + if (!out_of_bounds) { + partial_dbias += elt; + } + } else { + // If no activation, elt is 0 so we can safely do this + partial_dbias += elt; + } + } + __builtin_assume(amax >= 0); + if (IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + if constexpr (IS_DBIAS) { + const int dbias_offset_X = my_column; + const int dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +*/ + +constexpr size_t CHUNKS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; +constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; +constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; +constexpr size_t CHUNKS_PER_ITERATION = 32; +constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; +constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; +constexpr size_t SHMEM_BUFFERS = 2; +static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); + +/* +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const IType *input = input_ptr + block_offset; + OType *output = output_ptr + block_offset; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + + constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const int buff = iter % SHMEM_BUFFERS; + const int it_offset = iter * SHMEM_DIM; + + const int next_iter = iter + 1; + const int next_buff = next_iter % SHMEM_BUFFERS; + const int next_iter_offset = next_iter * SHMEM_DIM; + + if (next_iter < ITERATIONS) { + copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, + &(mbar[next_iter]), is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { + const int shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + float elt = static_cast(in_sh[buff][shmem_offset]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(elt)); + out_sh[buff][shmem_offset] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + ptx::cp_async_bulk_tensor_1d_shared_to_global( + reinterpret_cast(output + it_offset), + reinterpret_cast(&out_sh[buff]), transaction_size_OUT); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +*/ + +constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; +template +__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) + reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, const int rows, + const int cols) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + thread_id * nvec; + OType *const thread_out_base = dbias_output + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} + +template +void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, + musaStream_t stream) { + constexpr int reduce_dbias_store_bytes = 8; // stg.64 + constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); + const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); + + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); +} + +template +static void cast_fp8_1D(const Tensor &input, Tensor *output, musaStream_t stream) { + /* + const size_t N = product(input.data.shape); + + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + NVTE_CHECK(isFullTile, "Only full tiles are supported."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + const size_t chunks = DIVUP(N, CHUNK_SIZE); + const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + const float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(THREADS_PER_BLOCK); + const dim3 grid(blocks); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + const IType *input_ptr = reinterpret_cast(input.data.dptr); + OType *output_ptr = reinterpret_cast(output->data.dptr); + + cast_fp8_1D_kernel<<>>( + input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) + ); // NOLINT(*) + */ +} + +template +void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, + Tensor *workspace, musaStream_t stream) { + /* + checkCuDriverContext(stream); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); + const size_t blocks_Y = chunks_Y; + const size_t blocks_X = chunks_X; + + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(FP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + } + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(OType)); + + cast_fp8_2D_kernel + <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, + workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols); + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) + */ +} + +template +void mxfp8_quantize(const Tensor &input, const Tensor *act_input, + const Tensor *noop, // TODO (ksivamani) + Tensor *output, Tensor *dbias, Tensor *workspace, musaStream_t stream) { + /* + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); + checkCuDriverContext(stream); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + const auto &input_shape = input.data.shape; + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (use_rowwise_scaling) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + // TODO: Make more general + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + const dim3 block(MXFP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, + MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(IType)); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(OType)); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, + cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(OType)); + } + + cast_mxfp8_2D_kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + */ +} + +namespace detail { + +using Empty = transformer_engine::Empty; + +__device__ inline float identity(float value, const Empty &) { return value; } + +struct DequantizeParam { + const float *scale_inv; +}; + +__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { + return value * (*(param.scale_inv)); +} + +} // namespace detail + +template +void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, + musaStream_t stream) { + // constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, + musaStream_t stream) { + // constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input->data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +namespace { + +static bool is_full_tile_1D_tensor(const Tensor *const t) { + const size_t N = product(t->data.shape); + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + return isFullTile; +} + +bool dimensions_supported_by_TMA(const Tensor *const t) { + const size_t cols = t->flat_last_dim(); + constexpr int TMA_bytes = 16; + const int alignment_requirement = TMA_bytes / typeToSize(t->dtype()); + return cols % alignment_requirement == 0; +} + +} // namespace + +// Supported by the Arch >= 10.0 +template +void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, + musaStream_t stream) { + /* + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!IS_DBIAS && !IS_DACT) { + if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype())) { + // Aligned AND FP8 + cast_fp8_1D(input, output, stream); + } else { + // Unaligned + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } else if (!IS_DBIAS && IS_DACT) { + if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype())) { + // Aligned AND FP8 (+dAct) + cast_fp8_2D(input, act_input, output, dbias, workspace, + stream); + } else { + // Unaligned + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } else { + cast_fp8_2D(input, act_input, output, dbias, workspace, + stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8_quantize(input, act_input, noop, output, dbias, + workspace, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + } + */ +} + +extern void no_fp8_grad_bias( + const Tensor* gradO, + bool trans, + const Tensor* gradB, + musaStream_t stream); + +// Supported by the Arch < 10.0 +template +void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, + musaStream_t stream) { + if constexpr (IS_DBIAS) { + no_fp8_grad_bias(&input, true, dbias, stream); + } else if (!is_delayed_tensor_scaling(output->scaling_mode) || IS_DBIAS) { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + + " on GPU with compute capability < 10.0."); + } + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } +} + +template +void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, + Tensor *dbias, Tensor *workspace, musaStream_t stream) { + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr); + CheckInputTensor(*act_input, "activation_input"); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); + } + + NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // Supported by the Arch >= 10.0 + if (is_supported_by_CC_100()) { + fp8_quantize_arch_ge_100(input, act_input, noop, output, + dbias, workspace, stream); + } else { + // Supported by the Arch < 10.0 + fp8_quantize_arch_l_100(input, act_input, noop, output, + dbias, workspace, stream); + } +} + +namespace detail { + +template +void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + musaStream_t stream) { + const Tensor *input_tensor; + const Tensor *activation_input_tensor; + if constexpr (IS_DBIAS || IS_DACT) { + // backward - input is incoming gradient + input_tensor = reinterpret_cast(grad); + activation_input_tensor = reinterpret_cast(input); + } else { + // forward = input is activation input + input_tensor = reinterpret_cast(input); + activation_input_tensor = nullptr; + } + auto output_tensor = reinterpret_cast(output); + auto dbias_tensor = reinterpret_cast(dbias); + auto workspace_tensor = reinterpret_cast(workspace); + const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); + + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } + } else if (output_tensor->has_data()) { + fp8_quantize( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + } + break; + } + case NVTE_MTFP8_BLOCK_SCALING: { + NVTE_CHECK(output_tensor->has_data(), "Data should be provided for MTFP8 quantization."); + NVTE_CHECK(is_fp8_dtype(output_tensor->dtype())); + if (output_tensor->has_columnwise_data()) { + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + mtfp8_cast_transpose(input_tensor, &noop_tensor, output_tensor, stream); + } else { + NVTE_ERROR("MTFP8 cast_transpose_fused not supported."); + } + } else { + mtfp8_quantize( + input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + } + break; + } + // case NVTE_MXFP8_1D_SCALING: { + // mxfp8_quantize( + // *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + // workspace_tensor, stream); + // break; + // } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } +} + +} // namespace detail +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ diff --git a/transformer_engine/musa/common/util/dequantize_kernels.muh b/transformer_engine/musa/common/util/dequantize_kernels.muh new file mode 100644 index 0000000000..591302fcb0 --- /dev/null +++ b/transformer_engine/musa/common/util/dequantize_kernels.muh @@ -0,0 +1,367 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_kernels.cuh + * \brief CUDA kernels to cast from MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ + +#include +#include +#include +#include + +#include +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.muh" +#include "math.h" +// #include "ptx.cuh" +#include "transformer_engine/activation.h" +#include "transformer_engine/transpose.h" +#include "mtfp8_cast.muh" + +namespace transformer_engine { + +namespace dequantization { + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 128; +constexpr size_t BUFFERS_NUM = 2; + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t BUFFER_DIM_Y = 16; // only 32 is supported +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 16 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 = 128 / 16 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 128 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(ITERATIONS >= 1); + +/* +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, + const size_t scales_stride) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + // const int thread_offset_X_colwise = tid_colwise_X; + + // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned + __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + if (is_master_thread) { +// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_CHUNK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); + + int parity = 0; + constexpr int iteration_zero = 0; + constexpr int buffer_zero = 0; + if (is_master_thread) { + const int chunk_stage_offset_Y = chunk_offset_Y; + const int chunk_stage_offset_X = chunk_offset_X; + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[buffer_zero]), + reinterpret_cast(&tensor_map_input), chunk_stage_offset_X, + chunk_stage_offset_Y, &mbar[iteration_zero]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[iteration_zero], transaction_size); + + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[iteration_zero]); + } + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const int buff = iter % BUFFERS_NUM; + const int next_iter = iter + 1; + if (next_iter < ITERATIONS) { + if (is_master_thread) { + const int next_buff = next_iter % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[next_buff]), + reinterpret_cast(&tensor_map_input), chunk_it_offset_x, + chunk_it_offset_y, &mbar[next_iter]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[next_iter], transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[next_iter]); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + const int scale_offset_Y = + USE_ROWWISE_SCALING ? (scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y) + : (scales_colwise_chunk_offset_Y + (iter * BUFFER_DIM_Y) / SCALE_DIM_Y); + + const int scale_offset_X = + USE_ROWWISE_SCALING + ? (scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE) + : (scales_colwise_chunk_offset_X + tid_colwise_X); + + const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; + const e8m0_t biased_exponent = scales_ptr[scale_idx]; + const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec out; + + const int shmem_offset_y = thread_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out.data.elt[j] = static_cast(block_scale * static_cast(in.data.elt[j])); + } + out.store_to(&out_sh[buff][shmem_offset_y][shmem_offset_x]); + } else { +#pragma unroll + for (int i = 0; i < BUFFER_DIM_Y; ++i) { + const float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + out_sh[buff][i][tid_colwise_X] = static_cast(block_scale * elt); + } + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + // Destroy barrier. This invalidates the memory region of the barrier. If + // further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +*/ + +static void fp8_dequantize(const Tensor &input, Tensor *output, musaStream_t stream) { + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + constexpr int nvec = 32 / sizeof(OType); + detail::DequantizeParam p; + p.scale_inv = reinterpret_cast(input.scale_inv.dptr); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), nullptr, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, + stream);); // NOLINT(*) + ); // NOLINT(*) +} + +/* +static void mxfp8_dequantize(const Tensor &input, Tensor *output, musaStream_t stream) { + bool use_rowwise_scaling = input.has_data(); + bool use_colwise_scaling = input.has_columnwise_data(); + checkCuDriverContext(stream); + + const auto &input_shape = input.data.shape; + NVTE_CHECK(input_shape.size() >= 2, "Input must have at least 2 dimensions."); + + if (use_rowwise_scaling) { + NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + } + + if (use_colwise_scaling) { + NVTE_CHECK(input.has_columnwise_data(), "Cannot dequantize tensor without columnwise data."); + NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type."); + } + + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // TODO: Make more general + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); + + const size_t unpadded_scales_Y_rowwise = rows; + const size_t unpadded_scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise); + const size_t unpadded_scales_Y_colwise = DIVUP(rows, scale_dim_Y_colwise); + const size_t unpadded_scales_X_colwise = cols; + + const size_t scales_Y_rowwise = + DIVUP(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise) * + scale_tensor_alignment_Y_rowwise; + const size_t scales_X_rowwise = + DIVUP(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise) * + scale_tensor_alignment_X_rowwise; + const size_t scales_Y_colwise = + DIVUP(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise) * + scale_tensor_alignment_Y_colwise; + const size_t scales_X_colwise = + DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) * + scale_tensor_alignment_X_colwise; + + const e8m0_t *const scales_ptr = + use_rowwise_scaling ? reinterpret_cast(input.scale_inv.dptr) + : reinterpret_cast(input.columnwise_scale_inv.dptr); + + const size_t scales_stride = use_rowwise_scaling ? scales_X_rowwise : scales_X_colwise; + + const SimpleTensor &input_data = use_rowwise_scaling ? input.data : input.columnwise_data; + + const dim3 block(THREADS_PER_CHUNK); + const dim3 grid(chunks_X, chunks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(OType)); + + dequantize_mxfp8_kernel + <<>>(tensor_map_input, tensor_map_output, scales_ptr, + rows, cols, scales_stride);); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} +*/ +} // namespace dequantization + +namespace detail { + +void dequantize_helper(const Tensor &input, Tensor *output, musaStream_t stream) { + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if (is_tensor_scaling(input.scaling_mode)) { + dequantization::fp8_dequantize(input, output, stream); + } else if (is_mtfp_scaling(input.scaling_mode)) { + mtfp8_dequantize(&input, output, stream); + // } else if (is_mxfp_scaling(input.scaling_mode)) { + // if (is_supported_by_CC_100()) { + // dequantization::mxfp8_dequantize(input, output, stream); + // } else { + // NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + // } + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } +} + +} // namespace detail + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ diff --git a/transformer_engine/musa/common/util/logging.h b/transformer_engine/musa/common/util/logging.h new file mode 100644 index 0000000000..51427b5283 --- /dev/null +++ b/transformer_engine/musa/common/util/logging.h @@ -0,0 +1,79 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ + +#include +#include +#include +// #include + +#include + +#include "../util/string.h" + +#define NVTE_ERROR(...) \ + do { \ + throw ::std::runtime_error(::transformer_engine::concat_strings( \ + __FILE__ ":", __LINE__, " in function ", __func__, ": ", \ + ::transformer_engine::concat_strings(__VA_ARGS__))); \ + } while (false) + +#define NVTE_CHECK(expr, ...) \ + do { \ + if (!(expr)) { \ + NVTE_ERROR("Assertion failed: " #expr ". ", \ + ::transformer_engine::concat_strings(__VA_ARGS__)); \ + } \ + } while (false) + +#define NVTE_CHECK_CUDA(expr) \ + do { \ + const musaError_t status_NVTE_CHECK_MUSA = (expr); \ + if (status_NVTE_CHECK_MUSA != musaSuccess) { \ + NVTE_ERROR("MUSA Error: ", musaGetErrorString(status_NVTE_CHECK_MUSA)); \ + } \ + } while (false) + +#define NVTE_CHECK_CUBLAS(expr) \ + do { \ + const mublasStatus_t status_NVTE_CHECK_MUBLAS = (expr); \ + if (status_NVTE_CHECK_MUBLAS != MUBLAS_STATUS_SUCCESS) { \ + NVTE_ERROR("muBLAS Error: ", mublasStatus_to_string(status_NVTE_CHECK_MUBLAS));\ + } \ + } while (false) + +#define NVTE_CHECK_CUDNN(expr) \ + do { \ + const ::musa::dnn::Status status_NVTE_CHECK_MUDNN = (expr); \ + if (status_NVTE_CHECK_MUDNN != ::musa::dnn::Status::SUCCESS) { \ + NVTE_ERROR("muDNN Runtime Error(", \ + static_cast(status_NVTE_CHECK_MUDNN), \ + "). For more information, enable muDNN logging " \ + "by setting MUDNN_LOG_LEVEL=INFO in the environment."); \ + } \ + } while (false) + +#define NVTE_CHECK_CUDNN_FE(expr) \ + do { \ + } while (false) + +#define NVTE_CHECK_NVRTC(expr) \ + do { \ + } while (false) + +#define NVTE_CHECK_MU(expr) \ + do { \ + MUresult status = (expr); \ + if (status != MUSA_SUCCESS) { \ + const char* err_str; \ + muGetErrorString(status, &err_str); \ + NVTE_ERROR("musa driver Error: ", err_str); \ + } \ + } while (false) + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/musa/common/util/math.h b/transformer_engine/musa/common/util/math.h new file mode 120000 index 0000000000..2b1cd0a008 --- /dev/null +++ b/transformer_engine/musa/common/util/math.h @@ -0,0 +1 @@ +../../../common/util/math.h \ No newline at end of file diff --git a/transformer_engine/musa/common/util/mtfp8_blockwise_quantize.muh b/transformer_engine/musa/common/util/mtfp8_blockwise_quantize.muh new file mode 100644 index 0000000000..ac97cf698e --- /dev/null +++ b/transformer_engine/musa/common/util/mtfp8_blockwise_quantize.muh @@ -0,0 +1,378 @@ +#ifndef TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_BLOCKWISE_QUANTIZE_MUH_ +#define TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_BLOCKWISE_QUANTIZE_MUH_ + +#include "../common.h" +#include "../utils.muh" +#include "math.h" +#include "mtfp8_utils.muh" + +#include + +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine::mtfp8 { + +template < + typename Param, + float (*OP)(float, const Param &), + typename IType, + typename OType, + typename CType, + size_t VLEN, + size_t NCol> +__global__ void fp8_nn_blockwise_n_to_1_kernel( + const IType* inp, + const CType* noop, + OType* out, + CType* sinv, + size_t M, + size_t N, + size_t block_m, + size_t block_n, + size_t rounds, + size_t n_warps, + Param param) { + if (noop != nullptr && noop[0] == 1.0f) return; + using IVecT = Vec; + using CVecT = Vec; + using OVecT = Vec; + + using Trait_VLEN = VlenTrait; + using Trait_COL = VlenTrait; + + extern __shared__ __align__(alignof(size_t)) char temp[]; + + const size_t tid = threadIdx.y * blockDim.x + threadIdx.x; + const size_t warp_id = tid >> warp_bits; + const size_t lane_id = tid & warp_mask; + + auto* temp_base = reinterpret_cast(temp); + auto* offset_base = reinterpret_cast(temp + block_m * block_n * sizeof(CType)) + tid * rounds; + + const size_t stride = blockDim.x * blockDim.y; + + const size_t base_m = blockIdx.y * block_m; + const size_t base_n = blockIdx.x * block_n; + + size_t idx = tid; + size_t act_round = 0; + IVecT vec_in; + + CType amax = 0; + __shared__ CType staging[max_warps_per_block]; + + for (; act_round < rounds; ++act_round) { + size_t global_m = base_m; + if constexpr (Trait_COL::is_power_of_2) { + global_m += (idx >> Trait_COL::bits); + } else { + global_m += (idx / blockDim.x); + } + + if (global_m < M) { + size_t global_n = base_n; + if constexpr (Trait_COL::is_power_of_2) { + global_n += ((idx & Trait_COL::mask) << Trait_VLEN::bits); + } else { + global_n += ((idx % blockDim.x) << Trait_VLEN::bits); + } + + size_t offset = global_m * N + global_n; + *(offset_base + act_round) = offset; + + vec_in.load_from(inp + offset, 0); + CVecT& temp = *(temp_base + idx); +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + temp.data.elt[j] = (CType)(OP((float)vec_in.data.elt[j], param)); + amax = fmaxf(fabsf(temp.data.elt[j]), amax); + amax = fmaxf(amax, global_amax_min); + } + idx += stride; + } else { + break; + } + } + + amax = warp_reduce_max(amax); + if (lane_id == 0) { + staging[warp_id] = amax; + } + __syncthreads_lm(); + + amax = 0; + if (warp_id == 0) { + amax = tid < n_warps ? staging[tid] : 0; + amax = warp_reduce_max(amax); + } + if (tid == 0) { + staging[0] = amax; // block_amax + staging[1] = (CType)(Quantized_Limits::max_norm) / amax; // block_scale + } + __syncthreads_lm(); + amax = staging[1]; + + OVecT vec_out; + idx = tid; + for (size_t i = 0; i < act_round; ++i) { + CVecT& temp = *(temp_base + idx); +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + vec_out.data.elt[j] = (OType)(temp.data.elt[j] * amax); + } + vec_out.store_to(out + *(offset_base + i), 0); + idx += stride; + } + + if (tid == 0) { + const size_t sinv_offset = blockIdx.y * gridDim.x + blockIdx.x; + *(sinv + sinv_offset) = staging[0] * (CType)(Quantized_Limits::max_norm_rcp); + } +} + +template < + typename Param, + float (*OP)(float, const Param &), + typename IType, + typename OType, + typename CType, + size_t VLEN> +__global__ void fp8_nn_blockwise_n_to_1_kernel_no_align( + const IType* inp, + const CType* noop, + OType* out, + CType* sinv, + size_t M, + size_t N, + size_t block_m, + size_t block_n, + // dense + int dense_m, + int dense_n, + // sparse + int row_n, + int col_n, + Param param) { + using IVecT = Vec; + using CVecT = Vec; + using OVecT = Vec; + + if (noop != nullptr && noop[0] == 1.0f) return; + + const size_t tid = threadIdx.y * blockDim.x + threadIdx.x; + const size_t warp_id = threadIdx.y; + const size_t lane_id = threadIdx.x; + + const size_t base_m = blockIdx.y * block_m + warp_id; + size_t base_n = blockIdx.x * block_n; + + extern __shared__ __align__(alignof(CType)) CType shm[]; + CType amax = 0; + __shared__ CType staging[max_warps_per_block]; + + const size_t data_strd = blockDim.y * N; + + size_t data_base = base_m * N; + const IType* iptr = inp + data_base; + + size_t c_base; + size_t c_strd; + + const bool is_dense = (blockIdx.x < dense_n && blockIdx.y < dense_m); + if (is_dense) { + base_n += lane_id * VLEN; + c_base = tid; + c_strd = blockDim.x * blockDim.y; + + CVecT* c_vec = reinterpret_cast(shm) + c_base; + iptr += base_n; + + IVecT vec_in; + for (int i = 0; i < row_n; ++i, iptr+=data_strd, c_vec+=c_strd) { + vec_in.load_from(iptr, 0); + CVecT& c_tmp = *c_vec; +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + c_tmp.data.elt[j] = (CType)(OP((float)vec_in.data.elt[j], param)); + amax = fmaxf(fabsf(c_tmp.data.elt[j]), amax); + } + } + } else { + base_n += lane_id; + c_base = warp_id * block_n + lane_id; + c_strd = blockDim.y * block_n; + + CType* cptr = shm + c_base; + + size_t id_m = base_m; + for (int i = 0; (i < row_n) && (id_m < M); ++i, id_m+=blockDim.y, iptr+=data_strd, cptr+=c_strd) { + size_t id_n = base_n; + size_t ff_n = 0; + for (int j = 0; (j < col_n) && (id_n < N); ++j, id_n+=blockDim.x, ff_n+=blockDim.x) { + CType val = (CType)(OP((float)(*(iptr + id_n)), param)); + *(cptr + ff_n) = val; + amax = fmaxf(fabsf(val), amax); + } + } + } + amax = fmaxf(amax, global_amax_min); + + amax = warp_reduce_max(amax); + if (lane_id == 0) { + staging[warp_id] = amax; + } + __syncthreads_lm(); + + amax = 0; + if (warp_id == 0) { + amax = tid < blockDim.y ? staging[tid] : 0; + amax = warp_reduce_max(amax); + } + if (tid == 0) { + staging[0] = amax; // block_amax + staging[1] = (CType)(Quantized_Limits::max_norm) / amax; // block_scale + } + __syncthreads_lm(); + amax = staging[1]; + + OType* optr = out + data_base; + if (is_dense) { + CVecT* c_vec = reinterpret_cast(shm) + c_base; + optr += base_n; + + OVecT vec_out; + for (int i = 0; i < row_n; ++i, optr+=data_strd, c_vec+=c_strd) { + CVecT& c_tmp = *c_vec; +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + vec_out.data.elt[j] = (OType)(c_tmp.data.elt[j] * amax); + } + vec_out.store_to(optr, 0); + } + } else { + CType* cptr = shm + c_base; + + size_t id_m = base_m; + for (int i = 0; (i < row_n) && (id_m < M); ++i, id_m+=blockDim.y, optr+=data_strd, cptr+=c_strd) { + size_t id_n = base_n; + size_t ff_n = 0; + for (int j = 0; (j < col_n) && (id_n < N); ++j, id_n+=blockDim.x, ff_n+=blockDim.x) { + CType& val = *(cptr + ff_n); + val *= amax; + *(optr + id_n) = (OType)(val); + } + } + } + + if (tid == 0) { + const size_t sinv_offset = blockIdx.y * gridDim.x + blockIdx.x; + *(sinv + sinv_offset) = staging[0] * (CType)(Quantized_Limits::max_norm_rcp); + } +} + +template < + typename Param, + float (*OP)(float, const Param &), + typename IType, + typename OType, + typename CType> +inline void fp8_blockwise_cast( + const IType* inp, + const CType* noop, + OType* out, + CType* sinv, + size_t M, + size_t N, + size_t block_m, + size_t block_n, + Param param, + musaStream_t stream) { + NVTE_CHECK(block_m == block_n); + + const int block_x = (int)ceil_div(N, block_n); + const int block_y = (int)ceil_div(M, block_m); + dim3 blocks(block_x, block_y); + + if (N % block_n != 0) { + NVTE_CHECK(is_power_of_2(block_n)); + + // dense + const size_t dense_m = M / block_m; + const size_t dense_n = N / block_n; + const size_t full_vlen = block_n / warp_size; + + // sparse + const auto col_n = ceil_div(block_n, warp_size); + + const auto thread_x = warp_size; + const auto thread_y = std::min(block_m, max_warps_per_block); + const auto row_n = ceil_div(block_m, thread_y); + + NVTE_CHECK(is_power_of_2(col_n)); + dim3 threads((int)(thread_x), (int)(thread_y)); + + const size_t shm_trans = block_m * block_n * sizeof(CType); + + if (full_vlen == 4) { + fp8_nn_blockwise_n_to_1_kernel_no_align + <<>> + (inp, noop, out, sinv, M, N, block_m, block_n, + (int)dense_m, (int)dense_n, + (int)row_n, (int)col_n, + param); + } else if (full_vlen == 2) { + fp8_nn_blockwise_n_to_1_kernel_no_align + <<>> + (inp, noop, out, sinv, M, N, block_m, block_n, + (int)dense_m, (int)dense_n, + (int)row_n, (int)col_n, + param); + } else { + NVTE_CHECK(block_n >= 32); + fp8_nn_blockwise_n_to_1_kernel_no_align + <<>> + (inp, noop, out, sinv, M, N, block_m, block_n, + (int)dense_m, (int)dense_n, + (int)row_n, (int)col_n, + param); + } + + NVTE_CHECK_CUDA(musaGetLastError()); + return; + } + + constexpr size_t VLEN = io_bytes / sizeof(IType); + const size_t thread_x = block_n / VLEN; + + const size_t thread_y = std::min(block_m, max_threads_per_block / thread_x); + dim3 threads((int)(thread_x), (int)(thread_y)); + + const size_t n_threads = thread_x * thread_y; + NVTE_CHECK(n_threads % warp_size == 0); + const size_t n_warps = n_threads / warp_size; + + const size_t rounds = ceil_div(block_m, thread_y); + const size_t shm_trans = block_m * block_n * sizeof(CType); + const size_t shm_offsets = n_threads * rounds * sizeof(size_t); + const size_t shm_total = shm_trans + shm_offsets; + + if (thread_x == 32) { + fp8_nn_blockwise_n_to_1_kernel + <<>> + (inp, noop, out, sinv, M, N, block_m, block_n, rounds, n_warps, param); + } else if (thread_x == 16) { + fp8_nn_blockwise_n_to_1_kernel + <<>> + (inp, noop, out, sinv, M, N, block_m, block_n, rounds, n_warps, param); + } else { + fp8_nn_blockwise_n_to_1_kernel + <<>> + (inp, noop, out, sinv, M, N, block_m, block_n, rounds, n_warps, param); + } + + NVTE_CHECK_CUDA(musaGetLastError()); +} + +} // namespace transformer_engine::mtfp8 + +#endif // TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_BLOCKWISE_QUANTIZE_MUH_ diff --git a/transformer_engine/musa/common/util/mtfp8_cast.muh b/transformer_engine/musa/common/util/mtfp8_cast.muh new file mode 100644 index 0000000000..d31e63865f --- /dev/null +++ b/transformer_engine/musa/common/util/mtfp8_cast.muh @@ -0,0 +1,104 @@ +#ifndef TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_CAST_MUH_ +#define TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_CAST_MUH_ + +#include "mtfp8_groupwise_quantize.muh" +#include "mtfp8_blockwise_quantize.muh" + +namespace transformer_engine { + +namespace mtfp8 { + +template +inline void mtfp8_quantize_dispatch( + const Tensor* input, + const Tensor* noop, + Tensor* output, + Param param, + musaStream_t stream) { + NVTE_CHECK(noop->data.dtype == DType::kFloat32); + NVTE_CHECK(output->scale_inv.dtype == DType::kFloat32); + using CType = float; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, + const auto M = input->flat_first_dim(); + const auto N = input->flat_last_dim(); + const auto sinv_m = output->scale_inv.shape[0]; + const auto sinv_n = output->scale_inv.shape[1]; + size_t group_size = 128; + if (N % sinv_n == 0) { + group_size = next_power_of_2(N / sinv_n); + } + if (M == 1 || M == sinv_m) { + // 1 * N + fp8_groupwise_cast( + reinterpret_cast(input->data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale_inv.dptr), + M, N, group_size, param, stream); + } else { + // N * N + fp8_blockwise_cast( + reinterpret_cast(input->data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale_inv.dptr), + M, N, group_size, group_size, param, stream); + } + ); + ); +} + +} // namespace mtfp8 + +template < + bool IS_DBIAS, + bool IS_DACT, + bool IS_ACT, + typename Param, + float (*OP)(float, const Param &)> +inline void mtfp8_quantize( + const Tensor* input, + const Tensor* act_input, + const Tensor* noop, + Tensor* output, + Tensor* dbias, + Tensor* workspace, + musaStream_t stream) { + using namespace mtfp8; + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(*input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr); + CheckInputTensor(*act_input, "activation_input"); + NVTE_CHECK(input->dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input->data.shape == act_input->data.shape, "Shapes of both inputs must match."); + } + + NVTE_CHECK(!is_fp8_dtype(input->dtype()), "Input must be in higher precision."); + NVTE_CHECK(output->data.shape == input->data.shape, "Input and output shapes need to match."); + + if (IS_DBIAS) { + NVTE_ERROR("Not yet implemented."); + } + if (IS_DACT) { + NVTE_ERROR("Not yet implemented."); + } + + (void)workspace; + mtfp8_quantize_dispatch(input, noop, output, {}, stream); +} + +void mtfp8_dequantize(const Tensor* input, Tensor* output, musaStream_t stream); + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_CAST_MUH_ diff --git a/transformer_engine/musa/common/util/mtfp8_cast_transpose.h b/transformer_engine/musa/common/util/mtfp8_cast_transpose.h new file mode 100644 index 0000000000..1d3d6c7b47 --- /dev/null +++ b/transformer_engine/musa/common/util/mtfp8_cast_transpose.h @@ -0,0 +1,12 @@ +#ifndef TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_CAST_TRANSPOSE_H_ +#define TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_CAST_TRANSPOSE_H_ + +#include "../common.h" + +namespace transformer_engine { + +void mtfp8_cast_transpose(const Tensor* input, const Tensor* noop, Tensor* output_, musaStream_t stream); + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_CAST_TRANSPOSE_MUH_ diff --git a/transformer_engine/musa/common/util/mtfp8_cast_transpose.mu b/transformer_engine/musa/common/util/mtfp8_cast_transpose.mu new file mode 100644 index 0000000000..541ad77435 --- /dev/null +++ b/transformer_engine/musa/common/util/mtfp8_cast_transpose.mu @@ -0,0 +1,402 @@ +#include "mtfp8_cast_transpose.h" + +#include + +#include "../util/string.h" +#include "../utils.muh" +#include "mtfp8_utils.muh" + + +namespace transformer_engine { + +namespace mtfp8 { + +using CType = float; +constexpr size_t warps_per_tile = 4; +constexpr size_t block_size = warp_size * warps_per_tile; + +namespace { + +template +__device__ __forceinline__ T warpReduceMax(T max_value) { + max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16)); + max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8)); + max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4)); + max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2)); + max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1)); + return max_value; +} + +constexpr int max(int a, int b) { + return a > b ? a : b; +} + +} + + +template < + typename IType, + typename OType, + size_t N_ELEMENTS_PER_THREAD_X = 4/* VLEN */, + size_t N_ELEMENTS_PER_THREAD_Y = 4, + size_t BLOCK_SIZE_X = 32, + size_t BLOCK_SIZE_Y = 16, + size_t GROUP_SIZE = 128 +> +__global__ void mtfp8_cast_transpose_general_kernel_column_aligned( + const IType *__restrict__ const inp, + const CType *__restrict__ const noop, + OType *__restrict__ const out_c, + OType *__restrict__ const out_t, + CType *__restrict__ const scale_inv, + CType *__restrict__ const columnwise_scale_inv, + size_t ncols, + size_t nrows) { + // rowwise_group_size and columnwise_group_size should be equal + + if (noop != nullptr && noop[0] == 1.0f) return; + + using input_vec_t = Vec; + using out_vec_t = Vec; + using scale_vec_t = Vec; + + const uint32_t local_col_base_id = threadIdx.x * N_ELEMENTS_PER_THREAD_X; + const uint32_t global_col_base_id = blockIdx.x * GROUP_SIZE + local_col_base_id; + const uint32_t local_row_base_id = threadIdx.y * N_ELEMENTS_PER_THREAD_Y; + const uint32_t global_row_base_id = blockIdx.y * GROUP_SIZE; + + const uint32_t rowwise_scale_inv_stride = ncols / GROUP_SIZE; + + // if ((global_row_base_id + local_row_base_id) >= nrows) { + // return; + // } + + const IType* inp_load_ptr = inp + global_row_base_id * ncols + global_col_base_id; + OType* out_c_store_ptr = out_c + global_row_base_id * ncols + global_col_base_id; + CType* rowwise_scale_inv_ptr = scale_inv + global_row_base_id * rowwise_scale_inv_stride + blockIdx.x; + OType* out_t_store_ptr = out_t + global_row_base_id * ncols + global_col_base_id; + CType* columnwise_scale_inv_ptr = columnwise_scale_inv + blockIdx.y * ncols + global_col_base_id; + + constexpr int REPEAT_Y = DIVUP(GROUP_SIZE, BLOCK_SIZE_Y * N_ELEMENTS_PER_THREAD_Y); + constexpr int ELEMENTS_PER_BANK = 4 / sizeof(IType); // dword of bank is 32 bits by default + static_assert(ELEMENTS_PER_BANK != 0); + constexpr int NDWORD = N_ELEMENTS_PER_THREAD_X / ELEMENTS_PER_BANK; + static_assert(NDWORD != 0); + + // 0, 1, 2, ..., 31 + // 128 * 128 * 2 / 1024 + 128 * BLOCK_SIZE_Y * 2 / 1024 + // __shared__ IType shm[GROUP_SIZE][NDWORD][GROUP_SIZE / NDWORD]; + __shared__ IType shm[GROUP_SIZE][GROUP_SIZE]; + __shared__ IType shm_amax_columnwise[BLOCK_SIZE_Y][GROUP_SIZE + 2]; + + float amax_rowwise; + float amax_columnwise[N_ELEMENTS_PER_THREAD_X] = {0.f}; + + input_vec_t tmp_load_reg; + out_vec_t tmp_store_reg; + scale_vec_t scale_store_reg; + + #pragma unroll + for (int loop_y_id = 0; loop_y_id < REPEAT_Y; loop_y_id++) { + // assume no multiple loads along X dimension + + // TODO: try prefetch + + int group_inner_y_id = loop_y_id * BLOCK_SIZE_Y * N_ELEMENTS_PER_THREAD_Y + local_row_base_id; + // load input values into shared memory + #pragma unroll + for (int ii_y = 0; ii_y < N_ELEMENTS_PER_THREAD_Y; ii_y++) { + amax_rowwise = 0.f; + int ld_st_offset = global_row_base_id + group_inner_y_id + ii_y < nrows ? + group_inner_y_id + ii_y: + 0; + *reinterpret_cast(shm[group_inner_y_id + ii_y] + local_col_base_id) = *reinterpret_cast(inp_load_ptr + ld_st_offset * ncols); + tmp_load_reg.load_from(shm[group_inner_y_id + ii_y] + local_col_base_id, 0); + + #pragma unroll + for (int ii_x = 0; ii_x < N_ELEMENTS_PER_THREAD_X; ii_x++) { + amax_rowwise = fmaxf(fmaxf(amax_rowwise, fabsf(tmp_load_reg.data.elt[ii_x])), global_amax_min); + amax_columnwise[ii_x] = fmaxf(fmaxf(amax_columnwise[ii_x], fabsf(tmp_load_reg.data.elt[ii_x])), global_amax_min); + } + + amax_rowwise = warpReduceMax(amax_rowwise) * (float)(Quantized_Limits::max_norm_rcp); + + //// write back to scale_inv and out_c [rowwise result] + for (int ii_x = 0; ii_x < N_ELEMENTS_PER_THREAD_X; ii_x++) { + tmp_store_reg.data.elt[ii_x] = static_cast(float(tmp_load_reg.data.elt[ii_x]) / amax_rowwise); + } + tmp_store_reg.store_to(out_c_store_ptr + ld_st_offset * ncols, 0); + if (threadIdx.x == 0) { + rowwise_scale_inv_ptr[ld_st_offset * rowwise_scale_inv_stride] = amax_rowwise; + } + } + + for (int ii_x = 0; ii_x < N_ELEMENTS_PER_THREAD_X; ii_x++) { + shm_amax_columnwise[threadIdx.y][local_col_base_id + ii_x] = amax_columnwise[ii_x]; + } + } + + // RUN COLUMNWISE + + __syncthreads_lm(); + + for (int i = threadIdx.y; i < GROUP_SIZE; i += blockDim.y) { + IType amax = threadIdx.x < blockDim.y ? + shm_amax_columnwise[threadIdx.x][i] : + (IType)0.f; + amax = warpReduceMax((float)amax); + if (threadIdx.x == 0) { + shm_amax_columnwise[0][i] = amax; + } + } + + __syncthreads_lm(); + #pragma unroll + for (int ii = 0; ii < N_ELEMENTS_PER_THREAD_X; ii++) { + amax_columnwise[ii] = (float)shm_amax_columnwise[0][local_col_base_id + ii] * (float)(Quantized_Limits::max_norm_rcp); + } + + // write back to columnwise_scale_inv and out_t + for (int loop_y_id = 0; loop_y_id < REPEAT_Y; loop_y_id++) { + int group_inner_y_id = loop_y_id * BLOCK_SIZE_Y * N_ELEMENTS_PER_THREAD_Y + local_row_base_id; + for (int ii_y = 0; ii_y < N_ELEMENTS_PER_THREAD_Y; ii_y++) { + int group_inner_y_offset = group_inner_y_id + ii_y; + int store_offset = (global_row_base_id + group_inner_y_offset) < nrows ? + group_inner_y_offset : + 0; + + for (int ii_x = 0; ii_x < N_ELEMENTS_PER_THREAD_X; ii_x++) { + float value = (float)shm[group_inner_y_offset][local_col_base_id + ii_x] / amax_columnwise[ii_x]; + tmp_store_reg.data.elt[ii_x] = static_cast(value); + } + tmp_store_reg.store_to(out_t_store_ptr + store_offset * ncols, 0); + } + } + if (threadIdx.y == 0) { + #pragma unroll + for (int i = 0; i < N_ELEMENTS_PER_THREAD_X; i++) { + scale_store_reg.data.elt[i] = amax_columnwise[i]; + } + scale_store_reg.store_to(columnwise_scale_inv_ptr, 0); + } +} + + +template < + typename IType, + typename OType, + size_t N_ELEMENTS_PER_THREAD_X = 4/* VLEN */, + size_t N_ELEMENTS_PER_THREAD_Y = 4, + size_t BLOCK_SIZE_X = 32, + size_t BLOCK_SIZE_Y = 16, + size_t GROUP_SIZE = 128 +> +__global__ void mtfp8_cast_transpose_general_kernel_column_unaligned( + const IType *__restrict__ const inp, + const CType *__restrict__ const noop, + OType *__restrict__ const out_c, + OType *__restrict__ const out_t, + CType *__restrict__ const scale_inv, + CType *__restrict__ const columnwise_scale_inv, + size_t ncols, + size_t nrows) { + // rowwise_group_size and columnwise_group_size should be equal + + // if (noop != nullptr && noop[0] == 1.0f) return; + + using input_vec_t = Vec; + using out_vec_t = Vec; + using scale_vec_t = Vec; + + const uint32_t local_col_base_id = threadIdx.x * N_ELEMENTS_PER_THREAD_X; + uint32_t global_col_base_id = blockIdx.x * GROUP_SIZE; + global_col_base_id += ((global_col_base_id + local_col_base_id) < ncols ? local_col_base_id : 0); + const uint32_t local_row_base_id = threadIdx.y * N_ELEMENTS_PER_THREAD_Y; + const uint32_t global_row_base_id = blockIdx.y * GROUP_SIZE; + + const uint32_t rowwise_scale_inv_stride = (ncols + GROUP_SIZE - 1) / GROUP_SIZE; + + const IType* inp_load_ptr = inp + global_row_base_id * ncols + global_col_base_id; + OType* out_c_store_ptr = out_c + global_row_base_id * ncols + global_col_base_id; + CType* rowwise_scale_inv_ptr = scale_inv + global_row_base_id * rowwise_scale_inv_stride + blockIdx.x; + OType* out_t_store_ptr = out_t + global_row_base_id * ncols + global_col_base_id; + CType* columnwise_scale_inv_ptr = columnwise_scale_inv + blockIdx.y * ncols + global_col_base_id; + + constexpr int REPEAT_Y = DIVUP(GROUP_SIZE, BLOCK_SIZE_Y * N_ELEMENTS_PER_THREAD_Y); + constexpr int ELEMENTS_PER_BANK = 4 / sizeof(IType); // dword of bank is 32 bits by default + static_assert(ELEMENTS_PER_BANK != 0); + constexpr int NDWORD = N_ELEMENTS_PER_THREAD_X / ELEMENTS_PER_BANK; + static_assert(NDWORD != 0); + + // 0, 1, 2, ..., 31 + // 128 * 128 * 2 / 1024 + 128 * BLOCK_SIZE_Y * 2 / 1024 + // __shared__ IType shm[GROUP_SIZE][NDWORD][GROUP_SIZE / NDWORD]; + __shared__ IType shm[GROUP_SIZE][GROUP_SIZE]; + __shared__ IType shm_amax_columnwise[BLOCK_SIZE_Y][GROUP_SIZE + 2]; + + float amax_rowwise; + float amax_columnwise[N_ELEMENTS_PER_THREAD_X] = {0.f}; + + input_vec_t tmp_load_reg; + out_vec_t tmp_store_reg; + scale_vec_t scale_store_reg; + + #pragma unroll + for (int loop_y_id = 0; loop_y_id < REPEAT_Y; loop_y_id++) { + // assume no multiple loads along X dimension + + // TODO: try prefetch + + int group_inner_y_id = loop_y_id * BLOCK_SIZE_Y * N_ELEMENTS_PER_THREAD_Y + local_row_base_id; + // load input values into shared memory + #pragma unroll + for (int ii_y = 0; ii_y < N_ELEMENTS_PER_THREAD_Y; ii_y++) { + amax_rowwise = 0.f; + int ld_st_offset = global_row_base_id + group_inner_y_id + ii_y < nrows ? + group_inner_y_id + ii_y: + 0; + *reinterpret_cast(shm[group_inner_y_id + ii_y] + local_col_base_id) = *reinterpret_cast(inp_load_ptr + ld_st_offset * ncols); + tmp_load_reg.load_from(shm[group_inner_y_id + ii_y] + local_col_base_id, 0); + + #pragma unroll + for (int ii_x = 0; ii_x < N_ELEMENTS_PER_THREAD_X; ii_x++) { + amax_rowwise = fmaxf(fmaxf(amax_rowwise, fabsf(tmp_load_reg.data.elt[ii_x])), global_amax_min); + amax_columnwise[ii_x] = fmaxf(fmaxf(amax_columnwise[ii_x], fabsf(tmp_load_reg.data.elt[ii_x])), global_amax_min); + } + + amax_rowwise = warpReduceMax(amax_rowwise) * (float)(Quantized_Limits::max_norm_rcp); + + //// write back to scale_inv and out_c [rowwise result] + for (int ii_x = 0; ii_x < N_ELEMENTS_PER_THREAD_X; ii_x++) { + tmp_store_reg.data.elt[ii_x] = static_cast(float(tmp_load_reg.data.elt[ii_x]) / amax_rowwise); + } + tmp_store_reg.store_to(out_c_store_ptr + ld_st_offset * ncols, 0); + if (threadIdx.x == 0) { + rowwise_scale_inv_ptr[ld_st_offset * rowwise_scale_inv_stride] = amax_rowwise; + } + } + + for (int ii_x = 0; ii_x < N_ELEMENTS_PER_THREAD_X; ii_x++) { + shm_amax_columnwise[threadIdx.y][local_col_base_id + ii_x] = amax_columnwise[ii_x]; + } + } + + // RUN COLUMNWISE + + __syncthreads_lm(); + + for (int i = threadIdx.y; i < GROUP_SIZE; i += blockDim.y) { + IType amax = threadIdx.x < blockDim.y ? + shm_amax_columnwise[threadIdx.x][i] : + (IType)0.f; + amax = warpReduceMax((float)amax); + if (threadIdx.x == 0) { + shm_amax_columnwise[0][i] = amax; + } + } + + __syncthreads_lm(); + #pragma unroll + for (int ii = 0; ii < N_ELEMENTS_PER_THREAD_X; ii++) { + amax_columnwise[ii] = (float)shm_amax_columnwise[0][local_col_base_id + ii] * (float)(Quantized_Limits::max_norm_rcp); + } + + // write back to columnwise_scale_inv and out_t + for (int loop_y_id = 0; loop_y_id < REPEAT_Y; loop_y_id++) { + int group_inner_y_id = loop_y_id * BLOCK_SIZE_Y * N_ELEMENTS_PER_THREAD_Y + local_row_base_id; + for (int ii_y = 0; ii_y < N_ELEMENTS_PER_THREAD_Y; ii_y++) { + int group_inner_y_offset = group_inner_y_id + ii_y; + int store_offset = (global_row_base_id + group_inner_y_offset) < nrows ? + group_inner_y_offset : + 0; + + for (int ii_x = 0; ii_x < N_ELEMENTS_PER_THREAD_X; ii_x++) { + float value = (float)shm[group_inner_y_offset][local_col_base_id + ii_x] / amax_columnwise[ii_x]; + tmp_store_reg.data.elt[ii_x] = static_cast(value); + } + tmp_store_reg.store_to(out_t_store_ptr + store_offset * ncols, 0); + } + } + if (threadIdx.y == 0) { + #pragma unroll + for (int i = 0; i < N_ELEMENTS_PER_THREAD_X; i++) { + scale_store_reg.data.elt[i] = amax_columnwise[i]; + } + scale_store_reg.store_to(columnwise_scale_inv_ptr, 0); + } +} + +} // namespace mtfp8 + +void mtfp8_cast_transpose(const Tensor* input, const Tensor* noop, Tensor* output, musaStream_t stream) { + using namespace mtfp8; + CheckNoopTensor(*noop, "mtfp8_cast_transpose_noop"); + CheckInputTensor(*input, "mtfp8_cast_transpose_input"); + CheckOutputTensor(*output, "mtfp8_cast_transpose_output"); + + // Check that inputs and outputs are available + NVTE_CHECK(input->has_data(), "Input is not allocated"); + NVTE_CHECK(output->has_data(), "Output rowwise data is not allocated"); + NVTE_CHECK(output->has_columnwise_data(), "Output columnwise is not allocated"); + + // Flatten tensor to 2D + NVTE_CHECK(input->data.shape == output->data.shape, + "Input and output shapes do not match (input=", input->data.shape, + ", output=", output->data.shape); + const size_t row_length = input->flat_last_dim(); + const size_t num_rows = input->flat_first_dim(); + NVTE_CHECK(output->flat_first_dim() == num_rows && output->flat_last_dim() == row_length, + "Invalid output dimensions (expected ", std::vector{num_rows, row_length}, + ", got ", std::vector{output->flat_first_dim(), output->flat_last_dim()}, ")"); + + const auto rowwise_sinv_m = output->scale_inv.shape[0]; + const auto rowwise_sinv_n = output->scale_inv.shape[1]; + const auto columnwise_sinv_m = output->columnwise_scale_inv.shape[0]; + const auto columnwise_sinv_n = output->columnwise_scale_inv.shape[1]; + + const size_t group_size = next_power_of_2(row_length / rowwise_sinv_n); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input->data.dtype, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OutputType, + + constexpr int GROUP_SIZE = 128; // TODO: extend other group_size + NVTE_CHECK(group_size == GROUP_SIZE); + constexpr int BLOCK_SIZE_Y = 16; + constexpr int BLOCK_SIZE_X = 32; + constexpr int N_ELEMENTS_PER_THREAD_X = std::min(GROUP_SIZE / BLOCK_SIZE_X, 8); + constexpr int N_ELEMENTS_PER_THREAD_Y = 2; + + dim3 block(BLOCK_SIZE_X, BLOCK_SIZE_Y); + dim3 grid(DIVUP(row_length, group_size), DIVUP(num_rows, group_size)); + + // std::cout << "output->scale_inv.dptr: " << reinterpret_cast((void*)(output->scale_inv.dptr)) << std::endl; + if ((row_length % GROUP_SIZE) != 0) { + mtfp8_cast_transpose_general_kernel_column_unaligned + <<>>( + reinterpret_cast(input->data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + reinterpret_cast(output->scale_inv.dptr), + reinterpret_cast(output->columnwise_scale_inv.dptr), + row_length, + num_rows); + } else { + mtfp8_cast_transpose_general_kernel_column_aligned + <<>>( + reinterpret_cast(input->data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + reinterpret_cast(output->scale_inv.dptr), + reinterpret_cast(output->columnwise_scale_inv.dptr), + row_length, + num_rows); + } + ); + ); +} + +} // namespace transformer_engine diff --git a/transformer_engine/musa/common/util/mtfp8_dequantize.mu b/transformer_engine/musa/common/util/mtfp8_dequantize.mu new file mode 100644 index 0000000000..200ad7e9a6 --- /dev/null +++ b/transformer_engine/musa/common/util/mtfp8_dequantize.mu @@ -0,0 +1,102 @@ +#include "mtfp8_cast.muh" + +#include + +#include "../util/string.h" +#include "../utils.muh" +#include "mtfp8_utils.muh" + +namespace transformer_engine { + +namespace mtfp8 { + +template < + typename IType, + typename OType, + typename CType, + size_t VLEN> +__global__ void fp8_general_dequantize( + const IType* inp, + OType* out, + const CType* sinv, + size_t M, + size_t N, + size_t sinv_m, + size_t sinv_n, + size_t block_m, + size_t block_n) { + using IVecT = Vec; + using OVecT = Vec; + + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t offset = tid * VLEN; + + const size_t dimx = offset / N; + const size_t dimy = offset % N; + const bool valid = dimx < M; + + const size_t groupx = dimx / block_m; + const size_t groupy = dimy / block_n; + + IVecT vec_in; + OVecT vec_out; + if (valid) { + const CType s_inv = sinv[groupx * sinv_n + groupy]; + vec_in.load_from(inp + offset, 0); +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + vec_out.data.elt[j] = (OType)((CType)(vec_in.data.elt[j]) * s_inv); + } + vec_out.store_to(out + offset, 0); + } +} + +} // namespace mtfp8 + +void mtfp8_dequantize(const Tensor* input, Tensor* output, musaStream_t stream) { + NVTE_CHECK(is_fp8_dtype(input->data.dtype), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input->data.shape, "Input and output shapes need to match."); + + using namespace mtfp8; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + NVTE_CHECK(input->scale_inv.dtype == DType::kFloat32); + using CType = float; + + const auto M = input->flat_first_dim(); + const auto N = input->flat_last_dim(); + const auto sinv_m = input->scale_inv.shape[0]; + const auto sinv_n = input->scale_inv.shape[1]; + + size_t block_m = 1; + const size_t block_n = 128; //(N / sinv_n); + if (M != sinv_m) { + block_m = block_n; + } + + constexpr size_t VLEN = io_bytes / sizeof(OType); + const size_t threads = max_threads_per_block; + const size_t blocks = ceil_div(M * N, threads * VLEN); + + fp8_general_dequantize + <<<(int)blocks, (int)threads, 0, stream>>>( + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(input->scale_inv.dptr), + M, + N, + sinv_m, + sinv_n, + block_m, + block_n); + NVTE_CHECK_CUDA(musaGetLastError()); + ); + ); +} + +} // namespace transformer_engine diff --git a/transformer_engine/musa/common/util/mtfp8_groupwise_quantize.muh b/transformer_engine/musa/common/util/mtfp8_groupwise_quantize.muh new file mode 100644 index 0000000000..647d3f0936 --- /dev/null +++ b/transformer_engine/musa/common/util/mtfp8_groupwise_quantize.muh @@ -0,0 +1,329 @@ +#ifndef TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_GROUPWISE_QUANTIZE_MUH_ +#define TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_GROUPWISE_QUANTIZE_MUH_ + +#include "../common.h" +#include "../utils.muh" +#include "math.h" +#include "mtfp8_utils.muh" + +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine::mtfp8 { + +template < + typename Param, + float (*OP)(float, const Param &), + typename IType, + typename OType, + typename CType, + size_t VLEN> +__global__ void fp8_groupwise_1_to_1_kernel( + const IType* inp, + const CType* noop, + OType* out, + CType* sinv, + size_t numel, + Param param) { + if (noop != nullptr && noop[0] == 1.0f) return; + using IVecT = Vec; + using CVecT = Vec; + using OVecT = Vec; + + const size_t tid = threadIdx.y * blockDim.x + threadIdx.x; + const size_t warp_id = threadIdx.y; + const size_t lane_id = threadIdx.x; + + const size_t data_offset = blockIdx.x * VLEN * blockDim.x * blockDim.y + tid * VLEN; + const size_t sinv_offset = blockIdx.x * blockDim.y + warp_id; + + IVecT vec_in; + CType amax = 0; + CVecT vec_temp; + const bool valid = (data_offset < numel); + if (valid) { + vec_in.load_from(inp + data_offset, 0); +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + vec_temp.data.elt[j] = (CType)(OP((float)vec_in.data.elt[j], param)); + amax = fmaxf(fabsf(vec_temp.data.elt[j]), amax); + } + } + + amax = warp_reduce_max_broadcast(amax); + amax = fmaxf(amax, global_amax_min); + CType scale = (CType)(Quantized_Limits::max_norm) / amax; + + OVecT vec_out; + if (valid) { +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + vec_out.data.elt[j] = (OType)(vec_temp.data.elt[j] * scale); + } + vec_out.store_to(out + data_offset, 0); + } + + if (valid && lane_id == 0) { + *(sinv + sinv_offset) = amax * (CType)(Quantized_Limits::max_norm_rcp); + } +} + +template < + typename Param, + float (*OP)(float, const Param &), + typename IType, + typename OType, + typename CType, + size_t VLEN> +__global__ void fp8_groupwise_1_to_n_kernel( + const IType* inp, + const CType* noop, + OType* out, + CType* sinv, + size_t numel, + size_t groups_per_warp, + size_t group_size, + Param param) { + if (noop != nullptr && noop[0] == 1.0f) return; + using IVecT = Vec; + using CVecT = Vec; + using OVecT = Vec; + + const size_t tid = threadIdx.y * blockDim.x + threadIdx.x; + const size_t warp_id = threadIdx.y; + const size_t lane_id = threadIdx.x; + + const size_t data_offset = blockIdx.x * VLEN * blockDim.x * blockDim.y + tid * VLEN; + const size_t lane_offset = lane_id * VLEN; + const size_t group_idx = lane_offset / group_size; + const size_t sinv_offset = (blockIdx.x * blockDim.y + warp_id) * groups_per_warp + group_idx; + const bool write_to_sinv = (lane_offset % group_size == 0); + + IVecT vec_in; + CType amax = 0; + CVecT vec_temp; + const bool valid = (data_offset < numel); + if (valid) { + vec_in.load_from(inp + data_offset, 0); +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + vec_temp.data.elt[j] = (CType)(OP((float)vec_in.data.elt[j], param)); + amax = fmaxf(fabsf(vec_temp.data.elt[j]), amax); + } + } + + for (size_t i = 0; i < groups_per_warp; ++i) { + const bool flag = (i == group_idx); + CType group_max = flag ? amax : 0; + group_max = warp_reduce_max_broadcast(group_max); + if (flag) { + amax = group_max; + } + } + amax = fmaxf(amax, global_amax_min); + CType scale = (CType)(Quantized_Limits::max_norm) / amax; + + OVecT vec_out; + if (valid) { +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + vec_out.data.elt[j] = (OType)(vec_temp.data.elt[j] * scale); + } + vec_out.store_to(out + data_offset, 0); + } + + if (valid && write_to_sinv) { + *(sinv + sinv_offset) = amax * (CType)(Quantized_Limits::max_norm_rcp); + } +} + +template < + typename Param, + float (*OP)(float, const Param &), + typename IType, + typename OType, + typename CType, + size_t VLEN> +__global__ void fp8_groupwise_1_to_1_kernel_no_align( + const IType* inp, + const CType* noop, + OType* out, + CType* sinv, + size_t M, + size_t N, + size_t group_size, + size_t last_n, + Param param) { + if (noop != nullptr && noop[0] == 1.0f) return; + + const int row_id = blockIdx.y * blockDim.y + threadIdx.y; + if (row_id >= M) return; + + using IVecT = Vec; + using CVecT = Vec; + using OVecT = Vec; + + const int col_id = blockIdx.x; + const int lane_id = threadIdx.x; + + const size_t row_offset = row_id * N; + const size_t col_offset = col_id * group_size + lane_id * VLEN; + const size_t data_offset = row_offset + col_offset; + + const size_t sinv_offset = row_id * gridDim.x + col_id; + + CType amax = 0; + IVecT vec_in; + CVecT vec_temp; + + const IType* iptr = inp + data_offset; + const bool is_dense = (col_offset + VLEN) < N; + + if (is_dense) { + vec_in.load_from(iptr, 0); +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + vec_temp.data.elt[j] = (CType)(OP((float)vec_in.data.elt[j], param)); + amax = fmaxf(fabsf(vec_temp.data.elt[j]), amax); + } + } else { +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + if (col_offset + j < N) { + vec_temp.data.elt[j] = (CType)(OP((float)(*(iptr+j)), param)); + amax = fmaxf(fabsf(vec_temp.data.elt[j]), amax); + } + } + } + + amax = warp_reduce_max_broadcast(amax); + amax = fmaxf(amax, global_amax_min); + CType scale = (CType)(Quantized_Limits::max_norm) / amax; + + OType* optr = out + data_offset; + if (is_dense) { + OVecT vec_out; +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + vec_out.data.elt[j] = (OType)(vec_temp.data.elt[j] * scale); + } + vec_out.store_to(optr, 0); + } else { +#pragma unroll + for (size_t j = 0; j < VLEN; ++j) { + if (col_offset + j < N) { + *(optr+j) = (OType)(vec_temp.data.elt[j] * scale); + } + } + } + + if (lane_id == 0) { + *(sinv + sinv_offset) = amax * (CType)(Quantized_Limits::max_norm_rcp); + } +} + +template < + typename Param, + float (*OP)(float, const Param &), + typename IType, + typename OType, + typename CType> +inline void fp8_groupwise_cast( + const IType* inp, + const CType* noop, + OType* out, + CType* sinv, + size_t M, + size_t N, + size_t group_size, + Param param, + musaStream_t stream) { + + if (N % group_size != 0) { + constexpr int thx = warp_size; + constexpr int thy = 8; + dim3 threads(thx, thy); + + const int blk_x = N / group_size + 1; + const int blk_y = (int)ceil_div(M, (size_t)thy); + dim3 blocks(blk_x, blk_y); + + const int last_n = (int)(N % group_size); + +#define DISPATCH_1_TO_1_NO_ALIGN(G, V) \ + if (group_size == G) { \ + constexpr size_t VLEN = V; \ + fp8_groupwise_1_to_1_kernel_no_align \ + <<>>( \ + inp, noop, out, sinv, M, N, group_size, last_n, param); \ + NVTE_CHECK_CUDA(musaGetLastError()); \ + return; \ + } + + if constexpr (sizeof(IType) == 2) { + DISPATCH_1_TO_1_NO_ALIGN(128, 4); + NVTE_ERROR("Not supported [1, ", group_size, "] blocksize for mtfp8 groupwise quantize."); + } else if constexpr (sizeof(IType) == 4) { + DISPATCH_1_TO_1_NO_ALIGN(128, 4); + NVTE_ERROR("Not supported [1, ", group_size, "] blocksize for mtfp8 groupwise quantize."); + } + +#undef DISPATCH_1_TO_1_NO_ALIGN + + NVTE_ERROR("Not supported [1, ", group_size, "] blocksize for mtfp8 groupwise quantize."); + } + + const size_t numel = M * N; + constexpr size_t thd_x = warp_size; + constexpr size_t thd_y = max_threads_per_block / warp_size; + dim3 threads((int)thd_x, (int)thd_y); + +#define DISPATCH_1_TO_1(G, V) \ + if (group_size == G) { \ + constexpr size_t VLEN = V; \ + constexpr size_t elems_per_block = max_threads_per_block * VLEN; \ + const int blocks = (int)ceil_div(numel, elems_per_block); \ + fp8_groupwise_1_to_1_kernel \ + <<>>( \ + inp, noop, out, sinv, numel, param); \ + NVTE_CHECK_CUDA(musaGetLastError()); \ + return; \ + } + + if constexpr (sizeof(IType) == 2) { + DISPATCH_1_TO_1(32, 1); + DISPATCH_1_TO_1(64, 2); + DISPATCH_1_TO_1(128, 4); + DISPATCH_1_TO_1(256, 8); + NVTE_ERROR("Not supported [1, ", group_size, "] blocksize for mtfp8 groupwise quantize."); + } else if constexpr (sizeof(IType) == 4) { + DISPATCH_1_TO_1(32, 1); + DISPATCH_1_TO_1(64, 2); + DISPATCH_1_TO_1(128, 4); + DISPATCH_1_TO_1(256, 8); + NVTE_ERROR("Not supported [1, ", group_size, "] blocksize for mtfp8 groupwise quantize."); + } + +#undef DISPATCH_1_TO_1 + + constexpr size_t VLEN = io_bytes / sizeof(IType); + constexpr size_t elems_per_warp = VLEN * warp_size; + if (elems_per_warp % group_size == 0) { + constexpr size_t elems_per_block = max_threads_per_block * VLEN; + const int blocks = (int)ceil_div(numel, elems_per_block); + + const size_t groups_per_warp = elems_per_warp / group_size; + + fp8_groupwise_1_to_n_kernel + <<>>( + inp, noop, out, sinv, numel, groups_per_warp, group_size, param); + + NVTE_CHECK_CUDA(musaGetLastError()); + return; + } + + NVTE_ERROR("Not supported [1, ", group_size, "] blocksize for mtfp8 groupwise quantize."); +} + +} // namespace transformer_engine::mtfp8 + +#endif // TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_GROUPWISE_QUANTIZE_MUH_ diff --git a/transformer_engine/musa/common/util/mtfp8_utils.muh b/transformer_engine/musa/common/util/mtfp8_utils.muh new file mode 100644 index 0000000000..1b309b4565 --- /dev/null +++ b/transformer_engine/musa/common/util/mtfp8_utils.muh @@ -0,0 +1,65 @@ +#ifndef TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_UTILS_MUH_ +#define TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_UTILS_MUH_ + +#include "musa_driver.h" +#include "musa_runtime.h" + +namespace transformer_engine::mtfp8 { + +inline constexpr size_t io_bytes = 16; +inline constexpr size_t warp_size = 32; +inline constexpr size_t warp_bits = 5; +inline constexpr size_t warp_mask = 0x1f; +inline constexpr size_t max_threads_per_block = 1024; +inline constexpr size_t max_warps_per_block = max_threads_per_block / warp_size; + +inline bool is_power_of_2(size_t n) { + return (n > 0) && (n & (n - 1)) == 0; +} + +inline size_t next_power_of_2(size_t n) { + assert(n >= 1); + if (is_power_of_2(n)) { + return n; + } + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + n |= n >> 32; + n += 1; + assert(is_power_of_2(n)); + return n; +} + +inline size_t ceil_div(size_t a, size_t b) { + return (a + b - 1) / b; +} + +inline __device__ float global_amax_min = 1e-15; + +template +struct VlenTrait { + static constexpr bool is_power_of_2 = false; +}; + +#define ADD_VLEN_TRAIT(LEN, BITS, MASK) \ +template<> \ +struct VlenTrait { \ + static constexpr bool is_power_of_2 = true; \ + static constexpr size_t bits = BITS; \ + static constexpr size_t mask = MASK; \ +} + +ADD_VLEN_TRAIT(4, 2, 0x3); +ADD_VLEN_TRAIT(8, 3, 0x7); +ADD_VLEN_TRAIT(16, 4, 0xf); +ADD_VLEN_TRAIT(32, 5, 0x1f); +ADD_VLEN_TRAIT(64, 6, 0x3f); + +#undef ADD_VLEN_TRAIT + +} // namespace transformer_engine::mtfp8 + +#endif // TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MTFP8_UTILS_MUH_ diff --git a/transformer_engine/musa/common/util/mudnn.h b/transformer_engine/musa/common/util/mudnn.h new file mode 100644 index 0000000000..8873bcbd49 --- /dev/null +++ b/transformer_engine/musa/common/util/mudnn.h @@ -0,0 +1,90 @@ +#ifndef TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MUDNN_H_ +#define TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MUDNN_H_ + +#include +#include + +#include +#include +#include +#include + +namespace transformer_engine::musa { + +using ScalarType = typename at::ScalarType; +using DimVector = typename c10::DimVector; + +using MUTensor = typename at::musa::muTensor; + +inline std::vector Flat2DimShape(const Tensor* t) { + return {t->flat_first_dim(), t->flat_last_dim()}; +} + +inline std::pair +make_mudnn_sizes_strides(const std::vector& shape) { + auto mudnn_sizes = DimVector(shape.cbegin(), shape.cend()); + auto mudnn_strides = c10::contiguous_strides(mudnn_sizes); + return std::make_pair(std::move(mudnn_sizes), std::move(mudnn_strides)); +} + +inline ScalarType ToTorchDtype(DType te_dtype) { + auto th_dtype = ScalarType::Undefined; + switch (te_dtype) { + case DType::kFloat16: + th_dtype = ScalarType::Half; + break; + case DType::kBFloat16: + th_dtype = ScalarType::BFloat16; + break; + case DType::kFloat32: + th_dtype = ScalarType::Float; + break; + case DType::kFloat8E4M3: + th_dtype = ScalarType::Float8_e4m3fn; + break; + case DType::kFloat8E5M2: + th_dtype = ScalarType::Float8_e5m2; + break; + default: + break; + } + return th_dtype; +} + +inline void SetMUTensorDType(DType te_dtype, MUTensor& m_t) { + at::musa::SetMUTensorDType(ToTorchDtype(te_dtype), m_t); +} + +inline void SetMUTensorFormat( + const std::vector& shape, + MUTensor& m_t) { + const int ndim = shape.size(); + const auto mudnn_format = (ndim == 5) ? MUTensor::Format::NCDHW + : MUTensor::Format::NCHW; + m_t.SetFormat(mudnn_format); + + const auto ss = make_mudnn_sizes_strides(shape); + m_t.SetNdInfo(ndim, ss.first.data(), ss.second.data()); +} + +inline MUTensor CreateMUTensor(const SimpleTensor& st) { + MUTensor m_t; + SetMUTensorDType(st.dtype, m_t); + at::musa::SetMUTensorAddr(st.dptr, m_t); + SetMUTensorFormat(st.shape, m_t); + return m_t; +} + +inline MUTensor CreateMUTensor( + const SimpleTensor& st, + const std::vector& shape) { + MUTensor m_t; + SetMUTensorDType(st.dtype, m_t); + at::musa::SetMUTensorAddr(st.dptr, m_t); + SetMUTensorFormat(shape, m_t); + return m_t; +} + +} // namespace transformer_engine::musa + +#endif // TRANSFORMER_ENGINE_MUSA_COMMON_UTIL_MUDNN_H_ diff --git a/transformer_engine/musa/common/util/musa_driver.cpp b/transformer_engine/musa/common/util/musa_driver.cpp new file mode 100644 index 0000000000..931703c063 --- /dev/null +++ b/transformer_engine/musa/common/util/musa_driver.cpp @@ -0,0 +1,108 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include + +#include "../common.h" +#include "../util/musa_runtime.h" + +namespace transformer_engine { + +namespace { + +/*! \brief Wrapper class for a shared library + * + * \todo Windows support + */ +class Library { + public: + explicit Library(const char *filename) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support + NVTE_ERROR("Shared library initialization is not supported with Windows"); +#else + handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL); + NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed"); +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + ~Library() { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support +#else + if (handle_ != nullptr) { + dlclose(handle_); + } +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + Library(const Library &) = delete; // move-only + + Library(Library &&other) noexcept { swap(*this, other); } + + Library &operator=(Library other) noexcept { + // Copy-and-swap idiom + swap(*this, other); + return *this; + } + + friend void swap(Library &first, Library &second) noexcept; + + void *get() noexcept { return handle_; } + + const void *get() const noexcept { return handle_; } + + /*! \brief Get pointer corresponding to symbol in shared library */ + void *get_symbol(const char *symbol) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support + NVTE_ERROR("Shared library initialization is not supported with Windows"); +#else + void *ptr = dlsym(handle_, symbol); + NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library"); + return ptr; +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + private: + void *handle_ = nullptr; +}; + +void swap(Library &first, Library &second) noexcept { + using std::swap; + swap(first.handle_, second.handle_); +} + +/*! \brief Lazily-initialized shared library for CUDA driver */ +Library &musa_driver_lib() { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + constexpr char lib_name[] = "nvcuda.dll"; +#else + constexpr char lib_name[] = "libmusa.so"; +#endif + static Library lib(lib_name); + return lib; +} + +} // namespace + +namespace cuda_driver { + +void *get_symbol(const char *symbol) { + void *entry_point; + // cudaDriverEntryPointQueryResult driver_result; + // NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result)); + // NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess, + // "Could not find CUDA driver entry point for ", symbol); + entry_point = musa_driver_lib().get_symbol(symbol); + return entry_point; +} + +} // namespace cuda_driver + +} // namespace transformer_engine diff --git a/transformer_engine/musa/common/util/musa_driver.h b/transformer_engine/musa/common/util/musa_driver.h new file mode 100644 index 0000000000..f830f3b762 --- /dev/null +++ b/transformer_engine/musa/common/util/musa_driver.h @@ -0,0 +1,62 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ + +#include + +#include + +#include "../common.h" +#include "../util/string.h" + +namespace transformer_engine { + +namespace cuda_driver { + +/*! \brief Get pointer corresponding to symbol in CUDA driver library */ +void *get_symbol(const char *symbol); + +/*! \brief Call function in CUDA driver library + * + * The CUDA driver library (libcuda.so.1 on Linux) may be different at + * compile-time and run-time. In particular, the CUDA Toolkit provides + * stubs for the driver library in case compilation is on a system + * without GPUs. Indirect function calls into a lazily-initialized + * library ensures we are accessing the correct version. + * + * \param[in] symbol Function name + * \param[in] args Function arguments + */ +template +inline MUresult call(const char *symbol, ArgTs... args) { + using FuncT = MUresult(ArgTs...); + FuncT *func = reinterpret_cast(get_symbol(symbol)); + return (*func)(args...); +} + +} // namespace cuda_driver + +} // namespace transformer_engine + +#define NVTE_CHECK_CUDA_DRIVER(expr) \ + do { \ + const MUresult status_NVTE_CHECK_MUSA_DRIVER = (expr); \ + if (status_NVTE_CHECK_MUSA_DRIVER != MUSA_SUCCESS) { \ + const char *desc_NVTE_CHECK_MUSA_DRIVER; \ + ::transformer_engine::cuda_driver::call("muGetErrorString", status_NVTE_CHECK_MUSA_DRIVER, \ + &desc_NVTE_CHECK_MUSA_DRIVER); \ + NVTE_ERROR("MUSA Error: ", desc_NVTE_CHECK_MUSA_DRIVER); \ + } \ + } while (false) + +#define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \ + do { \ + NVTE_CHECK_CUDA_DRIVER(::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__)); \ + } while (false) + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ diff --git a/transformer_engine/musa/common/util/musa_runtime.cpp b/transformer_engine/musa/common/util/musa_runtime.cpp new file mode 100644 index 0000000000..998fb193ed --- /dev/null +++ b/transformer_engine/musa/common/util/musa_runtime.cpp @@ -0,0 +1,194 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/musa_runtime.h" + +#include +#include + +#include "../common.h" +#include "../util/musa_driver.h" +#include "../util/system.h" + +namespace transformer_engine { + +namespace cuda { + +namespace { + +// String with build-time CUDA include path +#include "string_path_musa_include.h" + +} // namespace + +int num_devices() { + auto query_num_devices = []() -> int { + int count; + NVTE_CHECK_CUDA(musaGetDeviceCount(&count)); + return count; + }; + static int num_devices_ = query_num_devices(); + return num_devices_; +} + +int current_device() { + // Return 0 if CUDA context is not initialized + MUcontext context; + NVTE_CALL_CHECK_CUDA_DRIVER(muCtxGetCurrent, &context); + if (context == nullptr) { + return 0; + } + + // Query device from CUDA runtime + int device_id; + NVTE_CHECK_CUDA(musaGetDevice(&device_id)); + return device_id; +} + +int sm_arch(int device_id) { + static std::vector cache(num_devices(), -1); + static std::vector flags(num_devices()); + if (device_id < 0) { + device_id = current_device(); + } + NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid MUSA device ID"); + auto init = [&]() { + musaDeviceProp prop; + NVTE_CHECK_CUDA(musaGetDeviceProperties(&prop, device_id)); + cache[device_id] = 10 * prop.major + prop.minor; + }; + std::call_once(flags[device_id], init); + return cache[device_id]; +} + +int sm_count(int device_id) { + static std::vector cache(num_devices(), -1); + static std::vector flags(num_devices()); + if (device_id < 0) { + device_id = current_device(); + } + NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid MUSA device ID"); + auto init = [&]() { + musaDeviceProp prop; + NVTE_CHECK_CUDA(musaGetDeviceProperties(&prop, device_id)); + cache[device_id] = prop.multiProcessorCount; + }; + std::call_once(flags[device_id], init); + return cache[device_id]; +} + +void stream_priority_range(int *low_priority, int *high_priority, int device_id) { + static std::vector> cache(num_devices()); + static std::vector flags(num_devices()); + if (device_id < 0) { + device_id = current_device(); + } + NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); + auto init = [&]() { + int ori_dev = current_device(); + if (device_id != ori_dev) NVTE_CHECK_CUDA(musaSetDevice(device_id)); + int min_pri, max_pri; + NVTE_CHECK_CUDA(musaDeviceGetStreamPriorityRange(&min_pri, &max_pri)); + if (device_id != ori_dev) NVTE_CHECK_CUDA(musaSetDevice(ori_dev)); + cache[device_id] = std::make_pair(min_pri, max_pri); + }; + std::call_once(flags[device_id], init); + *low_priority = cache[device_id].first; + *high_priority = cache[device_id].second; +} + +bool supports_multicast(int device_id) { +#if CUDART_VERSION >= 12010 + // NOTE: This needs to be guarded at compile time because the + // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions. + static std::vector cache(num_devices(), false); + static std::vector flags(num_devices()); + if (device_id < 0) { + device_id = current_device(); + } + NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); + auto init = [&]() { + CUdevice cudev; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id); + int result; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); + cache[device_id] = static_cast(result); + }; + std::call_once(flags[device_id], init); + return cache[device_id]; +#else + return false; +#endif +} + +const std::string &include_directory(bool required) { + static std::string path; + + // Update cached path if needed + static bool need_to_check_env = true; + if (path.empty() && required) { + need_to_check_env = true; + } + if (need_to_check_env) { + // Search for CUDA headers in common paths + using Path = std::filesystem::path; + std::vector> search_paths = {{"NVTE_MUSA_INCLUDE_DIR", ""}, + {"MUSA_HOME", ""}, + {"MUSA_DIR", ""}, + {"", string_path_musa_include}, + {"", "/usr/local/musa"}}; + for (auto &[env, p] : search_paths) { + if (p.empty()) { + p = getenv(env.c_str()); + } + if (!p.empty()) { + if (file_exists(p / "musa_runtime.h")) { + path = p; + break; + } + if (file_exists(p / "include" / "musa_runtime.h")) { + path = p / "include"; + break; + } + } + } + + // Throw exception if path is required but not found + if (path.empty() && required) { + std::string message; + message.reserve(2048); + message += "Could not find musa_runtime.h in"; + bool is_first = true; + for (const auto &[env, p] : search_paths) { + message += is_first ? " " : ", "; + is_first = false; + if (!env.empty()) { + message += env; + message += "="; + } + if (p.empty()) { + message += ""; + } else { + message += p; + } + } + message += + (". " + "Specify path to MUSA Toolkit headers " + "with NVTE_MUSA_INCLUDE_DIR."); + NVTE_ERROR(message); + } + need_to_check_env = false; + } + + // Return cached path + return path; +} + +} // namespace cuda + +} // namespace transformer_engine diff --git a/transformer_engine/musa/common/util/musa_runtime.h b/transformer_engine/musa/common/util/musa_runtime.h new file mode 100644 index 0000000000..f35b45d6ba --- /dev/null +++ b/transformer_engine/musa/common/util/musa_runtime.h @@ -0,0 +1,74 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_ + +#include + +#include + +namespace transformer_engine { + +namespace cuda { + +/* \brief Number of accessible devices */ +int num_devices(); + +/* \brief Which device is currently being used */ +int current_device(); + +/* \brief Compute capability of device + * + * \param[in] device_id CUDA device (default is current device) + * + * \return Compute capability as int. Last digit is minor revision, + * remaining digits are major revision. + */ +int sm_arch(int device_id = -1); + +/* \brief Number of multiprocessors on a device + * + * \param[in] device_id CUDA device (default is current device) + * + * \return Number of multiprocessors + */ +int sm_count(int device_id = -1); + +/* \brief Minimum and maximum stream priorities supported on device + * + * \param[in] device_id CUDA device (default is current device) + * + * \param[out] low_priority Lowest priority value on device. + * + * \param[out] high_priority Highest priority value on device. + */ +void stream_priority_range(int *low_priority, int *high_priority, int device_id = -1); + +/* \brief CUDA Multicast support status for device + * + * \param[in] device_id CUDA device (default is current device) + * + * \return CUDA multicast support flag + */ +bool supports_multicast(int device_id = -1); + +/* \brief Path to CUDA Toolkit headers + * + * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the + * environment. Otherwise searches in common install paths. + * + * \param[in] required Whether to throw exception if not found + * + * \return Path to include directory, or an empty string if not found + */ +const std::string &include_directory(bool required = false); + +} // namespace cuda + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_ diff --git a/transformer_engine/musa/common/util/padding.mu b/transformer_engine/musa/common/util/padding.mu new file mode 100644 index 0000000000..cd020a6e06 --- /dev/null +++ b/transformer_engine/musa/common/util/padding.mu @@ -0,0 +1,219 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../utils.muh" + +namespace transformer_engine { + +namespace { + +// Parameters to tune +constexpr int n_warps_per_tile = 4; +constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile; +constexpr int desired_load_store_size = 8; +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB + +struct MultiPaddingArgs { + // (input) Data buffers for input tensors + void* input_list[kMaxTensorsPerKernel]; + // (output) Data buffers for cast output tensors + void* output_list[kMaxTensorsPerKernel]; + // Input matrix heights + int num_rows_list[kMaxTensorsPerKernel]; + // Input matrix heights (padded) + int padded_num_rows_list[kMaxTensorsPerKernel]; + // Input matrix widths + int row_length_list[kMaxTensorsPerKernel]; + // tensor + int block_range[kMaxTensorsPerKernel + 1]; + // Number of tensors being processed by kernel + int num_tensors; +}; + +template +__global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiPaddingArgs args) { + using Vec = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr int bdimx = THREADS_PER_WARP; + constexpr int bdimy = n_warps_per_tile; + const int tid = threadIdx.x; + const int tidx = tid % bdimx; + const int tidy = tid / bdimx; + const int bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + constexpr int tile_dim_m = THREADS_PER_WARP * nvec; + constexpr int tile_dim_n = THREADS_PER_WARP * nvec; + + // Number of nvec x nvec subtiles for each thread to + // load/store + constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + + // Find tensor corresponding to block + int tensor_id = 0; + while (args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + const Type* input = reinterpret_cast(args.input_list[tensor_id]); + Type* output = reinterpret_cast(args.output_list[tensor_id]); + const int num_rows = args.num_rows_list[tensor_id]; + const int padded_num_rows = args.padded_num_rows_list[tensor_id]; + const int row_length = args.row_length_list[tensor_id]; + + // Find position of tile within tensor + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int tile_id = bid - args.block_range[tensor_id]; + const int tile_id_m = tile_id / num_tiles_n; + const int tile_id_n = tile_id % num_tiles_n; + const int tile_row = tile_id_m * tile_dim_m; + const int tile_col = tile_id_n * tile_dim_n; + + // Load input and store to registers + // Note: Each thread loads n_iterations subtiles, casts to output + // type, and transposes in registers. + Type local_zero = static_cast(0.f); +#pragma unroll + for (int iter = 0; iter < n_iterations; ++iter) { + const int i1 = tidy + iter * bdimy; + const int j1 = tidx; +#pragma unroll + for (int i2 = 0; i2 < nvec; ++i2) { + const int row = tile_row + i1 * nvec + i2; + const int col = tile_col + j1 * nvec; + Vec local_input; + Vec local_output; + local_input.clear(); + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + local_input.data.elt[j2] = input[row * row_length + col + j2]; + } + } + } +#pragma unroll + for (int j2 = 0; j2 < nvec; ++j2) { + local_output.data.elt[j2] = local_input.data.elt[j2]; + } + if (row < num_rows) { + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[row * row_length + col + j2] = local_output.data.elt[j2]; + } + } + } else if (row < padded_num_rows) { + // padding + for (int j2 = 0; j2 < nvec; ++j2) { + if (col + j2 < row_length) { + output[row * row_length + col + j2] = local_zero; + } + } + } + } + } +} + +} // namespace + +void multi_padding(const std::vector input_list, std::vector output_list, + const std::vector padded_num_rows_list, musaStream_t stream) { + // Check that number of tensors is valid + NVTE_CHECK(output_list.size() == input_list.size(), + "Number of input and output tensors must match"); + if (input_list.empty()) { + return; + } + + // Check that tensor properties are valid + DType type = input_list[0]->data.dtype; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + const auto& input = *input_list[tensor_id]; + const auto& output = *output_list[tensor_id]; + CheckInputTensor(input, "multi_padding_input_" + std::to_string(tensor_id)); + CheckInputTensor(output, "multi_padding_output_" + std::to_string(tensor_id)); + + NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match."); + NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match."); + + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); + NVTE_CHECK(output.data.shape[0] == padded_num_rows_list[tensor_id], + "output tensor shape does not match padded input shape."); + } + + // Input matrices are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles + const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); + + // Add tensors to kernel argument struct + MultiPaddingArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { + // Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_padding_kernel + <<>>(kernel_args);); // NOLINT(*) + kernel_args.num_tensors = 0; + } + + // Calculate number of thread blocks needed for tensor + const int num_rows = input_list[tensor_id]->data.shape[0]; + const int padded_num_rows = padded_num_rows_list[tensor_id]; + const int row_length = input_list[tensor_id]->data.shape[1]; + const int num_tiles_m = (padded_num_rows + tile_dim_m - 1) / tile_dim_m; + const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; + const int num_tiles = num_tiles_m * num_tiles_n; + + // Add tensor to kernel argument struct + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); + kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr; + kernel_args.num_rows_list[pos] = num_rows; + kernel_args.padded_num_rows_list[pos] = padded_num_rows; + kernel_args.row_length_list[pos] = row_length; + kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; + kernel_args.num_tensors++; + } + + // Launch kernel + if (kernel_args.num_tensors > 0) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type); + const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; + multi_padding_kernel + <<>>(kernel_args);); // NOLINT(*) + } +} + +} // namespace transformer_engine + +void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, + const int* padded_num_rows_list, musaStream_t stream) { + NVTE_API_CALL(nvte_multi_padding); + using namespace transformer_engine; + std::vector input_list_, output_list_; + std::vector padded_num_rows_list_; + for (size_t i = 0; i < num_tensors; ++i) { + input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); + output_list_.push_back(reinterpret_cast(output_list[i])); + padded_num_rows_list_.push_back(padded_num_rows_list[i]); + } + multi_padding(input_list_, output_list_, padded_num_rows_list_, stream); +} diff --git a/transformer_engine/musa/common/util/pybind_helper.h b/transformer_engine/musa/common/util/pybind_helper.h new file mode 100644 index 0000000000..e9a6eab648 --- /dev/null +++ b/transformer_engine/musa/common/util/pybind_helper.h @@ -0,0 +1,111 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ + +#include +#include +#include +#include + +#include "musa_runtime.h" + +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType", pybind11::module_local()) \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "CommOverlapType", \ + pybind11::module_local()) \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo", \ + pybind11::module_local()) \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + py::class_>(m, "CommOverlapCore", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ + py::call_guard()) \ + .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ + py::call_guard()) \ + .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ + py::call_guard()) \ + .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ + py::call_guard()); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def( \ + "get_stream_priority_range", \ + [](int device_id = -1) { \ + int low_pri, high_pri; \ + transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ + return std::make_pair(low_pri, high_pri); \ + }, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); + +#endif diff --git a/transformer_engine/musa/common/util/rtc.cpp b/transformer_engine/musa/common/util/rtc.cpp new file mode 100644 index 0000000000..3453bcdd72 --- /dev/null +++ b/transformer_engine/musa/common/util/rtc.cpp @@ -0,0 +1,237 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/rtc.h" + +#include +#include +#include + +#include "../common.h" +#include "../util/musa_driver.h" +#include "../util/string.h" +#include "../util/system.h" + +namespace transformer_engine { + +namespace rtc { + +namespace { + +// Strings with headers for RTC kernels +#include "string_code_util_math_h.h" +#include "string_code_utils_muh.h" + +/*! \brief Latest compute capability that NVRTC supports + * + * \return Compute capability as int. Last digit is minor revision, + * remaining digits are major revision. + */ +inline int max_supported_sm_arch() { + static int arch_ = -1; + // if (arch_ < 0) { + // int num_archs = 0; + // NVTE_CHECK_NVRTC(nvrtcGetNumSupportedArchs(&num_archs)); + // NVTE_CHECK(num_archs > 0, "Could not determine SM archs that NVRTC supports"); + // std::vector archs(num_archs); + // NVTE_CHECK_NVRTC(nvrtcGetSupportedArchs(archs.data())); + // arch_ = archs.back(); + // } + return arch_; +} + +} // namespace + +bool is_enabled() { + // static bool is_enabled_ = false; + // static bool need_to_check_env = true; + // if (need_to_check_env) { + // is_enabled_ = !getenv("NVTE_DISABLE_NVRTC"); + // need_to_check_env = false; + // } + // return is_enabled_; + return false; +} + +Kernel::Kernel(std::string mangled_name, std::string compiled_code) + : mangled_name_{std::move(mangled_name)}, + compiled_code_{std::move(compiled_code)}, + modules_(cuda::num_devices(), null_module), + functions_(cuda::num_devices(), null_function), + init_flags_{std::make_unique>(cuda::num_devices())} {} + +Kernel::~Kernel() { + for (int device_id = 0; device_id < static_cast(modules_.size()); ++device_id) { + // Unload CUDA modules if needed + if (modules_[device_id] != null_module) { + MUdevice device; + MUcontext context; + if (cuda_driver::call("muDeviceGet", &device, device_id) != MUSA_SUCCESS) { + continue; + } + if (cuda_driver::call("muDevicePrimaryCtxRetain", &context, device) != MUSA_SUCCESS) { + continue; + } + if (cuda_driver::call("muCtxSetCurrent", context) != MUSA_SUCCESS) { + continue; + } + cuda_driver::call("muModuleUnload", modules_[device_id]); + cuda_driver::call("muDevicePrimaryCtxRelease", device); + } + } +} + +Kernel::Kernel(Kernel&& other) noexcept { swap(*this, other); } + +Kernel& Kernel::operator=(Kernel other) noexcept { + // Copy-and-swap idiom + swap(*this, other); + return *this; +} + +void swap(Kernel& first, Kernel& second) noexcept { + using std::swap; + swap(first.mangled_name_, second.mangled_name_); + swap(first.compiled_code_, second.compiled_code_); + swap(first.modules_, second.modules_); + swap(first.functions_, second.functions_); + swap(first.init_flags_, second.init_flags_); +} + +MUfunction Kernel::get_function(int device_id) { + // Load kernel on device if needed + auto load_on_device = [&]() { + // Set driver context to proper device + MUdevice device; + MUcontext context; + NVTE_CALL_CHECK_CUDA_DRIVER(muDeviceGet, &device, device_id); + NVTE_CALL_CHECK_CUDA_DRIVER(muDevicePrimaryCtxRetain, &context, device); + NVTE_CALL_CHECK_CUDA_DRIVER(muCtxSetCurrent, context); + + // Load function into driver context + NVTE_CALL_CHECK_CUDA_DRIVER(muModuleLoadDataEx, &modules_[device_id], compiled_code_.c_str(), + 0, // numOptions + nullptr, // options + nullptr); // optionValues + NVTE_CALL_CHECK_CUDA_DRIVER(muModuleGetFunction, &functions_[device_id], modules_[device_id], + mangled_name_.c_str()); + + // Reset driver context + NVTE_CALL_CHECK_CUDA_DRIVER(muDevicePrimaryCtxRelease, device); + }; + std::call_once(init_flags_->at(device_id), load_on_device); + + // Return CUDA function + return functions_[device_id]; +} + +void Kernel::set_function_cache_config(int device_id, MUfunc_cache cache_config) { + NVTE_CALL_CHECK_CUDA_DRIVER(muFuncSetCacheConfig, get_function(device_id), cache_config); +} + +KernelManager& KernelManager::instance() { + NVTE_CHECK(is_enabled(), "NVRTC support is not enabled"); + static KernelManager instance_; + return instance_; +} + +void KernelManager::compile(const std::string& kernel_label, const std::string& kernel_name, + const std::string& code, const std::string& filename) { +// std::lock_guard lock_guard_(lock_); + +// // Choose whether to compile to PTX or cubin +// const int device_id = cuda::current_device(); +// const int sm_arch_ = cuda::sm_arch(device_id); +// const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch()); +// const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch); + +// // Compilation flags +// std::vector opts = { +// #if NDEBUG == 0 +// "-G", +// #endif +// "--std=c++17"}; +// if (compile_ptx) { +// opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch)); +// } else { +// opts.push_back(concat_strings("--gpu-architecture=sm_", compile_sm_arch)); +// } +// opts.push_back(concat_strings("-I", cuda::include_directory(true))); +// std::vector opts_ptrs; +// for (const auto& opt : opts) { +// opts_ptrs.push_back(opt.c_str()); +// } + +// // Compile source +// nvrtcProgram program; +// constexpr int num_headers = 2; +// constexpr const char* headers[num_headers] = {string_code_utils_muh, string_code_util_math_h}; +// constexpr const char* include_names[num_headers] = {"utils.muh", "util/math.h"}; +// NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, code.c_str(), filename.c_str(), num_headers, +// headers, include_names)); +// NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str())); +// const nvrtcResult compile_result = +// nvrtcCompileProgram(program, opts_ptrs.size(), opts_ptrs.data()); +// if (compile_result != NVRTC_SUCCESS) { +// // Display log if compilation failed +// std::string log = concat_strings("NVRTC compilation log for ", filename, ":\n"); +// const size_t log_offset = log.size(); +// size_t log_size; +// NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size)); +// log.resize(log_offset + log_size); +// NVTE_CHECK_NVRTC(nvrtcGetProgramLog(program, &log[log_offset])); +// log.back() = '\n'; +// std::cerr << log; +// NVTE_CHECK_NVRTC(compile_result); +// } + +// // Get mangled function name +// const char* mangled_name; +// NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program, kernel_name.c_str(), &mangled_name)); + +// // Get compiled code +// std::string compiled_code; +// if (compile_ptx) { +// size_t compiled_size; +// NVTE_CHECK_NVRTC(nvrtcGetPTXSize(program, &compiled_size)); +// compiled_code.resize(compiled_size); +// NVTE_CHECK_NVRTC(nvrtcGetPTX(program, compiled_code.data())); +// } else { +// size_t compiled_size; +// NVTE_CHECK_NVRTC(nvrtcGetCUBINSize(program, &compiled_size)); +// compiled_code.resize(compiled_size); +// NVTE_CHECK_NVRTC(nvrtcGetCUBIN(program, compiled_code.data())); +// } + +// // Cache compiled code +// const auto key = get_kernel_cache_key(kernel_label, device_id); +// kernel_cache_.insert({key, Kernel(mangled_name, std::move(compiled_code))}); +// kernel_cache_.at(key).get_function(device_id); // Make sure kernel is available on device + +// // Clean up +// NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program)); +} + +void KernelManager::set_cache_config(const std::string& kernel_label, MUfunc_cache cache_config) { + const int device_id = cuda::current_device(); + const auto key = get_kernel_cache_key(kernel_label, device_id); + NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to configure RTC kernel before compilation"); + kernel_cache_.at(key).set_function_cache_config(device_id, cache_config); +} + +bool KernelManager::is_compiled(const std::string& kernel_label, int device_id) const { + const auto key = get_kernel_cache_key(kernel_label, device_id); + return kernel_cache_.count(key) > 0; +} + +std::string KernelManager::get_kernel_cache_key(const std::string& kernel_label, + int device_id) const { + return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label); +} + +} // namespace rtc + +} // namespace transformer_engine diff --git a/transformer_engine/musa/common/util/rtc.h b/transformer_engine/musa/common/util/rtc.h new file mode 100644 index 0000000000..331ed4ec40 --- /dev/null +++ b/transformer_engine/musa/common/util/rtc.h @@ -0,0 +1,184 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_ + +#include +#include +// #include + +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/musa_driver.h" +#include "../util/musa_runtime.h" + +namespace transformer_engine { + +namespace rtc { + +/*! \brief Whether NVRTC support is enabled + * + * NVRTC support can be disabled by setting NVTE_DISABLE_NVRTC=1 in + * the environment. + */ +bool is_enabled(); + +/*! \brief Wrapper class for a runtime-compiled CUDA kernel */ +class Kernel { + public: + Kernel(std::string mangled_name, std::string compiled_code); + ~Kernel(); + Kernel(const Kernel &) = delete; // move-only + Kernel(Kernel &&) noexcept; + Kernel &operator=(Kernel) noexcept; + friend void swap(Kernel &first, Kernel &second) noexcept; + + /*! \brief Launch CUDA kernel + * + * Loads the kernel into the device the first time the device is + * accessed. + * + * \param[in] device_id CUDA device + * \param[in] grid_dim Grid dimensions in blocks + * \param[in] block_dim Thread block dimensions + * \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in + * bytes + * \param[in] stream CUDA stream + * \param[in] args Kernel arguments + */ + template + void launch(int device_id, const dim3 grid_dim, const dim3 block_dim, + unsigned int shared_mem_bytes, musaStream_t stream, ArgTs &&...args) { + void *arg_ptrs[] = {const_cast(static_cast(&args))...}; + NVTE_CALL_CHECK_CUDA_DRIVER(muLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y, + grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes, + static_cast(stream), arg_ptrs, nullptr); + } + + /*! \brief CUDA function for given CUDA device + * + * Loads the kernel into the device the first time the device is + * accessed. + */ + MUfunction get_function(int device_id); + + /*! \brief Sets the preferred cache configuration for a function + * + * Wrapper of the CUDA Driver API function "cuFuncSetCacheConfig" + */ + void set_function_cache_config(int device_id, MUfunc_cache cache_config); + + private: + /*! \brief Mangled function name */ + std::string mangled_name_; + /*! \brief Compiled assembly, either in PTX or cubin format */ + std::string compiled_code_; + /*! CUDA module for each CUDA device */ + std::vector modules_; + /*! CUDA function for each CUDA device */ + std::vector functions_; + + /*! Flags for thread-safe kernel initialization */ + std::unique_ptr> init_flags_; + + /*! \brief Uninitialized CUDA module */ + static constexpr MUmodule null_module = static_cast(nullptr); + /*! Uninitialized CUDA function */ + static constexpr MUfunction null_function = static_cast(nullptr); +}; + +/*! \brief Singleton class to manage runtime-compiled CUDA kernels */ +class KernelManager { + public: + /*! \brief Get singleton instance */ + static KernelManager &instance(); + + /*! \brief Compile CUDA kernel for current CUDA device + * + * The compiled kernel is cached and made available for launching. + * + * \param[in] kernel_label Unique identifying string for kernel + * \param[in] kernel_name Kernel name within source code + * \param[in] code Kernel source code + * \param[in] filename Path to associate with source code, + * primarily for debugging + */ + void compile(const std::string &kernel_label, const std::string &kernel_name, + const std::string &code, const std::string &filename); + + /*! \brief Whether CUDA kernel has been compiled for CUDA device + * + * \param[in] kernel_label Unique identifying string for kernel + * \param[in] device_id CUDA device (default is current device) + + * \return Whether kernel has been compiled + */ + bool is_compiled(const std::string &kernel_label, int device_id = -1) const; + + /*! \brief Launch CUDA kernel on current CUDA device + * + * Assumes the kernel has already been compiled. + * + * \param[in] kernel_label Unique identifying string for kernel + * \param[in] grid_dim Grid dimensions in blocks + * \param[in] block_dim Thread block dimensions + * \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in + * bytes + * \param[in] stream CUDA stream + * \param[in] args Kernel arguments + */ + template + void launch(const std::string &kernel_label, const dim3 grid_dim, const dim3 block_dim, + unsigned int shared_mem_bytes, musaStream_t stream, ArgTs &&...args) { + const int device_id = cuda::current_device(); + const auto key = get_kernel_cache_key(kernel_label, device_id); + NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to launch RTC kernel before compilation"); + kernel_cache_.at(key).launch(device_id, grid_dim, block_dim, shared_mem_bytes, stream, + std::forward(args)...); + } + + /*! \brief Sets the preferred cache configuration for a function in the context + * + * Assumes the kernel has already been compiled. + * + * \param[in] kernel_label Unique identifying string for kernel + * \param[in] cache_config Prefered cache configuration + */ + void set_cache_config(const std::string &kernel_label, MUfunc_cache cache_config); + + private: + /*! \brief Compiled kernels */ + std::unordered_map kernel_cache_; + /*! \brief Mutex for thread-safe compilation */ + std::mutex lock_; + + KernelManager() = default; + ~KernelManager() = default; + KernelManager(const KernelManager &) = delete; + KernelManager &operator=(const KernelManager &) = delete; + + /*! \brief Construct key for kernel cache + * + * \param[in] kernel_label Unique identifying string for kernel + * \param[in] device_id CUDA device (default is current device) + * + * \return Key for kernel cache + */ + std::string get_kernel_cache_key(const std::string &kernel_label, int device_id) const; +}; + +} // namespace rtc + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_ diff --git a/transformer_engine/musa/common/util/string.h b/transformer_engine/musa/common/util/string.h new file mode 120000 index 0000000000..6e5a1f46ba --- /dev/null +++ b/transformer_engine/musa/common/util/string.h @@ -0,0 +1 @@ +../../../common/util/string.h \ No newline at end of file diff --git a/transformer_engine/musa/common/util/string_header.h.in b/transformer_engine/musa/common/util/string_header.h.in new file mode 120000 index 0000000000..88f8bc8333 --- /dev/null +++ b/transformer_engine/musa/common/util/string_header.h.in @@ -0,0 +1 @@ +../../../common/util/string_header.h.in \ No newline at end of file diff --git a/transformer_engine/musa/common/util/system.cpp b/transformer_engine/musa/common/util/system.cpp new file mode 120000 index 0000000000..f02555c600 --- /dev/null +++ b/transformer_engine/musa/common/util/system.cpp @@ -0,0 +1 @@ +../../../common/util/system.cpp \ No newline at end of file diff --git a/transformer_engine/musa/common/util/system.h b/transformer_engine/musa/common/util/system.h new file mode 120000 index 0000000000..9b87cbf2f6 --- /dev/null +++ b/transformer_engine/musa/common/util/system.h @@ -0,0 +1 @@ +../../../common/util/system.h \ No newline at end of file diff --git a/transformer_engine/musa/common/util/vectorized_pointwise.h b/transformer_engine/musa/common/util/vectorized_pointwise.h new file mode 100644 index 0000000000..dc24979cc1 --- /dev/null +++ b/transformer_engine/musa/common/util/vectorized_pointwise.h @@ -0,0 +1,597 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_ + +#include + +#include "../common.h" +#include "../utils.muh" + +namespace transformer_engine { + +/* \brief Helper class that enables storing multiple values of type DType + as 1 value of type LType. +*/ +template +class VectorizedStorage { + public: + using LType = typename transformer_engine::BytesToType::Type; + constexpr static int nvec = n; + union vectorized_storage { + LType aligned; + DType separate[nvec]; // NOLINT(*) + + inline __device__ vectorized_storage() {} + inline __device__ ~vectorized_storage() {} + } scratch_; + + inline __device__ VectorizedStorage() {} + inline __device__ VectorizedStorage(const VectorizedStorage &y2) { + scratch_.aligned = y2.scratch_.aligned; + } + inline __device__ VectorizedStorage(const LType &y2) { scratch_.aligned = y2; } + inline __device__ VectorizedStorage &operator+=( + const VectorizedStorage &rhs) { +#pragma unroll + for (int i = 0; i < nvec; ++i) { + scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]); + } + return *this; + } + inline __device__ ~VectorizedStorage() {} + + /* \brief Access to separate elements. */ + inline __device__ DType *separate() { return scratch_.separate; } + + inline __device__ const DType *separate() const { return scratch_.separate; } + + inline __device__ LType &aligned() { return scratch_.aligned; } +}; + +// Returns const LType is DType is const +template +struct select_const { + using type = LType; +}; + +template +struct select_const { + using type = const LType; +}; + +/* \brief Helper class that enables accessing multiple values of type DType + as 1 value of type LType. Additional aligned template argument + allows performance optimizations if the pointer and the size of + the allocation is aligned to sizeof(LType) / sizeof(DType) elements. +*/ +template +class VectorizedAccessor { + public: + using StorageType = VectorizedStorage::type, nvec>; + using LType = typename select_const::type; + StorageType storage_; + + LType *aligned_ptr_; + DType *unaligned_ptr_; + int alignment_; + size_t n_elems_; + + inline __device__ VectorizedAccessor(DType *const ptr, const size_t size) { + unaligned_ptr_ = ptr; + if (aligned) { + alignment_ = 0; + aligned_ptr_ = reinterpret_cast(ptr); + n_elems_ = (size + nvec - 1) / nvec; + } else { + size_t ptr_as_number = reinterpret_cast(ptr); + alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType); + aligned_ptr_ = reinterpret_cast(ptr - alignment_); + n_elems_ = (size + alignment_ + nvec - 1) / nvec; + } + } + + /* \brief Alignment of the input pointer in elements. */ + inline __device__ int alignment() const { return alignment_; } + + /* \brief Access to separate elements. */ + inline __device__ DType *separate() { return storage_.scratch_.separate; } + + /* \brief Number of aligned elements that span the entire input tensor. */ + inline __device__ size_t num_aligned_elements() const { return n_elems_; } + + /* \brief Load values from the input. + \param id Aligned index of the element. + \param N size of the tensor. + */ + inline __device__ void load(const size_t id, const size_t N) { + if (aligned) { + storage_.scratch_.aligned = aligned_ptr_[id]; + } else { + if (id > 0 && id < n_elems_ - 1) { + storage_.scratch_.aligned = aligned_ptr_[id]; + } else { +#pragma unroll + for (int j = 0; j < nvec; ++j) { + DType *ptr = reinterpret_cast(&(aligned_ptr_[id])) + j; + if (reinterpret_cast(ptr) >= reinterpret_cast(unaligned_ptr_) && + reinterpret_cast(ptr) < reinterpret_cast(unaligned_ptr_ + N)) { + storage_.scratch_.separate[j] = *ptr; + } else { + storage_.scratch_.separate[j] = DType(); + } + } + } + } + } +}; + +/* \brief Class used for vectorized read-only access. */ +template +class VectorizedLoader : public VectorizedAccessor { + public: + inline __device__ VectorizedLoader(const DType *ptr, const size_t N) + : VectorizedAccessor(ptr, N) {} +}; + +/* \brief Class used for vectorized writable access. */ +template +class VectorizedStorer : public VectorizedAccessor { + public: + inline __device__ VectorizedStorer(DType *ptr, const size_t N) + : VectorizedAccessor(ptr, N) {} + + /* \brief Store values to the output. + \param id Aligned index of the element. + \param N size of the tensor. + */ + inline __device__ void store(const size_t id, const size_t N) { + if (aligned) { + this->aligned_ptr_[id] = this->storage_.scratch_.aligned; + } else { + if (id > 0 && id < this->n_elems_ - 1) { + this->aligned_ptr_[id] = this->storage_.scratch_.aligned; + } else { +#pragma unroll + for (int j = 0; j < nvec; ++j) { + DType *ptr = reinterpret_cast(&(this->aligned_ptr_[id])) + j; + if (reinterpret_cast(ptr) >= reinterpret_cast(this->unaligned_ptr_) && + reinterpret_cast(ptr) < reinterpret_cast(this->unaligned_ptr_ + N)) { + *ptr = this->storage_.scratch_.separate[j]; + } + } + } + } + } +}; + +constexpr int unary_kernel_threads = 512; + +template +__launch_bounds__(unary_kernel_threads) __global__ + void unary_kernel(const InputType *input, const ComputeType *noop, OutputType *output, + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, Param p, + const size_t N, const size_t num_aligned_elements) { + if (noop != nullptr && noop[0] == 1.0f) return; + + VectorizedLoader loader(input, N); + VectorizedStorer storer(output, N); + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + + const size_t M = num_aligned_elements; + + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { + loader.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const ComputeType val = static_cast(loader.separate()[i]); + ComputeType temp = OP(val, p); + if constexpr (is_fp8::value) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(temp), max); + + temp = temp * s; + } + + storer.separate()[i] = static_cast(temp); + } + storer.store(tid, N); + } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } + } +} + +template +__launch_bounds__(unary_kernel_threads) __global__ + void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output, + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, + Param p, const size_t N, const size_t num_aligned_elements) { + VectorizedLoader loader(input, N); + VectorizedLoader grad_loader(grad, N); + VectorizedStorer storer(output, N); + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + + const size_t M = num_aligned_elements; + + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { + loader.load(tid, N); + grad_loader.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const ComputeType val = static_cast(loader.separate()[i]); + const ComputeType g = static_cast(grad_loader.separate()[i]); + ComputeType temp = OP(val, p) * g; + if constexpr (is_fp8::value) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(temp), max); + + temp = temp * s; + } + + storer.separate()[i] = static_cast(temp); + } + storer.store(tid, N); + } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } + } +} + +namespace { + +inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim, const int nvec, + const int size) { + size_t ptr_as_number = reinterpret_cast(ptr); + int alignment = (ptr_as_number % (nvec * size)) / size; + return DIVUP(lead_dim + alignment, static_cast(nvec)); +} + +enum class Alignment { + SAME_ALIGNED, // All tensors aligned + SAME_UNALIGNED, // All tensors have the same misalignment + DIFFERENT // Tensors have different alignment +}; + +inline int CalcAlignment(const void *ptr, const int size) { + size_t ptr_as_number = reinterpret_cast(ptr); + return ptr_as_number % size; +} + +/* \brief Check alignment of the inputs and outputs when using vectorized accesses. + \param lead_dim Leading dimension of the tensors. + \param other_dim The size of the other dimensions of the tensors. + \param nvec Length of the vector. + \param ptrs Inputs and Outputs to the operator. +*/ +template +Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs) { + std::vector alignments; + alignments.reserve(sizeof...(T)); + + // calculate the alignments of all ptrs and store them into alignments + (..., alignments.push_back(CalcAlignment(ptrs, sizeof(*ptrs) * nvec))); + + bool all_same = std::all_of(alignments.cbegin(), alignments.cend(), + [alignments](int val) { return val == alignments.front(); }); + if (!all_same) { + return Alignment::DIFFERENT; + } + + if (alignments.front() == 0 && lead_dim % nvec == 0) { + // all alignment are 0 + return Alignment::SAME_ALIGNED; + } else { + return Alignment::SAME_UNALIGNED; + } +} + +} // namespace + +template +void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, + const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, + const Param params, musaStream_t stream) { + if (N != 0) { + auto align = CheckAlignment(N, nvec, input, output); + + size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); + constexpr size_t threads = unary_kernel_threads; + size_t num_blocks = DIVUP(num_aligned_elements, threads); + constexpr size_t max_blocks = 65535; + num_blocks = std::min(num_blocks, max_blocks); + + switch (align) { + case Alignment::SAME_ALIGNED: + unary_kernel<<>>( + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); + break; + case Alignment::SAME_UNALIGNED: + unary_kernel<<>>( + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); + break; + case Alignment::DIFFERENT: { + // If the pointers are aligned differently we cannot vectorize + unary_kernel<1, true, fp32, Param, OP><<>>( + input, noop, output, scale, amax, scale_inv, params, N, N); + break; + } + } + } +} + +template +void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input, + OutputType *output, const fp32 *scale, fp32 *amax, + fp32 *scale_inv, const size_t N, const Param params, + musaStream_t stream) { + if (N != 0) { + auto align = CheckAlignment(N, nvec, input, grad, output); + + size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); + constexpr size_t threads = unary_kernel_threads; + size_t num_blocks = DIVUP(num_aligned_elements, threads); + constexpr size_t max_blocks = 65535; + num_blocks = std::min(num_blocks, max_blocks); + + switch (align) { + case Alignment::SAME_ALIGNED: + unary_grad_kernel<<>>( + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + break; + case Alignment::SAME_UNALIGNED: + unary_grad_kernel<<>>( + grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + break; + case Alignment::DIFFERENT: { + // If the pointers are aligned differently we cannot vectorize + unary_grad_kernel<1, true, fp32, Param, OP><<>>( + grad, input, output, scale, amax, scale_inv, params, N, N); + break; + } + } + } +} + +template +__launch_bounds__(unary_kernel_threads) __global__ + void gated_act_kernel(const InputType *input, OutputType *output, const ComputeType *scale, + ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, + const Param p, const size_t num_aligned_elements) { + const size_t M = num_aligned_elements * m; + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { + const size_t id_x = tid % num_aligned_elements; + const size_t id_y = tid / num_aligned_elements; + VectorizedLoader loader0(input + id_y * n * 2, n); + VectorizedLoader loader1(input + id_y * n * 2 + n, n); + VectorizedStorer storer(output + id_y * n, n); + + loader0.load(id_x, n); + loader1.load(id_x, n); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const ComputeType val = static_cast(loader0.separate()[i]); + const ComputeType val2 = static_cast(loader1.separate()[i]); + ComputeType temp = static_cast(Activation(val, p) * val2); + if constexpr (is_fp8::value) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(temp), max); + temp = temp * s; + } + storer.separate()[i] = static_cast(static_cast(temp)); + } + storer.store(id_x, n); + } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } + } +} + +template +void GatedActivationKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, + fp32 *amax, fp32 *scale_inv, const size_t m, const size_t n, + const Param &p, musaStream_t stream) { + if (m != 0 && n != 0) { + size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType)); + constexpr size_t threads = unary_kernel_threads; + size_t num_blocks = DIVUP(num_aligned_elements * m, threads); + constexpr size_t max_blocks = 65535; + num_blocks = std::min(num_blocks, max_blocks); + + switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) { + case Alignment::SAME_ALIGNED: + gated_act_kernel + <<>>(input, output, scale, amax, scale_inv, m, n, p, + num_aligned_elements); + break; + case Alignment::SAME_UNALIGNED: + gated_act_kernel + <<>>(input, output, scale, amax, scale_inv, m, n, p, + num_aligned_elements); + break; + case Alignment::DIFFERENT: { + // If the pointers are aligned differently we cannot vectorize + gated_act_kernel<1, true, ComputeType, Param, Activation> + <<>>(input, output, scale, amax, scale_inv, m, n, p, n); + break; + } + } + } +} + +template +__launch_bounds__(unary_kernel_threads) __global__ + void dgated_act_kernel(const InputType *grad, const InputType *input, OutputType *output, + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, + const size_t m, const size_t n, const Param p, + const size_t num_aligned_elements) { + const size_t M = num_aligned_elements * m; + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { + const size_t id_x = tid % num_aligned_elements; + const size_t id_y = tid / num_aligned_elements; + VectorizedLoader grad_loader(grad + id_y * n, n); + VectorizedLoader input_loader0(input + id_y * n * 2, n); + VectorizedLoader input_loader1(input + id_y * n * 2 + n, n); + VectorizedStorer storer0(output + id_y * n * 2, n); + VectorizedStorer storer1(output + id_y * n * 2 + n, n); + + grad_loader.load(id_x, n); + input_loader0.load(id_x, n); + input_loader1.load(id_x, n); + +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const ComputeType grad_val = static_cast(grad_loader.separate()[i]); + const ComputeType gelu_in = static_cast(input_loader0.separate()[i]); + const ComputeType gate_in = static_cast(input_loader1.separate()[i]); + + ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; + ComputeType after_dgate = grad_val * Activation(gelu_in, p); + + if constexpr (is_fp8::value) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(after_dgelu), max); + after_dgelu = after_dgelu * s; + max = fmaxf(fabsf(after_dgate), max); + after_dgate = after_dgate * s; + } + + storer0.separate()[i] = static_cast(after_dgelu); + storer1.separate()[i] = static_cast(after_dgate); + } + storer0.store(id_x, n); + storer1.store(id_x, n); + } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } + } +} + +template +void DGatedActivationKernelLauncher(const InputType *grad, const InputType *input, + OutputType *output, const fp32 *scale, fp32 *amax, + fp32 *scale_inv, const size_t m, const size_t n, const Param &p, + musaStream_t stream) { + if (m != 0 && n != 0) { + size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, sizeof(InputType)); + constexpr size_t threads = unary_kernel_threads; + size_t num_blocks = DIVUP(num_aligned_elements * m, threads); + constexpr size_t max_blocks = 65535; + num_blocks = std::min(num_blocks, max_blocks); + + switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) { + case Alignment::SAME_ALIGNED: + dgated_act_kernel + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, num_aligned_elements); + break; + case Alignment::SAME_UNALIGNED: + dgated_act_kernel + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, num_aligned_elements); + break; + case Alignment::DIFFERENT: { + // If the pointers are aligned differently we cannot vectorize + dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation> + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, n); + break; + } + } + } +} + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_ diff --git a/transformer_engine/musa/common/utils.muh b/transformer_engine/musa/common/utils.muh new file mode 100644 index 0000000000..6408ebad5d --- /dev/null +++ b/transformer_engine/musa/common/utils.muh @@ -0,0 +1,989 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ +#define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ + +#include +#include +#include + +#if !defined(__MUSACC_RTC__) +#include +#else +// Importing C++ standard headers is a pain with NVRTC +using uint8_t = unsigned char; +using uint16_t = unsigned short int; // NOLINT(*) +using uint32_t = unsigned int; +using uint64_t = unsigned long long int; // NOLINT(*) +static_assert(sizeof(uint8_t) == 1); +static_assert(sizeof(uint16_t) == 2); +static_assert(sizeof(uint32_t) == 4); +static_assert(sizeof(uint64_t) == 8); +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr uint32_t THREADS_PER_WARP = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*) + return {a.x + b.x, a.y + b.y}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*) + a.x += b.x; + a.y += b.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sum { + inline __device__ Sum() {} + inline __device__ T operator()(const T &a, const T &b) const { return a + b; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) { + return __shfl_xor_sync(static_cast(-1), x, idx); +} + +template <> +inline __device__ float2 warp_shuffle_xor(const float2 &x, uint32_t idx) { + return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)}; +} + +template +inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) { + return __shfl_down_sync(static_cast(-1), x, idx); +} + +template <> +inline __device__ float2 warp_shuffle_down(const float2 &x, uint32_t idx) { + return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace transformer_engine { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint16 { + uint4 u; + uint4 v; + uint4 s; + uint4 t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint8 { + uint4 u; + uint4 v; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BytesToType {}; + +template <> +struct BytesToType<64> { + using Type = uint16; + static_assert(sizeof(Type) == 64); +}; + +template <> +struct BytesToType<32> { + using Type = uint8; + static_assert(sizeof(Type) == 32); +}; + +template <> +struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template <> +struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template <> +struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template <> +struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template <> +struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TypeToVec2 {}; + +template <> +struct TypeToVec2 { + using Type = float2; +}; + +template <> +struct TypeToVec2 { + using Type = half2; +}; + +template <> +struct TypeToVec2<__mt_bfloat16> { + using Type = __mt_bfloat162; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CTDBiasDActParam { + using InputType = IType; + using InputType2 = IType2; + using OutputType = OType; + using ComputeType = CType; + const IType *input; + const IType2 *act_input; + OType *output_c; + OType *output_t; + const CType *scale_ptr; + CType *amax; + CType *scale_inv; + CType *workspace; + CType *warp_scales_inv; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Get { + template + static inline __device__ R of(const T &vec); +}; + +template <> +template +inline __device__ R Get<0>::of(const T &vec) { + return vec.x; +} + +template <> +template +inline __device__ R Get<1>::of(const T &vec) { + return vec.y; +} + +template <> +template +inline __device__ R Get<2>::of(const T &vec) { + return vec.z; +} + +template <> +template +inline __device__ R Get<3>::of(const T &vec) { + return vec.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter { + static inline __device__ Dst convert(const Src &from) { return Dst(from); } +}; + +template <> +struct Converter { + static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); } +}; + +template <> +struct Converter { + static inline __device__ __mt_bfloat162 convert(const float2 &x) { +#if 1 + return __float22bfloat162_rn(x); +#else + union { + __mt_bfloat162 raw; + __mt_bfloat16 elt[2]; + } tmp; + tmp.elt[0] = __float2bfloat16_rn(x.x); + tmp.elt[1] = __float2bfloat16_rn(x.y); + return tmp.raw; +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Zeros { + static inline __device__ T get() { return T(0.f); } +}; + +template <> +struct Zeros { + static inline __device__ float2 get() { return make_float2(0.f, 0.f); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Vec { + enum { BYTES = NUM_ELT * sizeof(Elt_type) }; + + using Vec_type = typename BytesToType::Type; + using type = Elt_type; + + using Alias_type = union { + Vec_type vec; + Elt_type elt[NUM_ELT]; + }; + + Alias_type data; + + template + inline __device__ void to(Vec &other) { // NOLINT(*) +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + other.data.elt[it] = S(this->data.elt[it]); + } + } + + __device__ Vec() = default; + + __device__ Vec(const Elt_type& num) { +#pragma unroll + for (int i = 0; i < NUM_ELT; i++) { + this->data.elt[i] = num; + } + } + + __device__ Vec& operator=(const Elt_type& num) { +#pragma unroll + for (int i = 0; i < NUM_ELT; i++) { + this->data.elt[i] = num; + } + return *this; + } + + template + inline __device__ void assign(const Op &op) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + this->data.elt[it] = op(it); + } + } + + // Pointer is cast to vector type + inline __device__ void load_from(const void *base_ptr, size_t idx = 0) { + this->data.vec = static_cast(base_ptr)[idx]; + } + + // Pointer is cast to vector type + inline __device__ void store_to(void *base_ptr, size_t idx = 0) const { + static_cast(base_ptr)[idx] = this->data.vec; + } + + // Pointer is cast to element type. Loads min(count, NUM_ELT) + // elements and any remaining elements are set to zero. + inline __device__ void load_from_elts(const void *base_ptr, size_t idx = 0, + size_t count = NUM_ELT) { + const Elt_type *elt_ptr = static_cast(base_ptr) + idx; + if (count < NUM_ELT || reinterpret_cast(elt_ptr) % BYTES != 0) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + this->data.elt[it] = (it < count ? elt_ptr[it] : Elt_type(0.f)); + } + } else { + this->load_from(elt_ptr); + } + } + + // Pointer is cast to element type. Stores min(count, NUM_ELT) + // elements. + inline __device__ void store_to_elts(void *base_ptr, size_t idx = 0, + size_t count = NUM_ELT) const { + Elt_type *elt_ptr = static_cast(base_ptr) + idx; + if (count < NUM_ELT || reinterpret_cast(elt_ptr) % BYTES != 0) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + if (it < count) { + elt_ptr[it] = this->data.elt[it]; + } + } + } else { + this->store_to(elt_ptr); + } + } + + inline __device__ void clear() { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + this->data.elt[it] = Elt_type(0.f); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct InterCTASync { + inline __device__ InterCTASync(int *barrier, int group, int num_groups, int group_size) + : phase_counter_(0), + b0_(barrier + group) // The barrier for this group of CTAs. + , + b1_(barrier + group + num_groups) // The barrier for this group of CTAs. + , + group_size_(group_size) { + // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! + } + + inline __device__ void spin_wait_(int *barrier, int step, int expected) { + asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); + for (int found = -1; found != expected;) { + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); + } + } + + inline __device__ void sync() { + // ALL THREADS MUST ENTER! + + // We switch barrier every iteration. + int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; + // We decrement every other iteration. + bool dec = phase_counter_ & 0x2; + int step = dec ? -1 : 1; + int expected = dec ? 0 : group_size_; + // There are only 4 phases: up/down for b0/b1. + phase_counter_ = (phase_counter_ + 1) & 0x3; + + if (threadIdx.x == 0) { + spin_wait_(barrier, step, expected); + } + // CTA waits for thread 0 + __syncthreads(); + } + + int phase_counter_; + int *b0_; + int *b1_; + int group_size_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer : public Reducer { + using Base = Reducer; + using Type = typename Base::Type; + + enum { SMEM_BYTES = Base::SMEM_BYTES }; + + enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; + enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; + + // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) + enum { + WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES + }; + + template + inline __device__ Reducer(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), + inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW), + bidn_(bidn) // CTA id within the group. + , + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), + w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {} + + template + inline __device__ T allreduce(T data, const Op &op) { + data = Base::reduce(data, op); + // We switch workspace every iteration. + T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + // Warp leaders 0 hold the CTA-local results. + if (this->warp_n_ == 0 && this->lane_ == 0) { + workspace[bidn_] = data; + } + inter_cta_.sync(); + static_assert(CTAS_PER_ROW <= 32); + T total = Zeros::get(); + if (this->lane_ < CTAS_PER_ROW) { + total = workspace[this->lane_]; + } + total = Reducer::allreduce_(total, op); + + return total; + } + + InterCTASync inter_cta_; + + T *const w0_; + T *const w1_; + int bidn_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer { + using Type = T; + enum { SMEM_BYTES = 0 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : warp_n_(warp_n), lane_(lane) {} + + template + static inline __device__ T allreduce_(T data, const Op &op) { +#pragma unroll + for (int it = 1; it < THREADS_PER_WARP; it *= 2) { + data = op(data, warp_shuffle_xor(data, it)); + } + return data; + } + + template + inline __device__ T allreduce(T data, const Op &op) { + return allreduce_(data, op); + } + + template + inline __device__ T reduce(T data, const Op &op) { +// only lane 0 holds the result! +#pragma unroll + for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) { + data = op(data, warp_shuffle_down(data, it)); + } + return data; + } + int warp_n_; + int lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer : public Reducer { + using Base = Reducer; + + using Type = T; + + enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), + use0_(true), + smem0_(&(static_cast(smem)[warp_m * WARPS_N])), + smem1_(smem0_ + WARPS_M * WARPS_N) {} + + template + inline __device__ T allreduce(T data, const Op &op) { + T *const smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + data = Base::reduce(data, op); + if (this->lane_ == 0) { + smem[this->warp_n_] = data; + } + __syncthreads(); + T out = Zeros::get(); +#pragma unroll + for (int it = 0; it < WARPS_N; it++) { + out = op(out, smem[it]); + } + return out; + } + + template + inline __device__ T reduce(T data, const Op &op) { + T *const smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // only intra-CTA group leader holds the result! + data = Base::reduce(data, op); + if (this->lane_ == 0) { + smem[this->warp_n_] = data; + } + __syncthreads(); + T out = Zeros::get(); + if (this->warp_n_ == 0 && this->lane_ == 0) { +#pragma unroll + for (int it = 0; it < WARPS_N; it++) { + out = op(out, smem[it]); + } + } + return out; + } + + T *const smem0_; + T *const smem1_; + bool use0_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct DynamicReducer : public Reducer { + using Base = Reducer; + using Type = typename Base::Type; + + template + inline __device__ DynamicReducer(const Params ¶ms, uint32_t bidm, uint32_t bidn, + uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), + inter_cta_(params.barrier, bidm, params.ctas_per_col, params.ctas_per_row), + bidn_(bidn) // CTA id within the group. + , + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * params.ctas_per_row), + w1_(w0_ + params.ctas_per_col * WARPS_M * params.ctas_per_row) {} + + template + inline __device__ T allreduce(T data, const Op &op) { + // Trivial case + if (inter_cta_.group_size_ == 1) { + return Base::allreduce(data, op); + } + + data = Base::reduce(data, op); + // We switch workspace every iteration. + T *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + // Warp leaders 0 hold the CTA-local results. + if (this->warp_n_ == 0 && this->lane_ == 0) { + workspace[bidn_] = data; + } + inter_cta_.sync(); + T total = Zeros::get(); + for (int it = this->lane_; it < inter_cta_.group_size_; it += THREADS_PER_WARP) { + total = op(total, workspace[it]); + } + total = Reducer::allreduce_(total, op); + + return total; + } + + template + inline __device__ T reduce(T data, const Op &op) { + return allreduce(data, op); + } + + InterCTASync inter_cta_; + + T *const w0_; + T *const w1_; + int bidn_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/* +This is an implementation of the parallel Welford algorithm for incrementally computing variance + +This algorithm is known as Chan's update formulae (Chat et al '79): +http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf + +An introduction is provided by Wikipedia here: +https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance?section=5#Parallel_algorithm + +A detailed reference on the exact version implemented (with better numerical stability) is provided here: +https://dbs.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf +*/ + +template +inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, + int num_active) { // NOLINT(*) + // Assume at least leftmost is valid and + // init: step = next_pow2(num_active) / 2 (might get NaN otherwise) + int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); + +#pragma unroll + for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { + // Exchange + T n_b = warp_shuffle_down(n_a, step); + T m_b = warp_shuffle_down(m_a, step); + T m2_b = warp_shuffle_down(m2_a, step); + + // Update + const T n_ab = n_a + n_b; // We can handle one of them being 0, not both. + // Might have different n per thread, otherwise this would simplify :( + const T rn_ab = 1.f / n_ab; + const T delta = m_a - m_b; + const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + + n_a = n_ab; + m_a = m_ab; + m2_a = m2_ab; + } + // Intra-warp broadcast (only lane 0 has valid stats). + m_a = __shfl_sync(static_cast(-1), m_a, 0); + m2_a = __shfl_sync(static_cast(-1), m2_a, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + // This could be done generically with the Reducer. But then we + // would have to exchange 3 instead of 2 fields. + + using BlockStats = Stats; + using stats_t = typename BlockStats::stats_t; + + enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; + + template + inline __device__ Stats(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW), + block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), + bidn_(bidn) // CTA id within the group. + , + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), + w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW), + warp_n_(warp_n), + lane_(lane) {} + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; + // TODO(ptredak) rn is not really needed here.. + constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); + stats_t block_stats = block_stats_.compute(elts, block_rn); + + stats_t *const workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + if (warp_n_ == 0 && lane_ == 0) { + workspace[bidn_] = block_stats; + } + + // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. + inter_cta_.sync(); + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume CTA group size in N less than 32, such that we can finalize with a single warp. + static_assert(CTAS_PER_ROW <= 32); + + // Every warp does the final reduction locally. + if (lane_ < CTAS_PER_ROW) { + stats_t result = workspace[lane_]; + n = ELTS_PER_ROW_PER_CTA; + m = transformer_engine::Get<0>::of(result); + m2 = transformer_engine::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); + + return {m, m2}; + } + + InterCTASync inter_cta_; + BlockStats block_stats_; + + stats_t *const w0_; + stats_t *const w1_; + int bidn_; + int warp_n_; + int lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + using WarpStats = Stats; + using stats_t = typename WarpStats::stats_t; + + enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; + + template + inline __device__ Stats(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) { + smem0_ = static_cast(smem) + warp_m * WARPS_N; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + stats_t *smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // Compute warp local for all WARPS_N + constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP); + stats_t warp_stats = warp_stats_.compute(elts, warp_rn); + + // Each warp warp leader stores its stats + const auto warp_n = warp_stats_.reducer_.warp_n_; + const auto lane = warp_stats_.reducer_.lane_; + if (lane == 0) { + smem[warp_n] = warp_stats; + } + __syncthreads(); + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume that there are less than 32 warps, such that we can finalize with a single warp + static_assert(WARPS_N <= 32); + if (lane < WARPS_N) { + stats_t result = smem[lane]; + n = N * THREADS_PER_WARP; + m = transformer_engine::Get<0>::of(result); + m2 = transformer_engine::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, WARPS_N); + + return {m, m2}; + } + WarpStats warp_stats_; + stats_t *smem0_; + stats_t *smem1_; + bool use0_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + using stats_t = typename TypeToVec2::Type; + // The simple Warp reducer. + using Reducer = Reducer; + + enum { SMEM_BYTES = 0 }; + + template + inline __device__ Stats(const Params ¶ms, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t lane, void *smem) + : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {} + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + auto sum = Sum(); + + T m = Zeros::get(); +#pragma unroll + for (int it = 0; it < N; it++) { + m += elts[it]; + } + m = reducer_.allreduce(m, sum) * rn; + + T m2 = Zeros::get(); +#pragma unroll + for (int it = 0; it < N; it++) { + T diff = (elts[it] - m); + m2 += diff * diff; + } + m2 = reducer_.allreduce(m2, sum); + + return {m, m2}; + } + + Reducer reducer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ float warp_reduce_max(const float m) { + float tmp = m; +#pragma unroll + for (int delta = num_elems / 2; delta > 0; delta /= 2) { + const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta); + __builtin_assume(tmp >= 0); + __builtin_assume(other_m >= 0); + tmp = fmaxf(tmp, other_m); + } + return tmp; +} + +__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) { + float val_tmp = val; +#pragma unroll + for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) { + const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset); + __builtin_assume(val_tmp >= 0); + __builtin_assume(val_other >= 0); + val_tmp = fmaxf(val_tmp, val_other); + } + // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id + constexpr int subwarp_lane_zero = 0; + val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero); + return val_tmp; +} + +template +__device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) { + __shared__ float staging[num_warps]; + constexpr int warp_size = 32; + const float my_max = m; + const float my_warp_max = warp_reduce_max(my_max); + if (threadIdx.x % 32 == 0) { + staging[warpid] = my_warp_max; + } + __syncthreads(); + compute_t result = 0; + if (warpid == 0) { + const float my_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0; + result = warp_reduce_max(my_max); + } + return result; +} + +/** + * Max reduction in subwarps + * E.g., if nvec=4, each warp processes 128 elements (32 x 4), that covers four MXFP8 scaling factors. + * To compute an actual scaling factor for 32 consequentive elements, only 8 threads need to participate, + * thus splitting the warp into 4x smaller subwarps 8-thread width. + * 'Butterfly' reduction is used inside subwarps. + */ +template +__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) { + float val_tmp = val; +#pragma unroll + for (int offset = subwarp_width / 2; offset > 0; offset /= 2) { + const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width); + __builtin_assume(val_tmp >= 0); + __builtin_assume(val_other >= 0); + val_tmp = fmaxf(val_tmp, val_other); + } + // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id + constexpr int subwarp_lane_zero = 0; + val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width); + return val_tmp; +} + +// Works only on positive values +__device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) { + atomicMax(reinterpret_cast(addr), __float_as_int(value)); +} + +// Works only on positive values +__device__ __forceinline__ void atomicMinFloat(float *addr, const float value) { + atomicMin(reinterpret_cast(addr), __float_as_int(value)); +} + +template +__device__ __forceinline__ void reciprocal(T *value_inv, const T value) { + *value_inv = 1 / value; +} + +template <> +__device__ __forceinline__ void reciprocal(float *value_inv, const float value) { + *value_inv = __frcp_rn(value); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using fp8e4m3 = __mt_fp8_e4m3; +using fp8e5m2 = __mt_fp8_e5m2; +using e8m0_t = uint8_t; + +constexpr uint32_t FP32_MANTISSA_BITS = 23; +constexpr uint32_t FP32_EXPONENT_BIAS = 127; + +enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 }; + +template +struct Numeric_Traits; + +template <> +struct Numeric_Traits { + static constexpr int maxUnbiasedExponent = 8; + static constexpr double maxNorm = 448; +}; + +template <> +struct Numeric_Traits { + static constexpr int maxUnbiasedExponent = 15; + static constexpr double maxNorm = 57344; +}; + +template +struct Quantized_Limits { + static constexpr int max_unbiased_exponent = Numeric_Traits::maxUnbiasedExponent; + static constexpr float max_norm = Numeric_Traits::maxNorm; + static constexpr float max_norm_rcp = 1.0 / max_norm; + static constexpr float emax = 1 << max_unbiased_exponent; + static constexpr float emax_rcp = 1.0 / emax; +}; + +__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } +#if 0 + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); +#else + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +#endif +} + +__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { + return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); +} + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ diff --git a/transformer_engine/musa/pytorch/__init__.py b/transformer_engine/musa/pytorch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/transformer_engine/musa/pytorch/attention.py b/transformer_engine/musa/pytorch/attention.py new file mode 100644 index 0000000000..afc29aba9f --- /dev/null +++ b/transformer_engine/musa/pytorch/attention.py @@ -0,0 +1,1066 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Attention.""" +from contextlib import nullcontext +from importlib.metadata import version as get_pkg_version +import math +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import logging +from packaging.version import Version as PkgVersion + +import torch + +import transformer_engine_torch as tex +from transformer_engine.pytorch.utils import get_cudnn_version + +from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.float8_tensor import Float8Tensor + +from transformer_engine.pytorch.constants import ( + AttnMaskTypes, + AttnTypes, + QKVLayouts, + dist_group_type, +) +from transformer_engine.pytorch.distributed import ( + get_distributed_world_size, + set_all_rng_states, + CudaRNGStatesTracker, + graph_safe_rng_available, +) +from transformer_engine.pytorch.graph import is_graph_capturing +_flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) +_flash_attn_version_required = PkgVersion("2.0.6") +_flash_attn_max_version = PkgVersion("2.6.8") +_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") +_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") +_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") + + +from transformer_engine.pytorch.attention.dot_product_attention.utils import ( + check_set_window_size, + get_cu_seqlens, + _get_full_cu_seqlens, + get_alibi, + get_qkv_layout, + +) + +from transformer_engine.pytorch.attention.dot_product_attention.utils import ( + check_set_window_size, + get_cu_seqlens, + _get_full_cu_seqlens, + get_alibi, + get_qkv_layout, + InferenceParams, + get_cu_seqlens_and_indices, + UnpackTensor, + get_indices, + PackTensors, + AttentionParams, + get_attention_backend, +) + +from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + UnfusedDotProductAttention, + FusedAttention, + _PrepareQKVForFA, + + # FlashAt +) + +from transformer_engine.pytorch.attention.dot_product_attention import ( + _attention_backends, +) + + +from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled +_flash_attn_3_is_installed = False +_flash_attn_3_version = PkgVersion("0") +# HACK(huang.huang): recompute-variance for fa: import packages +from transformer_engine.pytorch.attention.dot_product_attention import ( + DotProductAttention +) +from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + FlashAttention +) +# HACK(huang.huang): + + +# HACK(huang.huang): recompute-variance for fa: implement flash_attn_varlen_func_variance +# which will return [coreattention_output, lse, ...] instead of coreattention_output only; +# and will seperate the execution of the sdp_kernel from other operations before and after it +_MIN_MUSA_DIM = 64 +_MAX_MUSA_DIM = 192 +def flash_attn_varlen_func_variance( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + block_table=None, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + # The input shape of varlen flash is [bs x seq_len, nheads, head_dim] + # but the input of sdp is [bs, nheads, seq_len, head_dim] + # seq_len = max_seqlen_q + # bs = q.shape[0] // seq_len + head_dim= q.shape[-1] + if head_dim >= _MIN_MUSA_DIM and head_dim <= _MAX_MUSA_DIM: + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False): + attn_output = torch.ops.aten._scaled_dot_product_attention_flash_musa( + q, + k, + v, + dropout_p=dropout_p, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=causal #self.is_causal and attention_mask is None and q_len > 1, + ) + return attn_output + else: + raise NotImplementedError(f"head_dim={head_dim} is not supported by flash_attn_varlen_func_variance") +# HACK(huang.huang) + + +# HACK(huang.huang): recompute-variance for fa: modify __init__ for FlashAttention and DotProductAttention, +# just add a attr "recompute_variance" for them +def FlashAttention__init__( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = nullcontext, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + recompute_variance: bool = False, # MUSA patch: support recompute_variance +) -> None: + super(FlashAttention, self).__init__() + + assert ( + _flash_attn_version >= _flash_attn_version_required + ), f"FlashAttention minimum version {_flash_attn_version_required} is required." + assert ( + _flash_attn_version <= _flash_attn_max_version + ), f"FlashAttention maximum version {_flash_attn_max_version} is supported." + + self.softmax_scale = softmax_scale + self.attention_dropout_ctx = attention_dropout_ctx + self.attention_dropout = attention_dropout + self.attention_type = attention_type + self.layer_number = 1 if layer_number is None else layer_number + self.deterministic = deterministic + self.recompute_variance = recompute_variance # MUSA patch: support recompute_variance + + +def DotProductAttention__init__( + self, + num_attention_heads: int, + kv_channels: Union[int, Tuple[int, int]], + num_gqa_groups: Optional[int] = None, + attention_dropout: float = 0.0, + qkv_format: str = "sbhd", + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + sequence_parallel: bool = False, + tp_size: int = 1, + get_rng_state_tracker: Optional[Callable] = None, + tp_group: Optional[dist_group_type] = None, + layer_number: Optional[int] = None, + attention_type: str = "self", + cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, + cp_global_ranks: List[int] = None, + cp_stream: torch.cuda.Stream = None, + cp_comm_type: str = "p2p", + softmax_scale: Optional[float] = None, + recompute_variance: bool = False, # MUSA patch: support for variance computation +) -> None: + super(DotProductAttention, self).__init__() + + self.logger = logging.getLogger("DotProductAttention") + # self.logger.setLevel(_log_level) + if not self.logger.hasHandlers(): + self.logger.addHandler(_stream_handler) + self.qkv_format = qkv_format + attn_mask_type = attn_mask_type.replace(",", "_") + if attn_mask_type == "causal_padding": + attn_mask_type = "padding_causal" + self.attn_mask_type = attn_mask_type + self.window_size = check_set_window_size(attn_mask_type, window_size) + if tp_group is None: + self.tp_size = tp_size + if tp_size == 1: + self.set_tensor_parallel_group(tp_group) + else: + self.tp_size = get_distributed_world_size(tp_group) + self.set_tensor_parallel_group(tp_group) + self.get_rng_state_tracker = get_rng_state_tracker + self.num_attention_heads = num_attention_heads + self.layer_number = 1 if layer_number is None else layer_number + self.cp_group = cp_group + self.cp_global_ranks = cp_global_ranks + self.cp_stream = cp_stream + self.cp_comm_type = cp_comm_type + + self.recompute_variance = recompute_variance # MUSA patch: support for variance computation + self.hidden_size_per_attention_head_k = ( + kv_channels if isinstance(kv_channels, int) else kv_channels[0] + ) + self.hidden_size_per_attention_head_v = ( + kv_channels if isinstance(kv_channels, int) else kv_channels[1] + ) + + self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups + self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) + + assert ( + num_attention_heads % self.num_gqa_groups == 0 + ), "The number of attention heads must be divisible by the number of GQA groups!" + + self.rng_states_tracker = None + if sequence_parallel or get_rng_state_tracker is None: + attention_dropout_ctx = nullcontext + else: + self.rng_states_tracker = get_rng_state_tracker() + set_all_rng_states(self.rng_states_tracker.get_states()) + attention_dropout_ctx = self.rng_states_tracker.fork + + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt( + kv_channels if isinstance(kv_channels, int) else kv_channels[0] + ) + + self.deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() + ) + # To use the workspace optimization path for determinism, please + # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0, + # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0. + cudnn_version = get_cudnn_version() + if (8, 9, 5) <= cudnn_version < (9, 0, 0): + if self.deterministic: + os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" + + # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT + # - unset: enables workspace optimization when required workspace is <= 256MB + # or when bias gradient needs to be computed + # - n: enables workspace optimization when required workspace is <= n bytes + # - -1: enables workspace optimization always + # - 0: disables workspace optimization always + if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ: + if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0": + os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0" + if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": + os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" + + assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" + + self.attention_type = attention_type + self.attention_dropout = attention_dropout + + attn_kwargs = { + "attention_dropout": attention_dropout, + "attention_dropout_ctx": attention_dropout_ctx, + } + + self.flash_attention = FlashAttention( + softmax_scale, + attention_type=attention_type, + layer_number=layer_number, + deterministic=self.deterministic, + recompute_variance=self.recompute_variance, # MUSA patch: support for variance computation + **attn_kwargs, + ) + + # Instantiating three types since use of flash-attn and FusedAttention + # might be ruled out due to forward inputs. + self.fused_attention = FusedAttention( + softmax_scale, + attention_type=attention_type, + layer_number=layer_number, + deterministic=self.deterministic, + **attn_kwargs, + ) + self.unfused_attention = UnfusedDotProductAttention( + softmax_scale, **attn_kwargs, layer_number=layer_number + ) + + def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument + """ + Temporarily remove core_attention._extra_state as a missing key + when loading older Transformer Engine checkpoints. Will phase out + this hook in Transformer Engine 2.0. + """ + for key in incompatible_keys.missing_keys: + if "core_attention._extra_state" in key: + incompatible_keys.missing_keys.remove(key) + + self.register_load_state_dict_post_hook(remove_extra_states_check) +# HACK(huang.huang) + + +# HACK(huang.huang): recompute-variance for fa: add functions "forward_fa", "forward_after_fa", "forward_before_fa" for DotProductAttention +def FlashAttention_forward_after_fa(self, output, qkv_format, indices_q, batch_size, attn_mask_type, max_seqlen_q, q_shape, v_shape): + bs = q_shape[0] + q_seq_len = q_shape[1] + output = output[0].transpose(1, 2).contiguous().view(bs, q_seq_len, q_shape[-2], v_shape[-1]) #core_output, args* + if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: + output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) + + if qkv_format == "sbhd": + # (bs)hd -> bs(hd) -> sb(hd) + output = output.view(batch_size, max_seqlen_q, -1).transpose(0, 1).contiguous() + elif qkv_format == "bshd": + # (bs)hd -> bs(hd) + output = output.view(batch_size, max_seqlen_q, -1).contiguous() + elif qkv_format == "thd": + # thd -> t(hd) + output = output.view(output.shape[0], -1).contiguous() + return output + +def FlashAttention_forward_fa( + self, + query_layer, + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + attn_mask_type, + window_size, + alibi_slopes, + qkv_format, + indices_q, + batch_size, + q_shape, + v_shape, + *args, + ): + with self.attention_dropout_ctx(): + fa_optional_forward_kwargs = {} + if _flash_attn_2_3_plus: + fa_optional_forward_kwargs["window_size"] = window_size + if _flash_attn_2_4_plus: + fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes + if _flash_attn_2_4_1_plus: + fa_optional_forward_kwargs["deterministic"] = self.deterministic + output = flash_attn_varlen_func_variance( + query_layer, + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + self.attention_dropout if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_optional_forward_kwargs, + ) + return output, qkv_format, indices_q, batch_size, attn_mask_type, max_seqlen_q, q_shape, v_shape + +def FlashAttention_forward_before_fa( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, + cp_global_ranks: List[int] = None, + cp_stream: torch.cuda.Stream = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, +) -> torch.Tensor: + """flash-attn fprop""" + + assert all( + x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) + for x in [query_layer, key_layer, value_layer] + ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors." + assert ( + query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda + ), "FlashAttention currently only supports CUDA tensors." + assert ( + qkv_layout in QKVLayouts + ), f"FlashAttention does not support qkv_layout = {qkv_layout}!" + + cp_size = 1 + if isinstance(cp_group, dist_group_type): + cp_size = get_distributed_world_size(cp_group) + elif isinstance(cp_group, list): + for group in cp_group: + cp_size *= get_distributed_world_size(group) + context_parallel = cp_size > 1 + + qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + + if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): + if qkv_format == "sbhd": + # For now just 128, will make it more general in the future + if ( + query_layer.shape[-1] == 128 + and query_layer.shape[0] * query_layer.shape[1] >= 512 + and qkv_layout == "sbh3d" + ): + query_layer, key_layer, value_layer = _PrepareQKVForFA.apply( + query_layer, key_layer, value_layer + ) + else: + query_layer, key_layer, value_layer = [ + x.transpose(0, 1) for x in (query_layer, key_layer, value_layer) + ] + if context_parallel: + query_layer, key_layer, value_layer = [ + x.contiguous() for x in (query_layer, key_layer, value_layer) + ] + else: + if qkv_format == "sbhd": + query_layer._data, key_layer._data, value_layer._data = [ + x.transpose(0, 1) + for x in (query_layer._data, key_layer._data, value_layer._data) + ] + query_layer, key_layer, value_layer = [ + Float8Tensor.make_like(x, data=x._data, shape=x._data.shape) + for x in (query_layer, key_layer, value_layer) + ] + if context_parallel: + query_layer._data, key_layer._data, value_layer._data = [ + x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) + ] + + batch_size = query_layer.shape[0] + + if qkv_format in ["sbhd", "bshd"]: + max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size + indices_q = None + if "padding" in attn_mask_type: + assert not context_parallel, "Padding mask not supported with context parallelism!" + # [b * s, h, d] + query_layer, key_layer, value_layer = [ + x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) + for x in [query_layer, key_layer, value_layer] + ] + + + if self.attention_type == "self": + assert ( + max_seqlen_q == max_seqlen_kv + ), "Maximum sequence length for Q and KV should be the same." + if cu_seqlens_q is None: + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" + cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask) + else: + indices_q = get_indices(max_seqlen_q, cu_seqlens_q) + cu_seqlens_kv = cu_seqlens_q + query_layer, key_layer, value_layer = PackTensors.apply( + indices_q, query_layer, key_layer, value_layer + ) + else: + if cu_seqlens_q is None or cu_seqlens_kv is None: + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" + cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0]) + cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1]) + else: + indices_q = get_indices(max_seqlen_q, cu_seqlens_q) + indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv) + query_layer = PackTensors.apply(indices_q, query_layer) + key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer) + else: + # Cumulative sequence lengths for unpadded data + if cu_seqlens_q is None: + cu_seqlens_q = _get_full_cu_seqlens( + batch_size, + max_seqlen_q, + query_layer.device, + ) + if cu_seqlens_kv is None: + cu_seqlens_kv = _get_full_cu_seqlens( + batch_size, + max_seqlen_kv, + key_layer.device, + ) + elif qkv_format == "thd": + assert ( + cu_seqlens_q is not None and cu_seqlens_kv is not None + ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" + if max_seqlen_q is None: + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + max_seqlen_q = seqlens_q.max().item() + if max_seqlen_kv is None: + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + max_seqlen_kv = seqlens_kv.max().item() + + + if context_parallel and all( + not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] + ): + assert ( + alibi_slopes is None + ), "Alibi slope bias addition is not supported with context parallelism." + with self.attention_dropout_ctx(): + output = attn_forward_func_with_cp( + self.training, + query_layer, + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q if qkv_format == "thd" else None, + cu_seqlens_kv if qkv_format == "thd" else None, + self.attention_dropout if self.training else 0.0, + cp_group, + cp_global_ranks, + cp_stream, + cp_comm_type, + softmax_scale=self.softmax_scale, + qkv_format="bshd" if qkv_format == "sbhd" else qkv_format, + attn_mask_type=attn_mask_type, + deterministic=self.deterministic, + window_size=window_size, + quantizers=quantizers, + ) + else: + + from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled + + if CPUOffloadEnabled: + tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv] + for tensor in tensor_list: + if tensor is not None: + tensor.activation_offloading = True + + + # transpose before fa, which will be saved for bwd + bs = query_layer.shape[0] + seq_len = query_layer.shape[1] + kv_seq_len = key_layer.shape[1] + # seq_len = max_seqlen_q + # bs = query_layer.shape[0] // seq_len + q_shape = query_layer.shape + v_shape = value_layer.shape + query_layer = query_layer.view(bs, seq_len, query_layer.shape[-2], query_layer.shape[-1]).transpose(1, 2) + key_layer = key_layer.view(bs, kv_seq_len, key_layer.shape[-2], key_layer.shape[-1]).transpose(1, 2) + value_layer = value_layer.view(bs, kv_seq_len, value_layer.shape[-2], value_layer.shape[-1]).transpose(1, 2) + + return ( + query_layer, + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + attn_mask_type, + window_size, + alibi_slopes, + qkv_format, + indices_q, + batch_size, + q_shape, + v_shape) + +def DotProductAttention_forward_fa( + self, + *args, + ): + return self.flash_attention.forward_fa(*args) + +def DotProductAttention_forward_after_fa(self, *args): + output = self.flash_attention.forward_after_fa(*args) + return output + +def DotProductAttention_forward_before_fa( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_format: Optional[str] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + cu_seqlens_q_padded: Optional[torch.Tensor] = None, + cu_seqlens_kv_padded: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: Optional[str] = None, + window_size: Optional[Tuple[int, int]] = None, + checkpoint_core_attention: bool = False, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + fast_zero_fill: bool = True, + inference_params: Optional[InferenceParams] = None, +) -> torch.Tensor: + with self.prepare_forward( + query_layer, + num_gemms=3, + allow_non_contiguous=True, + ) as query_layer: + if self.fp8: + if self.fp8_meta["recipe"].fp8_mha: + if not self.fp8_meta["recipe"].fp8_dpa: + self.fp8_meta["recipe"].fp8_dpa = True + self.logger.warning( + """Forcing fp8_meta["recipe"].fp8_dpa=True due to """ + """fp8_meta["recipe"].fp8_mha=True""" + ) + + if self.fp8 and self.fp8_meta["recipe"].fp8_dpa: + forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True) + backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False) + assert forward_dtype in [ + tex.DType.kFloat8E4M3, + tex.DType.kFloat8E5M2, + ] and backward_dtype in [ + tex.DType.kFloat8E4M3, + tex.DType.kFloat8E5M2, + ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types.""" + + assert ( + query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda + ), "DotProductAttention only supports CUDA tensors." + assert ( + query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype + ), "Queries, keys and values must have the same data type!" + assert ( + key_layer.shape[:-1] == value_layer.shape[:-1] + ), "Keys and values must have the same batch size, sequence length and number of heads!" + assert ( + key_layer.shape[-1] == self.hidden_size_per_attention_head_k + ), f"Keys have head_dim = {key_layer.shape[-1]}, " + "but expected head_dim = {self.hidden_size_per_attention_head_k}!" + assert ( + value_layer.shape[-1] == self.hidden_size_per_attention_head_v + ), f"Values have head_dim = {value_layer.shape[-1]}, " + "but expected head_dim = {self.hidden_size_per_attention_head_v}!" + + if qkv_format is None: + qkv_format = self.qkv_format + + if attn_mask_type is None: + attn_mask_type = self.attn_mask_type + else: + attn_mask_type = attn_mask_type.replace(",", "_") + if attn_mask_type == "causal_padding": + attn_mask_type = "padding_causal" + assert ( + attn_mask_type in AttnMaskTypes + ), f"Attention mask type {attn_mask_type} is not supported!" + if qkv_format == "thd": + assert ( + "padding" in attn_mask_type + ), "Attention mask type must be padding or padding_causal for qkv_format=thd!" + + if window_size is None: + window_size = self.window_size + window_size = check_set_window_size(attn_mask_type, window_size) + + if self.rng_states_tracker is not None and is_graph_capturing(): + assert isinstance( + self.rng_states_tracker, CudaRNGStatesTracker + ), "Unsupported RNG states tracker." + assert ( + graph_safe_rng_available() + ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." + + if inference_params is not None: + assert self.layer_number is not None, "Layer number must be set!" + + # convert causal to causal_bottom_right in inference when KV-caching is in use + # so users can run with the same attn_mask_type for training and inference + if attn_mask_type in ["causal", "padding_causal"]: + attn_mask_type = attn_mask_type + "_bottom_right" + + if qkv_format == "bshd": + key_layer = key_layer.transpose(0, 1) + value_layer = value_layer.transpose(0, 1) + + ( + inference_key_memory, + inference_value_memory, + ) = inference_params.key_value_memory_dict[self.layer_number] + + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + + # Copy keys and values into KV-cache + inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( + key_layer + ) + inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( + value_layer + ) + key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] + value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] + + if qkv_format == "bshd": + key_layer = key_layer.transpose(0, 1) + value_layer = value_layer.transpose(0, 1) + + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + assert ( + key_layer.shape[-2] == self.num_gqa_groups_per_partition + and value_layer.shape[-2] == self.num_gqa_groups_per_partition + ), ( + "Keys and values must have num_gqa_group =" + f" {self.num_gqa_groups_per_partition} heads!" + ) + assert qkv_format in [ + "sbhd", + "bshd", + "thd", + ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!" + + if qkv_format == "thd": + assert all( + len(x.shape) == 3 for x in (query_layer, key_layer, value_layer) + ), "Queries, keys and values must be 3D tensors when qkv_format = thd!" + assert ( + cu_seqlens_q is not None and cu_seqlens_kv is not None + ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" + assert ( + cu_seqlens_q.shape == cu_seqlens_kv.shape + and len(cu_seqlens_q.shape) == 1 + and len(cu_seqlens_kv.shape) == 1 + ), "cu_seqlens_q and cu_seqlens_q must both have shape [batch_size + 1]!" + assert ( + cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 + ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" + batch_size = len(cu_seqlens_q) - 1 + if max_seqlen_q is None: + if cu_seqlens_q_padded is not None: + seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] + else: + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64) + if max_seqlen_kv is None: + if cu_seqlens_kv_padded is not None: + seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1] + else: + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) + + cp_size = 1 + if isinstance(self.cp_group, dist_group_type): + cp_size = get_distributed_world_size(self.cp_group) + elif isinstance(self.cp_group, list): + for group in self.cp_group: + cp_size *= get_distributed_world_size(group) + context_parallel = cp_size > 1 + + if qkv_format in ["sbhd", "bshd"]: + assert all( + len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) + ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!" + if qkv_format == "sbhd": + max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv + batch_size = query_layer.shape[1] + else: + max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv + batch_size = query_layer.shape[0] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size + if cu_seqlens_q is not None: + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + assert all( + seqlens_q <= max_seqlen_q + ), """Sequence lengths indicated by cu_seqlens_q must be no greater than + the sequence dimension in 'query_layer'!""" + if cu_seqlens_kv is not None: + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + assert all( + seqlens_kv <= max_seqlen_kv + ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than + the sequence dimension in 'key_layer' and 'value_layer'!""" + if cu_seqlens_q is None or cu_seqlens_kv is None: + if "padding" in attn_mask_type: + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" + if self.attention_type == "self": + cu_seqlens_q = get_cu_seqlens(attention_mask) + cu_seqlens_kv = cu_seqlens_q + else: + cu_seqlens_q = get_cu_seqlens(attention_mask[0]) + cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) + else: + cu_seqlens_q = _get_full_cu_seqlens( + batch_size, + max_seqlen_q, + query_layer.device, + ) + cu_seqlens_kv = _get_full_cu_seqlens( + batch_size, + max_seqlen_kv, + key_layer.device, + ) + + if ( + isinstance(query_layer, Float8Tensor) + and isinstance(key_layer, Float8Tensor) + and isinstance(value_layer, Float8Tensor) + ): + qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout( + query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format + ) + else: + qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout( + query_layer, key_layer, value_layer, qkv_format=qkv_format + ) + + global _alibi_cache + if alibi_slopes is not None: + assert ( + core_attention_bias_type == "alibi" + ), "core_attention_bias_type must be alibi in order to use alibi_slopes!" + if self.layer_number == 1: + _alibi_cache["_alibi_slopes_require_update"] = True + _alibi_cache["_alibi_bias_require_update"] = True + bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],) + if core_attention_bias_type == "alibi": + assert ( + core_attention_bias is None + ), "core_attention_bias must be None when core_attention_bias_type is alibi!" + if ( + _alibi_cache["_num_heads"] != query_layer.shape[-2] + or _alibi_cache["_max_seqlen_q"] != max_seqlen_q + or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv + or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment + or _alibi_cache["_alibi_slopes"] is None + ): + _alibi_cache["_alibi_slopes_require_update"] = True + _alibi_cache["_alibi_bias_require_update"] = True + + core_attention_bias_shape = None + if core_attention_bias is not None: + if ( + core_attention_bias.shape[0] == batch_size + and core_attention_bias.shape[1] == query_layer.shape[-2] + ): + core_attention_bias_shape = "bhss" + elif ( + core_attention_bias.shape[0] == 1 + and core_attention_bias.shape[1] == query_layer.shape[-2] + ): + core_attention_bias_shape = "1hss" + elif ( + core_attention_bias.shape[0] == batch_size and core_attention_bias.shape[1] == 1 + ): + core_attention_bias_shape = "b1ss" + elif core_attention_bias.shape[0] == 1 and core_attention_bias.shape[1] == 1: + core_attention_bias_shape = "11ss" + else: + assert ( + False + ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" + + pad_between_seqs = ( + cu_seqlens_q_padded is not None + and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]) + ) or ( + cu_seqlens_kv_padded is not None + and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]) + ) + + attention_params = AttentionParams( + qkv_type=type(query_layer), + qkv_dtype=query_layer.dtype, + qkv_layout=qkv_layout, + batch_size=batch_size, + num_heads=query_layer.shape[-2], + num_gqa_groups=key_layer.shape[-2], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=query_layer.shape[-1], + head_dim_v=value_layer.shape[-1], + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias_shape=core_attention_bias_shape, + core_attention_bias_requires_grad=( + core_attention_bias.requires_grad if core_attention_bias is not None else False + ), + pad_between_seqs=pad_between_seqs, + attention_dropout=self.attention_dropout, + context_parallel=context_parallel, + deterministic=self.deterministic, + is_training=self.training, + fp8=self.fp8, + fp8_meta=self.fp8_meta, + ) + global _attention_backends, _use_flash_attn_3 + if ( + _attention_backends["attention_params"] is None + or attention_params != _attention_backends["attention_params"] + ): + _attention_backends["attention_params"] = attention_params + _attention_backends["backend_selection_requires_update"] = True + + if _attention_backends["backend_selection_requires_update"]: + _use_flash_attn_3 = _flash_attn_3_is_installed + ( + use_flash_attention, + use_fused_attention, + fused_attention_backend, + use_unfused_attention, + _, + ) = get_attention_backend(attention_params) + if use_flash_attention: + self.logger.info( + "Running with FlashAttention backend (version %s)", + _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version, + ) + elif use_fused_attention: + self.logger.info( + "Running with FusedAttention backend (sub-backend %s)", + int(fused_attention_backend), + ) + elif use_unfused_attention: + self.logger.info("Running with UnfusedDotProductAttention backend") + else: + use_flash_attention = _attention_backends["use_flash_attention"] + use_fused_attention = _attention_backends["use_fused_attention"] + fused_attention_backend = _attention_backends["fused_attention_backend"] + use_unfused_attention = _attention_backends["use_unfused_attention"] + + use_flash_attention = True #TODO:huang.huang set fa manually now! + if use_flash_attention: + if core_attention_bias_type == "alibi": + alibi_slopes, _ = get_alibi( + query_layer.shape[-2], + max_seqlen_q, + max_seqlen_kv, + alibi_slopes=alibi_slopes, + ) + return self.flash_attention.forward_before_fa( + query_layer, + key_layer, + value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=self.cp_group, + cp_global_ranks=self.cp_global_ranks, + cp_stream=self.cp_stream, + cp_comm_type=self.cp_comm_type, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, + fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + ) + + raise RuntimeError("No dot product attention support for the provided inputs!") +# HACK(huang.huang) + +from .utils import replace_attr, add_attr +replace_attr(FlashAttention,"__init__", FlashAttention__init__) +add_attr(FlashAttention, "forward_fa", FlashAttention_forward_fa) +add_attr(FlashAttention, "forward_before_fa", FlashAttention_forward_before_fa) +add_attr(FlashAttention, "forward_after_fa", FlashAttention_forward_after_fa) + +replace_attr(DotProductAttention, "__init__", DotProductAttention__init__) +add_attr(DotProductAttention, "forward_fa", DotProductAttention_forward_fa) +add_attr(DotProductAttention, "forward_before_fa", DotProductAttention_forward_before_fa) +add_attr(DotProductAttention, "forward_after_fa", DotProductAttention_forward_after_fa) + diff --git a/transformer_engine/musa/pytorch/cpp_extensions/__init__.py b/transformer_engine/musa/pytorch/cpp_extensions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/transformer_engine/musa/pytorch/cpp_extensions/cast.py b/transformer_engine/musa/pytorch/cpp_extensions/cast.py new file mode 100644 index 0000000000..96b25fb3a2 --- /dev/null +++ b/transformer_engine/musa/pytorch/cpp_extensions/cast.py @@ -0,0 +1,19 @@ +import torch + +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, +) + + +def cast_to_fp8(src: torch.Tensor, out: Float8Tensor): + assert isinstance(src, torch.Tensor) + assert isinstance(out, Float8Tensor), "Only supports Float8Tensor now." + out.quantize_(src, noop_flag=None) + + +def weak_support_fp8_cast(): + import transformer_engine.pytorch.cpp_extensions as m + setattr(m, "cast_to_fp8", cast_to_fp8) + + +weak_support_fp8_cast() diff --git a/transformer_engine/musa/pytorch/csrc/common.cpp b/transformer_engine/musa/pytorch/csrc/common.cpp new file mode 100644 index 0000000000..bd8617155e --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/common.cpp @@ -0,0 +1,249 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common.h" + +#include "c10/util/ArrayRef.h" +#include "pybind.h" +#include "transformer_engine/transformer_engine.h" +#include "common/util/musa_runtime.h" +namespace transformer_engine::pytorch { + +std::vector getTensorShape(at::Tensor t) { + std::vector shape; + for (auto s : t.sizes()) { + shape.push_back(s); + } + return shape; +} + +std::unique_ptr convert_quantizer(py::handle quantizer) { + init_extension(); + if (quantizer.is_none()) { + return std::make_unique(quantizer); + } + for (auto [_check_type, check_quantizer_type, _create_tensor, create_quantizer] : + detail::custom_types_converters) { + if (check_quantizer_type(quantizer.ptr())) { + return create_quantizer(quantizer); + } + } + + NVTE_ERROR("Unexpected type for quantizer"); +} + +transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, + const std::string& fp8_recipe) { + // if e4m3 or hybrid + forward + if ((fp8_recipe == "E4M3") || ((fp8_recipe == "HYBRID") && e4m3_if_hybrid)) { + return transformer_engine::DType::kFloat8E4M3; + } + return transformer_engine::DType::kFloat8E5M2; +} + +TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) { + NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!"); + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + for (auto [check_type, check_quantizer_type, create_tensor, _] : + detail::custom_types_converters) { + if (check_type(tensor.ptr())) { + NVTE_CHECK(quantizer.is_none() || check_quantizer_type(quantizer.ptr()), + "Unexpected quantization params type."); + auto x = create_tensor(tensor, my_quantizer.get()); + return x; + } + } + + // Regular pyTorch tensor + at::Tensor torch_tensor = tensor.cast(); + + // #TODO (pgadzinski) - needed in attention for non-contiguous tensors. + //if (!torch_tensor.is_contiguous()) { + // torch_tensor = torch_tensor.contiguous(); + //} + auto ret = TensorWrapper(my_quantizer->get_scaling_mode()); + ret.set_rowwise_data(torch_tensor.data_ptr(), + GetTransformerEngineDType(torch_tensor.scalar_type()), + getTensorShape(torch_tensor)); + my_quantizer->set_quantization_params(&ret); + return ret; +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type) { + return transformer_engine::TensorWrapper(data_ptr, shape, type); +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type) { + return transformer_engine::TensorWrapper(data_ptr, shape, type); +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { + transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); + std::vector shape; + for (auto s : tensor.sizes()) { + shape.push_back(s); + } + return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + const std::vector meta_shape{1}; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + return ret; +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, + const std::vector& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const std::vector& scale_inv_shape, + const std::vector& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); + const std::vector meta_shape{1}; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, + columnwise_scale_inv_shape); + return ret; +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, + const at::Tensor scale, + at::Tensor scale_inv, + NVTEScalingMode scaling_mode) { + transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); + + auto tensor_shape = getTensorShape(tensor); + auto scale_inv_shape = getTensorShape(scale_inv); + + NVTE_CHECK(amax.scalar_type() == at::kFloat); + NVTE_CHECK(scale.scalar_type() == at::kFloat); + NVTE_CHECK(scale_inv.scalar_type() == at::kFloat); + + return makeTransformerEngineTensor(tensor.data_ptr(), tensor_shape, dtype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr(), scale_inv_shape, + scaling_mode); +} + +template +T product(const std::vector& shape) { + T ret = 1; + for (auto s : shape) { + ret *= s; + } + return ret; +} + +template size_t product(const std::vector& shape); +template int64_t product(const std::vector& shape); + +size_t product(const NVTEShape& shape, size_t begin, size_t end) { + NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, + " in a shape with ", shape.ndim, " entries"); + size_t ret = 1; + for (size_t i = begin; i < end; ++i) { + ret *= shape.data[i]; + } + return ret; +} + +std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape) { + std::vector shape; + for (size_t i = 0; i < nvte_shape.ndim; i++) { + shape.push_back(nvte_shape.data[i]); + } + return shape; +} + +at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, + bool init_to_zeros) { + std::vector shape_int64(shape.begin(), shape.end()); + c10::IntArrayRef ar_shape(shape_int64); + const auto opt = at::TensorOptions().dtype(GetATenDType(type)) + .device( + c10::DeviceType::PrivateUse1, + static_cast(transformer_engine::cuda::current_device())); + if (init_to_zeros) { + return at::zeros(ar_shape, opt); + } else { + return at::empty(ar_shape, opt); + } +} + +at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type, + bool init_to_zeros) { + auto size = shape.ndim; + const auto opt = at::TensorOptions().dtype(GetATenDType(type)) + .device( + c10::DeviceType::PrivateUse1, + static_cast(transformer_engine::cuda::current_device())); + if (size == 2 && init_to_zeros) { + return at::zeros({static_cast(shape.data[0]), static_cast(shape.data[1])}, opt); + } else if (size == 2) { + return at::empty({static_cast(shape.data[0]), static_cast(shape.data[1])}, opt); + } else if (size == 1 && init_to_zeros) { + return at::zeros({static_cast(shape.data[0])}, opt); + } else if (size == 1) { + return at::empty({static_cast(shape.data[0])}, opt); + } + NVTE_CHECK(false, "Should never reach here! func: allocateSpace"); +} + +at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype) { + const auto opt = at::TensorOptions().dtype(GetATenDType(dtype)) + .device( + c10::DeviceType::PrivateUse1, + static_cast(transformer_engine::cuda::current_device())); + return at::empty({static_cast(M), static_cast(N)}, opt); +} + +at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype) { + const auto opt = at::TensorOptions().dtype(GetATenDType(dtype)) + .device( + c10::DeviceType::PrivateUse1, + static_cast(transformer_engine::cuda::current_device())); + return at::empty({static_cast(M)}, opt); +} + +void* getDataPtr(at::Tensor tensor, int offset) { + void* dptr = nullptr; + if (tensor.numel() > 0) { + dptr = tensor.data_ptr(); + } + if (dptr != nullptr && offset != 0) { + char* char_ptr = reinterpret_cast(dptr); + char_ptr += offset * tensor.element_size(); + dptr = reinterpret_cast(char_ptr); + } + return dptr; +} + +std::vector convertShape(const NVTEShape& shape) { + return std::vector(shape.data, shape.data + shape.ndim); +} + +int roundup(const int value, const int multiple) { + assert(multiple > 0); + return ((value + multiple - 1) / multiple) * multiple; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/musa/pytorch/csrc/common.h b/transformer_engine/musa/pytorch/csrc/common.h new file mode 100644 index 0000000000..845a065ecd --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/common.h @@ -0,0 +1,328 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// #include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "c10/util/ArrayRef.h" +#include "common/util/logging.h" + +namespace transformer_engine::pytorch { + +// Each tensor here is shape (N, ) holding all scaling +// data for a single FP8 block, e.g. LayerNormLinear +class FP8TensorMeta { + public: + at::Tensor scale; + at::Tensor scale_inv; + at::Tensor amax_history; +}; + +// Used as named indices on the `scale`, `scale_inv`, +// and `amax` tensors in the `FP8TensorMeta` class. +enum FP8FwdTensors { + GEMM1_INPUT = 0, + GEMM1_WEIGHT = 1, + GEMM1_OUTPUT = 2, + GEMM2_INPUT = 3, + GEMM2_WEIGHT = 4, + GEMM2_OUTPUT = 5, + GEMM3_INPUT = 6, + GEMM3_WEIGHT = 7, + GEMM3_OUTPUT = 8 +}; + +// Used as named indices on the `scale`, `scale_inv`, +// and `amax` tensors in the `FP8TensorMeta` class. +enum FP8BwdTensors { + GRAD_OUTPUT1 = 0, + GRAD_INPUT1 = 1, + GRAD_OUTPUT2 = 2, + GRAD_INPUT2 = 3, + GRAD_OUTPUT3 = 4, + GRAD_INPUT3 = 5 +}; + +class Quantizer { + public: + virtual NVTEScalingMode get_scaling_mode() const = 0; + + virtual void set_quantization_params(TensorWrapper* tensor) const = 0; + + virtual std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const = 0; + + virtual ~Quantizer() = default; + + bool rowwise_usage = true; + bool columnwise_usage = true; + bool internal = false; + py::handle quantizer; + + protected: + explicit Quantizer(const py::handle& quantizer); +}; + +class NoneQuantizer : public Quantizer { + public: + explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {} + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override {} + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class Float8Quantizer : public Quantizer { + public: + at::Tensor scale; + at::Tensor scale_inv; + at::Tensor amax; + DType dtype; + + explicit Float8Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class MXFP8Quantizer : public Quantizer { + public: + DType dtype; + + explicit MXFP8Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class MTFP8Quantizer : public Quantizer { + public: + DType dtype; + int64_t block_m; + int64_t block_n; + + explicit MTFP8Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_MTFP8_BLOCK_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +std::unique_ptr convert_quantizer(py::handle quantizer); + +std::vector getTensorShape(at::Tensor t); + +transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, + const std::string& fp8_recipe); + +inline at::ScalarType GetATenDType(transformer_engine::DType t) { + switch (t) { + case transformer_engine::DType::kInt32: + return torch::kInt32; + case transformer_engine::DType::kInt64: + return torch::kInt64; + case transformer_engine::DType::kFloat32: + return at::kFloat; + case transformer_engine::DType::kFloat16: + return at::kHalf; + case transformer_engine::DType::kBFloat16: + return at::kBFloat16; + case transformer_engine::DType::kByte: + return at::kByte; + case transformer_engine::DType::kFloat8E4M3: + return at::kFloat8_e4m3fn; + case transformer_engine::DType::kFloat8E5M2: + return at::kFloat8_e5m2; + default: + NVTE_ERROR("Invalid type"); + } +} + +inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { + switch (t) { + case at::kFloat8_e4m3fn: + return transformer_engine::DType::kFloat8E4M3; + case at::kFloat8_e5m2: + return transformer_engine::DType::kFloat8E5M2; + case at::kHalf: + return transformer_engine::DType::kFloat16; + case at::kFloat: + return transformer_engine::DType::kFloat32; + case at::kBFloat16: + return transformer_engine::DType::kBFloat16; + case at::kBool: + return transformer_engine::DType::kByte; + case torch::kByte: + return transformer_engine::DType::kByte; + case torch::kInt32: + return transformer_engine::DType::kInt32; + case torch::kInt64: + return transformer_engine::DType::kInt64; + default: + std::cout << "Type: " << static_cast(t) << std::endl; + NVTE_ERROR("Invalid type"); + } +} + +inline transformer_engine::DType GetTransformerEngineDType(int DType_value) { + return static_cast(DType_value); +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, + const std::vector& shape, + const transformer_engine::DType type); + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, + const std::vector& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const std::vector& scale_inv_shape = {1}, + const std::vector& columnwise_scale_inv_shape = {1}, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, + const NVTEShape& shape, + const transformer_engine::DType type); + +transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor); + +TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer); + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +template +T product(const std::vector& shape); + +size_t product(const NVTEShape& shape, size_t begin, size_t end); + +std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); + +at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, + bool init_to_zeros); + +at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type, + bool init_to_zeros = false); + +at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype); + +at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype); + +void* getDataPtr(at::Tensor tensor, int offset = 0); + +std::vector convertShape(const NVTEShape& shape); + +int roundup(const int value, const int multiple); + +} // namespace transformer_engine::pytorch + +namespace std { +template +string to_string(const vector& vec) { + string ret = "["; + for (const auto& val : vec) { + ret += to_string(val) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} + +// Torch shape -> string +template +string to_string(const c10::ArrayRef& vec) { + string ret = "["; + for (const auto& val : vec) { + ret += to_string(val) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} + +inline string to_string(const NVTEShape& s) { + string ret = "["; + for (size_t i = 0; i < s.ndim; ++i) { + ret += to_string(s.data[i]) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} +} // namespace std + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/musa/pytorch/csrc/extensions.h b/transformer_engine/musa/pytorch/csrc/extensions.h new file mode 100644 index 0000000000..2c811cfaec --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions.h @@ -0,0 +1,474 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ + +#include + +#include "common.h" + +/*************************************************************************************************** + * Permutation + **************************************************************************************************/ + +std::tuple> moe_permute_fwd( + at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, + int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num); + +at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK); + +at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK); + +std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob); + +std::tuple moe_permute_mask(const transformer_engine::DType dtype, + at::Tensor input, at::Tensor row_id_map, + at::Tensor probs, int num_tokens, + int num_experts, int num_out_tokens, + int hidden_size); + +std::tuple moe_unpermute_mask(const transformer_engine::DType dtype, + at::Tensor input, at::Tensor row_id_map, + at::Tensor merging_probs, + at::Tensor permuted_probs, int num_tokens, + int num_experts, int hidden_size); + +std::tuple moe_unpermute_mask_bwd_with_merging_probs( + const transformer_engine::DType dtype, at::Tensor fwd_output_grad, at::Tensor fwd_input, + at::Tensor merging_probs, at::Tensor row_id_map, int num_tokens, int num_experts, + int num_out_tokens, int hidden_size); +/*************************************************************************************************** + * Attention + **************************************************************************************************/ + +NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype, + const transformer_engine::DType kv_dtype, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, float p_dropout, + size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right); + +std::vector fused_attn_fwd( + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, const std::vector window_size, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const c10::optional cu_seqlens_q_padded, + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle o_quantizer, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread); + +std::vector fused_attn_bwd( + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, + const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const c10::optional cu_seqlens_q_padded, + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle dp_quantizer, py::handle dqkv_quantizer); + +at::Tensor fa_prepare_fwd(at::Tensor qkvi); +at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); + +/*************************************************************************************************** + * GEMM + **************************************************************************************************/ + +using MaybeTensor = std::optional; + +void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, + std::vector A_scaling_mode, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, transformer_engine::DType B_type, + std::vector B_scaling_mode, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count, int m_split, int n_split, + bool gemm_producer, at::Tensor counter); + +std::optional> te_general_grouped_gemm( + std::vector A, bool transa, std::vector B, bool transb, + std::optional> D, transformer_engine::DType D_type, + std::vector m_splits, std::vector bias, + transformer_engine::DType bias_type, bool single_output, std::vector pre_gelu_out, + bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count); + +/*************************************************************************************************** + * Transpose + **************************************************************************************************/ + +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, + std::vector quantizer_list, + transformer_engine::DType otype); +namespace transformer_engine::pytorch { +std::vector fused_multi_quantize_batch_init(std::vector input_list, + size_t hidden_dim, + std::vector m_splits, + std::vector quantizer_list, + transformer_engine::DType otype); + } +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, + std::optional output = std::nullopt); + +namespace transformer_engine::pytorch { + +/*************************************************************************************************** + * Activations + **************************************************************************************************/ + +py::object gelu(const at::Tensor &input, py::handle quantizer); + +py::object relu(const at::Tensor &input, py::handle quantizer); + +py::object geglu(const at::Tensor &input, py::handle quantizer); + +py::object qgeglu(const at::Tensor &input, py::handle quantizer); + +py::object reglu(const at::Tensor &input, py::handle quantizer); + +py::object swiglu(const at::Tensor &input, py::handle quantizer); + +py::object qgelu(const at::Tensor &input, py::handle quantizer); + +py::object srelu(const at::Tensor &input, py::handle quantizer); + +py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +} // namespace transformer_engine::pytorch + +/*************************************************************************************************** + * LayerNorm + **************************************************************************************************/ + +std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &mu, const at::Tensor &rsigma, + const at::Tensor &gamma, const int sm_margin, + const bool zero_centered_gamma); + +std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, + float eps, py::object ln_out, py::handle quantizer, + transformer_engine::DType out_dtype, const int sm_margin, + const bool zero_centered_gamma); + +/*************************************************************************************************** + * RMSNorm + **************************************************************************************************/ + +std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &rsigma, const at::Tensor &gamma, + const int sm_margin, const bool zero_centered_gamma); + +std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, + py::object ln_out, py::handle quantizer, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma); + +/*************************************************************************************************** + * Cast + **************************************************************************************************/ + +namespace transformer_engine::pytorch { + +py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, + std::optional noop); + +py::object dequantize(const py::handle &input, transformer_engine::DType otype); + +std::vector bgrad_quantize(const at::Tensor &input, py::handle py_quantizer); + +std::vector gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, + py::handle quantizer, std::optional out_dtype, MaybeTensor bias, + DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr, + std::optional comm_type = std::nullopt, + MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false); + +/*************************************************************************************************** + * Cast fusions + **************************************************************************************************/ + +std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +} // namespace transformer_engine::pytorch + +/*************************************************************************************************** + * Softmax + **************************************************************************************************/ + +at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor); + +at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, + float scale_factor); + +at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor); + +at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, + float scale_factor); + +at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor); + +at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, + at::Tensor softmax_results_, + float scale_factor); + +at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor); + +at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_, + at::Tensor softmax_results_, + float scale_factor); + +/*************************************************************************************************** + * FP8 recipe + **************************************************************************************************/ + +void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, float margin); + +// Note that the start_offset is the logical offset along the tensor dimension. +// The offset in bytes is start_offset * sizeof(tensor.dtype) +void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, + size_t w, size_t start_offset, size_t block_len); + +void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale, + size_t h, size_t w, size_t start_offset, size_t block_len, + const transformer_engine::DType out_dtype); + +/*************************************************************************************************** + * Rotary positional embedding + **************************************************************************************************/ + +at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, + const bool transpose_output_memory); + +at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const bool transpose_output_memory); + +at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, + const at::Tensor &freqs, const int cp_size, const int cp_rank); + +at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, + const at::Tensor &freqs, const int cp_size, const int cp_rank); + +/*************************************************************************************************** + * Miscellaneous + **************************************************************************************************/ + +size_t get_mublas_version(); + +size_t get_mudnn_version(); + +/*************************************************************************************************** + * Support THD format for Context Parallel + **************************************************************************************************/ + +at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, + int half_idx); + +void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, bool lse_packed); + +at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, + bool lse_packed, int second_half_lse_seqlen); + +void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, + const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, + bool only_second_half, bool lse_packed); + +void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, const std::string &first_half, + const std::string &second_half); + +at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, + int world_size, int rank); + +/*************************************************************************************************** + * multi_tensor_* kernels + **************************************************************************************************/ + +void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, float scale); + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + at::optional per_tensor_python); + +std::tuple multi_tensor_unscale_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + at::Tensor inv_scale, at::optional per_tensor_python); + +using transformer_engine::DType; +void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay); + +void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype); + +void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor lr, const float beta1, const float beta2, + const float epsilon, at::Tensor step, const int mode, + const int bias_correction, const float weight_decay, + at::Tensor inv_scale); + +void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor lr, const float beta1, const float beta2, + const float epsilon, at::Tensor step, const int mode, + const int bias_correction, const float weight_decay, + at::Tensor inv_scale); + +void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, float wd, + float momentum, float dampening, float lr, bool nesterov, bool first_run, + bool wd_after_momentum, float scale); + +void multi_tensor_compute_scale_and_scale_inv_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + float max_fp8, bool force_pow_2_scales, float epsilon); + +/*************************************************************************************************** + * padding + **************************************************************************************************/ + +void fused_multi_row_padding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector padded_input_row_list); + +/*************************************************************************************************** + * swizzle + **************************************************************************************************/ + +void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans); + +at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv); + +at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv); + +/*************************************************************************************************** + * Comm+GEMM Overlap Wrappers + **************************************************************************************************/ + +class CommOverlapHelper : torch::CustomClassHolder { + private: + bool initialized{false}; + bool backend_is_mccl{false}; + std::map pgs; + + public: + int myrank = -1; + int numranks = -1; + int mylocal = -1; + int numlocal = -1; + int mynode = -1; + int numnodes = -1; + + CommOverlapHelper(); + + CommOverlapHelper(c10d::ProcessGroup *world_group, + std::optional intra_node_group, + std::optional inter_node_group); + + ~CommOverlapHelper(); + + void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, + ExtComm comm); + + void ub_barrier(ExtComm comm); +}; + +class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { + public: + CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits = 3, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, + bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = false, + bool rs_overlap_first_gemm = false); + + ~CommOverlap() {} + + void set_buffer_params(py::handle quantizer); + + void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false); + + py::object get_buffer(py::handle quantizer, bool local_chunk = false, + std::optional> shape = std::nullopt); + +}; // CommOverlap + +class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { + public: + CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + transformer_engine::CommOverlapType comm_type, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, + bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, + bool aggregate = false); + + ~CommOverlapP2P() {} + + void set_buffer_params(py::handle quantizer); + + void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false); + + py::object get_buffer(py::handle quantizer, bool local_chunk = false, + std::optional> shape = std::nullopt); + +}; // CommOverlapP2P + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/musa/pytorch/csrc/extensions/activation.cpp b/transformer_engine/musa/pytorch/csrc/extensions/activation.cpp new file mode 100644 index 0000000000..5ae30c60b4 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/activation.cpp @@ -0,0 +1,118 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" +#include "pybind.h" + +namespace transformer_engine::pytorch { + +template +py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = input.contiguous(); + + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + input_shape[input_shape.size() - 1] /= shape_divisor; + auto fake_tensor_type = input.scalar_type(); + + auto [te_output, out] = + my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); + + act_func(te_input.data(), te_output.data(), at::musa::getCurrentMUSAStream()); + + return out; +} + +template +py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input, + py::handle quantizer) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = input.contiguous(); + auto grad_tensor = grad.contiguous(); + + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + auto fake_tensor_type = input.scalar_type(); + + auto [te_output, out] = + my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); + + act_func(te_grad.data(), te_input.data(), te_output.data(), at::musa::getCurrentMUSAStream()); + + return out; +} + +py::object gelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} + +py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object relu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} + +py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object geglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object qgeglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object reglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object swiglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object qgelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} + +py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object srelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} + +py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/musa/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/musa/pytorch/csrc/extensions/apply_rope.cpp new file mode 100644 index 0000000000..301ac9ca52 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/apply_rope.cpp @@ -0,0 +1,223 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, + const bool transpose_output_memory) { + using namespace transformer_engine::pytorch; + TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(input.size(0) <= freqs.size(0), + "expected freqs tensor has a longer sequence length than input"); + TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(input.size(3) >= freqs.size(3), + "expected the last dim of the input tensor equals or is " + "greater than the freqs tensor"); + TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + + // input sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = input.size(0); + const int b = input.size(1); + const int h = input.size(2); + const int d = input.size(3); + // input strides + const int stride_s = input.stride(0); + const int stride_b = input.stride(1); + const int stride_h = input.stride(2); + const int stride_d = input.stride(3); + // freqs' shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = freqs.size(3); + + // output + auto act_options = input.options().requires_grad(false); + at::Tensor output; + if (transpose_output_memory) { + output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + output = torch::empty({s, b, h, d}, act_options); + } + // output strides + const int o_stride_s = output.stride(0); + const int o_stride_b = output.stride(1); + const int o_stride_h = output.stride(2); + const int o_stride_d = output.stride(3); + + auto input_cu = makeTransformerEngineTensor(input); + auto freqs_cu = makeTransformerEngineTensor(freqs); + auto output_cu = makeTransformerEngineTensor(output); + + nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2, + stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, at::musa::getCurrentMUSAStream()); + + return output; +} + +at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const bool transpose_output_memory) { + using namespace transformer_engine::pytorch; + TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(output_grads.size(0) <= freqs.size(0), + "expected freqs tensor has a longer sequence length than output_grads"); + TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(output_grads.size(3) >= freqs.size(3), + "expected the last dim of the output_grads tensor equals or is " + "greater than the freqs tensor"); + TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + + // output_grads sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = output_grads.size(0); + const int b = output_grads.size(1); + const int h = output_grads.size(2); + const int d = output_grads.size(3); + // output_grads strides + const int stride_s = output_grads.stride(0); + const int stride_b = output_grads.stride(1); + const int stride_h = output_grads.stride(2); + const int stride_d = output_grads.stride(3); + // freqs' shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = freqs.size(3); + + auto act_options = output_grads.options().requires_grad(false); + at::Tensor input_grads; + if (transpose_output_memory) { + input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + input_grads = torch::empty({s, b, h, d}, act_options); + } + const int o_stride_s = input_grads.stride(0); + const int o_stride_b = input_grads.stride(1); + const int o_stride_h = input_grads.stride(2); + const int o_stride_d = input_grads.stride(3); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto freqs_cu = makeTransformerEngineTensor(freqs); + auto input_grads_cu = makeTransformerEngineTensor(input_grads); + + nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h, + d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, at::musa::getCurrentMUSAStream()); + + return input_grads; +} + +at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, + const at::Tensor &freqs, const int cp_size, const int cp_rank) { + using namespace transformer_engine::pytorch; + TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); + TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); + TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(input.size(2) >= freqs.size(3), + "expected the last dim of the input tensor equals or is " + "greater than the freqs tensor"); + TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + + // input sizes: (t, h, d) + // t: cumulative sum of sequence lengths + // h: head num + // d: dim of each head + const int t = input.size(0); + const int h = input.size(1); + const int d = input.size(2); + // input strides + const int stride_t = input.stride(0); + const int stride_h = input.stride(1); + const int stride_d = input.stride(2); + // batch size + const int b = cu_seqlens.size(0) - 1; + // freqs' shape is (max_s, 1, 1, d2) + const int max_s = freqs.size(0); + const int d2 = freqs.size(3); + + // output + auto act_options = input.options().requires_grad(false); + auto output = torch::empty({t, h, d}, act_options); + // output strides + const int o_stride_t = output.stride(0); + const int o_stride_h = output.stride(1); + const int o_stride_d = output.stride(2); + + auto input_cu = makeTransformerEngineTensor(input); + auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); + auto freqs_cu = makeTransformerEngineTensor(freqs); + auto output_cu = makeTransformerEngineTensor(output); + + nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, + stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, + at::musa::getCurrentMUSAStream()); + + return output; +} + +at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, + const at::Tensor &freqs, const int cp_size, const int cp_rank) { + using namespace transformer_engine::pytorch; + TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); + TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); + TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(output_grads.size(2) >= freqs.size(3), + "expected the last dim of the output_grads tensor equals or is " + "greater than the freqs tensor"); + TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + + // output_grads sizes: (t, h, d) + // t: cumulative sum of sequence lengths + // h: head num + // d: dim of each head + const int t = output_grads.size(0); + const int h = output_grads.size(1); + const int d = output_grads.size(2); + // output_grads strides + const int stride_t = output_grads.stride(0); + const int stride_h = output_grads.stride(1); + const int stride_d = output_grads.stride(2); + // batch size + const int b = cu_seqlens.size(0) - 1; + // freqs' shape is (max_s, 1, 1, d2) + const int max_s = freqs.size(0); + const int d2 = freqs.size(3); + + auto act_options = output_grads.options().requires_grad(false); + auto input_grads = torch::empty({t, h, d}, act_options); + const int o_stride_t = input_grads.stride(0); + const int o_stride_h = input_grads.stride(1); + const int o_stride_d = input_grads.stride(2); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); + auto freqs_cu = makeTransformerEngineTensor(freqs); + auto input_grads_cu = makeTransformerEngineTensor(input_grads); + + nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, + stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, + at::musa::getCurrentMUSAStream()); + + return input_grads; +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/attention.mu b/transformer_engine/musa/pytorch/csrc/extensions/attention.mu new file mode 100644 index 0000000000..225733ae50 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/attention.mu @@ -0,0 +1,1011 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common/common.h" +#include "common/fused_attn/thd_utils.h" +#include "extensions.h" + +using namespace transformer_engine::fused_attn; + +constexpr int block_size = 512; +constexpr int ctas_per_sm = 4; + +// get the fused attention backend +NVTE_Fused_Attn_Backend get_fused_attn_backend( + const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, + attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, + head_dim_qk, head_dim_v, window_size_left, window_size_right); + return fused_attention_backend; +} + +// fast zero-fills of tensors +template +__global__ void __launch_bounds__(block_size) + mha_fill_kernel(scalar_t *out_tensor, const int32_t *const start_row, const size_t num_rows) { + size_t row_stride = gridDim.y * blockDim.x; + size_t row_index = blockIdx.x + static_cast(start_row[0]); + size_t col_index = blockIdx.y * blockDim.x + threadIdx.x; + while (row_index < num_rows) { + out_tensor[row_index * row_stride + col_index] = 0; + row_index += gridDim.x; + } +} + +// fast zero-fills of tensors +void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &start_index) { + std::vector shape = transformer_engine::pytorch::convertShape(self.shape()); + + auto max_tokens = shape[0]; + auto fcd_size = 1; + for (int i = 1; i <= shape.size(); i++) { + fcd_size *= shape[i]; + } + TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); + const int num_mp = at::musa::getCurrentDeviceProperties()->multiProcessorCount; + uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); + uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y); + dim3 dim_grid(num_blk_x, num_blk_y); + dim3 dim_block(block_size); + // trzeba jakos przekonwertowac DType na scalar_type + at::ScalarType scalar_type = transformer_engine::pytorch::GetATenDType(self.dtype()); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "mha_fill", [&]() { + mha_fill_kernel<<>>( + static_cast(self.get_rowwise_data().data_ptr), + static_cast(start_index.data_ptr()), max_tokens); + C10_MUSA_KERNEL_LAUNCH_CHECK(); + }); +} + +// extract seed and offset from PhiloxMusaState +__global__ void unpack(at::PhiloxMusaState arg, int64_t *rng_state_ptr) { + if (arg.captured_) { + rng_state_ptr[0] = static_cast(*arg.seed_.ptr); + rng_state_ptr[1] = + static_cast(*(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); + } else { + rng_state_ptr[0] = static_cast(arg.seed_.val); + rng_state_ptr[1] = static_cast(arg.offset_.val); + } +} + +// extract PhiloxMusaState from CUDA random number generator +at::PhiloxMusaState init_philox_state(at::MUSAGeneratorImpl *gen, size_t elts_per_thread) { + at::PhiloxMusaState philox_args; + std::lock_guard lock(gen->mutex_); + philox_args = gen->philox_musa_state(elts_per_thread); + return philox_args; +} + +// fused attention FWD with separate Q, K and V tensors +std::vector fused_attn_fwd( + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, const std::vector window_size, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const c10::optional cu_seqlens_q_padded, + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle o_quantizer, const c10::optional Bias, + const c10::optional rng_gen, size_t rng_elts_per_thread) { + using namespace transformer_engine; + using namespace transformer_engine::pytorch; + TensorWrapper te_Q, te_K, te_V, te_O, te_S; + + auto none = py::none(); + std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); + std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); + + te_Q = makeTransformerEngineTensor(Q, none); + te_K = makeTransformerEngineTensor(K, none); + te_V = makeTransformerEngineTensor(V, none); + + // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types. + const transformer_engine::DType qkv_type = te_Q.dtype(); + const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + + std::vector q_shape = convertShape(te_Q.shape()); + std::vector k_shape = convertShape(te_K.shape()); + std::vector v_shape = convertShape(te_V.shape()); + auto options = torch::TensorOptions() + .dtype(GetATenDType(qkv_type)) + .device(c10::kPrivateUse1, c10::musa::current_device()); + // create output tensor O + + auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; + o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; + py::object o_python, s_python; + std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; + + // construct NVTE tensors + TensorWrapper te_Bias; + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; + TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + auto h = q_shape[q_shape.size() - 2]; + auto d = q_shape[q_shape.size() - 1]; + if (set_zero && ((h * d) % block_size == 0) && + (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + te_O.zero_(at::musa::getCurrentMUSAStream()); + } + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + te_O.zero_(at::musa::getCurrentMUSAStream()); + } + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); + } + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { + auto bias_sizes = Bias.value().sizes().vec(); + std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32); + } + auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); + std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); + std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + te_cu_seqlens_q = + makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, DType::kInt32); + te_cu_seqlens_kv = + makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32); + + if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { + auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); + std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), + cu_seqlens_q_padded_sizes.end()}; + auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); + std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), + cu_seqlens_kv_padded_sizes.end()}; + te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), + cu_seqlens_q_padded_shape, DType::kInt32); + te_cu_seqlens_kv_padded = makeTransformerEngineTensor( + cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + } + + // extract rng seed and offset + auto gen = at::get_generator_or_default( + rng_gen, at::musa::detail::getDefaultMUSAGenerator()); + at::PhiloxMusaState philox_args = init_philox_state(gen, rng_elts_per_thread); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + unpack<<<1, 1, 0, at::musa::getCurrentMUSAStream()>>>( + philox_args, static_cast(rng_state.data_ptr())); + auto te_rng_state = makeTransformerEngineTensor(rng_state); + + // create auxiliary output tensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), + te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, + max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, + attn_mask_type, window_size[0], window_size[1], workspace.data(), + at::musa::getCurrentMUSAStream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + + // output_tensors = [O, nvte_aux_tensor_pack.tensors] + std::vector output_tensors; + output_tensors.push_back(o_python); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + // allocate memory for nvte_aux_tensor_pack.tensors + at::Tensor output_tensor; + if (nvte_aux_tensor_pack.size >= 2) { + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { + if (i < nvte_aux_tensor_pack.size - 2) { + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + } else if (i == nvte_aux_tensor_pack.size - 2) { + output_tensor = rng_state; + } else if (i == nvte_aux_tensor_pack.size - 1) { + output_tensor = Bias.value(); + } + } else { + output_tensor = (i < nvte_aux_tensor_pack.size - 1) + ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) + : rng_state; + } + } else { + output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + } + output_tensors.push_back(py::cast(output_tensor)); + tensor->data.dptr = output_tensor.data_ptr(); + } + + // execute the kernel + nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), + te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, + max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, + attn_mask_type, window_size[0], window_size[1], workspace.data(), + at::musa::getCurrentMUSAStream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + // if training, [O, softmax-related tensors, rng_state]; if inference, [O] + return output_tensors; +} + +// fused attention BWD with separate Q, K and V +std::vector fused_attn_bwd( + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, + const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const c10::optional cu_seqlens_q_padded, + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle dp_quantizer, py::handle dqkv_quantizer) { + using namespace transformer_engine; + using namespace transformer_engine::pytorch; + auto none = py::none(); + TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; + te_Q = makeTransformerEngineTensor(Q, none); + te_K = makeTransformerEngineTensor(K, none); + te_V = makeTransformerEngineTensor(V, none); + te_O = makeTransformerEngineTensor(O, none); + te_dO = makeTransformerEngineTensor(dO, none); + // qkv type from the te_Q + std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); + const transformer_engine::DType qkv_type = te_Q.dtype(); + const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + + py::object s_python, dp_python; + std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); + std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + + std::vector q_shape = convertShape(te_Q.shape()); + std::vector k_shape = convertShape(te_K.shape()); + std::vector v_shape = convertShape(te_V.shape()); + auto h_q = q_shape[q_shape.size() - 2]; + auto h_kv = k_shape[k_shape.size() - 2]; + auto d_qk = q_shape[q_shape.size() - 1]; + auto d_v = v_shape[v_shape.size() - 1]; + auto options = torch::TensorOptions() + .dtype(GetATenDType(dqkv_type)) + .device(c10::kPrivateUse1, c10::musa::current_device()); + std::vector o_shape{q_shape.begin(), q_shape.end()}; + o_shape[o_shape.size() - 1] = d_v; + + at::Tensor dQ, dK, dV, dQKV, dKV; + py::object py_dQ, py_dK, py_dV; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + std::vector tmp_shape; + + switch (layout_group) { + case NVTE_QKV_Layout_Group::NVTE_3HD: + tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); + dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 3); + dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 3); + dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 3); + break; + case NVTE_QKV_Layout_Group::NVTE_H3D: + tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); + dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 2); + dK = dQKV.index({"...", torch::indexing::Slice(1, 2, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 2); + dV = dQKV.index({"...", torch::indexing::Slice(2, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 2); + break; + case NVTE_QKV_Layout_Group::NVTE_HD_2HD: + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); + dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 3); + dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 3); + break; + case NVTE_QKV_Layout_Group::NVTE_HD_H2D: + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); + dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); + dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 2); + dV = dKV.index({"...", torch::indexing::Slice(1, torch::indexing::None, 1), + torch::indexing::Slice(0, torch::indexing::None, 1)}) + .squeeze(tmp_shape.size() - 2); + break; + case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector(k_shape.begin(), k_shape.end()); + dK = torch::empty(tmp_shape, options); + tmp_shape = std::vector(v_shape.begin(), v_shape.end()); + dV = torch::empty(tmp_shape, options); + break; + default: + NVTE_ERROR("QKV layout not supported!"); + } + std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); + std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK); + std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV); + + // construct NVTE tensors + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && + dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && + (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQ.fill_(0); + dK.fill_(0); + dV.fill_(0); + } + + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dQ.fill_(0); + dK.fill_(0); + dV.fill_(0); + } + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); + } + + // create cu_seqlens tensorwrappers + auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); + std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); + std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, + DType::kInt32, nullptr, nullptr, nullptr); + + TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; + if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { + auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); + std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), + cu_seqlens_q_padded_sizes.end()}; + auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); + std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), + cu_seqlens_kv_padded_sizes.end()}; + te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), + cu_seqlens_q_padded_shape, DType::kInt32); + te_cu_seqlens_kv_padded = makeTransformerEngineTensor( + cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + } + + // convert auxiliary tensors from forward to NVTETensors + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); + std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); + tensor->data.shape = std::vector(tmp.begin(), tmp.end()); + tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); + } + + // create dBias the same shape as Bias + at::Tensor dBias; + TensorWrapper te_dBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + if (nvte_aux_tensor_pack.size >= 2) { + std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); + dBias = torch::empty(bias_shape, options); + te_dBias = makeTransformerEngineTensor(dBias); + } else { + dBias = torch::empty({1, static_cast(h_q), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv)}, + options); + te_dBias = makeTransformerEngineTensor(dBias); + } + if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + dBias.fill_(0); + } + } + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), + te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + window_size[0], window_size[1], deterministic, workspace.data(), + at::musa::getCurrentMUSAStream()); + + // allocate memory for workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), + te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), + te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + window_size[0], window_size[1], deterministic, workspace.data(), + at::musa::getCurrentMUSAStream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + return {py_dQ, py_dK, py_dV, py::cast(dBias)}; +} + +namespace flash_attention { + +constexpr int warp_size = 32; +constexpr int type_size = 2; // FP16 or BF16 +constexpr int nvec = sizeof(uint64_t) / type_size; +constexpr int load_size = warp_size * nvec; +constexpr int block_size = 512; + +template +__launch_bounds__(block_size) __global__ + void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z, + const size_t W) { + const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + const int id_in_warp = threadIdx.x % warp_size; + const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec; + const T *my_input = qkvi + offset_input; + + const size_t s = warpid / B; + if (s >= S) return; + + const size_t b = warpid % B; + + const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec; + + T *my_output = qkv + offset_output; + + for (int i = 0; i < Z; ++i) { + uint64_t *out = reinterpret_cast(my_output + i * load_size); + *out = *reinterpret_cast(my_input + i * load_size * 3); + } +} + +template +__launch_bounds__(block_size) __global__ + void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B, + const size_t S, const size_t Z, const size_t W) { + const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v); + + const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size; + const int id_in_warp = threadIdx.x % warp_size; + const size_t offset_input = warpid * W * Z + id_in_warp * nvec; + const T *my_input = input + offset_input; + + const size_t b = warpid / S; + if (b >= B) return; + + const size_t s = warpid % S; + + const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W; + + T *my_output = qkv + offset_output; + + for (int i = 0; i < Z; ++i) { + uint64_t *out = reinterpret_cast(my_output + i * load_size * 3); + *out = *reinterpret_cast(my_input + i * load_size); + } +} + +} // namespace flash_attention + +at::Tensor fa_prepare_fwd(at::Tensor qkvi) { + NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half || + qkvi.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(qkvi.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(qkvi.size(3) == flash_attention::load_size); + NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride."); + NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride."); + NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride."); + NVTE_CHECK(qkvi.stride(0) == 3 * qkvi.size(3) * qkvi.size(2) * qkvi.size(1), "Wrong stride."); + + // [s, b, n, h * 3] -> [3, b, s, n, h] + std::vector shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)}; + at::Tensor qkv = at::empty( + shape, qkvi.options().device(c10::kPrivateUse1, c10::musa::current_device())); + + size_t warps = qkvi.size(0) * qkvi.size(1); + size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; + size_t blocks = (warps + warps_per_block - 1) / warps_per_block; + dim3 grid(blocks, 3); + int threads = flash_attention::block_size; + if (qkvi.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + flash_attention::prepare_kernel_fwd + <<>>( + qkvi.data_ptr(), qkv.data_ptr(), shape[1], shape[2], shape[3], shape[4]); + } else { + using dtype = at::BFloat16; + flash_attention::prepare_kernel_fwd + <<>>( + qkvi.data_ptr(), qkv.data_ptr(), shape[1], shape[2], shape[3], shape[4]); + } + + return qkv; +} + +at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { + NVTE_CHECK(q.is_contiguous()); + NVTE_CHECK(k.is_contiguous()); + NVTE_CHECK(v.is_contiguous()); + NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor."); + NVTE_CHECK(q.scalar_type() == at::ScalarType::Half || + q.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(k.scalar_type() == q.scalar_type()); + NVTE_CHECK(v.scalar_type() == q.scalar_type()); + NVTE_CHECK(q.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(q.size(3) == flash_attention::load_size); + NVTE_CHECK(k.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(k.size(3) == flash_attention::load_size); + NVTE_CHECK(v.size(3) % flash_attention::load_size == 0); + NVTE_CHECK(v.size(3) == flash_attention::load_size); + + // 3 x [s, b, n, h] -> [b, s, n, 3 * h] + + std::vector shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)}; + at::Tensor qkv = at::empty( + shape, q.options().device(c10::kPrivateUse1, c10::musa::current_device())); + + size_t warps = q.size(0) * q.size(1); + size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; + size_t blocks = (warps + warps_per_block - 1) / warps_per_block; + dim3 grid(blocks, 3); + int threads = flash_attention::block_size; + if (q.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + flash_attention::prepare_kernel_bwd + <<>>( + q.data_ptr(), k.data_ptr(), v.data_ptr(), qkv.data_ptr(), + q.size(0), q.size(1), q.size(2), q.size(3)); + } else { + using dtype = at::BFloat16; + flash_attention::prepare_kernel_bwd + <<>>( + q.data_ptr(), k.data_ptr(), v.data_ptr(), qkv.data_ptr(), + q.size(0), q.size(1), q.size(2), q.size(3)); + } + + return qkv; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Read the half of a THD tensor + **************************************************************************************************/ + +at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, + int half_idx) { + NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + NVTE_CHECK(cu_seqlens.size(0) >= 2); + + // Shapes of q and dq are [t, h, d], so the dimension of "t" is 0 + // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1 + int seq_dim = tensor.dim() == 3 ? 0 : 1; + + int batch = cu_seqlens.size(0) - 1; + int num_heads = tensor.size(seq_dim + 1); + int dim_per_head = tensor.size(seq_dim + 2); + int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type()); + + // For 128-bits load/store + NVTE_CHECK(hidden_size_in_bytes % 16 == 0); + + // Generate output + std::vector shape(tensor.dim()); + for (size_t i = 0; i < shape.size(); i++) { + shape[i] = tensor.size(i); + } + shape[seq_dim] /= 2; + at::Tensor half = at::empty( + shape, tensor.options().device(c10::kPrivateUse1, c10::musa::current_device())); + + // Launch Kernel + constexpr unsigned int block = 256; + unsigned int grid_x = (tensor.size(seq_dim) / 2 * 32 + block - 1) / block; + unsigned int grid_y = 1; + for (int i = 0; i < seq_dim; i++) { + grid_y *= tensor.size(i); + } + dim3 grid = {grid_x, grid_y}; + thd_read_half_tensor_kernel<<>>( + half.data_ptr(), tensor.data_ptr(), cu_seqlens.data_ptr(), batch, hidden_size_in_bytes, + half_idx, tensor.size(seq_dim)); + + return half; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: softmax_lse related operations + **************************************************************************************************/ + +void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, bool lse_packed) { + NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); + NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + + int batch, num_heads, lse_seqlen, second_half_lse_seqlen; + + if (lse_packed) { + NVTE_CHECK(lse.dim() == 2); + NVTE_CHECK(lse_per_step.dim() == 2); + + batch = cu_seqlens.size(0) - 1; + num_heads = lse.size(0); + lse_seqlen = lse.size(1); + second_half_lse_seqlen = lse_per_step.size(1); + + NVTE_CHECK(lse_per_step.size(0) == num_heads); + NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2); + } else { + NVTE_CHECK(lse.dim() == 3); + NVTE_CHECK(lse_per_step.dim() == 3); + + batch = lse.size(0); + num_heads = lse.size(1); + lse_seqlen = lse.size(2); + second_half_lse_seqlen = lse_per_step.size(2); + + NVTE_CHECK(lse_per_step.size(0) == batch); + NVTE_CHECK(lse_per_step.size(1) == num_heads); + NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2); + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + } + + constexpr unsigned int block = 256; + unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block; + unsigned int grid_y = num_heads; + dim3 grid = {grid_x, grid_y}; + if (lse_packed) { + thd_lse_kernel + <<>>( + lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), + batch, num_heads, lse_seqlen, second_half_lse_seqlen); + } else { + thd_lse_kernel + <<>>( + lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), + batch, num_heads, lse_seqlen, second_half_lse_seqlen); + } +} + +at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, + bool lse_packed, int second_half_lse_seqlen) { + NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + + int batch, num_heads, lse_seqlen; + std::vector shape; + + if (lse_packed) { + NVTE_CHECK(lse.dim() == 2); + + batch = cu_seqlens.size(0) - 1; + num_heads = lse.size(0); + lse_seqlen = lse.size(1); + + NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2); + + shape = {num_heads, second_half_lse_seqlen}; + } else { + NVTE_CHECK(lse.dim() == 3); + + batch = lse.size(0); + num_heads = lse.size(1); + lse_seqlen = lse.size(2); + + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2); + + shape = {batch, num_heads, second_half_lse_seqlen}; + } + + at::Tensor half_lse = at::zeros( + shape, lse.options().device(c10::kPrivateUse1, c10::musa::current_device())); + + constexpr unsigned int block = 256; + unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block; + unsigned int grid_y = num_heads; + dim3 grid = {grid_x, grid_y}; + if (lse_packed) { + thd_lse_kernel + <<>>( + lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, + num_heads, lse_seqlen, second_half_lse_seqlen); + } else { + thd_lse_kernel + <<>>( + lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, + num_heads, lse_seqlen, second_half_lse_seqlen); + } + + return half_lse; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Out correction in forward + **************************************************************************************************/ + +template +static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step, + const at::Tensor &lse, const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, bool lse_packed) { + NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type()); + NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + + int total_tokens = out.size(0); + int num_heads = out.size(1); + int dim_per_head = out.size(2); + + NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1)); + NVTE_CHECK(out_per_step.size(1) == num_heads); + NVTE_CHECK(out_per_step.size(2) == dim_per_head); + + int batch, lse_seqlen, lse_per_step_seqlen; + if (lse_packed) { + batch = cu_seqlens.size(0) - 1; + lse_seqlen = lse.size(1); + lse_per_step_seqlen = lse_per_step.size(1); + + NVTE_CHECK(lse.size(0) == num_heads); + NVTE_CHECK(lse_seqlen >= total_tokens); + NVTE_CHECK(lse_per_step.size(0) == num_heads); + NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1)); + } else { + batch = lse.size(0); + lse_seqlen = lse.size(2); + lse_per_step_seqlen = lse_per_step.size(2); + + NVTE_CHECK(lse.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(0) == batch); + NVTE_CHECK(lse_per_step.size(1) == num_heads); + NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1)); + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + } + + constexpr int tile = 16; + constexpr int block = 512; + unsigned int grid_x = + (static_cast(total_tokens) / (only_second_half + 1) * tile + block - 1) / block; + dim3 grid = {grid_x, (unsigned int)num_heads}; + + if (lse_packed) { + thd_out_correction_kernel + <<>>( + out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), + lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, + dim_per_head, lse_seqlen, lse_per_step_seqlen); + } else { + thd_out_correction_kernel + <<>>( + out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), + lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, + dim_per_head, lse_seqlen, lse_per_step_seqlen); + } +} + +void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, + const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, + bool only_second_half, bool lse_packed) { + if (only_second_half) { + if (out.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); + } else if (out.scalar_type() == at::ScalarType::Float) { + using dtype = float; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); + } else { + NVTE_ERROR("Unsupported dtype of out\n"); + } + } else { + if (out.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); + } else if (out.scalar_type() == at::ScalarType::Float) { + using dtype = float; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens, + lse_packed); + } else { + NVTE_ERROR("Unsupported dtype of out\n"); + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Gradients correction in backward + **************************************************************************************************/ + +template +static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens) { + NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + + // Shape of dq is [t, h, d], so the dimension of "t" is 0 + // Shape of dkv is [2, t, h, d], so the dimension of "t" is 1 + int seq_dim = grad.dim() == 3 ? 0 : 1; + + int total_tokens = grad.size(seq_dim); + int num_heads = grad.size(seq_dim + 1); + int dim_per_head = grad.size(seq_dim + 2); + int batch = cu_seqlens.size(0) - 1; + + if constexpr (functor_idx < 2) { + NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens / 2); + } else { + NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens); + } + NVTE_CHECK(grad_per_step.size(seq_dim + 1) == num_heads); + NVTE_CHECK(grad_per_step.size(seq_dim + 2) == dim_per_head); + + size_t hidden_size = num_heads * dim_per_head; + NVTE_CHECK((hidden_size * c10::elementSize(grad.scalar_type())) % 16 == 0); + + constexpr unsigned int block = 256; + unsigned int grid_x; + if constexpr (functor_idx < 2) { + grid_x = (total_tokens / 2 * 32 + block - 1) / block; + } else { + grid_x = (total_tokens * 32 + block - 1) / block; + } + unsigned int grid_y = 1; + for (int i = 0; i < seq_dim; i++) { + grid_y *= grad.size(i); + } + dim3 grid = {grid_x, grid_y}; + + thd_grad_correction_kernel + <<>>( + grad.data_ptr(), grad_per_step.data_ptr(), cu_seqlens.data_ptr(), + batch, hidden_size, total_tokens); +} + +template +static void thd_grad_dispatcher(at::Tensor grad, const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, const std::string &first_half, + const std::string &second_half) { + if (first_half == "add" && second_half == "none") { + thd_grad_correction_helper, EmptyFunctor, 0>(grad, grad_per_step, + cu_seqlens); + } else if (first_half == "copy" && second_half == "none") { + thd_grad_correction_helper(grad, grad_per_step, + cu_seqlens); + } else if (first_half == "none" && second_half == "add") { + thd_grad_correction_helper, 1>(grad, grad_per_step, + cu_seqlens); + } else if (first_half == "none" && second_half == "copy") { + thd_grad_correction_helper(grad, grad_per_step, + cu_seqlens); + } else if (first_half == "add" && second_half == "copy") { + thd_grad_correction_helper, CopyFunctor, 2>(grad, grad_per_step, + cu_seqlens); + } else if (first_half == "copy" && second_half == "add") { + thd_grad_correction_helper, 2>(grad, grad_per_step, + cu_seqlens); + } else { + NVTE_ERROR("Unsupported Functor of first half and second_half\n"); + } +} + +void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, const std::string &first_half, + const std::string &second_half) { + if (grad.scalar_type() == at::ScalarType::Half) { + thd_grad_dispatcher(grad, grad_per_step, cu_seqlens, first_half, second_half); + } else if (grad.scalar_type() == at::ScalarType::BFloat16) { + thd_grad_dispatcher(grad, grad_per_step, cu_seqlens, first_half, second_half); + } else if (grad.scalar_type() == at::ScalarType::Float) { + thd_grad_dispatcher(grad, grad_per_step, cu_seqlens, first_half, second_half); + } else { + NVTE_ERROR("Unsupported dtype of grad\n"); + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Generate partitioned indices for input tokens + **************************************************************************************************/ + +at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, + int world_size, int rank) { + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + NVTE_CHECK(cu_seqlens.size(0) >= 2); + NVTE_CHECK(rank >= 0 && rank < world_size); + NVTE_CHECK(world_size > 0); + NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0); + + int batch = cu_seqlens.size(0) - 1; + + std::vector shape = {total_tokens / world_size}; + at::Tensor output = at::empty(shape, + c10::TensorOptions() + .dtype(at::ScalarType::Int) + .device(c10::kPrivateUse1, c10::musa::current_device())); + + constexpr unsigned int block = 256; + unsigned int grid = (output.size(0) + block - 1) / block; + thd_partition_indices_kernel<<>>( + output.data_ptr(), cu_seqlens.data_ptr(), batch, total_tokens, world_size, rank); + + return output; +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/bias.cpp b/transformer_engine/musa/pytorch/csrc/extensions/bias.cpp new file mode 100644 index 0000000000..4f69519a29 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/bias.cpp @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common.h" +#include "pybind.h" +#include "transformer_engine/cast.h" + +namespace transformer_engine::pytorch { + +std::vector bgrad_quantize(const at::Tensor& input, py::handle py_quantizer) { + auto quantizer = convert_quantizer(py_quantizer); + + auto input_tensor = makeTransformerEngineTensor(input); + + auto dbias = allocateTorchTensor(input.size(-1), input_tensor.dtype()); + + std::vector output_shape; + for (auto s : input.sizes()) { + output_shape.emplace_back(static_cast(s)); + } + auto [out_tensor, out] = quantizer->create_tensor(output_shape, input_tensor.dtype()); + + // Return immediately if tensors are empty + if (product(output_shape) == 0) { + return {py::cast(dbias.zero_()), out}; + } + + auto dbias_tensor = makeTransformerEngineTensor(dbias); + // Query workspace size and allocate workspace + transformer_engine::TensorWrapper workspace; + nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(), + at::musa::getCurrentMUSAStream()); + + void* workspace_data_ptr = nullptr; + if (workspace.shape().ndim > 0) { + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace_data_ptr = workspace_data.data_ptr(); + } + workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype()); + + // Launch kernel + nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(), + at::musa::getCurrentMUSAStream()); + + return {py::cast(dbias), out}; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/musa/pytorch/csrc/extensions/cast.cpp b/transformer_engine/musa/pytorch/csrc/extensions/cast.cpp new file mode 100644 index 0000000000..57c11b5bb7 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/cast.cpp @@ -0,0 +1,129 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/cast.h" + +#include "common.h" +#include "extensions.h" +#include "pybind.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine::pytorch { + +py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output, + std::optional noop) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = tensor.contiguous(); + + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + auto fake_tensor_type = tensor.scalar_type(); + if (!detail::IsFloatingPointType(fake_tensor_type)) { + fake_tensor_type = at::kFloat; + } + + TensorWrapper te_output; + py::object out; + if (output.is_none()) { + DType fake_te_type = GetTransformerEngineDType(fake_tensor_type); + std::tie(te_output, out) = my_quantizer->create_tensor(input_shape, fake_te_type); + } else { + out = output; + te_output = makeTransformerEngineTensor(output, quantizer); + } + + TensorWrapper te_noop; + if (noop.has_value()) { + te_noop = makeTransformerEngineTensor(*noop); + } else { + te_noop = TensorWrapper(); + } + + if (te_output.numel() == 0) return out; + nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), + at::musa::getCurrentMUSAStream()); + + return out; +} + +py::object dequantize(const py::handle& input, transformer_engine::DType otype) { + init_extension(); + + const auto none = py::none(); + + const auto& input_tensor = makeTransformerEngineTensor(input, none); + + NoneQuantizer q(none); + + const auto& shape = convertShape(input_tensor.shape()); + + auto [out_tensor, out] = q.create_tensor(shape, otype); + + nvte_dequantize(input_tensor.data(), out_tensor.data(), at::musa::getCurrentMUSAStream()); + + return out; +} + +template +std::vector dbias_dact(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + + auto grad_tensor = makeTransformerEngineTensor(grad_output); + + auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype()); + auto act_input_tensor = makeTransformerEngineTensor(act_input); + + const auto& shape = convertShape(grad_tensor.shape()); + auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype()); + + auto dbias_tensor = makeTransformerEngineTensor(grad_bias); + + // Query workspace size and allocate workspace + transformer_engine::TensorWrapper workspace; + func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), + workspace.data(), at::musa::getCurrentMUSAStream()); + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + + // Launch kernel + func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), + workspace.data(), at::musa::getCurrentMUSAStream()); + + return {py::cast(grad_bias), dact}; +} + +std::vector dbias_dgelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} + +std::vector dbias_dsilu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} + +std::vector dbias_drelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} + +std::vector dbias_dqgelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} + +std::vector dbias_dsrelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/musa/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/musa/pytorch/csrc/extensions/comm_gemm_overlap.cpp new file mode 100644 index 0000000000..956637f0d8 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -0,0 +1,327 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../extensions.h" +#include "transformer_engine/transformer_engine.h" + +#define HALF_BYTES 2 +#define UB_MAX_SM 32 + +using namespace torch::indexing; +using namespace std::placeholders; + +namespace te = transformer_engine; + +/*************************************************************************************************** + * CommOverlapHelper + **************************************************************************************************/ + +CommOverlapHelper::CommOverlapHelper() { +#ifndef NVTE_UB_WITH_MPI + NVTE_ERROR("Internal TE error: Dummy CommOverlapHelper init without NVTE_UB_WITH_MPI=1!"); +#endif +} // empty constructor for NVTE_UB_WITH_MPI=1 + +CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, + std::optional intra_domain_group, + std::optional inter_domain_group) { +#ifndef NVTE_UB_WITH_MPI + pgs.insert({"world", world_group}); + myrank = pgs["world"]->getRank(); + numranks = pgs["world"]->getSize(); + c10d::ProcessGroup::BackendType backend = pgs["world"]->getBackendType(); + backend_is_mccl = true; // (backend != c10d::ProcessGroup::BackendType::MCCL); + + if (intra_domain_group.has_value()) { + // Get local rank on node and number of local ranks + NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend, + "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", + "group!", pgs["world"]->getBackendName()); + pgs.insert({"intra", intra_domain_group.value()}); + mylocal = pgs["intra"]->getRank(); + numlocal = pgs["intra"]->getSize(); + + if (numlocal == numranks) { + // Intra-node group is same as the world group so there can only be 1 node + NVTE_CHECK( + mylocal == myrank, + "Internal TE error: Local rank must be equal to global rank when intra-node group size ", + "is equal to the world group size!"); + mynode = 0; + numnodes = 1; + } else { + // Intra-node group is different than the world group so there must be multiple nodes + NVTE_CHECK( + inter_domain_group.has_value(), + "Internal TE error: Inter-node group cannot be `None` when intra-node group is not ", + "identical to the world_group!"); + + // Get node ID and number of nodes + NVTE_CHECK( + inter_domain_group.value()->getBackendType() == backend, + "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", + "group!", pgs["world"]->getBackendName()); + pgs.insert({"inter", inter_domain_group.value()}); + mynode = pgs["inter"]->getRank(); + numnodes = pgs["inter"]->getSize(); + } + } else { + // Intra-node group is not set so we assume there is only 1 node + mylocal = myrank; + numlocal = numranks; + pgs.insert({"intra", world_group}); + + mynode = 0; + numnodes = 1; + } + + initialized = true; +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled with NVTE_UB_WITH_MPI=1!"); +#endif +} + +CommOverlapHelper::~CommOverlapHelper() { +#ifndef NVTE_UB_WITH_MPI + for (auto &pg : pgs) pg.second = nullptr; + backend_is_mccl = false; + initialized = false; +#endif +} + +void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void *localdata, + size_t localbytes, ExtComm group) { +#ifndef NVTE_UB_WITH_MPI + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); + + auto localtensor = + torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto localtmp = (backend_is_mccl) + ? localtensor.to(c10::Device(c10::kPrivateUse1), false, false) + : localtensor; + auto globaltensor = + torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto globaltmp = (backend_is_mccl) + ? globaltensor.to(c10::Device(c10::kPrivateUse1), false, false) + : globaltensor; + + std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; + std::vector localchunk = {localtmp}; + auto work = pgs[group]->allgather(globalchunks, localchunk); + work->wait(); + + if (backend_is_mccl) { + globaltensor.copy_(globaltmp.cpu()); + globaltmp = torch::Tensor(); + localtmp = torch::Tensor(); + } +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_allgather is a no-op when TE is compiled ", + "with NVTE_UB_WITH_MPI=1!"); +#endif +} + +void CommOverlapHelper::ub_barrier(ExtComm group) { +#ifndef NVTE_UB_WITH_MPI + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); + auto work = pgs[group]->barrier(); + work->wait(); +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_barrier is a no-op when TE is compiled ", + "with NVTE_UB_WITH_MPI=1!"); +#endif +} + +/*************************************************************************************************** + * CommOverlap + **************************************************************************************************/ + +CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits, + int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool use_ce, bool rs_overlap_first_gemm) + : te::CommOverlapBase(buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), + helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, + helper->mynode, helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, + num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, atomic_gemm, use_ce, rs_overlap_first_gemm) {} + +void CommOverlap::set_buffer_params(py::handle quantizer) { + std::unique_ptr my_quantizer = te::pytorch::convert_quantizer(quantizer); + my_quantizer->set_quantization_params(&_ubuf); + _ubuf_scale_inv_initialized = true; +} + +/* +** Helper function to copy input to _ubuf +*/ +void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) { + auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer); + auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr(); + NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!"); + + char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); + if (local_chunk) { + if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the local communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size(); + } else { + if (input_tensor.numel() > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the global communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + } + + // Copy either row or columnwise data into the communication buffer's columnwise data + // NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with + // the Userbuffers communicator. + c10::musa::MUSAStream stream_main = at::musa::getCurrentMUSAStream(); + NVTE_CHECK_CUDA(musaEventRecord(_start_d2dcopy, (musaStream_t)stream_main)); + NVTE_CHECK_CUDA(musaStreamWaitEvent((musaStream_t)_stream_comm, _start_d2dcopy, 0)); + NVTE_CHECK_CUDA(musaMemcpyAsync(ubuf_ptr, input_tensor.dptr(), + input_tensor.numel() * input_tensor.element_size(), + musaMemcpyDeviceToDevice, (musaStream_t)_stream_comm)); +} + +py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk, + std::optional> shape) { + using namespace te::pytorch; + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); + if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + + std::vector torch_shape; + if (shape.has_value()) { + torch_shape = shape.value(); + auto requested = product(torch_shape); + auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel(); + NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, + ") does not match allocated buffer size (", expected, ")!"); + } else { + int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0); + int64_t output_c_dim1 = _ubuf.size(1); + torch_shape = {output_c_dim0, output_c_dim1}; + } + + auto ubuf_tensor = torch::from_blob(reinterpret_cast(ubuf_wt_ptr), torch_shape, + at::dtype(GetATenDType(_ubuf.dtype())).device( + c10::DeviceType::PrivateUse1)); + + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + std::vector te_shape; + for (auto s : torch_shape) te_shape.emplace_back(static_cast(s)); + + // Always output a rowwise-only QuantizedTensor + // TODO (Alp): This needs to produce an un-interleaved transpose when required. + auto is_internal = my_quantizer->internal; + auto uses_columnwise = my_quantizer->columnwise_usage; + my_quantizer->internal = false; + my_quantizer->columnwise_usage = false; + auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor); + my_quantizer->internal = is_internal; + my_quantizer->columnwise_usage = uses_columnwise; + return py_tensor; +} + +/*************************************************************************************************** + * CommOverlapP2P + **************************************************************************************************/ + +CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + te::CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, + bool aggregate) + : te::CommOverlapP2PBase( + buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, + helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, + tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm, aggregate) {} + +void CommOverlapP2P::set_buffer_params(py::handle quantizer) { + std::unique_ptr my_quantizer = te::pytorch::convert_quantizer(quantizer); + my_quantizer->set_quantization_params(&_ubuf); + for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]); +} + +/* +** Copy input to _ubufs[0] +*/ +void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) { + auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer); + auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr(); + NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!"); + + c10::musa::MUSAStream stream_main = at::musa::getCurrentMUSAStream(); + if (local_chunk) { + // Copy input to the target ubuf chunk by rank offset + if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the local communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + NVTE_CHECK_CUDA(musaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr, + input_tensor.numel() * input_tensor.element_size(), + musaMemcpyDeviceToDevice, (musaStream_t)stream_main)); + + } else { + if (input_tensor.numel() > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the global communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + NVTE_CHECK_CUDA(musaMemcpyAsync(_ubuf.dptr(), input_ptr, + input_tensor.numel() * input_tensor.element_size(), + musaMemcpyDeviceToDevice, (musaStream_t)stream_main)); + } +} + +py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk, + std::optional> shape) { + using namespace te::pytorch; + char *ubuf_wt_ptr = reinterpret_cast(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr()); + + std::vector torch_shape; + if (shape.has_value()) { + torch_shape = shape.value(); + auto requested = product(torch_shape); + auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel(); + NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, + ") does not match allocated buffer size (", expected, ")!"); + } else { + int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0); + int64_t output_c_dim1 = _ubuf.size(1); + torch_shape = {output_c_dim0, output_c_dim1}; + } + auto ubuf_tensor = torch::from_blob(reinterpret_cast(ubuf_wt_ptr), torch_shape, + at::dtype(GetATenDType(_ubuf.dtype())).device( + c10::DeviceType::PrivateUse1));; + + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + std::vector te_shape; + for (auto s : torch_shape) te_shape.emplace_back(static_cast(s)); + + // Always output a rowwise-only QuantizedTensor + // TODO (Alp): This needs to produce an un-interleaved transpose when required. + auto is_internal = my_quantizer->internal; + auto uses_columnwise = my_quantizer->columnwise_usage; + my_quantizer->internal = false; + my_quantizer->columnwise_usage = false; + auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor); + my_quantizer->internal = is_internal; + my_quantizer->columnwise_usage = uses_columnwise; + return py_tensor; +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.mu b/transformer_engine/musa/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.mu new file mode 100644 index 0000000000..b1eb1c732f --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.mu @@ -0,0 +1,229 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common/common.h" +#include "common/utils.muh" +#include "extensions.h" +#include "type_shim.h" + +constexpr int kTileDim = 128; +constexpr int kThreadsPerBlock = 256; + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + fp8_block_scaling_compute_partial_amax_kernel(const IType *input, float *amax_ptr, + const size_t amax_stride_h, + const size_t amax_stride_w, const size_t h, + const size_t w, const size_t start_offset, + const size_t len) { + constexpr int kThreadsPerWarp = 32; + constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp; + constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; + constexpr int kLoopsPerCol = kTileDim / kNumWarps; + + const int tile_col = blockIdx.x; + const int tile_row = blockIdx.y; + const size_t end_offset = start_offset + len; + const IType *input_minus_offset = input - start_offset; + + __shared__ float smem[kNumWarps]; + float amax = 0.0f; + + for (int loop_col = 0; loop_col < kLoopsPerCol; ++loop_col) { + size_t r = tile_row * kTileDim + loop_col * kNumWarps + threadIdx.x / kThreadsPerWarp; + for (int loop_row = 0; loop_row < kLoopsPerRow; ++loop_row) { + size_t c = tile_col * kTileDim + loop_row * kThreadsPerWarp + (threadIdx.x % kThreadsPerWarp); + size_t idx = r * w + c; + if (r < h && c < w && idx >= start_offset && idx < end_offset) { + float other_amax = fabs(static_cast(input_minus_offset[idx])); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + } + } + + for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { + float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + + if (threadIdx.x % kThreadsPerWarp == 0) { + smem[threadIdx.x / kThreadsPerWarp] = amax; + } + + __syncthreads(); + + if (threadIdx.x == 0) { + for (int i = 0; i < kNumWarps; ++i) { + float other_amax = smem[i]; + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax; + } +} + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + fp8_block_scaling_partial_cast_kernel(const IType *input, OType *output, const float *scale_ptr, + const size_t scale_stride_h, const size_t scale_stride_w, + const size_t h, const size_t w, const size_t start_offset, + const size_t len) { + using transformer_engine::Vec; + + static_assert(sizeof(OType) == 1); + constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType); + constexpr int kThreadsPerWarp = 32; + constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp; + constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; + constexpr int kRowsPerWarp = kTileDim / kNumWarps; + + __shared__ OType smem[kTileDim][kTileDim + kNumOutputElemsPerBank]; + + const int tile_w = blockIdx.x; + const int tile_h = blockIdx.y; + const size_t end_offset = start_offset + len; + const IType *input_minus_offset = input - start_offset; + OType *output_minus_offset = output - start_offset; + + const float scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w]; + + // Load input data into shared memory + bool skip_store = true; + for (int i = 0; i < kRowsPerWarp; ++i) { + for (int j = 0; j < kLoopsPerRow; ++j) { + const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i; + const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j; + const int h_in_input = tile_h * kTileDim + h_in_smem; + const int w_in_input = tile_w * kTileDim + w_in_smem; + const size_t idx_in_input = static_cast(h_in_input) * w + w_in_input; + if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset && + idx_in_input < end_offset) { + float inp = static_cast(input_minus_offset[idx_in_input]) * scale; + smem[h_in_smem][w_in_smem] = static_cast(inp); + skip_store = false; + } + } + } + + for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { + bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta); + skip_store = skip_store && other_skip_store; + } + skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0); + if (skip_store) { + return; + } + + // Store the casted data into the output. + // Note that this store operation might write "out-of-bounds", but it is intentional: + // 1. The "out-of-bounds" here only crosses the boundary of the "local shard" (i.e., the region + // from start_offset to end_offset), not the boundary of the entire output memory. Therefore, + // this out-of-bounds write will not cause illegal memory access. + // 2. We assume that the subsequent all-gather operation happens in-place, so any parts that + // should not be updated here will be overwritten by the all-gather. + // This tricky approach allows us to avoid checking whether each output index falls within + // [start, end), resulting in a significant performance improvement. + Vec vec_output; + for (int i = 0; i < kRowsPerWarp; ++i) { + const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i; + const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank; + for (int j = 0; j < kNumOutputElemsPerBank; ++j) { + vec_output.data.elt[j] = smem[row_in_smem][col_in_smem + j]; + } + const int row_in_output = tile_h * kTileDim + row_in_smem; + const int col_in_output = tile_w * kTileDim + col_in_smem; + const size_t idx_in_output = static_cast(row_in_output) * w + col_in_output; + if (row_in_output < h) { + if constexpr (kWidthAligned) { + vec_output.store_to(output_minus_offset + idx_in_output); + } else { + int num = min(static_cast(kNumOutputElemsPerBank), + static_cast(col_in_output < w ? w - col_in_output : 0)); + vec_output.store_to_elts(output_minus_offset, idx_in_output, num); + } + } + } +} + +void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, + size_t w, size_t start_offset, size_t block_len) { + TORCH_CHECK(block_len == 128, "Currently only support block_len = 128"); + TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor"); + TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor"); + TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float || + tensor.scalar_type() == at::ScalarType::BFloat16, + "tensor must be a float or bfloat16 tensor"); + + size_t amax_stride_h = amax.stride(0); + size_t amax_stride_w = amax.stride(1); + size_t len = tensor.numel(); + + assert(h > 0 && w > 0); + assert(start_offset < h * w); + assert(start_offset + len <= h * w); + + size_t blocks_x = (w + kTileDim - 1) / kTileDim; + size_t blocks_y = (h + kTileDim - 1) / kTileDim; + assert(blocks_x <= std::numeric_limits::max()); + assert(blocks_y <= std::numeric_limits::max()); + dim3 grid(blocks_x, blocks_y); + + auto stream = at::musa::getCurrentMUSAStream(); + + DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor.scalar_type(), 0, "compute_partial_amax", + fp8_block_scaling_compute_partial_amax_kernel + <<>>( + tensor.data_ptr(), amax.data_ptr(), + amax_stride_h, amax_stride_w, h, w, start_offset, len);) +} + +void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale, + size_t h, size_t w, size_t start_offset, size_t block_len, + const transformer_engine::DType out_dtype) { + TORCH_CHECK(block_len == 128, "Currently only support block_len = 128"); + TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); + TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor"); + TORCH_CHECK( + inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16, + "input must be a float or bfloat16 tensor"); + TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor"); + TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 || + out_dtype == transformer_engine::DType::kFloat8E5M2, + "out_dtype must be kFloat8E4M3 or kFloat8E5M2"); + + size_t scale_stride_h = scale.stride(0); + size_t scale_stride_w = scale.stride(1); + size_t len = inp.numel(); + + assert(h > 0 && w > 0); + assert(start_offset < h * w); + assert(start_offset + len <= h * w); + + size_t blocks_x = (w + kTileDim - 1) / kTileDim; + size_t blocks_y = (h + kTileDim - 1) / kTileDim; + assert(blocks_x <= std::numeric_limits::max()); + assert(blocks_y <= std::numeric_limits::max()); + dim3 grid(blocks_x, blocks_y); + + auto stream = at::musa::getCurrentMUSAStream(); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + inp.scalar_type(), 0, "partial_cast", + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + out_dtype, fp8_type, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + w % kTileDim == 0, kWidthAligned, + fp8_block_scaling_partial_cast_kernel + <<>>(inp.data_ptr(), + reinterpret_cast(out.data_ptr()), + scale.data_ptr(), scale_stride_h, + scale_stride_w, h, w, start_offset, len);))) +} \ No newline at end of file diff --git a/transformer_engine/musa/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/musa/pytorch/csrc/extensions/gemm.cpp new file mode 100644 index 0000000000..6a6f6ca2ba --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/gemm.cpp @@ -0,0 +1,404 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include + +#include "../common.h" +#include "common.h" +#include "common/util/musa_runtime.h" +#include "common/util/system.h" +#include "extensions.h" +#include "pybind.h" +#include "transformer_engine/transformer_engine.h" + +namespace { + +void* get_data_ptr(MaybeTensor tensor) { + if (tensor.has_value()) return tensor->data_ptr(); + return nullptr; +} + +size_t get_size(MaybeTensor tensor, int dim) { + if (tensor.has_value()) return static_cast(tensor->size(dim)); + return 0; +} + +} // namespace + +namespace transformer_engine::pytorch { + +namespace detail { + +std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, + const NVTEShape& B_shape, const bool transb) { + // Flatten outer dims to get 2D matrices + const size_t A0 = product(A_shape, 0, A_shape.ndim - 1); + const size_t A1 = A_shape.data[A_shape.ndim - 1]; + const size_t B0 = product(B_shape, 0, B_shape.ndim - 1); + const size_t B1 = B_shape.data[B_shape.ndim - 1]; + + // Check matrix dims + NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(", + A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")"); + + // Construct output dims + std::vector ret; + if (transb) { + ret.emplace_back(B1); + } else { + // Unflatten B0 + for (size_t i = 0; i < B_shape.ndim - 1; ++i) { + ret.emplace_back(B_shape.data[i]); + } + } + if (transa) { + ret.emplace_back(A0); + } else { + ret.emplace_back(A1); + } + return ret; +} + +bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { + if (expected.size() != actual.ndim) return false; + for (size_t i = 0; i < expected.size(); ++i) { + if (expected[i] != actual.data[i]) return false; + } + return true; +} + +} // namespace detail + +std::pair createOutputTensor(const std::vector& shape, + DType dtype, py::handle quantizer) { + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape, dtype); +} + +std::vector gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, + py::handle quantizer, std::optional out_dtype, MaybeTensor bias, + DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, CommOverlapCore* comm_overlap, + std::optional comm_type, MaybeTensor extra_output, + bool bulk_overlap) { + // Input tensors + NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); + NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); + auto none = py::none(); + TensorWrapper A_tensor = makeTransformerEngineTensor(A, none); + TensorWrapper B_tensor = makeTransformerEngineTensor(B, none); + + // Check tensor dimensions + const auto& A_shape = A_tensor.shape(); + const auto& B_shape = B_tensor.shape(); + const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); + NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); + NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); + + // Output tensor + TensorWrapper D_tensor; + if (D.is_none()) { + DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); + std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); + } else { + D_tensor = makeTransformerEngineTensor(D, quantizer); + NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), + "GEMM output has invalid dims (expected ", std::to_string(D_shape), ", got ", + std::to_string(D_tensor.shape()), ")"); + if (out_dtype) { + NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ", + static_cast(*out_dtype), ", found ", static_cast(D_tensor.dtype()), ")"); + } + } + + // Bias tensor + TensorWrapper bias_tensor; + MaybeTensor bias_grad = std::nullopt; + if (bias.has_value()) { + if (grad) { + auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(c10::kPrivateUse1); + bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); + bias_tensor = makeTransformerEngineTensor(*bias_grad); + } else { + if (!bias->is_contiguous()) { + bias = bias->contiguous(); + } + bias_tensor = makeTransformerEngineTensor(*bias); + } + } + + // Activation input tensor + MaybeTensor pre_gelu_out = std::nullopt; + DType gelu_type = bias_type; + if (gelu) { + if (!grad) { + auto dtype = GetATenDType(gelu_type); + auto opts = torch::TensorOptions().dtype(dtype).device(c10::kPrivateUse1); + std::vector torch_shape; + for (auto v : D_shape) { + torch_shape.push_back(v); + } + pre_gelu_out = at::empty(torch_shape, opts); + } else { + if (gelu_in.has_value()) { + pre_gelu_out = *gelu_in; + } + } + } + const auto gelu_shape = gelu ? D_shape : std::vector{0}; + + auto te_pre_gelu_out = + makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); + + // Workspace + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); + + // Set an external SM Margin to all the GEMMs. + // This comes in handy when DP is overlapped with GEMMs + const int device_id = c10::musa::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + + auto main_stream = at::musa::getCurrentMUSAStream(); + if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { + if (comm_overlap) { + // Prepare extra output tensor + TensorWrapper extra_output_tensor; + if (extra_output.has_value()) { + extra_output_tensor = makeTransformerEngineTensor(*extra_output); + } else { + extra_output_tensor = + makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); + } + + // Direct GEMM call to the correct overlap + if (bulk_overlap) { + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, comm_type.value(), extra_output_tensor, + main_stream); + } else if (comm_type.value() == CommOverlapType::AG) { + if (comm_overlap->is_atomic_gemm()) { + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + } else { + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, extra_output_tensor, main_stream); + } + } else { + if (comm_overlap->is_atomic_gemm()) { + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + } else { + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, extra_output_tensor, main_stream); + } + } + } else { + // Launch GEMM + nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), + te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), + accumulate, use_split_accumulator, num_math_sms, main_stream); + } + } else { + if (D_tensor.numel() != 0 && !accumulate) { + D_tensor.zero_(main_stream); + } + if (bias.has_value()) { + if (bias->numel() != 0 && grad) { + bias_grad->zero_(); + } + } + } + + // Pack outputs + std::vector out; + out.emplace_back(std::move(D)); + out.emplace_back(py::cast(bias_grad)); + if (gelu && !grad) { + out.emplace_back(py::cast(*pre_gelu_out)); + } else { + out.emplace_back(py::none()); + } + if (extra_output.has_value()) { + out.emplace_back(py::cast(extra_output)); + } else { + out.emplace_back(py::none()); + } + return out; +} + +} // namespace transformer_engine::pytorch + +void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, + std::vector A_scaling_mode, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, transformer_engine::DType B_type, + std::vector B_scaling_mode, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count, int m_split, int n_split, + bool gemm_producer, at::Tensor counter) { + using namespace transformer_engine; + using namespace transformer_engine::pytorch; + + // TODO: Handle scaling modes + NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; + NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; + + auto te_A = makeTransformerEngineTensor( + A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, + nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), + nvte_scaling_modeA); + auto te_B = makeTransformerEngineTensor( + B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, + nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), + nvte_scaling_modeB); + // TODO: D_scale_inv cannot be nullptr when D_type is FP8. + auto te_D = makeTransformerEngineTensor( + D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, + D_amax.data_ptr(), D_scale.data_ptr(), nullptr); + auto te_bias = + makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, bias_type); + auto te_counter = makeTransformerEngineTensor( + counter.data_ptr(), {static_cast(counter.size(0))}, DType::kInt32); + + const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr + ? std::vector{static_cast(pre_gelu_out.size(0))} + : std::vector{static_cast(pre_gelu_out.size(0)), + static_cast(pre_gelu_out.size(1))}; + auto te_pre_gelu_out = makeTransformerEngineTensor( + pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); + + nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), + te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), + accumulate, use_split_accumulator, math_sm_count, m_split, n_split, + gemm_producer, te_counter.data(), at::musa::getCurrentMUSAStream()); +} + +std::optional> te_general_grouped_gemm( + std::vector A, bool transa, std::vector B, bool transb, + std::optional> D, transformer_engine::DType D_type, + std::vector m_splits, std::vector bias, + transformer_engine::DType bias_type, bool single_output, std::vector pre_gelu_out, + bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count) { + using namespace transformer_engine; + using namespace transformer_engine::pytorch; + std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, + te_pre_gelu_out_vector, te_workspace_vector; + std::vector wrappers; + std::vector D_vectors; + + auto none = py::none(); + + std::vector single_output_begins; + std::vector single_output_ends; + int slicing_dim; + if (single_output && D == std::nullopt) { + NVTE_ERROR("not implemented, D should be allocated for single output case."); + } + + void* output_data_ptr; + if (single_output) { + output_data_ptr = (*D)[0].data_ptr(); + } + + for (size_t i = 0; i < A.size(); i++) { + auto te_A = makeTransformerEngineTensor(A[i], none); + auto te_B = makeTransformerEngineTensor(B[i], none); + + // if there is single output + at::Tensor out_tensor; + auto size_t_shape = + pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); + std::vector D_shape; + for (size_t t : size_t_shape) { + D_shape.push_back(t); + } + auto dtype = GetATenDType(D_type); + auto opts = torch::TensorOptions().dtype(dtype).device(c10::kPrivateUse1); + if (single_output) { + bool hasZeroDim = std::any_of(D_shape.begin(), D_shape.end(), + [](int64_t value) { return value == 0; }); + if (!hasZeroDim) { + out_tensor = at::from_blob(output_data_ptr, D_shape, opts); + } + char* char_ptr = reinterpret_cast(output_data_ptr); + char_ptr += m_splits[i] * te_A.size(0) * (*D)[0].element_size(); + output_data_ptr = reinterpret_cast(char_ptr); + D_vectors.emplace_back(out_tensor); + } else { + if (D == std::nullopt) { + auto opts = torch::TensorOptions().dtype(dtype).device(c10::kPrivateUse1); + out_tensor = at::empty(D_shape, opts); + D_vectors.emplace_back(out_tensor); + } else { + out_tensor = (*D)[i]; + } + } + + if (te_A.numel() == 0 || te_B.numel() == 0) { + if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_(); + if (bias[i].numel() != 0 && grad) { + bias[i].zero_(); + } + if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_(); + continue; + } + + auto te_D = makeTransformerEngineTensor(out_tensor); + auto te_bias = makeTransformerEngineTensor(bias[i]); + auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); + + const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr + ? std::vector{static_cast(te_pre_gelu_out.size(0))} + : std::vector{static_cast(te_pre_gelu_out.size(0)), + static_cast(te_pre_gelu_out.size(1))}; + + DType gelu_type = bias_type; + te_pre_gelu_out = + makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); + + te_A_vector.emplace_back(te_A.data()); + te_B_vector.emplace_back(te_B.data()); + te_D_vector.emplace_back(te_D.data()); + te_bias_vector.emplace_back(te_bias.data()); + te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); + + wrappers.emplace_back(std::move(te_A)); + wrappers.emplace_back(std::move(te_B)); + wrappers.emplace_back(std::move(te_D)); + wrappers.emplace_back(std::move(te_bias)); + wrappers.emplace_back(std::move(te_pre_gelu_out)); + } + for (size_t i = 0; i < workspace.size(); i++) { + auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte); + te_workspace_vector.emplace_back(wsp.data()); + wrappers.emplace_back(std::move(wsp)); + } + // For now, we only have multi-stream cublas backend. + nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + te_bias_vector.data(), te_pre_gelu_out_vector.data(), + te_A_vector.size(), transa, transb, grad, + te_workspace_vector.data(), accumulate, use_split_accumulator, + math_sm_count, at::musa::getCurrentMUSAStream()); + return bias; +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/misc.cpp b/transformer_engine/musa/pytorch/csrc/extensions/misc.cpp new file mode 100644 index 0000000000..e7eab96e64 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/misc.cpp @@ -0,0 +1,11 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +size_t get_mublas_version() { return MUBLAS_VERSION_MAJOR * 10000ul + MUBLAS_VERSION_MINOR * 100ul + MUBLAS_VERSION_PATCH; } + +size_t get_mudnn_version() { return ::musa::dnn::GetVersion(); } diff --git a/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.mu b/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.mu new file mode 100644 index 0000000000..f6d03e5be7 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.mu @@ -0,0 +1,644 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "common/utils.muh" +#include "multi_tensor_apply.muh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 +#define THREADS_PER_WARP 32 + +typedef enum { + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) +} adamMode_t; + +using MATH_T = float; +using fp8e4m3 = __mt_fp8_e4m3; +using fp8e5m2 = __mt_fp8_e5m2; +using transformer_engine::DType; + +template +struct is_fp8 : std::false_type {}; + +template <> +struct is_fp8 : std::true_type {}; + +template <> +struct is_fp8 : std::true_type {}; + +template +struct FP8Data { + float scale; + float *amax_ptr; + float *scale_inv_ptr; + float max; + int warp_id; +}; + +template <> +struct FP8Data {}; + +template +struct AdamFunctorMaster { + static constexpr bool is_fp8_type = is_fp8::value; + + __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, + TensorListMetadata<5, is_fp8_type> &tl, // NOLINT(*) + const float beta1, const float beta2, + const float beta1_correction, + const float beta2_correction, const float epsilon, + const float lr, adamMode_t mode, const float decay) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + FP8Data fp8_data; + + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; + + GRAD_T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); + g += chunk_idx * chunk_size; + + PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); + p += chunk_idx * chunk_size; + + FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + m += chunk_idx * chunk_size; + + FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + v += chunk_idx * chunk_size; + + FULL_T *p_master = reinterpret_cast(tl.addresses[4][tensor_loc]); + p_master += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + if constexpr (is_fp8_type) { + float *scale_ptr = reinterpret_cast(tl.fp8_meta_addresses[0][tensor_loc]); + fp8_data.scale = scale_ptr != nullptr ? *scale_ptr : 1; + fp8_data.amax_ptr = reinterpret_cast(tl.fp8_meta_addresses[1][tensor_loc]); + fp8_data.scale_inv_ptr = reinterpret_cast(tl.fp8_meta_addresses[2][tensor_loc]); + fp8_data.warp_id = threadIdx.x / THREADS_PER_WARP; + fp8_data.max = 0; + } + + // see note in multi_tensor_scale_kernel.cu + for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = static_cast(g[i]); + r_p[ii] = static_cast(p_master[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } + +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p_master[i] = static_cast(r_p[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); + if constexpr (is_fp8_type) { + __builtin_assume(fp8_data.max >= 0); + fp8_data.max = fmaxf(fabsf(r_p[ii]), fp8_data.max); + p[i] = static_cast(r_p[ii] * fp8_data.scale); + } else { + p[i] = static_cast(r_p[ii]); + } + } + } + } + + if constexpr (is_fp8_type) { + fp8_data.max = transformer_engine::reduce_max( + fp8_data.max, fp8_data.warp_id); + if (threadIdx.x == 0) { + if (fp8_data.amax_ptr != nullptr) { + transformer_engine::atomicMaxFloat(fp8_data.amax_ptr, fp8_data.max); + } + if (fp8_data.scale_inv_ptr != nullptr) { + *fp8_data.scale_inv_ptr = __frcp_rn(fp8_data.scale); + } + } + } + } +}; + +template +struct AdamFunctor { + __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, + TensorListMetadata<4> &tl, // NOLINT(*) + const float beta1, const float beta2, + const float beta1_correction, + const float beta2_correction, const float epsilon, + const float lr, adamMode_t mode, const float decay) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; + + GRAD_T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); + g += chunk_idx * chunk_size; + + PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); + p += chunk_idx * chunk_size; + + FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + m += chunk_idx * chunk_size; + + FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = static_cast(g[i]); + r_p[ii] = static_cast(p[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = static_cast(r_p[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); + } + } + } + } +}; + +template +struct AdamCapturableFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, + TensorListMetadata<4> &tl, // NOLINT(*) + const float beta1, const float beta2, const int *step, + const int bias_correction, const float epsilon, + const float *lr, adamMode_t mode, const float decay, + const float *inv_scale) { + if (*noop_gmem == 1) return; + + float beta1_correction = 1.0f, beta2_correction = 1.0f; + if (bias_correction == 1) { + beta1_correction = 1 - pow(beta1, *step); + beta2_correction = 1 - pow(beta2, *step); + } + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); + g += chunk_idx * chunk_size; + + T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); + p += chunk_idx * chunk_size; + + FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + m += chunk_idx * chunk_size; + + FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = static_cast(g[i]) * (*inv_scale); + g[i] = static_cast(r_g[ii]); + r_p[ii] = static_cast(p[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (*lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (*lr * update); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = static_cast(r_p[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); + } + } + } + } +}; + +template +struct AdamCapturableMasterFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, + TensorListMetadata<5> &tl, // NOLINT(*) + const float beta1, const float beta2, const int *step, + const int bias_correction, const float epsilon, + const float *lr, adamMode_t mode, const float decay, + const float *inv_scale) { + if (*noop_gmem == 1) return; + + float beta1_correction = 1.0f, beta2_correction = 1.0f; + if (bias_correction == 1) { + beta1_correction = 1 - pow(beta1, *step); + beta2_correction = 1 - pow(beta2, *step); + } + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); + g += chunk_idx * chunk_size; + + T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); + p += chunk_idx * chunk_size; + + FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + m += chunk_idx * chunk_size; + + FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + v += chunk_idx * chunk_size; + + FULL_T *p_master = reinterpret_cast(tl.addresses[4][tensor_loc]); + p_master += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = static_cast(g[i]) * (*inv_scale); + g[i] = static_cast(r_g[ii]); + r_p[ii] = static_cast(p_master[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (*lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (*lr * update); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = static_cast(r_p[ii]); + p_master[i] = static_cast(r_p[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); + } + } + } + } +}; + +void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay) { + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + size_t max_size = 0; + bool requires_64bit_indexing = false; + for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) { + for (auto it2 = it->begin(); it2 != it->end(); it2++) { + if (it2->numel() > max_size) { + max_size = it2->numel(); + if (max_size >= INT_MAX) { + requires_64bit_indexing = true; + break; + } + } + } + if (requires_64bit_indexing) { + break; + } + } + + const auto g_in_type = tensor_lists[0][0].scalar_type(); + const auto p_in_type = tensor_lists[1][0].scalar_type(); + auto tl_size = tensor_lists.size(); + + // case 4: g, p, m, v + // case 5: g, p, m, v, p_master + TORCH_CHECK(tl_size == 4 || tl_size == 5, "tensor list must contain 4 or 5"); + + if (requires_64bit_indexing) { + if (tl_size == 4) { + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, + tensor_lists, + AdamFunctor(), beta1, + beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); + } else { + // g, p, m, v, p_master + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, + tensor_lists, + AdamFunctorMaster(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); + } + } else { + if (tl_size == 4) { + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor(), beta1, + beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); + } else { + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); + } + } + AT_MUSA_CHECK(musaGetLastError()); +} + +void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype) { + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + size_t max_size = 0; + bool requires_64bit_indexing = false; + for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) { + for (auto it2 = it->begin(); it2 != it->end(); it2++) { + if (it2->numel() > max_size) { + max_size = it2->numel(); + if (max_size >= INT_MAX) { + requires_64bit_indexing = true; + break; + } + } + } + if (requires_64bit_indexing) { + break; + } + } + + const auto g_in_type = tensor_lists[0][0].scalar_type(); + auto tl_size = tensor_lists.size(); + + // case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv + TORCH_CHECK(tl_size == 8, "tensor list must contain 8 tensors"); + + if (requires_64bit_indexing) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + fp8_dtype, FP8_T, + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 0, "adam", + multi_tensor_apply<5, true>( + (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + fp8_dtype, FP8_T, + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 0, "adam", + multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctorMaster(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, + lr, (adamMode_t)mode, weight_decay);)); + } + AT_MUSA_CHECK(musaGetLastError()); +} + +void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor lr, const float beta1, const float beta2, + const float epsilon, at::Tensor step, const int mode, + const int bias_correction, const float weight_decay, + at::Tensor inv_scale) { + using namespace at; + + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + tensor_lists[0][0].scalar_type(), 0, "adam", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamCapturableFunctor(), beta1, beta2, + step.data_ptr(), bias_correction, epsilon, lr.data_ptr(), + (adamMode_t)mode, weight_decay, inv_scale.data_ptr());) + + AT_MUSA_CHECK(musaGetLastError()); +} + +void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor lr, const float beta1, const float beta2, + const float epsilon, at::Tensor step, const int mode, + const int bias_correction, const float weight_decay, + at::Tensor inv_scale) { + using namespace at; + + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + tensor_lists[0][0].scalar_type(), 0, "adam", + multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamCapturableMasterFunctor(), beta1, beta2, + step.data_ptr(), bias_correction, epsilon, lr.data_ptr(), + (adamMode_t)mode, weight_decay, inv_scale.data_ptr());) + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.mu b/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.mu new file mode 100644 index 0000000000..ce69549b5e --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.mu @@ -0,0 +1,68 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include +// Stringstream is a big hammer, but I want to rely on operator<< for dtype. +#include + +#include "common/recipe/recipe_common.muh" +#include "common/utils.muh" +#include "multi_tensor_apply.muh" +#include "type_shim.h" + +#define BLOCK_SIZE 256 + +struct ComputeScaleAndScaleInvFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, + TensorListMetadata<3> &tl, // NOLINT(*) + float max_fp8, bool force_pow_2_scales, + float epsilon) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + float *amax = reinterpret_cast(tl.addresses[0][tensor_loc]); + amax += chunk_idx * chunk_size; + + float *scale = reinterpret_cast(tl.addresses[1][tensor_loc]); + scale += chunk_idx * chunk_size; + + float *scale_inv = reinterpret_cast(tl.addresses[2][tensor_loc]); + scale_inv += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) { + float scale_val = transformer_engine::compute_scale_from_amax( + amax[i_start], max_fp8, force_pow_2_scales, epsilon, std::numeric_limits::max()); + scale[i_start] = scale_val; + transformer_engine::reciprocal(scale_inv + i_start, scale_val); + } + } +}; + +void multi_tensor_compute_scale_and_scale_inv_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + float max_fp8, bool force_pow_2_scales, float epsilon) { + using namespace at; + + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon); + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.mu b/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.mu new file mode 100644 index 0000000000..e9af2c65c2 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.mu @@ -0,0 +1,412 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.muh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int src_offset) { + typedef typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*) +} + +template +struct L2NormFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, + TensorListMetadata<1> &tl, // NOLINT(*), + float *output, float *output_per_tensor, + bool per_tensor, int max_chunks_per_tensor) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + x_t *x = reinterpret_cast(tl.addresses[0][tensor_loc]); + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + __shared__ float s_vals[512]; + + float vals[ILP]; // = {0}; // this probably works too but I want to be sure... + x_t r_x[ILP]; + for (int i = 0; i < ILP; i++) { + vals[i] = 0.f; + r_x[i] = 0; + } + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float next = static_cast(r_x[ii]); + vals[ii] += next * next; + } + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + float next = static_cast(x[i]); + vals[ii] += next * next; + } + } + } + } + + float val = 0.f; + for (int i = 0; i < ILP; i++) val += vals[i]; + + float final = reduce_block_into_lanes(s_vals, val); + + if (threadIdx.x == 0) { + if (!isfinite(final)) + *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. + output[blockIdx.x] += final; + if (per_tensor) + output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + + chunk_idx] = final; + } + } +}; + +template +struct UnscaleL2NormFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, + TensorListMetadata<1> &tl, // NOLINT(*), + const float *inv_scale, float *output, + float *output_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + x_t *x = reinterpret_cast(tl.addresses[0][tensor_loc]); + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + __shared__ float s_vals[512]; + + float vals[ILP]; // = {0}; // this probably works too but I want to be sure... + x_t r_x[ILP]; + for (int i = 0; i < ILP; i++) { + vals[i] = 0.f; + r_x[i] = 0; + } + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float next = static_cast(r_x[ii]) * (*inv_scale); + vals[ii] += next * next; + } + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + float next = static_cast(x[i]) * (*inv_scale); + vals[ii] += next * next; + } + } + } + } + + float val = 0.f; + for (int i = 0; i < ILP; i++) val += vals[i]; + + float final = reduce_block_into_lanes(s_vals, val); + + if (threadIdx.x == 0) { + if (!isfinite(final)) + *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. + output[blockIdx.x] += final; + if (per_tensor) + output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + + chunk_idx] = final; + } + } +}; + +// Probably better to template, but since we are not likely to support other norm +template +struct MaxNormFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, + TensorListMetadata<1> &tl, // NOLINT(*), + float *output, float *output_per_tensor, + bool per_tensor, int max_chunks_per_tensor) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + x_t *x = reinterpret_cast(tl.addresses[0][tensor_loc]); + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + __shared__ float s_vals[512]; + + float vals[ILP]; // = {0}; // this probably works too but I want to be sure... + x_t r_x[ILP]; + for (int i = 0; i < ILP; i++) { + vals[i] = 0.f; + r_x[i] = 0; + } + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float next = static_cast(r_x[ii]); + vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); + } + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + float next = static_cast(x[i]); + vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); + } + } + } + } + + float val = 0.f; + for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i])); + + float final = reduce_block_into_lanes_max_op(s_vals, val); + + if (threadIdx.x == 0) { + if (!isfinite(final)) + *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. + output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final)); + if (per_tensor) + output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + + chunk_idx] = final; + } + } +}; + +__global__ void cleanup(float *output, float *output_per_tensor, float *ret, float *ret_per_tensor, + bool per_tensor, int max_chunks_per_tensor) { + __shared__ float vals[512]; + + if (blockIdx.x == 0) { + float val = 0; + if (threadIdx.x < 320) val = output[threadIdx.x]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) *ret = sqrt(final); + } + + if (per_tensor) { + float *output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor; + + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val += output_this_tensor[i]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final); + } +} + +__global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, + float *ret_per_tensor, bool per_tensor, int max_chunks_per_tensor, + int norm_type, float alpha, float beta) { + __shared__ float vals[512]; + + if (blockIdx.x == 0) { + float val = 0; + if (threadIdx.x < 320) val = output[threadIdx.x]; + + if (norm_type == 0) { + float final = reduce_block_into_lanes_max_op(vals, val); + if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final; + } else { + float final = reduce_block_into_lanes(vals, val); + if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final); + } + } + + if (per_tensor) { + float *output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor; + + if (norm_type == 0) { + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val = fmaxf(fabsf(val), fabsf(output_this_tensor[i])); + + float final = reduce_block_into_lanes_max_op(vals, val); + + if (threadIdx.x == 0) + ret_per_tensor[blockIdx.x] = alpha * ret_per_tensor[blockIdx.x] + beta * final; + } else { + float val = 0; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) + val += output_this_tensor[i]; + + float final = reduce_block_into_lanes(vals, val); + + if (threadIdx.x == 0) + ret_per_tensor[blockIdx.x] = + sqrt(alpha * ret_per_tensor[blockIdx.x] * ret_per_tensor[blockIdx.x] + beta * final); + } + } +} + +std::tuple multi_tensor_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + at::optional per_tensor_python) { + bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false; + + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); + auto output = at::zeros({320}, float_options); + + at::Tensor output_per_tensor; + at::Tensor ret_per_tensor; + + int ntensors = tensor_lists[0].size(); + int max_chunks_per_tensor = -1; + + if (per_tensor) { + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) + max_chunks_per_tensor = max_chunks_this_tensor; + } + output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options); + ret_per_tensor = at::empty({ntensors}, float_options); + } else { + ret_per_tensor = at::empty({0}, float_options); + } + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", + multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + L2NormFunctor(), output.data_ptr(), + per_tensor ? output_per_tensor.data_ptr() : nullptr, per_tensor, + max_chunks_per_tensor);) + + AT_MUSA_CHECK(musaGetLastError()); + // AT_MUSA_CHECK(musaDeviceSynchronize()); + + // This involves one more small kernel launches, but will be negligible end to end. + // I could get rid of these by hacking the functor + multi tensor harness with persistence + // logic, but keeping it simple for now + auto ret = at::empty({1}, output.options()); + const at::musa::OptionalMUSAGuard device_guard(device_of(output)); + auto stream = at::musa::getCurrentMUSAStream(); + cleanup<<>>( + output.data_ptr(), per_tensor ? output_per_tensor.data_ptr() : nullptr, + ret.data_ptr(), per_tensor ? ret_per_tensor.data_ptr() : nullptr, per_tensor, + max_chunks_per_tensor); + + return std::tuple(ret, ret_per_tensor); +} + +std::tuple multi_tensor_unscale_l2norm_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + at::Tensor inv_scale, at::optional per_tensor_python) { + bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false; + + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); + auto output = at::zeros({320}, float_options); + + at::Tensor output_per_tensor; + at::Tensor ret_per_tensor; + + int ntensors = tensor_lists[0].size(); + int max_chunks_per_tensor = -1; + + if (per_tensor) { + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) + max_chunks_per_tensor = max_chunks_this_tensor; + } + output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options); + ret_per_tensor = at::empty({ntensors}, float_options); + } else { + ret_per_tensor = at::empty({0}, float_options); + } + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda", + multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + UnscaleL2NormFunctor(), inv_scale.data_ptr(), + output.data_ptr(), + per_tensor ? output_per_tensor.data_ptr() : nullptr, per_tensor, + max_chunks_per_tensor);) + + AT_MUSA_CHECK(musaGetLastError()); + // AT_MUSA_CHECK(musaDeviceSynchronize()); + + // This involves one more small kernel launches, but will be negligible end to end. + // I could get rid of these by hacking the functor + multi tensor harness with persistence + // logic, but keeping it simple for now + auto ret = at::empty({1}, output.options()); + const at::musa::OptionalMUSAGuard device_guard(device_of(output)); + auto stream = at::musa::getCurrentMUSAStream(); + cleanup<<>>( + output.data_ptr(), per_tensor ? output_per_tensor.data_ptr() : nullptr, + ret.data_ptr(), per_tensor ? ret_per_tensor.data_ptr() : nullptr, per_tensor, + max_chunks_per_tensor); + + return std::tuple(ret, ret_per_tensor); +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.mu b/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.mu new file mode 100644 index 0000000000..f1851af4a5 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.mu @@ -0,0 +1,120 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +// Another possibility: +// #include + +#include +// Stringstream is a big hammer, but I want to rely on operator<< for dtype. +#include + +#include "multi_tensor_apply.muh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int src_offset) { + typedef typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*) +} + +template +struct ScaleFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, + TensorListMetadata<2> &tl, // NOLINT(*) + float scale) { + // I'd like this kernel to propagate infs/nans. + // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + in_t *in = reinterpret_cast(tl.addresses[0][tensor_loc]); + in += chunk_idx * chunk_size; + + out_t *out = reinterpret_cast(tl.addresses[1][tensor_loc]); + out += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + bool finite = true; + in_t r_in[ILP]; + out_t r_out[ILP]; + + // to make things simple, we put aligned case in a different code path + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_in, in, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = static_cast(r_in[ii]) * scale; + finite = finite && isfinite(r_in[ii]); + } + // store + load_store(out, r_out, i_start, 0); + } + } else { + // Non-divergent exit condition for __syncthreads, not necessary here + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_in[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) r_in[ii] = in[i]; + } + // note for clarification to future michael: + // From a pure memory dependency perspective, there's likely no point unrolling + // the write loop, since writes just fire off once their LDGs arrive. + // Put another way, the STGs are dependent on the LDGs, but not on each other. + // There is still compute ILP benefit from unrolling the loop though. +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + r_out[ii] = static_cast(r_in[ii]) * scale; + finite = finite && isfinite(r_in[ii]); + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) out[i] = r_out[ii]; + } + } + } + if (!finite) *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. + } +}; + +void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, float scale) { + using namespace at; + // The output (downscaled) type is always float. + // If build times suffer, think about where to put this dispatch, + // and what logic should be moved out of multi_tensor_apply. + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda", + DISPATCH_FLOAT_HALF_AND_BFLOAT( + tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda", + multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + ScaleFunctor(), scale);)) + AT_MUSA_CHECK(musaGetLastError()); + + // AT_MUSA_CHECK(musaDeviceSynchronize()); +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.mu b/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.mu new file mode 100644 index 0000000000..c11e6ce884 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.mu @@ -0,0 +1,203 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include "multi_tensor_apply.muh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +/** + * Perform fused SGD on multiple buffers + * N: number of tensors + * tl[0] : gradients + * tl[1] : weights + * tl[2] : momentum buffers + * tl[3] : fp16 weights (if appropriate) + * wd : weight_decay (scalar) + * momentum : momentum (scalar) + * dampening : momentum dampening (scalar) + * lr : learning rate (scalar) + * nesterov : enable nesterov (bool) + * first run : necessary for proper momentum handling & init + * wd_after_momentum : apply weight decay _after_ momentum instead of before + **/ +template +struct SGDFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, + TensorListMetadata& tl, // NOLINT(*) + float wd, float momentum, float dampening, float lr, + bool nesterov, bool first_run, bool wd_after_momentum, + float scale) { + // Early exit if we don't need to do anything + if (*noop_gmem) return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T_grad* grad_in = reinterpret_cast(tl.addresses[0][tensor_loc]); + grad_in += chunk_idx * chunk_size; + + T_weight* weight_in = reinterpret_cast(tl.addresses[1][tensor_loc]); + weight_in += chunk_idx * chunk_size; + + T_weight* mom_in = reinterpret_cast(tl.addresses[2][tensor_loc]); + mom_in += chunk_idx * chunk_size; + + at::Half* model_weights_out = nullptr; + if (N == 4) { + model_weights_out = (at::Half*)tl.addresses[3][tensor_loc]; + model_weights_out += chunk_idx * chunk_size; + } + + n -= chunk_idx * chunk_size; + + // Non-divergent exit condition for the __syncthreads + float incoming_grads[ILP]; + float incoming_weights[ILP]; + float incoming_moms[ILP]; + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + incoming_grads[ii] = 0; + incoming_weights[ii] = 0; + incoming_moms[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + incoming_grads[ii] = static_cast(grad_in[i]) * scale; + incoming_weights[ii] = static_cast(weight_in[i]); + incoming_moms[ii] = static_cast(mom_in[i]); + } + } + +// note for clarification to future michael: +// From a pure memory dependency perspective, there's likely no point unrolling +// the write loop, since writes just fire off once their LDGs arrive. +// Put another way, the STGs are dependent on the LDGs, but not on each other. +// There is still compute ILP benefit from unrolling the loop though. +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + // apply weight decay before momentum if necessary + if (wd != 0.f && !wd_after_momentum) incoming_grads[ii] += wd * incoming_weights[ii]; + + if (momentum != 0.f) { + if (!first_run) + incoming_moms[ii] = + incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii]; + else // initialize momentums to current incoming grads + incoming_moms[ii] = incoming_grads[ii]; + + if (nesterov) + incoming_grads[ii] += momentum * incoming_moms[ii]; + else + incoming_grads[ii] = incoming_moms[ii]; + } + + // Apply WD after momentum if desired + if (wd != 0.f && wd_after_momentum) incoming_grads[ii] += wd * incoming_weights[ii]; + + // adjust the weight and write out + weight_in[i] += (-lr * incoming_grads[ii]); + + // if necessary, write out an fp16 copy of the weights + if (N == 4) model_weights_out[i] = static_cast(weight_in[i]); + + // also write out the new momentum + if (momentum != 0.f) mom_in[i] = incoming_moms[ii]; + } + } + } + } +}; + +void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, float wd, + float momentum, float dampening, float lr, bool nesterov, bool first_run, + bool wd_after_momentum, float scale) { + auto num_tensors = tensor_lists.size(); + auto grad_type = tensor_lists[0][0].scalar_type(); + auto weight_type = tensor_lists[1][0].scalar_type(); + + if (num_tensors == 4) { + for (int i = 0; i < tensor_lists[3].size(); i++) + TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half, + "Additional output tensors should always be fp16."); + } + + TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), + "expected noop flag to be on the same device as tensors"); + + // We have 3 possibilities to handle here, in terms of + // grad_type, param_type, momentum_type, requires_fp16_copy + // 1. fp16, fp16, fp16, No + // 2. fp32, fp32, fp32, No + // 3. fp16, fp32, fp32, Yes + // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case + // It's easier to hardcode these possibilities than to use + // switches etc. to handle the cross-product of cases where + // we don't want the majority of them. + + // Case 1. fp16, fp16, fp16, No + if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Half && + num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor<3, at::Half, at::Half>(), wd, momentum, dampening, lr, + nesterov, first_run, wd_after_momentum, scale); + } + // Case 2. fp16, fp32, fp32, No + // else if (grad_type == at::ScalarType::Half && + // weight_type == at::ScalarType::Float && + // num_tensors == 3) { + // multi_tensor_apply<3>( + // BLOCK_SIZE, + // chunk_size, + // noop_flag, + // tensor_lists, + // SGDFunctor<3, at::Half, float>(), + // wd, + // momentum, + // dampening, + // lr, + // nesterov, + // first_run, + // wd_after_momentum); + // } + // Case 2. fp32, fp32, fp32, No + else if (grad_type == at::ScalarType::Float && // NOLINT(*) + weight_type == at::ScalarType::Float && num_tensors == 3) { + multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor<3, float, float>(), wd, momentum, dampening, lr, nesterov, + first_run, wd_after_momentum, scale); + } + // Case 3. fp16, fp32, fp32, Yes + else if (grad_type == at::ScalarType::Half && // NOLINT(*) + weight_type == at::ScalarType::Float && num_tensors == 4) { + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor<4, at::Half, float>(), wd, momentum, dampening, lr, nesterov, + first_run, wd_after_momentum, scale); + } + // Case 4. fp32, fp32, fp32, Yes + else if (grad_type == at::ScalarType::Float && // NOLINT(*) + weight_type == at::ScalarType::Float && num_tensors == 4) { + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + SGDFunctor<4, float, float>(), wd, momentum, dampening, lr, nesterov, + first_run, wd_after_momentum, scale); + } else { + AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ", + "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors); + } + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/musa/pytorch/csrc/extensions/normalization.cpp new file mode 100644 index 0000000000..a32fc91f8f --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/normalization.cpp @@ -0,0 +1,275 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +#include +#include + +namespace transformer_engine::pytorch { +std::pair createOutputTensor(const NVTEShape &shape, DType dtype, + py::handle quantizer) { + std::vector shape_vec; + for (int i = 0; i < shape.ndim; i++) { + size_t t = shape.data[i]; + shape_vec.push_back(t); + } + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape_vec, dtype); +} +std::pair createOutputTensor(std::vector &shape, DType dtype, + py::handle quantizer) { + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape, dtype); +} +} // namespace transformer_engine::pytorch + +std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &mu, const at::Tensor &rsigma, + const at::Tensor &gamma, const int sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; + const auto &dz_ = dz.contiguous(); + const auto &x_ = x.contiguous(); + const auto &mu_ = mu.contiguous(); + const auto &rsigma_ = rsigma.contiguous(); + const auto &gamma_ = gamma.contiguous(); + + auto dx = at::empty_like(x_); + auto dgamma = at::empty_like(gamma_); + auto dbeta = at::empty_like(gamma_); + transformer_engine::TensorWrapper workspace; + + auto dz_cu = makeTransformerEngineTensor(dz_); + auto x_cu = makeTransformerEngineTensor(x_); + auto mu_cu = makeTransformerEngineTensor(mu_); + auto rsigma_cu = makeTransformerEngineTensor(rsigma_); + auto gamma_cu = makeTransformerEngineTensor(gamma_); + auto dx_cu = makeTransformerEngineTensor(dx); + auto dgamma_cu = makeTransformerEngineTensor(dgamma); + auto dbeta_cu = makeTransformerEngineTensor(dbeta); + + // This call populates tensors with the required config. + // nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + // dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + // at::musa::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + // zero_centered_gamma, at::musa::getCurrentMUSAStream()); + + // Alloc space for Tensors. + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + + // Actual call to bwd kernel. + // nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + // dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + // at::musa::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + // zero_centered_gamma, at::musa::getCurrentMUSAStream()); + + return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)}; +} + +std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, + float eps, py::object ln_out, py::handle quantizer, + DType out_dtype, const int sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; + using namespace transformer_engine; + + auto none = py::none(); + const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none); + + TensorWrapper bias_tensor; + MaybeTensor bias_grad = std::nullopt; + if (bias.has_value()) { + bias_tensor = makeTransformerEngineTensor(*bias); + } + + // Tensor dimensions + size_t N = static_cast(input_tensor.size(0)); + size_t H = static_cast(input_tensor.size(1)); + std::vector size = {N, H}; + + // Construct Transformer Engine tensors + at::Tensor mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + at::Tensor rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + + TensorWrapper ln_out_tensor; + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + py::object ln_output; + + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + // Use high precision output from normalization + NoneQuantizer q{none}; + std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, out_dtype); + } else { + if (ln_out.is_none()) { + std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype); + } else { + ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } + } + TensorWrapper mu_cu = makeTransformerEngineTensor(mu); + TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); + + // Query workspace sizes + transformer_engine::TensorWrapper workspace; + // nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps, + // ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), + // at::musa::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + // zero_centered_gamma, at::musa::getCurrentMUSAStream()); + + // Allocate workspaces + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + + // Launch kernel + // nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps, + // ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), + // at::musa::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + // zero_centered_gamma, at::musa::getCurrentMUSAStream()); + + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + TensorWrapper cast_out_tensor; + if (ln_out.is_none()) { + std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype); + } else { + cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } + + nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr, + at::musa::getCurrentMUSAStream()); + } + + return {ln_out, py::cast(mu), py::cast(rsigma)}; +} + +std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &rsigma, const at::Tensor &gamma, + const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; + const auto &dz_ = dz.contiguous(); + const auto &x_ = x.contiguous(); + const auto &rsigma_ = rsigma.contiguous(); + const auto &gamma_ = gamma.contiguous(); + + auto dx = at::empty_like(x_); + auto dgamma = at::empty_like(gamma_); + transformer_engine::TensorWrapper workspace; + + auto dz_cu = makeTransformerEngineTensor(dz_); + auto x_cu = makeTransformerEngineTensor(x_); + auto rsigma_cu = makeTransformerEngineTensor(rsigma_); + auto gamma_cu = makeTransformerEngineTensor(gamma_); + auto dx_cu = makeTransformerEngineTensor(dx); + auto dgamma_cu = makeTransformerEngineTensor(dgamma); + + std::tie(dx, dgamma) = at::_fused_rmsnorm_backward( + dz_, rsigma_, x_, {x_.size(-1)}, 1e-5, gamma_); + + // This call populates tensors with the required config. + // nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + // dgamma_cu.data(), workspace.data(), + // at::musa::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + // zero_centered_gamma, at::musa::getCurrentMUSAStream()); + + // Alloc space for Tensors. + // auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + // workspace = + // makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + + // Actual call to bwd kernel. + // nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + // dgamma_cu.data(), workspace.data(), + // at::musa::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + // zero_centered_gamma, at::musa::getCurrentMUSAStream()); + + return {py::cast(dx), py::cast(dgamma)}; +} + +std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, + py::object ln_out, py::handle quantizer, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; + using namespace transformer_engine; + + auto none = py::none(); + const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none); + + // Tensor dimensions + size_t N = static_cast(input_tensor.shape().data[0]); + size_t H = static_cast(input_tensor.shape().data[1]); + + // Construct Transformer Engine tensors + auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + std::vector size = {N, H}; + TensorWrapper ln_out_tensor; + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + py::object ln_output; + + auto rsigma_cu = makeTransformerEngineTensor(rsigma); + const at::Tensor& th_input = input.cast(); + const at::Tensor& th_weight = weight.cast(); + + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + // Use high precision output from normalization + NoneQuantizer q{none}; + std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, otype); + + at::Tensor th_out = ln_output.cast(); + std::tie(th_out, rsigma) = at::musa::FusedRMSNormForwardOut( + th_input, th_out, {th_input.size(-1)}, eps, th_weight); + + } else { + if (ln_out.is_none()) { + std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype); + } else { + ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } + + at::Tensor th_out = ln_out.cast(); + std::tie(th_out, rsigma) = at::musa::FusedRMSNormForwardOut( + th_input, th_out, {th_input.size(-1)}, eps, th_weight); + + } + // auto rsigma_cu = makeTransformerEngineTensor(rsigma); + + // Query workspace sizes + // transformer_engine::TensorWrapper workspace; + // nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(), + // rsigma_cu.data(), workspace.data(), + // at::musa::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + // zero_centered_gamma, at::musa::getCurrentMUSAStream()); + + // Allocate workspaces + // auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + // workspace = + // makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + + // Launch kernel + // nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(), + // rsigma_cu.data(), workspace.data(), + // at::musa::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + // zero_centered_gamma, at::musa::getCurrentMUSAStream()); + + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + TensorWrapper cast_out_tensor; + if (ln_out.is_none()) { + std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype); + } else { + cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer); + } + + nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr, + at::musa::getCurrentMUSAStream()); + } + + return {ln_out, py::none(), py::cast(rsigma)}; +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/padding.cpp b/transformer_engine/musa/pytorch/csrc/extensions/padding.cpp new file mode 100644 index 0000000000..d1334b361d --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/padding.cpp @@ -0,0 +1,80 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +void fused_multi_row_padding(at::Tensor input, at::Tensor output, + std::vector input_row_list, + std::vector padded_input_row_list) { + using namespace transformer_engine; + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(), + "Number of input row list and padded row list must match."); + NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2."); + NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); + + const int num_tensors = input_row_list.size(); + // Extract properties from PyTorch tensors + std::vector input_dptr_list, output_dptr_list; + std::vector> input_shape_list, output_shape_list; + std::vector input_type_list; + void* d_input_ptr = reinterpret_cast(input.data_ptr()); + void* d_output_ptr = reinterpret_cast(output.data_ptr()); + for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { + input_dptr_list.push_back(d_input_ptr); + output_dptr_list.push_back(d_output_ptr); + + // Move the input pointer to the next split. + char* input_char_ptr = reinterpret_cast(d_input_ptr); + const size_t input_dptr_offset = + input_row_list[tensor_id] * input.size(1) * input.element_size(); + input_char_ptr += input_dptr_offset; + d_input_ptr = reinterpret_cast(input_char_ptr); + + input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); + + // Move the output pointer to the next split. + char* output_char_ptr = reinterpret_cast(d_output_ptr); + const size_t output_dptr_offset = + padded_input_row_list[tensor_id] * output.size(1) * output.element_size(); + output_char_ptr += output_dptr_offset; + d_output_ptr = reinterpret_cast(output_char_ptr); + + output_shape_list.push_back( + {padded_input_row_list[tensor_id], static_cast(output.size(1))}); + } + + // Construct TE tensors + std::vector nvte_input_list, nvte_output_list; + std::vector tensor_wrappers; + auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + transformer_engine::DType dtype) -> NVTETensor { + tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); + return tensor_wrappers.back().data(); + }; + + std::vector padded_num_rows_list; + for (size_t i = 0; i < input_dptr_list.size(); ++i) { + if (input_dptr_list[i] == nullptr || input_row_list[i] == 0) continue; + nvte_input_list.emplace_back( + make_tensor(input_dptr_list[i], input_shape_list[i], input_type_list[i])); + nvte_output_list.emplace_back( + make_tensor(output_dptr_list[i], output_shape_list[i], input_type_list[i])); + padded_num_rows_list.emplace_back(padded_input_row_list[i]); + } + + // Check tensor lists + NVTE_CHECK(nvte_output_list.size() == nvte_input_list.size(), + "Number of input and output tensors must match"); + NVTE_CHECK(padded_num_rows_list.size() == nvte_input_list.size() && + "Number of input and padded row list must match"); + + // Launch TE kernel + nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(), + padded_num_rows_list.data(), at::musa::getCurrentMUSAStream()); +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/permutation.mu b/transformer_engine/musa/pytorch/csrc/extensions/permutation.mu new file mode 100644 index 0000000000..d72a46d851 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/permutation.mu @@ -0,0 +1,314 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "extensions.h" + +std::tuple> moe_permute_fwd( + at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, + int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num) { + using namespace transformer_engine::pytorch; + const int num_tokens = input.size(0); + int num_cols = input.size(1); + const int topK = indices.size(1); + + // Initialize the workspace on the first run + if (workspace.empty()) { + auto options = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kPrivateUse1).requires_grad(false); + + at::Tensor sorted_indices = torch::empty(max_expanded_token_num, options); + at::Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options); + at::Tensor sorted_row_id = + torch::empty(max_expanded_token_num, + torch::dtype(torch::kInt32).device(torch::kPrivateUse1).requires_grad(false)); + + size_t temp_storage_bytes = 0; + int *temp_ptr = nullptr; + cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr, + temp_ptr, max_expanded_token_num); + at::Tensor temp_storage = torch::empty( + temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kPrivateUse1).requires_grad(false)); + + workspace.push_back(sorted_indices); + workspace.push_back(row_id); + workspace.push_back(sorted_row_id); + workspace.push_back(temp_storage); + } + + int *indices_ptr = reinterpret_cast(getDataPtr(indices, 0)); + int *sorted_indices_ptr = reinterpret_cast(getDataPtr(workspace[0], 0)); + int *row_id_ptr = reinterpret_cast(getDataPtr(workspace[1], 0)); + int *sorted_row_id_ptr = reinterpret_cast(getDataPtr(workspace[2], 0)); + + void *d_temp_storage = getDataPtr(workspace[3], 0); + size_t temp_storage_bytes = std::numeric_limits::max(); + + cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr, + sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, + num_tokens * topK); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input.scalar_type(); + + // Output buffer alloc + num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; + at::Tensor permuted_output = torch::empty( + {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kPrivateUse1).requires_grad(false)); + at::Tensor row_id_map = torch::empty( + {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kPrivateUse1).requires_grad(false)); + + auto stream = at::musa::getCurrentMUSAStream().stream(); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast(input.size(0)), static_cast(num_cols)}, dtype); + auto permuted_output_cu = makeTransformerEngineTensor( + permuted_output.data_ptr(), + {static_cast(permuted_output.size(0)), static_cast(num_cols)}, dtype); + auto sorted_row_id_cu = + makeTransformerEngineTensor(sorted_row_id_ptr, {static_cast(num_tokens * topK)}, + transformer_engine::DType::kInt32); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + + nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), + row_id_map_cu.data(), transformer_engine::TensorWrapper().data(), + transformer_engine::TensorWrapper().data(), + transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols, + num_out_tokens, stream); + + return std::make_tuple(permuted_output, row_id_map, workspace); +} + +at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK) { + return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK); +} + +at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK) { + using namespace transformer_engine::pytorch; + int num_cols = input.size(1); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input.scalar_type(); + + // Output buffer alloc + at::Tensor unpermuted_output = torch::empty( + {num_tokens, num_cols}, torch::dtype(_st).device(torch::kPrivateUse1).requires_grad(false)); + + auto stream = at::musa::getCurrentMUSAStream().stream(); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast(input.size(0)), static_cast(num_cols)}, dtype); + auto unpermuted_output_cu = makeTransformerEngineTensor( + unpermuted_output.data_ptr(), + {static_cast(unpermuted_output.size(0)), static_cast(num_cols)}, dtype); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + auto prob_cu = makeTransformerEngineTensor(prob); + + nvte_unpermute(input_cu.data(), unpermuted_output_cu.data(), row_id_map_cu.data(), prob_cu.data(), + num_tokens, topK, num_cols, stream); + + return unpermuted_output; +} + +std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob) { + using namespace transformer_engine::pytorch; + const int topK = (prob.numel() > 0) ? prob.size(1) : 1; + const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); + int num_cols = input_bwd.size(1); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input_bwd.scalar_type(); + + // Output buffer alloc + at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, + torch::dtype(_st).device(torch::kPrivateUse1).requires_grad(false)); + at::Tensor prob_grad = torch::empty( + {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kPrivateUse1).requires_grad(false)); + + auto stream = at::musa::getCurrentMUSAStream().stream(); + + auto input_bwd_cu = makeTransformerEngineTensor( + input_bwd.data_ptr(), {static_cast(input_bwd.size(0)), static_cast(num_cols)}, + dtype); + auto act_grad_cu = makeTransformerEngineTensor( + act_grad.data_ptr(), {static_cast(act_grad.size(0)), static_cast(num_cols)}, + dtype); + auto input_fwd_cu = makeTransformerEngineTensor( + input_fwd.data_ptr(), {static_cast(input_fwd.size(0)), static_cast(num_cols)}, + dtype); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + auto prob_cu = makeTransformerEngineTensor(prob); + auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); + + nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(), + row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(), + num_tokens, topK, num_cols, 0, stream); + + return std::make_tuple(act_grad, prob_grad); +} + +// HACK(sherry): suppport fp32/fp64 router +std::tuple moe_permute_mask(const transformer_engine::DType dtype, + at::Tensor input, at::Tensor row_id_map, + at::Tensor probs, int num_tokens, + int num_experts, int num_out_tokens, + int hidden_size) { + using namespace transformer_engine::pytorch; + const transformer_engine::DType probs_dtype = GetTransformerEngineDType(probs.scalar_type()); + + at::Tensor output = + torch::empty({num_out_tokens, hidden_size}, + torch::dtype(input.dtype()).device(torch::kPrivateUse1).requires_grad(false)); + at::Tensor permuted_probs = + torch::empty({num_out_tokens}, + torch::dtype(probs.dtype()).device(torch::kPrivateUse1).requires_grad(false)); + + auto stream = at::musa::getCurrentMUSAStream().stream(); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast(num_tokens), static_cast(hidden_size)}, dtype); + auto output_cu = makeTransformerEngineTensor( + output.data_ptr(), {static_cast(num_out_tokens), static_cast(hidden_size)}, + dtype); + auto row_id_map_cu = makeTransformerEngineTensor( + row_id_map.data_ptr(), + {static_cast(row_id_map.size(0)), static_cast(row_id_map.size(1))}, + transformer_engine::DType::kInt64); + auto probs_cu = makeTransformerEngineTensor( + probs.data_ptr(), {static_cast(num_tokens), static_cast(num_experts)}, probs_dtype); // probs dtype + auto permuted_probs_cu = makeTransformerEngineTensor( + permuted_probs.data_ptr(), {static_cast(num_out_tokens)}, probs_dtype); // probs dtype + + if(dtype == probs_dtype){ + nvte_permute_mask(input_cu.data(), output_cu.data(), row_id_map_cu.data(), probs_cu.data(), + permuted_probs_cu.data(), num_tokens, num_experts, num_out_tokens, hidden_size, + stream); + } + else{ + nvte_permute_mask_high_precision_probs(input_cu.data(), output_cu.data(), row_id_map_cu.data(), probs_cu.data(), + permuted_probs_cu.data(), num_tokens, num_experts, num_out_tokens, hidden_size, + stream); + } + return std::make_tuple(output, permuted_probs); +} + +std::tuple moe_unpermute_mask(const transformer_engine::DType dtype, + at::Tensor input, at::Tensor row_id_map, + at::Tensor merging_probs, + at::Tensor permuted_probs, int num_tokens, + int num_experts, int hidden_size) { + using namespace transformer_engine::pytorch; + const transformer_engine::DType probs_dtype = GetTransformerEngineDType(permuted_probs.scalar_type()); + + at::Tensor output = + torch::empty({num_tokens, hidden_size}, + torch::dtype(input.dtype()).device(torch::kPrivateUse1).requires_grad(false)); + at::Tensor unpermuted_probs = + torch::empty({num_tokens, num_experts}, + torch::dtype(permuted_probs.dtype()).device(torch::kPrivateUse1).requires_grad(false)); + + auto stream = at::musa::getCurrentMUSAStream().stream(); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast(input.size(0)), static_cast(hidden_size)}, + dtype); + auto output_cu = makeTransformerEngineTensor( + output.data_ptr(), {static_cast(num_tokens), static_cast(hidden_size)}, + dtype); + auto row_id_map_cu = makeTransformerEngineTensor( + row_id_map.data_ptr(), + {static_cast(row_id_map.size(0)), static_cast(row_id_map.size(1))}, + transformer_engine::DType::kInt64); + + auto merging_probs_cu = makeTransformerEngineTensor( + merging_probs.data_ptr(), {static_cast(num_tokens), static_cast(num_experts)}, + probs_dtype); +// auto permuted_probs_cu = makeTransformerEngineTensor(permuted_probs); + auto permuted_probs_cu = makeTransformerEngineTensor(permuted_probs.data_ptr(), {static_cast(num_tokens), static_cast(num_experts)}, + probs_dtype); + auto unpermuted_probs_cu = makeTransformerEngineTensor( + unpermuted_probs.data_ptr(), + {static_cast(num_tokens), static_cast(num_experts)}, probs_dtype); + + if(dtype == probs_dtype){ + nvte_unpermute_mask(input_cu.data(), output_cu.data(), row_id_map_cu.data(), + merging_probs_cu.data(), permuted_probs_cu.data(), unpermuted_probs_cu.data(), + num_tokens, num_experts, hidden_size, stream); + }else{ + nvte_unpermute_mask_high_precision_probs(input_cu.data(), output_cu.data(), row_id_map_cu.data(), + merging_probs_cu.data(), permuted_probs_cu.data(), unpermuted_probs_cu.data(), + num_tokens, num_experts, hidden_size, stream); + } + return std::make_tuple(output, unpermuted_probs); +} +// HACK(sherry) + + + +std::tuple moe_unpermute_mask_bwd_with_merging_probs( + const transformer_engine::DType dtype, at::Tensor fwd_output_grad, at::Tensor fwd_input, + at::Tensor merging_probs, at::Tensor row_id_map, int num_tokens, int num_experts, + int num_out_tokens, int hidden_size) { + using namespace transformer_engine::pytorch; + + at::Tensor fwd_input_grad = + torch::empty({num_out_tokens, hidden_size}, + torch::dtype(torch::kBFloat16).device(torch::kPrivateUse1).requires_grad(false)); + at::Tensor merging_probs_grad = + torch::empty({num_tokens, num_experts}, + torch::dtype(torch::kBFloat16).device(torch::kPrivateUse1).requires_grad(false)); + + auto stream = at::musa::getCurrentMUSAStream().stream(); + + auto fwd_output_grad_cu = makeTransformerEngineTensor( + fwd_output_grad.data_ptr(), + {static_cast(num_tokens), static_cast(hidden_size)}, dtype); + auto fwd_input_grad_cu = makeTransformerEngineTensor( + fwd_input_grad.data_ptr(), + {static_cast(num_out_tokens), static_cast(hidden_size)}, dtype); + auto fwd_input_cu = makeTransformerEngineTensor( + fwd_input.data_ptr(), {static_cast(num_out_tokens), static_cast(hidden_size)}, + dtype); + auto merging_probs_cu = makeTransformerEngineTensor( + merging_probs.data_ptr(), {static_cast(num_tokens), static_cast(num_experts)}, + dtype); + auto merging_probs_grad_cu = makeTransformerEngineTensor( + merging_probs_grad.data_ptr(), + {static_cast(num_tokens), static_cast(num_experts)}, dtype); + auto row_id_map_cu = makeTransformerEngineTensor( + row_id_map.data_ptr(), + {static_cast(row_id_map.size(0)), static_cast(row_id_map.size(1))}, + transformer_engine::DType::kInt64); + + nvte_unpermute_mask_bwd_with_merging_probs(fwd_output_grad_cu.data(), fwd_input_grad_cu.data(), + fwd_input_cu.data(), merging_probs_cu.data(), + merging_probs_grad_cu.data(), row_id_map_cu.data(), + num_tokens, num_experts, hidden_size, stream); + + return std::make_tuple(fwd_input_grad, merging_probs_grad); +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/musa/pytorch/csrc/extensions/pybind.cpp new file mode 100644 index 0000000000..5475cc71ea --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/pybind.cpp @@ -0,0 +1,362 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "pybind.h" + +#include +#include +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../extensions.h" +#include "common.h" + +namespace transformer_engine::pytorch { + +PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove +PyTypeObject *Float8TensorBasePythonClass = nullptr; +PyTypeObject *Float8QuantizerClass = nullptr; +PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove +PyTypeObject *MXFP8TensorBasePythonClass = nullptr; +PyTypeObject *MXFP8QuantizerClass = nullptr; + +PyTypeObject *MTFP8TensorPythonClass = nullptr; /// TODO Remove +PyTypeObject *MTFP8TensorBasePythonClass = nullptr; +PyTypeObject *MTFP8QuantizerClass = nullptr; + +void init_float8_extension() { + if (Float8TensorPythonClass) return; + auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); + Float8QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); + Float8TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor")); + auto fp8_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.float8_tensor_base"); + Float8TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorBase")); + NVTE_CHECK(Float8TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch Float8 extension."); +} + +void init_mxfp8_extension() { + if (MXFP8TensorPythonClass) return; + auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor"); + MXFP8QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer")); + MXFP8TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor")); + auto fp8_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base"); + MXFP8TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorBase")); + NVTE_CHECK(MXFP8TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch MXFP8 extension."); +} + +void init_mtfp8_extension() { + if (MTFP8TensorPythonClass) return; + auto fp8_module = py::module_::import("transformer_engine.musa.pytorch.tensor.mtfp8_tensor"); + MTFP8QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MTFP8Quantizer")); + MTFP8TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MTFP8Tensor")); + auto fp8_base_module = + py::module_::import("transformer_engine.musa.pytorch.tensor.mtfp8_tensor_base"); + MTFP8TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "MTFP8TensorBase")); + + NVTE_CHECK(MTFP8TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch MTFP8 extension."); + NVTE_CHECK(MTFP8TensorBasePythonClass != nullptr, + "Internal error: could not initialize pyTorch MTFP8 extension."); + NVTE_CHECK(MTFP8QuantizerClass != nullptr, + "Internal error: could not initialize pyTorch MTFP8 extension."); +} + +void init_extension() { + init_float8_extension(); + init_mxfp8_extension(); + init_mtfp8_extension(); +} + +} // namespace transformer_engine::pytorch + +#include "common/util/pybind_helper.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), + py::arg("output") = py::none(), py::arg("noop") = py::none()); + m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), + py::arg("otype")); + m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, + "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); + m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", + py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"), + py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), + py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), + py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"), + py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, + py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false); + m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.", + py::call_guard()); + m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.", + py::call_guard()); + m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"), + py::arg("quantizer")); + m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_drelu", transformer_engine::pytorch::dbias_drelu, "DReLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dqgelu", transformer_engine::pytorch::dbias_dqgelu, "DQGeLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dsrelu", transformer_engine::pytorch::dbias_dsrelu, + "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), + py::arg("quantizer")); + + // Permutation functions + m.def("moe_permute_fwd", moe_permute_fwd); + m.def("moe_permute_bwd", moe_permute_bwd); + m.def("moe_unpermute_fwd", moe_unpermute_fwd); + m.def("moe_unpermute_bwd", moe_unpermute_bwd); + + // Permutation with mask functions + m.def("moe_permute_mask", moe_permute_mask); + m.def("moe_unpermute_mask", moe_unpermute_mask); + m.def("moe_unpermute_mask_bwd_with_merging_probs", moe_unpermute_mask_bwd_with_merging_probs); + + // Softmax functions + m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD", + py::call_guard()); + m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD", + py::call_guard()); + m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward, + "Scaled Masked Softmax FWD", py::call_guard()); + m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward, + "Scaled Masked Softmax BWD", py::call_guard()); + m.def("scaled_upper_triang_masked_softmax_forward", &scaled_upper_triang_masked_softmax_forward, + "Scaled Upper-Triangular Masked Softmax FWD", py::call_guard()); + m.def("scaled_upper_triang_masked_softmax_backward", &scaled_upper_triang_masked_softmax_backward, + "Scaled Upper-Triangular Masked Softmax BWD", py::call_guard()); + m.def("scaled_aligned_causal_masked_softmax_forward", + &scaled_aligned_causal_masked_softmax_forward, + "Scaled Bottom-Right Corner Aligned Masked Softmax FWD", + py::call_guard()); + m.def("scaled_aligned_causal_masked_softmax_backward", + &scaled_aligned_causal_masked_softmax_backward, + "Scaled Bottom-Right Corner Aligned Masked Softmax BWD", + py::call_guard()); + + // Other granular functions + m.def("layernorm_fwd", &layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"), + py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), + py::arg("sm_margin"), py::arg("zero_centered_gamma")); + m.def("layernorm_bwd", &layernorm_bwd, "Backward of LayerNorm"); + m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"), + py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), + py::arg("zero_centered_gamma")); + m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm"); + m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose", + py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); + m.def("fused_multi_quantize_batch_init", &transformer_engine::pytorch::fused_multi_quantize_batch_init, "Fused Multi-tensor Init + Cast + Transpose", + py::arg("input_list"),py::arg("hidden_dim"), py::arg("m_splits"), py::arg("quantizer_list"), py::arg("otype")); + m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); + m.def("fused_attn_fwd", &fused_attn_fwd, + "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); + m.def("fused_attn_bwd", &fused_attn_bwd, + "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); + m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), + py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); + m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention", + py::call_guard()); + m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", + py::call_guard()); + m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend", + py::call_guard()); + m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction, + "Update amax history and FP8 scale/scale_inv after reduction", + py::call_guard()); + m.def("fp8_block_scaling_compute_partial_amax", &fp8_block_scaling_compute_partial_amax, + "Compute partial amax from master weights for fp8 block scaling", py::arg("tensor"), + py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len")); + m.def("fp8_block_scaling_partial_cast", &fp8_block_scaling_partial_cast, + "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"), + py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"), + py::arg("out_dtype")); + m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding", + py::call_guard()); + // fused apply rope + m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD", + py::call_guard()); + m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD", + py::call_guard()); + m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format", + py::call_guard()); + m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format", + py::call_guard()); + + // Misc + m.def("get_cublasLt_version", &get_mublas_version, "Get cublasLt version", + py::call_guard()); + m.def("get_cudnn_version", &get_mudnn_version, "Get cuDNN version", + py::call_guard()); + m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams); + + // Support THD format for Context Parallel + m.def("thd_read_half_tensor", &thd_read_half_tensor, + "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD " + "tensor", + py::call_guard()); + m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction, + "Correct the second half of the softmax_lse", py::call_guard()); + m.def("thd_read_second_half_lse", &thd_read_second_half_lse, + "Read the second half of the softmax_lse", py::call_guard()); + m.def("thd_out_correction", &thd_out_correction, + "Correct the THD format output of context parallelism in forward pass", + py::call_guard()); + m.def("thd_grad_correction", &thd_grad_correction, + "Correct the THD format gradients of context parallelism in backward pass", + py::call_guard()); + m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices, + "Generate partitioned indices for inputs in THD format", + py::call_guard()); + + // multi-tensor functions + m.def("multi_tensor_scale", &multi_tensor_scale_cuda, + "Fused overflow check + scale for a list of contiguous tensors", + py::call_guard()); + m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, + "Computes L2 norm for a list of contiguous tensors", + py::call_guard()); + m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda, + "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only " + "performed for L2 norm computation, and tensors are not updated)", + py::call_guard()); + m.def("multi_tensor_adam", &multi_tensor_adam_cuda, + "Compute and apply gradient update to parameters for Adam optimizer", + py::call_guard()); + m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda, + "Compute and apply gradient update to parameters for Adam optimizer", + py::call_guard()); + m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda, + "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " + "support and LR scheduling", + py::call_guard()); + m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda, + "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " + "support, LR scheduling and FP32 master weights", + py::call_guard()); + m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, + "Fused SGD optimizer for list of contiguous tensors", + py::call_guard()); + m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda, + "Fused compute scale and scale_inv from amax", py::call_guard()); + + // Data structures + py::class_(m, "FP8TensorMeta") + .def(py::init<>()) + .def_readwrite("scale", &transformer_engine::pytorch::FP8TensorMeta::scale) + .def_readwrite("scale_inv", &transformer_engine::pytorch::FP8TensorMeta::scale_inv) + .def_readwrite("amax_history", &transformer_engine::pytorch::FP8TensorMeta::amax_history); + + py::enum_(m, "FP8FwdTensors") + .value("GEMM1_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_INPUT) + .value("GEMM1_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_WEIGHT) + .value("GEMM1_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_OUTPUT) + .value("GEMM2_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_INPUT) + .value("GEMM2_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_WEIGHT) + .value("GEMM2_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_OUTPUT) + .value("GEMM3_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_INPUT) + .value("GEMM3_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_WEIGHT) + .value("GEMM3_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_OUTPUT); + + py::enum_(m, "FP8BwdTensors") + .value("GRAD_OUTPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT1) + .value("GRAD_INPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT1) + .value("GRAD_OUTPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT2) + .value("GRAD_INPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT2) + .value("GRAD_OUTPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT3) + .value("GRAD_INPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT3); + + py::class_(m, "CommOverlapHelper") + .def(py::init<>(), py::call_guard()) + .def(py::init, + std::optional>(), + py::call_guard(), py::arg("world_group"), + py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); + + py::class_, transformer_engine::CommOverlapBase, + transformer_engine::CommOverlapCore>(m, "CommOverlap") + .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, + int, int, int, int, bool, bool, bool, bool>(), + py::call_guard(), py::arg("buffer_shape"), + py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), + py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, + py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, + py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, + py::arg("atomic_gemm") = false, py::arg("use_ce") = false, py::arg("rs_overlap_first_gemm") = false) + .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"), + py::arg("quantizer"), py::arg("local_chunk") = false) + .def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"), + py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) + .def("set_buffer_params", &CommOverlap::set_buffer_params); + + py::class_, + transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>( + m, "CommOverlapP2P") + .def(py::init &, at::ScalarType, CommOverlapHelper *, int, + transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool, + bool>(), + py::call_guard(), py::arg("buffer_shape"), + py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), + py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, + py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, + py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, + py::arg("use_ce") = true, py::arg("aggregate") = false) + .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), + py::arg("quantizer"), py::arg("local_chunk") = false) + .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"), + py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) + .def("set_buffer_params", &CommOverlapP2P::set_buffer_params); +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/musa/pytorch/csrc/extensions/quantizer.cpp new file mode 100644 index 0000000000..8c6551d7c8 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/quantizer.cpp @@ -0,0 +1,324 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "common.h" +#include "pybind.h" +#include "torch/torch.h" +#include "util.h" + +namespace transformer_engine::pytorch { + +constexpr size_t MXFP8_BLOCK_SIZE = 32; + +Quantizer::Quantizer(const py::handle& quantizer) { + if (quantizer.is_none()) { + this->rowwise_usage = true; + this->columnwise_usage = true; + this->internal = false; + } else { + this->rowwise_usage = quantizer.attr("rowwise_usage").cast(); + this->columnwise_usage = quantizer.attr("columnwise_usage").cast(); + this->internal = quantizer.attr("internal").cast(); + this->quantizer = quantizer; + } +} + +Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + const at::Tensor& scale = quantizer.attr("scale").cast(); + const at::Tensor& amax = quantizer.attr("amax").cast(); + const DType type = quantizer.attr("dtype").cast(); + + this->amax = amax; + this->scale = scale; + this->dtype = type; +} + +std::pair NoneQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + at::TensorOptions opts; + opts = opts.dtype(GetATenDType(dtype)).device(torch::kPrivateUse1); + std::vector torch_shape; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + } + at::Tensor ret; + if (rowwise_data.has_value()) { + ret = std::move(*rowwise_data); + } else { + ret = at::empty(torch_shape, opts); + } + + TensorWrapper tensor; + tensor.set_rowwise_data(ret.data_ptr(), dtype, shape); + return {std::move(tensor), py::cast(ret)}; +} + +void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { + tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + at::TensorOptions opts; + opts = opts.dtype(torch::kFloat32).device(torch::kPrivateUse1); + tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair Float8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector rowwise_torch_shape; + std::vector columnwise_torch_shape; + + if (!shape.empty()) { + columnwise_torch_shape.emplace_back(static_cast(shape.back())); + } + for (size_t i = 0; i < shape.size(); ++i) { + if (i < shape.size() - 1) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } + rowwise_torch_shape.emplace_back(static_cast(shape[i])); + } + at::TensorOptions opts; + opts = opts.dtype(torch::kUInt8).device(torch::kPrivateUse1); + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(rowwise_torch_shape, opts); + } + } + const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); + at::Tensor columnwise_data; + bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported(); + if (create_transpose) { + columnwise_data = at::empty(columnwise_torch_shape, opts); + } + const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); + opts = opts.dtype(torch::kFloat32); + at::Tensor scale_inv = at::reciprocal(scale); + py::object ret; + if (internal) { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } else { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), + "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } + TensorWrapper tensor(this->get_scaling_mode()); + if (rowwise_usage) { + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (create_transpose) { + std::vector transposed_shape; + for (auto s : columnwise_torch_shape) { + transposed_shape.emplace_back(static_cast(s)); + } + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); + tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + this->set_quantization_params(&tensor); + return {std::move(tensor), std::move(ret)}; +} + +MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); +} + +void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair MXFP8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector torch_shape; + size_t numel = 1; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + numel *= s; + } + + TensorWrapper tensor(NVTE_MXFP8_1D_SCALING); + at::TensorOptions opts; + at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, + columnwise_scale_inv; // TODO(pgadzinski) - change + opts = opts.dtype(torch::kUInt8).device(torch::kPrivateUse1); + auto last_dim = static_cast(torch_shape.back()); + + NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, + "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + " (got shape=", torch_shape, ")"); + + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(torch_shape, opts); + } + auto sinv0 = roundup(numel / last_dim, 128); + auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); + rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{static_cast(sinv0), static_cast(sinv1)}); + } + + if (columnwise_usage) { + auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); + auto sinv1 = roundup(last_dim, 128); + columnwise_data = at::empty(torch_shape, opts); + columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); + tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{static_cast(sinv0), static_cast(sinv1)}); + } + this->set_quantization_params(&tensor); + + py::object ret; + if (internal) { + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } else { + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); + ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } + + return {std::move(tensor), std::move(ret)}; +} + +MTFP8Quantizer::MTFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + dtype = quantizer.attr("dtype").cast(); + block_m = quantizer.attr("block_m").cast(); + block_n = quantizer.attr("block_n").cast(); + + NVTE_CHECK(block_m > 0); + if (block_m == 1) { + NVTE_CHECK(block_n > 0 && block_n % 16 == 0); + } else { + NVTE_CHECK((block_m % 16 == 0) && (block_m == block_n)); + } +} + +void MTFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data( + rowwise_data.data_ptr, static_cast(rowwise_data.dtype), rowwise_data.shape); + tensor->set_columnwise_data( + columnwise_data.data_ptr, static_cast(columnwise_data.dtype), columnwise_data.shape); +} + +std::pair MTFP8Quantizer::create_tensor( + const std::vector& shape, DType fake_dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector torch_shape; + int64_t numel = 1; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + numel *= static_cast(s); + } + const auto dim_n = torch_shape.back(); + const auto dim_m = numel / dim_n; + + TensorWrapper tensor(NVTE_MTFP8_BLOCK_SCALING); + auto opt = at::TensorOptions().device(torch::kPrivateUse1); + + at::Tensor data, rowwise_scale_inv; + at::Tensor columnwise_data, columnwise_scale_inv; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(torch_shape, opt.dtype(torch::kUInt8)); + } + + const auto sinv0 = (dim_m + block_m - 1) / block_m; + const auto sinv1 = (dim_n + block_n - 1) / block_n; + + rowwise_scale_inv = at::empty({sinv0, sinv1}, opt.dtype(torch::kFloat)); + tensor.set_rowwise_data(data.data_ptr(), dtype, shape); + tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat32, + std::vector{static_cast(sinv0), static_cast(sinv1)}); + } + + const bool can_not_share = (block_m != block_n); + if (columnwise_usage && can_not_share) { + const auto sinv0 = (dim_m + block_n - 1) / block_n; + const auto sinv1 = (dim_n + block_m - 1) / block_m; + + columnwise_data = at::empty(torch_shape, opt.dtype(torch::kUInt8)); + columnwise_scale_inv = at::empty({sinv0, sinv1}, opt.dtype(torch::kFloat)); + tensor.set_columnwise_data(columnwise_data.data_ptr(), dtype, shape); + tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat32, + std::vector{static_cast(sinv0), static_cast(sinv1)}); + } + this->set_quantization_params(&tensor); + + py::object ret; + if (internal) { + py::handle MTFP8TensorClass(reinterpret_cast(MTFP8TensorBasePythonClass)); + ret = MTFP8TensorClass("rowwise_data"_a = data, + "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = dtype, + "quantizer"_a = quantizer); + } else { + py::handle MTFP8TensorClass(reinterpret_cast(MTFP8TensorPythonClass)); + ret = MTFP8TensorClass("shape"_a = torch_shape, + "dtype"_a = GetATenDType(fake_dtype), + "rowwise_data"_a = data, + "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = dtype, + "quantizer"_a = quantizer); + } + + return {std::move(tensor), std::move(ret)}; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/musa/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/musa/pytorch/csrc/extensions/recipe.cpp new file mode 100644 index 0000000000..9a60b08b78 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/recipe.cpp @@ -0,0 +1,48 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include + +#include "common/common.h" +#include "extensions.h" + +void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin) { + using namespace transformer_engine; + using namespace transformer_engine::pytorch; + size_t num_tensors = amax_histories.size(); + std::vector t_amax_histories(num_tensors); + std::vector t_scales(num_tensors); + std::vector te_amax_histories(num_tensors); + std::vector te_scales(num_tensors); + for (size_t i = 0; i < num_tensors; i++) { + t_amax_histories[i].data.dptr = amax_histories[i].data_ptr(); + auto amax_sizes = amax_histories[i].sizes().vec(); + std::vector amax_shape{amax_sizes.begin(), amax_sizes.end()}; + t_amax_histories[i].data.shape = amax_shape; + t_amax_histories[i].data.dtype = DType::kFloat32; + + t_scales[i].data.dptr = scales[i].data_ptr(); + auto scale_sizes = scales[i].sizes().vec(); + std::vector scale_shape{scale_sizes.begin(), scale_sizes.end()}; + t_scales[i].data.shape = scale_shape; + t_scales[i].data.dtype = DType::kFloat32; + + te_amax_histories[i] = reinterpret_cast(&t_amax_histories[i]); + te_scales[i] = reinterpret_cast(&t_scales[i]); + } + nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales, + amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, + at::musa::getCurrentMUSAStream()); +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/softmax.cpp b/transformer_engine/musa/pytorch/csrc/extensions/softmax.cpp new file mode 100644 index 0000000000..447aecfb52 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/softmax.cpp @@ -0,0 +1,247 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" + +at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { + using namespace transformer_engine::pytorch; + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + const int batches = input.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + + AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); + AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); + AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); + + // Output + auto act_options = input.options().requires_grad(false); + auto softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + auto input_cu = makeTransformerEngineTensor(input); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor, + at::musa::getCurrentMUSAStream()); + + return softmax_results; +} + +at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, + float scale_factor) { + using namespace transformer_engine::pytorch; + + auto output_grads = output_grad_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + // Produce gradients in place. + nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), + output_grads_cu.data(), scale_factor, + at::musa::getCurrentMUSAStream()); + + return output_grads; +} + +at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) { + using namespace transformer_engine::pytorch; + + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + if (!input.is_contiguous()) input = input.contiguous(); + if (!mask.is_contiguous()) mask = mask.contiguous(); + + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + + AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); + AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); + AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); + TORCH_CHECK(pad_batches == 1 || pad_batches == batches); + TORCH_CHECK(mask.size(1) == 1); + TORCH_CHECK(mask.size(2) == query_seq_len); + TORCH_CHECK(mask.size(3) == key_seq_len); + + auto act_options = input.options().requires_grad(false); + auto softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + auto input_cu = makeTransformerEngineTensor(input); + auto mask_cu = makeTransformerEngineTensor(mask); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(), + scale_factor, at::musa::getCurrentMUSAStream()); + + return softmax_results; +} + +at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, + float scale_factor) { + using namespace transformer_engine::pytorch; + + auto output_grads = output_grad_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + // Produce gradients in place. + nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), + output_grads_cu.data(), scale_factor, + at::musa::getCurrentMUSAStream()); + + return output_grads; +} + +at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) { + using namespace transformer_engine::pytorch; + + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + AT_ASSERTM(seq_len <= 16384, "Sequence length must be 16384 or less"); + + // Output + auto act_options = input.options().requires_grad(false); + auto softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options); + + auto input_cu = makeTransformerEngineTensor(input); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), + scale_factor, at::musa::getCurrentMUSAStream()); + + return softmax_results; +} + +at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, + at::Tensor softmax_results_, + float scale_factor) { + using namespace transformer_engine::pytorch; + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + TORCH_CHECK(output_grads.size(1) == output_grads.size(2)); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + // Produce gradients in place. + nvte_scaled_upper_triang_masked_softmax_backward( + output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, + at::musa::getCurrentMUSAStream()); + + return output_grads; +} + +at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) { + using namespace transformer_engine::pytorch; + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + const int batches = input.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + + AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less"); + AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); + AT_ASSERTM(query_seq_len >= 1, "Query sequence length must be greater or equal to 1"); + + // Output + auto act_options = input.options().requires_grad(false); + auto softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + auto input_cu = makeTransformerEngineTensor(input); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + nvte_scaled_aligned_causal_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), + scale_factor, at::musa::getCurrentMUSAStream()); + + return softmax_results; +} + +at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_, + at::Tensor softmax_results_, + float scale_factor) { + using namespace transformer_engine::pytorch; + + auto output_grads = output_grad_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto softmax_results_cu = makeTransformerEngineTensor(softmax_results); + + // Produce gradients in place. + nvte_scaled_aligned_causal_masked_softmax_backward( + output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, + at::musa::getCurrentMUSAStream()); + + return output_grads; +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/musa/pytorch/csrc/extensions/swizzle.cpp new file mode 100644 index 0000000000..28af37ce3b --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/swizzle.cpp @@ -0,0 +1,120 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" +#include "transformer_engine/transformer_engine.h" + +void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { + using namespace transformer_engine::pytorch; + + if (input.scaling_mode() == NVTE_INVALID_SCALING) { + NVTE_ERROR("Invalid scaling mode for swizzle."); + } else if (input.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + return; + } + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + NVTEBasicTensor scale_inv; + if (rowwise) { + scale_inv = input.get_rowwise_scale_inv(); + } else { + scale_inv = input.get_columnwise_scale_inv(); + } + + auto input_shape = nvte_shape_to_vector(input.shape()); + auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + + // Allocate memory for swizzled output. + auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kPrivateUse1); + std::vector scale_inv_shape_int; + for (size_t i = 0; i < scale_inv_shape.size(); ++i) { + scale_inv_shape_int.push_back(static_cast(scale_inv_shape[i])); + } + auto swizzled_scale_inv = at::empty(scale_inv_shape_int, options); + void* scale_inv_dptr = scale_inv.data_ptr; + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + // Reconstruct input only to avoid swizzling both directions if not needed. + // Use any 8 bit type, it's irrelevant. + transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + input_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } else { + input_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + output_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, + scale_inv_shape); + } + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::musa::getCurrentMUSAStream()); + + if (rowwise) { + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } else { + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } +} + +at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) { + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kPrivateUse1); + auto swizzled_scale_inv = at::empty_like(scale_inv, options); + + void* scale_inv_dptr = getDataPtr(scale_inv, 0); + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), getTensorShape(input), + DType::kFloat8E4M3, nullptr, nullptr, scale_inv_dptr, + getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING); + auto output_cu = makeTransformerEngineTensor( + input.data_ptr(), getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + swizzled_scale_inv_dptr, getTensorShape(swizzled_scale_inv), NVTE_MXFP8_1D_SCALING); + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::musa::getCurrentMUSAStream()); + + return swizzled_scale_inv; +} + +at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) { + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kPrivateUse1); + auto swizzled_scale_inv = at::empty_like(scale_inv, options); + + // Return immediately if tensor is empty + if (scale_inv.numel() == 0) { + return swizzled_scale_inv; + } + + void* scale_inv_dptr = getDataPtr(scale_inv, 0); + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + auto input_cu = makeTransformerEngineTensor( + nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + nullptr, scale_inv_dptr, {1}, getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING); + auto output_cu = makeTransformerEngineTensor( + nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + nullptr, swizzled_scale_inv_dptr, {1}, getTensorShape(swizzled_scale_inv), + NVTE_MXFP8_1D_SCALING); + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::musa::getCurrentMUSAStream()); + + return swizzled_scale_inv; +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/musa/pytorch/csrc/extensions/transpose.cpp new file mode 100644 index 0000000000..bb992164e8 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/transpose.cpp @@ -0,0 +1,434 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "ATen/core/TensorBody.h" +#include "extensions.h" +#include "pybind.h" +#include "util.h" +#include "common.h" +namespace transformer_engine::pytorch { + +void _batch_init_alloc_outputs(size_t hidden_dim, std::vector& m_splits, + std::vector>& quantizers, + std::vector quantizer_list, + std::vector& output_list, + std::vector& output_list_py, + transformer_engine::DType otype) { + using namespace py::literals; + int num_splits = m_splits.size(); + + // Validate all quantizers are consistent + bool rowwise_usage = quantizers[0]->rowwise_usage; + bool columnwise_usage = quantizers[0]->columnwise_usage; + transformer_engine::DType fp8_dtype = static_cast(quantizers[0].get())->dtype; + NVTEScalingMode scaling_mode = static_cast(quantizers[0].get())->get_scaling_mode(); + + for (size_t i = 1; i < quantizers.size(); i++) { + NVTE_CHECK(rowwise_usage == quantizers[i]->rowwise_usage, + "All quantizers must have same rowwise usage"); + NVTE_CHECK(columnwise_usage == quantizers[i]->columnwise_usage, + "All quantizers must have same columnwise usage"); + NVTE_CHECK(fp8_dtype == static_cast(quantizers[i].get())->dtype, + "All quantizers must have same dtype"); + } + bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported(); + + // size_t hidden_dim = input_view.size(1); + size_t fp8_elem_size = 1; // FP8 uses 1 byte per element + + // Precompute all shapes and sizes + std::vector> rowwise_shapes; + std::vector> columnwise_shapes; + std::vector rowwise_sizes; + std::vector columnwise_sizes; + + for (int i = 0; i < num_splits; i++) { + // Rowwise shape is [m_splits[i], hidden_dim] + std::vector r_shape = {(size_t)m_splits[i], hidden_dim}; + rowwise_shapes.push_back(std::move(r_shape)); + rowwise_sizes.push_back(m_splits[i] * hidden_dim * fp8_elem_size); + + // Columnwise shape is [hidden_dim, m_splits[i]] (transposed) + std::vector c_shape = {hidden_dim, (size_t)m_splits[i]}; + columnwise_shapes.push_back(std::move(c_shape)); + columnwise_sizes.push_back(hidden_dim * m_splits[i] * fp8_elem_size); + } + + // Compute total sizes for bulk allocation + size_t total_rowwise = std::accumulate(rowwise_sizes.begin(), rowwise_sizes.end(), 0); + size_t total_columnwise = std::accumulate(columnwise_sizes.begin(), columnwise_sizes.end(), 0); + + // Allocate memory in bulk + at::TensorOptions opts = at::TensorOptions() + .dtype(torch::kUInt8) + .device(torch::kMUSA); + + // Create scale inverse tensors (batched) + std::vector scale_tensors; + for (auto& quantizer : quantizers) { + Float8Quantizer* fq = static_cast(quantizer.get()); + scale_tensors.push_back(fq->scale); + } + + at::Tensor all_scales = torch::stack(scale_tensors); + at::Tensor all_scale_invs = at::reciprocal(all_scales); + + at::Tensor rowwise_full_tensor; + at::Tensor columnwise_full_tensor; + // each from_blob will hold a reference to the full tensor, since we need to keep the full tensor alive + // when all the views are gone, the full tensor will be garbage collected + std::shared_ptr rowwise_full_tensor_holder; + std::shared_ptr columnwise_full_tensor_holder; + + // Allocate and split rowwise data + std::vector rowwise_data_list; + if (rowwise_usage > 0) { + rowwise_full_tensor = at::empty({(int64_t)total_rowwise}, opts); + rowwise_full_tensor_holder = std::make_shared(rowwise_full_tensor); + uint8_t* rowwise_ptr = rowwise_full_tensor.data_ptr(); + + for (int i = 0; i < num_splits; i++) { + if (rowwise_sizes[i] == 0) { + rowwise_data_list.emplace_back(at::empty({static_cast(rowwise_shapes[i][0]), + static_cast(rowwise_shapes[i][1])}, + opts + )); + } else { + rowwise_data_list.emplace_back(at::from_blob( + rowwise_ptr, + {static_cast(rowwise_shapes[i][0]),static_cast(rowwise_shapes[i][1])}, + [rowwise_full_tensor_holder](void*) {}, // Keep buffer alive + opts + )); + } + // rowwise_data_list.push_back(tensor); + rowwise_ptr += rowwise_sizes[i]; + } + } + + // Allocate and split columnwise data + std::vector columnwise_data_list; + if (create_transpose > 0) { + columnwise_full_tensor = at::empty({(int64_t)total_columnwise}, opts); + columnwise_full_tensor_holder = std::make_shared(columnwise_full_tensor); + uint8_t* columnwise_ptr = columnwise_full_tensor.data_ptr(); + + for (int i = 0; i < num_splits; i++) { + if (columnwise_sizes[i] == 0) { + columnwise_data_list.emplace_back(at::empty({static_cast(columnwise_shapes[i][0]), + static_cast(columnwise_shapes[i][1])}, + opts + )); + } else { + columnwise_data_list.emplace_back(at::from_blob( + columnwise_ptr, + {static_cast(columnwise_shapes[i][0]),static_cast(columnwise_shapes[i][1])}, + [columnwise_full_tensor_holder](void*) {}, // Keep buffer alive + opts + )); + } + columnwise_ptr += columnwise_sizes[i]; + } + } + + float* scale_invs_ptr = all_scale_invs.data_ptr(); + + // Create output tensors and Python objects + for (int i = 0; i < num_splits; i++) { + + // Create Python Float8Tensor object + py::object rowwise_py = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); + py::object columnwise_py = create_transpose ? py::cast(columnwise_data_list[i]) : py::none(); + py::object scale_inv_py = py::cast(all_scale_invs[i]); + + py::object py_tensor; + if (quantizers[i]->internal) { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + py_tensor = Float8TensorClass( + "data"_a = rowwise_py, + "fp8_scale_inv"_a = scale_inv_py, + "fp8_dtype"_a = fp8_dtype, + "data_transpose"_a = columnwise_py, + "quantizer"_a = quantizer_list[i] + ); + } else { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + std::vector rowwise_torch_shape = { + static_cast(rowwise_shapes[i][0]), + static_cast(rowwise_shapes[i][1]) + }; + py_tensor = Float8TensorClass( + "shape"_a = rowwise_torch_shape, + "dtype"_a = GetATenDType(otype), + "data"_a = rowwise_py, + "fp8_scale_inv"_a = scale_inv_py, + "fp8_dtype"_a = fp8_dtype, + "data_transpose"_a = columnwise_py, + "quantizer"_a = quantizer_list[i] + ); + } + output_list_py.emplace_back(std::move(py_tensor)); + + // as for tensor wrappers, these tensor wrappers are going to be quantized, so no need to insert empty tensors here + // even if m_split[i]==0 we also need to perform the operation below, + // otherwise will meet "Unable to cast Python instance of type to C++ type 'at::Tensor'" before following gemm + + // Create TensorWrapper + TensorWrapper tensor(scaling_mode); + + if (rowwise_usage) { + tensor.set_rowwise_data( + rowwise_data_list[i].data_ptr(), + fp8_dtype, + rowwise_shapes[i] + ); + // Explicitly specify the shape type as std::vector + tensor.set_rowwise_scale_inv>( + scale_invs_ptr + i, + DType::kFloat32, + {1} // Scale shape is always [1] + ); + + } + + if (create_transpose) { + tensor.set_columnwise_data( + columnwise_data_list[i].data_ptr(), + fp8_dtype, + columnwise_shapes[i] + ); + // Explicitly specify the shape type as std::vector + tensor.set_columnwise_scale_inv>( + scale_invs_ptr + i, + DType::kFloat32, + {1} // Scale shape is always [1] + ); + + } + + // Set quantization parameters + static_cast(quantizers[i].get())->set_quantization_params(&tensor); + if (m_splits[i] == 0) { + continue; + } + output_list.emplace_back(std::move(tensor)); + } +} + +std::vector fused_multi_quantize_batch_init(std::vector input_list, + size_t hidden_dim, + std::vector m_splits, + std::vector quantizer_list, + transformer_engine::DType otype) { + init_extension(); + std::vector nvte_inputs; + std::vector nvte_outputs; + std::vector py_outputs; + std::vector input_wrappers; + std::vector output_wrappers; + std::vector tensor_wrappers; + auto none = py::none(); + + // Validate inputs + NVTE_CHECK(input_list.size() == quantizer_list.size(), + "Input list and quantizer list must have same size"); + NVTE_CHECK(input_list.size() == m_splits.size(), + "Input list and m_splits must have same size"); + + // Convert quantizers + std::vector> quantizers; + for (auto& q : quantizer_list) { + quantizers.push_back(convert_quantizer(q)); + } + + // Check if we can use bulk allocation (all Float8 quantizers with same config) + bool use_batch_init = true; + if (!detail::IsFloat8Quantizers(quantizer_list[0].ptr())) { + use_batch_init = false; + } else { + auto* first_q = static_cast(quantizers[0].get()); + for (size_t i = 1; i < quantizers.size(); i++) { + auto* q = static_cast(quantizers[i].get()); + if (q->rowwise_usage != first_q->rowwise_usage || + q->columnwise_usage != first_q->columnwise_usage || + q->dtype != first_q->dtype) { + use_batch_init = false; + break; + } + } + } + + // Process inputs + if (use_batch_init) { + // Create input tensor wrappers + for (size_t i = 0; i < input_list.size(); i++) { + if (m_splits[i] == 0){ + continue; + } + auto input_tensor = makeTransformerEngineTensor(input_list[i], none); + nvte_inputs.emplace_back(input_tensor.data()); + input_wrappers.emplace_back(std::move(input_tensor)); + } + + // Bulk allocate outputs + _batch_init_alloc_outputs(hidden_dim, m_splits, quantizers, quantizer_list, + output_wrappers, py_outputs, otype); + + // Prepare output tensor list + for (auto& wrapper : output_wrappers) { + if (wrapper.data()) { // Skip empty tensors + nvte_outputs.emplace_back(wrapper.data()); + } + } + } else { + // Fallback to original per-tensor allocation + for (size_t i = 0; i < input_list.size(); i++) { + auto input_tensor = makeTransformerEngineTensor(input_list[i], none); + const NVTEShape input_shape = input_tensor.shape(); + + TensorWrapper output_tensor; + + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + py::object o; + std::tie(output_tensor, o) = + quantizers[i]->create_tensor(output_shape, otype); + py_outputs.push_back(o); + if (input_tensor.numel() == 0) continue; + + nvte_inputs.emplace_back(input_tensor.data()); + nvte_outputs.emplace_back(output_tensor.data()); + tensor_wrappers.emplace_back(std::move(input_tensor)); + tensor_wrappers.emplace_back(std::move(output_tensor)); + } + } + + // Validate tensor lists + NVTE_CHECK(nvte_outputs.size() == nvte_inputs.size(), + "Input/output tensor count mismatch"); + + // Check if we can use fused kernel + bool with_fused_kernel = true; + for (auto& tensor : nvte_outputs) { + if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING || + nvte_tensor_columnwise_data(tensor) == nullptr) { + with_fused_kernel = false; + break; + } + } + + // Launch TE kernel + if (with_fused_kernel) { + nvte_multi_cast_transpose(nvte_inputs.size(), nvte_inputs.data(), + nvte_outputs.data(), at::musa::getCurrentMUSAStream()); + } else { + for (size_t i = 0; i < nvte_outputs.size(); i++) { + nvte_quantize(nvte_inputs[i], nvte_outputs[i], + at::musa::getCurrentMUSAStream()); + } + } + + return py_outputs; +} +} + +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, + std::vector quantizer_list, + transformer_engine::DType otype) { + using namespace transformer_engine::pytorch; + std::vector nvte_tensor_input_list; + std::vector nvte_tensor_output_list; + std::vector py_output_objects_list; + std::vector tensor_wrappers; + auto none = py::none(); + + // create TE tensors from input + for (int i = 0; i < input_list.size(); i++) { + auto input_tensor = makeTransformerEngineTensor(input_list[i], none); + const NVTEShape input_shape = input_tensor.shape(); + + transformer_engine::TensorWrapper output_tensor; + + if (output_list == std::nullopt) { + std::unique_ptr quantizer = convert_quantizer(quantizer_list[i]); + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + py::object o; + std::tie(output_tensor, o) = quantizer->create_tensor(output_shape, otype); + py_output_objects_list.push_back(o); + } else { + output_tensor = makeTransformerEngineTensor((*output_list)[i], quantizer_list[i]); + } + if (input_tensor.numel() == 0) continue; + + nvte_tensor_output_list.emplace_back(output_tensor.data()); + nvte_tensor_input_list.emplace_back(input_tensor.data()); + tensor_wrappers.emplace_back(std::move(input_tensor)); + tensor_wrappers.emplace_back(std::move(output_tensor)); + } + + // Check tensor lists + NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(), + "Number of input and output tensors must match"); + + // Choose implementation + // Note: Currently only have fused kernel for FP8 cast-transpose + bool with_fused_kernel = true; + for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + const auto& tensor = nvte_tensor_output_list[i]; + if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING) { + with_fused_kernel = false; + break; + } + if (nvte_tensor_columnwise_data(tensor) == nullptr) { + with_fused_kernel = false; + break; + } + } + + // Launch TE kernel + if (with_fused_kernel) { + nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), + nvte_tensor_output_list.data(), at::musa::getCurrentMUSAStream()); + } else { + for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], + at::musa::getCurrentMUSAStream()); + } + } + return py_output_objects_list; +} + +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, + std::optional output) { + using namespace transformer_engine::pytorch; + + const auto dim = input.dim(); + NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); + + if (input.dim() > 2) { + input = input.view({-1, input.size(dim - 1)}); + } + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + at::Tensor out; + if (output.has_value()) { + out = *output; + } else { + out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); + } + if (M == 0 || N == 0) return out; + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); + auto output_cu = makeTransformerEngineTensor(out.data_ptr(), {N, M}, otype); + + nvte_transpose(input_cu.data(), output_cu.data(), at::musa::getCurrentMUSAStream()); + + return out; +} diff --git a/transformer_engine/musa/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/musa/pytorch/csrc/extensions/type_converters.cpp new file mode 100644 index 0000000000..714b00b743 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/type_converters.cpp @@ -0,0 +1,109 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common.h" +#include "pybind.h" + +namespace transformer_engine::pytorch { +namespace detail { + +TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) { + const at::Tensor &data = tensor.attr("_data").cast(); + const at::Tensor &scale_inv = tensor.attr("_scale_inv").cast(); + float *scale_inv_dptr = reinterpret_cast(scale_inv.data_ptr()); + const DType dtype = tensor.attr("_fp8_dtype").cast(); + + const auto &shape = getTensorShape(data); + + bool transpose_valid = !tensor.attr("_transpose_invalid").cast(); + std::optional transpose = std::nullopt; + if (transpose_valid) { + transpose = tensor.attr("_transpose").cast>(); + } + + auto ret = TensorWrapper(quantizer->get_scaling_mode()); + + ret.set_rowwise_data(data.data_ptr(), dtype, shape); + if (transpose_valid && transpose != std::nullopt) { + const auto &transpose_shape = getTensorShape(*transpose); + ret.set_columnwise_data(transpose->data_ptr(), dtype, transpose_shape); + } + + const auto scale_inv_dtype = GetTransformerEngineDType(scale_inv.scalar_type()); + const auto scale_inv_shape = getTensorShape(scale_inv); + ret.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + ret.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + quantizer->set_quantization_params(&ret); + return ret; +} + +TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + if (rowwise_usage) { + const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto &shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); + + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat8E8M0, scale_inv_rowwise_shape); + } + + if (columnwise_usage) { + const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto &shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat8E8M0, + scale_inv_colwise_shape); + } + + quantizer->set_quantization_params(&ret); + return ret; +} + +TensorWrapper NVTETensorFromMTFP8Tensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + auto ret = TensorWrapper(NVTE_MTFP8_BLOCK_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + if (rowwise_usage) { + const at::Tensor &rowwise_data = tensor.attr("_rowwise_data").cast(); + const auto &shape = getTensorShape(rowwise_data); + ret.set_rowwise_data(rowwise_data.data_ptr(), dtype, shape); + + const at::Tensor &rowwise_scale_inv = tensor.attr("_rowwise_scale_inv").cast(); + const auto rowwise_scale_inv_shape = getTensorShape(rowwise_scale_inv); + ret.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat32, rowwise_scale_inv_shape); + } + + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + if (columnwise_usage) { + const at::Tensor &colwise_data = tensor.attr("_columnwise_data").cast(); + const auto &shape = getTensorShape(colwise_data); + ret.set_columnwise_data(colwise_data.data_ptr(), dtype, shape); + + const at::Tensor &colwise_scale_inv = tensor.attr("_columnwise_scale_inv").cast(); + const auto colwise_scale_inv_shape = getTensorShape(colwise_scale_inv); + ret.set_columnwise_scale_inv(colwise_scale_inv.data_ptr(), DType::kFloat32, colwise_scale_inv_shape); + } + + quantizer->set_quantization_params(&ret); + return ret; +} + +} // namespace detail + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/musa/pytorch/csrc/extensions/util.cpp b/transformer_engine/musa/pytorch/csrc/extensions/util.cpp new file mode 100644 index 0000000000..4dc4f83282 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/extensions/util.cpp @@ -0,0 +1,16 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "util.h" + +// #include "ATen/cuda/CUDAContextLight.h" +#include "common/util/musa_runtime.h" + +bool non_tn_fp8_gemm_supported() { + // int major = at::cuda::getCurrentDeviceProperties()->major; + // return major >= 10; + return transformer_engine::cuda::sm_arch() >= 31; +} diff --git a/transformer_engine/musa/pytorch/csrc/multi_tensor_apply.muh b/transformer_engine/musa/pytorch/csrc/multi_tensor_apply.muh new file mode 100644 index 0000000000..3dcb283b57 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/multi_tensor_apply.muh @@ -0,0 +1,141 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "common/common.h" + +// This header is the one-stop shop for all your multi-tensor apply needs. + +// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) +constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; +constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; + +template +struct TensorListMetadataBase { + void *addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; + int start_tensor_this_launch; +}; + +template +struct TensorListMetadata : public TensorListMetadataBase {}; + +template +struct TensorListMetadata : public TensorListMetadataBase { + void *fp8_meta_addresses[3][depth_to_max_tensors[n - 1]]; +}; + +template +__global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop_flag, T tl, + U callable, ArgTypes... args) { + // Hand the chunk information to the user-supplied functor to process however + // it likes. + callable(chunk_size, noop_flag, tl, args...); +} + +template +void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor &noop_flag, + const std::vector> &tensor_lists, T callable, + ArgTypes... args) { + if constexpr (USE_FP8) { + TORCH_CHECK(tensor_lists.size() == depth + 3, + "tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, " + "amax, scale_inv) for fp8"); + } else { + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + } + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kPrivateUse1, "expected input to be on cuda"); + for (int l = 0; l < depth; l++) { // No range-based for because I need indices + TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) { + // TODO: Print which tensor fails. + bool contiguous_memory = tensor_lists[l][t].is_contiguous(); + contiguous_memory = + (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) || + tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d)); + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, + "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); + } + } + + if constexpr (USE_FP8) { + TORCH_CHECK(tensor_lists[depth].size() == len0 && tensor_lists[depth + 1].size() == len0, + "Size mismatch among tensor lists"); + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + const at::musa::OptionalMUSAGuard device_guard(device_of(tensor_lists[0][0])); + auto stream = at::musa::getCurrentMUSAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + if constexpr (USE_FP8) { + for (int i = 0; i < 3; i++) + tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t].data_ptr(); + } + loc_tensor_info++; + + auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) { + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = + (loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) { + multi_tensor_apply_kernel<<>>( + chunk_size, noop_flag.data_ptr(), tl, callable, args...); + + AT_MUSA_CHECK(musaGetLastError()); + + // Reset. The control flow possibilities here make my brain hurt. + loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) { + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + } + if constexpr (USE_FP8) { + for (int i = 0; i < 3; i++) { + tl.fp8_meta_addresses[i][0] = tl.fp8_meta_addresses[i][loc_tensor_info - 1]; + } + } + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/transformer_engine/musa/pytorch/csrc/pybind.h b/transformer_engine/musa/pytorch/csrc/pybind.h new file mode 100644 index 0000000000..af6d061ea0 --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/pybind.h @@ -0,0 +1,119 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#define PYBIND11_DETAILED_ERROR_MESSAGES // TODO remove + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ +#include +#include +#include +#include + +#include "common.h" +#include "transformer_engine/transformer_engine.h" + +#if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4) +namespace pybind11::detail { + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype")); + + type_caster() : value(at::kFloat) {} + + bool load(handle src, bool) { + PyObject* obj = src.ptr(); + if (THPDtype_Check(obj)) { + value = reinterpret_cast(obj)->scalar_type; + return true; + } + return false; + } + + static handle cast( + const at::ScalarType& src, + return_value_policy /* policy */, + handle /* parent */) { + return Py_NewRef(torch::getTHPDtype(src)); + } +}; + +} // namespace pybind11::detail +#endif + +namespace transformer_engine::pytorch { + +extern PyTypeObject *Float8TensorPythonClass; +extern PyTypeObject *Float8TensorBasePythonClass; +extern PyTypeObject *Float8QuantizerClass; +extern PyTypeObject *MXFP8TensorPythonClass; +extern PyTypeObject *MXFP8TensorBasePythonClass; +extern PyTypeObject *MXFP8QuantizerClass; + +extern PyTypeObject *MTFP8TensorPythonClass; +extern PyTypeObject *MTFP8TensorBasePythonClass; +extern PyTypeObject *MTFP8QuantizerClass; + +void init_extension(); + +void init_float8_extension(); + +void init_mxfp8_extension(); + +namespace detail { +inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } +inline bool IsFloat8QParams(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } + +inline bool IsFloat8Tensor(PyObject *obj) { + return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass; +} + +inline bool IsMXFP8QParams(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } + +inline bool IsMXFP8Tensor(PyObject *obj) { + return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; +} + +TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); + +template +std::unique_ptr CreateQuantizer(const py::handle quantizer) { + return std::make_unique(quantizer); +} + +TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantization_params); + +std::unique_ptr CreateMXFP8Params(const py::handle params); + +inline bool IsFloatingPointType(at::ScalarType type) { + return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; +} + +inline bool IsMTFP8Tensor(PyObject *obj) { + return Py_TYPE(obj) == MTFP8TensorPythonClass || Py_TYPE(obj) == MTFP8TensorBasePythonClass; +} + +inline bool IsMTFP8QParams(PyObject *obj) { return Py_TYPE(obj) == MTFP8QuantizerClass; } + +TensorWrapper NVTETensorFromMTFP8Tensor(py::handle tensor, Quantizer *quantization_params); + +std::unique_ptr CreateMTFP8Params(const py::handle params); + +constexpr std::array custom_types_converters = { + std::make_tuple(IsFloat8Tensor, IsFloat8QParams, NVTETensorFromFloat8Tensor, + CreateQuantizer), + std::make_tuple(IsMTFP8Tensor, IsMTFP8QParams, NVTETensorFromMTFP8Tensor, + CreateQuantizer), + std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, NVTETensorFromMXFP8Tensor, + CreateQuantizer)}; + +} // namespace detail + +} // namespace transformer_engine::pytorch + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ diff --git a/transformer_engine/musa/pytorch/csrc/type_shim.h b/transformer_engine/musa/pytorch/csrc/type_shim.h new file mode 120000 index 0000000000..e3cd94bfda --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/type_shim.h @@ -0,0 +1 @@ +../../../pytorch/csrc/type_shim.h \ No newline at end of file diff --git a/transformer_engine/musa/pytorch/csrc/util.h b/transformer_engine/musa/pytorch/csrc/util.h new file mode 120000 index 0000000000..c02b622b2d --- /dev/null +++ b/transformer_engine/musa/pytorch/csrc/util.h @@ -0,0 +1 @@ +../../../pytorch/csrc/util.h \ No newline at end of file diff --git a/transformer_engine/musa/pytorch/distributed.py b/transformer_engine/musa/pytorch/distributed.py new file mode 100644 index 0000000000..9580c9c81c --- /dev/null +++ b/transformer_engine/musa/pytorch/distributed.py @@ -0,0 +1,681 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Methods needed for distributed training (DP/TP).""" +from __future__ import annotations + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import warnings + +import torch +from torch.utils.checkpoint import detach_variable, noop_context_fn +from torch.nn import Identity + +from transformer_engine.pytorch.constants import dist_group_type +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_autocast +from transformer_engine.pytorch.utils import safely_set_viewless_tensor_data +from transformer_engine.pytorch.distributed import ( + _get_cuda_rng_state, _get_active_autocast_contexts, activation_recompute_forward, + gather_split_1d_tensor, split_tensor_into_1d_equal_chunks, _set_cuda_rng_state, + has_te_modules,) + + +from transformer_engine.pytorch.tensor.float8_tensor import Float8TensorBase +from transformer_engine.pytorch.cpu_offload import set_offloading_param, get_fine_grained_offload_handler + +_USE_REENTRANT_ACTIVATION_RECOMPUTE = True + +# HACK(huang.huang): recompute-variance for [somefunc+fa] and [somefunc+linear/groupedLinear], +# which can save a forward for fa/linear when backward recompute +# 2025.4.7: support list of linear as last_function, and args "mid_function" to support complex situations +class IdentityTupleOp(torch.nn.Module): + """ + This is a placeholder for IdentityTupleOp(*args) -> args, + """ + + def __init__(self,): + super().__init__() + + def forward(self, *args): + return args + +class _CheckpointFunctionVirance(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + """ + + @staticmethod + def forward( + ctx, + run_function: Callable, + last_function: Union[Callable, tuple[Callable]], + mid_function: Union[Callable, tuple[Callable]], + fine_grained_offload: bool, + distribute_saved_activations: bool, + get_rng_state_tracker: Union[Callable, None], + tp_group: Union[dist_group_type, None], + context_fn: Union[Callable, None], + kwargs: Dict[str, Any], + *args: Tuple[torch.Tensor, ...], + ) -> Tuple[torch.Tensor, ...]: + """Call forward function while saving state to be able to + redo the computation later.""" + if not isinstance(last_function, tuple): + last_function = (last_function, ) + mid_function = tuple(IdentityTupleOp() for _ in last_function) if mid_function is None else mid_function + ctx.run_function = run_function + ctx.last_function = last_function + ctx.mid_function = mid_function + ctx.distribute_saved_activations = distribute_saved_activations + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) + if get_rng_state_tracker is not None: + ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() + + if context_fn is not None: + forward_ctx, recompute_ctx = context_fn() + else: + forward_ctx, recompute_ctx = noop_context_fn() + # Preserve torch autocast context for the backward pass + torch_gpu_amp_ctx, torch_cpu_amp_ctx = _get_active_autocast_contexts() + with torch.no_grad(), forward_ctx: + with activation_recompute_forward(activation_recompute=True, recompute_phase=False): + outputs = run_function(*args) + outputs = outputs if isinstance(outputs, tuple) else (outputs, ) + total_outputs = [] + for i, func in enumerate(last_function): + outputs_f = mid_function[i](*outputs) + outputs_f = outputs_f if isinstance(outputs_f, tuple) else (outputs_f, ) + outputs_f = func(*outputs_f) + total_outputs.append(outputs_f) + if len(total_outputs)==1: + #maintain original behavior when only one last_function + total_outputs=total_outputs[0] + else: + flat_outputs = [] + for outputs_f in total_outputs: + if isinstance(outputs_f, tuple): + #Manually remove bias_out which is 'None', and assign 'None' to grad-bias in the corresponding backward direction + outputs_f = tuple([x for x in outputs_f if x is not None]) + flat_outputs.append(outputs_f) + total_outputs = flat_outputs + #The reentrant version does not consider tensors in nested structures (e.g., custom objects, lists, dicts, etc) + # as participating in autograd, while the non-reentrant version does + total_outputs = sum( [x if isinstance(x, tuple) else (x,) for x in total_outputs ], tuple()) + # Divide hidden states across model parallel group and only keep + # the chunk corresponding to the current rank. + if distribute_saved_activations: + ctx.input_0_shape = args[0].data.shape + safely_set_viewless_tensor_data( + args[0], + split_tensor_into_1d_equal_chunks(args[0].data, tp_group, new_buffer=True), + ) + + # Store everything. + ctx.inputs = [arg if not torch.is_tensor(arg) else None for arg in args] + tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args] + + fine_grained_offload_handler = get_fine_grained_offload_handler() + if fine_grained_offload and not fine_grained_offload_handler.is_last_layer(): + assert len(tensor_inputs) == 2 # [input, prob] + fc1_input = tensor_inputs[0] + if isinstance(fc1_input, Float8TensorBase): + # now type(fc1_input) == bf16-Tensor + fc1_input_data = fc1_input._data + set_offloading_param(fc1_input_data, 'fine_grained_offloading', 'fc1_inp') + ctx.fc1_input_scale_inv = fc1_input._scale_inv + ctx.fc1_input_fp8_dtype = fc1_input._fp8_dtype + ctx.tensor_tags = fine_grained_offload_handler.register_offload(fc1_input_data) + + ctx.save_for_backward(*tensor_inputs[1:]) + else: + ctx.fc1_input_scale_inv = None + set_offloading_param(fc1_input, 'fine_grained_offloading', 'fc1_inp') + ctx.tensor_tags = fine_grained_offload_handler.register_offload(fc1_input) + + ctx.save_for_backward(*tensor_inputs[1:]) + # fc1_input_prob = tensor_inputs[1] + # set_offloading_param(fc1_input_prob, 'fine_grained_offloading', 'fc1_inp_prob') + # ctx.tensor_tags_1 = fine_grained_offload_handler.register_offload(fc1_input_prob) + else: + ctx.tensor_tags = None + ctx.save_for_backward(*tensor_inputs) + + fp8 = FP8GlobalStateManager.is_fp8_enabled() + ctx.get_rng_state_tracker = get_rng_state_tracker + ctx.tp_group = tp_group + ctx.recompute_ctx = recompute_ctx + ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx + ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx + ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.kwargs = kwargs + + return total_outputs + + @staticmethod + def backward( + ctx, *args: Tuple[Union[torch.Tensor, None], ...] + ) -> Tuple[Union[torch.Tensor, None], ...]: + """Call backward function with activation recomputation.""" + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad(), please use .backward() if possible" + ) + + if ctx.tensor_tags is None: + inputs = tuple( + t if t is not None else arg for (t, arg) in zip(ctx.saved_tensors, ctx.inputs) + ) + else: + fine_grained_offload_handler = get_fine_grained_offload_handler() + assert not fine_grained_offload_handler.is_last_layer() + fc1_input = fine_grained_offload_handler.wait_reload(ctx.tensor_tags) + if ctx.fc1_input_scale_inv is not None: + fc1_input = Float8TensorBase(data=fc1_input, + fp8_scale_inv=ctx.fc1_input_scale_inv, + fp8_dtype=ctx.fc1_input_fp8_dtype) + fc1_input.requires_grad = True + # need grad when reload a detached cpu tensor + inputs = tuple( + t if t is not None else arg for (t, arg) in zip((fc1_input, *ctx.saved_tensors), ctx.inputs) + ) + + get_rng_state_tracker = ctx.get_rng_state_tracker + + if ctx.distribute_saved_activations: + safely_set_viewless_tensor_data( + inputs[0], + gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(ctx.input_0_shape), + ) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) + if get_rng_state_tracker is not None: + bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False) + if get_rng_state_tracker is not None: + get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + # Compute the forward pass. + detached_inputs = detach_variable(inputs) + # ori_outputs is not requires_grad + + with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward( + activation_recompute=True, recompute_phase=True + ), fp8_autocast( + enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe + ): + outputs = ctx.run_function(*detached_inputs) + outputs = outputs if isinstance(outputs, tuple) else (outputs, ) + total_outputs = [] + for i,func in enumerate(ctx.mid_function): + outputs_f = func(*outputs) + if isinstance(outputs_f, torch.Tensor): + outputs_f = [outputs_f,] + total_outputs.append(outputs_f) + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False) + if get_rng_state_tracker is not None: + get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker) + + + #backward_custom need to be executed under this context while something like self.fp8 will change outside of context + total_grad_input = [] + for i,func in enumerate(ctx.last_function): + if isinstance(func, Identity): + grad_input_f = args[i] + else: + grad_out_bias = args[i] if isinstance(args[i], tuple) else (args[i], None) + grad_input_f = func.backward_custom(*total_outputs[i], *grad_out_bias) + if isinstance(grad_input_f, torch.Tensor): + grad_input_f = (grad_input_f,) + total_grad_input.append(grad_input_f) + + + total_outputs_with_grad = [] + total_args_with_grad = [] + for j, outputs in enumerate(total_outputs): + outputs_with_grad = [] + args_with_grad = [] + for i, output in enumerate(outputs): + if torch.is_tensor(output) and output.requires_grad: + outputs_with_grad.append(output) + args_with_grad.append(total_grad_input[j][i]) + total_outputs_with_grad += outputs_with_grad + total_args_with_grad += args_with_grad + + if len(total_outputs_with_grad) == 0: + raise RuntimeError( + "none of output has requires_grad=True, this checkpoint() is not necessary" + ) + torch.autograd.backward(total_outputs_with_grad, total_args_with_grad) + + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs + ) + return (None, None, None, None, None, None, None, None, None) + grads + + +@torch._disable_dynamo +def checkpointVirance( + function: Callable, + last_function: Callable, + *args: Tuple[torch.Tensor, ...], + mid_function: Optional[Callable] = None, + fine_grained_offload: bool = False, + **kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, ...]: + """ + Checkpoint a part of the model by trading compute for memory. This function is based on + `torch.utils.checkpoint.checkpoint `_. + + .. warning:: + + It is the user's responsibility to ensure identical behavior when calling + :attr:`function` from the forward and backward pass. If different output is + produced (e.g. due to global state), then the checkpointed version won't + be numerically equivalent. + + .. warning:: + `use_reentrant=False` does not support early stopping, and will execute the entire forward + pass for the checkpointed module when recomputing activations in the backward pass. + + Parameters + ---------- + function: Callable + pytorch module used to run the forward and backward passes using + the specified :attr:`args` and :attr:`kwargs`. + distribute_saved_activations: bool, default = False + if set to `True` and `use_reentrant=True`, first tensor argument is distributed + across the specified tensor parallel group (`tp_group`) before saving it for the + backward pass. This has no effect when `use_reentrant=False`. + get_rng_state_tracker: `Callable`, default = None + python callable which returns an instance of :func:`CudaRNGStatesTracker`. + tp_group : ProcessGroup, default = None + tensor parallel process group. Used only when `distribute_saved_activations=True` + and `use_reentrant=True`. If `None`, it falls back to the default group. + use_reentrant : bool, default = True + perform checkpointing in reentrant mode. + args : tuple + tuple of torch tensors for inputs to :attr:`function`. + kwargs : dict + dictionary of string keys for keyword arguments to :attr:`function`. + """ + # Pop out te.distributed.checkpoint() arguments + global _USE_REENTRANT_ACTIVATION_RECOMPUTE + _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True) + distribute_saved_activations = kwargs.pop("distribute_saved_activations", False) + tp_group = kwargs.pop("tp_group", None) + get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None) + + # Ensure backward compatibility. + if ( + len(args) > 3 + and isinstance(args[0], bool) + and callable(args[1]) + and isinstance(args[2], None | dist_group_type) + ): + warnings.warn( + "Passing non-tensor non-keyword arguments is deprecated and support will be removed in " + "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and " + "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.", + DeprecationWarning, + stacklevel=2, + ) + distribute_saved_activations = args[0] + get_rng_state_tracker = args[1] + tp_group = args[2] + args = args[3:] + + # Trigger the native PyTorch checkpoint if the function is not or does not contain a + # Transformer Engine module. + context_fn = kwargs.pop("context_fn", noop_context_fn) + determinism_check = kwargs.pop("determinism_check", "default") + debug = kwargs.pop("debug", False) + + assert has_te_modules(function) and has_te_modules(last_function), "only support when has te modules" + + # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need + # to scatter/gather activations that we will recompute anyway. + setattr(function, "fsdp_wrapped", False) + setattr(function, "fsdp_group", None) + if isinstance(last_function, tuple): + for func in last_function: + setattr(func, "fsdp_wrapped", False) + setattr(func, "fsdp_group", None) + else: + setattr(last_function, "fsdp_wrapped", False) + setattr(last_function, "fsdp_group", None) + if mid_function is not None: + if isinstance(mid_function, tuple): + setattr(func, "fsdp_wrapped", False) + setattr(func, "fsdp_group", None) + else: + setattr(mid_function, "fsdp_wrapped", False) + setattr(mid_function, "fsdp_group", None) + # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments + # and execute TE's own checkpointing + # NOTE: This logic uses the TE checkpoint on all custom callable `function` handles because we + # cannot be sure there are no TE modules inside the function. It also means we might run + # the TE checkpoint for non-TE modules, so the TE checkpoint has to support a potential + # user context function. + del determinism_check, debug + + if _USE_REENTRANT_ACTIVATION_RECOMPUTE: + # If saved activations need to be distributed but there is no process group, + # default to the world group. + if distribute_saved_activations: + assert torch.distributed.is_initialized(), "torch.distributed is not initialized." + tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group + + return _CheckpointFunctionVirance.apply( + function, + last_function, + mid_function, + fine_grained_offload, + distribute_saved_activations, + get_rng_state_tracker, + tp_group, + context_fn, + kwargs, + *args, + ) + + +class _CheckpointFunctionViranceAttention(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + """ + + @staticmethod + def forward( + ctx, + run_function: Callable, + last_function: Callable, + distribute_saved_activations: bool, + get_rng_state_tracker: Union[Callable, None], + tp_group: Union[dist_group_type, None], + context_fn: Union[Callable, None], + kwargs: Dict[str, Any], + *args: Tuple[torch.Tensor, ...], + ) -> Tuple[torch.Tensor, ...]: + """Call forward function while saving state to be able to + redo the computation later.""" + ctx.run_function = run_function + ctx.last_function = last_function + ctx.distribute_saved_activations = distribute_saved_activations + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) + if get_rng_state_tracker is not None: + ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() + + if context_fn is not None: + forward_ctx, recompute_ctx = context_fn() + else: + forward_ctx, recompute_ctx = noop_context_fn() + + # Preserve torch autocast context for the backward pass + torch_gpu_amp_ctx, torch_cpu_amp_ctx = _get_active_autocast_contexts() + with torch.no_grad(), forward_ctx: + with activation_recompute_forward(activation_recompute=True, recompute_phase=False): + outputs = run_function(*args) + outputs = last_function.forward_before_fa(*outputs[:4], **outputs[4]) + outputs = last_function.forward_fa(*outputs) + #outputs: Union[output=Union[Tensor output, Tensor logsumexp, Tensor dropout_mask], + # qkv_format, indices_q, batch_size, attn_mask_type, max_seqlen_q, q_shape, v_shape] + core_attn_out = last_function.forward_after_fa(*outputs) + + # Divide hidden states across model parallel group and only keep + # the chunk corresponding to the current rank. + if distribute_saved_activations: + ctx.input_0_shape = args[0].data.shape + safely_set_viewless_tensor_data( + args[0], + split_tensor_into_1d_equal_chunks(args[0].data, tp_group, new_buffer=True), + ) + + # Store everything. + ctx.inputs = [arg if not torch.is_tensor(arg) else None for arg in args] + [None]*len(outputs[0]) #pad None to match len of tensor_inputs + tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args] + ctx.save_for_backward(*tensor_inputs, *outputs[0]) + + fp8 = FP8GlobalStateManager.is_fp8_enabled() + ctx.get_rng_state_tracker = get_rng_state_tracker + ctx.tp_group = tp_group + ctx.recompute_ctx = recompute_ctx + ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx + ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx + ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.kwargs = kwargs + (ctx.qkv_format, ctx.indices_q, ctx.batch_size, + ctx.attn_mask_type, ctx.max_seqlen_q, ctx.q_shape, ctx.v_shape) = outputs[1:] + + return core_attn_out + + @staticmethod + def backward( + ctx, *args: Tuple[Union[torch.Tensor, None], ...] + ) -> Tuple[Union[torch.Tensor, None], ...]: + """Call backward function with activation recomputation.""" + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad(), please use .backward() if possible" + ) + + inputs = tuple( + t if t is not None else arg for (t, arg) in zip(ctx.saved_tensors, ctx.inputs) + ) + fa_output = inputs[-3:] + inputs = inputs[:-3] + get_rng_state_tracker = ctx.get_rng_state_tracker + + if ctx.distribute_saved_activations: + safely_set_viewless_tensor_data( + inputs[0], + gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(ctx.input_0_shape), + ) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) + if get_rng_state_tracker is not None: + bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False) + if get_rng_state_tracker is not None: + get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + # Compute the forward pass. + detached_inputs = detach_variable(inputs) + detached_ori_outputs = detach_variable(fa_output) + detached_ori_outputs[0].requires_grad = True #only 0 element need grad in output of FA: [Tensor output, Tensor logsumexp, Tensor dropout_mask] + # ori_outputs is not requires_grad + + with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward( + activation_recompute=True, recompute_phase=True + ), fp8_autocast( + enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe + ): + outputs_before_fa = ctx.run_function(*detached_inputs) + # outputs_before_fa: query, key, value, attention_mask, {"attn_mask_type":attn_mask_type, "attention_bias":attention_bias, "packed_seq_params":packed_seq_params} + outputs_before_fa = ctx.last_function.forward_before_fa(*outputs_before_fa[:4], **outputs_before_fa[4]) + outputs = ctx.last_function.forward_after_fa(detached_ori_outputs, + ctx.qkv_format, ctx.indices_q, + ctx.batch_size, ctx.attn_mask_type, + ctx.max_seqlen_q, ctx.q_shape, ctx.v_shape) + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False) + if get_rng_state_tracker is not None: + get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + + outputs_with_grad = [] + args_with_grad = [] + for i, output in enumerate(outputs): + if torch.is_tensor(output) and output.requires_grad: + outputs_with_grad.append(output) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "none of output has requires_grad=True, this checkpoint() is not necessary" + ) + + torch.autograd.backward(outputs_with_grad, args_with_grad) + + #costum bwd fa + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False): + with torch.no_grad(): + grad_input = torch.ops.aten._scaled_dot_product_attention_flash_musa_backward( + # ori_outputs[0][0].grad, + detached_ori_outputs[0].grad, + *outputs_before_fa[:3], #q, k, v + *detached_ori_outputs, #(Tensor output, Tensor logsumexp, Tensor dropout_mask) + is_causal="causal" in ctx.attn_mask_type, #causal same as fwd + ) + + #bwd before fa: for qkv + torch.autograd.backward(outputs_before_fa[:3], grad_input) + + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs + ) + return (None, None, None, None, None, None, None) + grads + + +@torch._disable_dynamo +def checkpointViranceAttention( + function: Callable, + last_function: Callable, + *args: Tuple[torch.Tensor, ...], + **kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, ...]: + """ + Checkpoint a part of the model by trading compute for memory. This function is based on + `torch.utils.checkpoint.checkpoint `_. + + .. warning:: + + It is the user's responsibility to ensure identical behavior when calling + :attr:`function` from the forward and backward pass. If different output is + produced (e.g. due to global state), then the checkpointed version won't + be numerically equivalent. + + .. warning:: + `use_reentrant=False` does not support early stopping, and will execute the entire forward + pass for the checkpointed module when recomputing activations in the backward pass. + + Parameters + ---------- + function: Callable + pytorch module used to run the forward and backward passes using + the specified :attr:`args` and :attr:`kwargs`. + distribute_saved_activations: bool, default = False + if set to `True` and `use_reentrant=True`, first tensor argument is distributed + across the specified tensor parallel group (`tp_group`) before saving it for the + backward pass. This has no effect when `use_reentrant=False`. + get_rng_state_tracker: `Callable`, default = None + python callable which returns an instance of :func:`CudaRNGStatesTracker`. + tp_group : ProcessGroup, default = None + tensor parallel process group. Used only when `distribute_saved_activations=True` + and `use_reentrant=True`. If `None`, it falls back to the default group. + use_reentrant : bool, default = True + perform checkpointing in reentrant mode. + args : tuple + tuple of torch tensors for inputs to :attr:`function`. + kwargs : dict + dictionary of string keys for keyword arguments to :attr:`function`. + """ + # Pop out te.distributed.checkpoint() arguments + global _USE_REENTRANT_ACTIVATION_RECOMPUTE + _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True) + distribute_saved_activations = kwargs.pop("distribute_saved_activations", False) + tp_group = kwargs.pop("tp_group", None) + get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None) + + # Ensure backward compatibility. + if ( + len(args) > 3 + and isinstance(args[0], bool) + and callable(args[1]) + and isinstance(args[2], None | dist_group_type) + ): + warnings.warn( + "Passing non-tensor non-keyword arguments is deprecated and support will be removed in " + "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and " + "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.", + DeprecationWarning, + stacklevel=2, + ) + distribute_saved_activations = args[0] + get_rng_state_tracker = args[1] + tp_group = args[2] + args = args[3:] + + # Trigger the native PyTorch checkpoint if the function is not or does not contain a + # Transformer Engine module. + context_fn = kwargs.pop("context_fn", noop_context_fn) + determinism_check = kwargs.pop("determinism_check", "default") + debug = kwargs.pop("debug", False) + + assert has_te_modules(function) and has_te_modules(last_function), "only support when has te modules" + + # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need + # to scatter/gather activations that we will recompute anyway. + setattr(function, "fsdp_wrapped", False) + setattr(function, "fsdp_group", None) + setattr(last_function, "fsdp_wrapped", False) + setattr(last_function, "fsdp_group", None) + # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments + # and execute TE's own checkpointing + # NOTE: This logic uses the TE checkpoint on all custom callable `function` handles because we + # cannot be sure there are no TE modules inside the function. It also means we might run + # the TE checkpoint for non-TE modules, so the TE checkpoint has to support a potential + # user context function. + del determinism_check, debug + + if _USE_REENTRANT_ACTIVATION_RECOMPUTE: + # If saved activations need to be distributed but there is no process group, + # default to the world group. + if distribute_saved_activations: + assert torch.distributed.is_initialized(), "torch.distributed is not initialized." + tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group + + return _CheckpointFunctionViranceAttention.apply( + function, + last_function, + distribute_saved_activations, + get_rng_state_tracker, + tp_group, + context_fn, + kwargs, + *args, + ) +# HACK(huang.huang) + +from .utils import add_attr +import transformer_engine +add_attr(transformer_engine.pytorch.distributed, "checkpointViranceAttention", checkpointViranceAttention) +add_attr(transformer_engine.pytorch.distributed, "checkpointVirance", checkpointVirance) \ No newline at end of file diff --git a/transformer_engine/musa/pytorch/fp8.py b/transformer_engine/musa/pytorch/fp8.py new file mode 100644 index 0000000000..8f85325ca8 --- /dev/null +++ b/transformer_engine/musa/pytorch/fp8.py @@ -0,0 +1,261 @@ +from pydantic.dataclasses import dataclass +from typing import Tuple, Optional, Dict, Any +import os + +import torch, torch_musa + +import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + Format, + Recipe, +) +from transformer_engine.pytorch.fp8 import ( + RecipeState, + get_fp8_te_dtype, + FP8GlobalStateManager, + DelayedScaling +) +from .tensor.mtfp8_tensor import ( + MTFP8Quantizer, + MTFP8Tensor +) +from transformer_engine.pytorch.utils import get_device_compute_capability + +from .utils import add_attr, wrap_attr, replace_attr + + +@dataclass() +class MTFP8BlockScaling(Recipe): + margin: int = 0 + fp8_format: Format = Format.HYBRID + fp8_dpa: bool = False + fp8_mha: bool = False + + tile_size: int = 128 + + def __post_init__(self) -> None: + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert self.tile_size == 128, "Only supports 128 tile_size yet." + + def __repr__(self) -> str: + return ( + f"margin={self.margin}, " + f"format={str(self.fp8_format).split('.')[1]}, " + f"tile_size={self.tile_size}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}" + ) + + +def musa_recipe_mtfp8(self): + return isinstance(self, MTFP8BlockScaling) + + +def common_recipe___init___workaround(): + from transformer_engine.common import recipe + add_attr(recipe, "MTFP8BlockScaling", MTFP8BlockScaling) + add_attr(recipe.Recipe, "mtfp8", musa_recipe_mtfp8) + replace_attr(recipe, "MXFP8BlockScaling", MTFP8BlockScaling) +common_recipe___init___workaround() + +def replace_mtfp8_tensor(): + from transformer_engine.pytorch.tensor import mxfp8_tensor + replace_attr(mxfp8_tensor, "MXFP8Tensor", MTFP8Tensor) +replace_mtfp8_tensor() + +class MTFP8BlockScalingRecipeState(RecipeState): + recipe: MTFP8BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: MTFP8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + activation_blocks = { + "block_m": 1, + "block_n": self.recipe.tile_size, + } + weight_blocks = { + "block_m": self.recipe.tile_size, + "block_n": self.recipe.tile_size, + } + + if mode == "forward": + assert num_quantizers % 3 == 0 + n_gemms = self.num_quantizers // 3 + self.blocks = [activation_blocks] * n_gemms + self.blocks += [weight_blocks] * n_gemms + self.blocks += [activation_blocks] * n_gemms + else: + assert num_quantizers % 2 == 0 + n_gemms = self.num_quantizers // 2 + self.blocks = [activation_blocks] * n_gemms + self.blocks += [activation_blocks] * n_gemms + + if device is None: + device = torch.device("musa") + + def make_quantizers(self) -> list: + return [MTFP8Quantizer( + self.dtype, + **(self.blocks[i % self.num_quantizers]), + ) for i in range(self.num_quantizers)] + + +def musa_recipe_state_create( + recipe: Recipe, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, +) -> RecipeState: + if recipe.mtfp8(): + return MTFP8BlockScalingRecipeState( + recipe, + mode=mode, + num_quantizers=num_quantizers, + device=device, + ) + return RecipeState._orig_create( + recipe, + mode=mode, + num_quantizers=num_quantizers, + device=device, + ) + + +def musa_check_fp8_support() -> Tuple[bool, str]: + if get_device_compute_capability() >= (3, 1): + return True, "" + return False, "Device compute capability 3.1 or higher required for FP8 execution." + + +@classmethod +def musa_add_fp8_tensors_to_global_buffer( + cls, + fp8_meta: Dict[str, Any], +) -> None: + if fp8_meta["recipe"].mtfp8(): + return + cls._orig_add_fp8_tensors_to_global_buffer(fp8_meta) + + +@classmethod +def musa_copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: + if fp8_meta["recipe"].mtfp8(): + return + cls._orig_copy_forward_fp8_meta_tensors_for_recompute(fp8_meta) + + +@classmethod +def musa_get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any], quantizers=None) -> None: + if fp8_meta["recipe"].mtfp8(): + return + # [Previous Version HACK - Preserved for historical context] + #HACK(huang.huang): not call _orig_get_old_fp8_meta_tensors_for_recompute directly while needs + #to modify the ori implement of get_old_fp8_meta_tensors_for_recompute; + #add .clone() when save meta into updated*, otherwise updated tensor will change along with meta and cause precision issue + # + # [New Optimization HACK - Pointer Swap for D2D Overhead] + #Replace clone()/copy() with pointer swapping to avoid D2D transfers (~100μs each). + # + # [New Fix HACK - change scale in quantizer instead of fp8_meta] + #Set the stash scale to the quantizer, as the scale used in the cast is actually the scale saved in quantizer, not fp8_meta + #On the other hand, since scale in quantizer and fp8_meta are not same ptr after pointer swapping, it's not necessary to save fp8_meta. + #Since we only update scale with amax once forward_step or backward_step finished, it's okay to temporarily decouple fp8_meta and quantizer + if not int(os.getenv("USE_RECOMPUTE_VARIANCE", 0)): + cls._orig_get_old_fp8_meta_tensors_for_recompute(fp8_meta) + else: + # below is revised vesrion of ori get_old_fp8_meta_tensors_for_recompute + + # Retrieve stashed amaxes and scales from phase 1 pre forward. + buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" + stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() + + # Replace amaxes and scales with stashed values for phase 2 forward + for i, quantizer in enumerate(quantizers["scaling_fwd"]): + quantizer.amax_history = stashed_fp8_meta[0][0][i] + quantizer.scale = stashed_fp8_meta[1][i] + #HACK(huang.huang) + +def musa_restore_fp8_meta_tensors(fp8_meta: Dict[str, Any], quantizers=None) -> None: + if fp8_meta["recipe"].mtfp8(): + return + #HACK(huang.huang): Replace clone()/copy() with pointer swapping to avoid D2D transfers (~100μs each), + # worked with musa_get_old_fp8_meta_tensors_for_recompute + # [New Fix HACK - change scale in quantizer instead of fp8_meta] + # restore scale in quantizer from fp8_meta + if not int(os.getenv("USE_RECOMPUTE_VARIANCE", 0)): + FP8GlobalStateManager._orig_restore_fp8_meta_tensors(fp8_meta) + else: + # below is revised vesrion of ori restore_fp8_meta_tensors + for i, quantizer in enumerate(quantizers["scaling_fwd"]): + quantizer.amax_history = fp8_meta["scaling_fwd"].amax_history[0][i] + quantizer.scale = fp8_meta["scaling_fwd"].scale[i] + ##HACK(huang.huang) + +def musa_get_default_fp8_recipe() -> Recipe: + """FP8 recipe with default args.""" + if os.getenv("FP8_PER_BLOCK", False): + return MTFP8BlockScaling() + return DelayedScaling() + +#HACK(huang.huang): add flag `skip` to change the behavior of reduce which is not necessary in recompute +# TE will reduce amx history once exit a recompute context in forward and backward, we move them to the end of forward and backward, +# the corresponding call is in megatron/core/pipeline_parallel/schedules.py +@classmethod +def musa_reduce_and_update_fp8_tensors( + cls, + forward: bool = True, + skip: bool = True, +) -> None: + if skip: + return + cls._orig_reduce_and_update_fp8_tensors(forward) +#HACK(huang.huang) + +def pytorch_fp8_workaround(): + from transformer_engine.pytorch import fp8 + add_attr(fp8, "MTFP8BlockScalingRecipeState", MTFP8BlockScalingRecipeState) + wrap_attr(fp8.RecipeState, "create", musa_recipe_state_create) + replace_attr(fp8, "check_fp8_support", musa_check_fp8_support) + wrap_attr( + fp8.FP8GlobalStateManager, + "add_fp8_tensors_to_global_buffer", + musa_add_fp8_tensors_to_global_buffer, + ) + wrap_attr( + fp8.FP8GlobalStateManager, + "copy_forward_fp8_meta_tensors_for_recompute", + musa_copy_forward_fp8_meta_tensors_for_recompute, + ) + wrap_attr( + fp8.FP8GlobalStateManager, + "get_old_fp8_meta_tensors_for_recompute", + musa_get_old_fp8_meta_tensors_for_recompute, + ) + wrap_attr( + fp8.FP8GlobalStateManager, + "restore_fp8_meta_tensors", + musa_restore_fp8_meta_tensors, + ) + if int(os.getenv("USE_RECOMPUTE_VARIANCE", 0)): + wrap_attr( + fp8.FP8GlobalStateManager, + "reduce_and_update_fp8_tensors", + musa_reduce_and_update_fp8_tensors, + ) + replace_attr(fp8, "get_default_fp8_recipe", musa_get_default_fp8_recipe) + + +pytorch_fp8_workaround() diff --git a/transformer_engine/musa/pytorch/module/__init__.py b/transformer_engine/musa/pytorch/module/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/transformer_engine/musa/pytorch/module/base.py b/transformer_engine/musa/pytorch/module/base.py new file mode 100644 index 0000000000..bdfcaca39b --- /dev/null +++ b/transformer_engine/musa/pytorch/module/base.py @@ -0,0 +1,82 @@ +import os +from typing import Any, Dict, Generator +from contextlib import contextmanager +import torch +from transformer_engine.pytorch.fp8 import Recipe, FP8GlobalStateManager +from ..fp8 import MTFP8BlockScalingRecipeState +from transformer_engine.pytorch.distributed import ( + is_fp8_activation_recompute_enabled, + in_fp8_activation_recompute_phase, +) +from ..utils import wrap_attr, replace_attr + + +def musa_set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + + if self.fp8_meta_tensors_initialized: + recipe_state = self.fp8_meta[fp8_meta_tensor_key] + if recipe.mtfp8() and isinstance(recipe_state, MTFP8BlockScalingRecipeState): + return + + self._orig_set_meta_tensor(fwd, recipe) + +#HACK(huang.huang): support Pointer Swap for D2D Overhead +#just change the args parse to get_old_fp8_meta_tensors_for_recompute and restore_fp8_meta_tensors +@contextmanager +def TransformerEngineBaseModule_prepare_forward( + self, + inp: torch.Tensor, + num_gemms: int = 1, + allow_non_contiguous: bool = False, +) -> Generator[torch.Tensor, None, None]: + """Checks and prep for FWD. + The context manager is needed because there isn't a way for a module to know + if it's the last FP8 module in the forward autocast. It is useful + to setup the forward aggregated amax reduction for every module + just in case. The autocast exit will pick up the most recent one. + """ + if not int(os.getenv("USE_RECOMPUTE_VARIANCE", 0)): + with self._orig_prepare_forward(inp, num_gemms, allow_non_contiguous) as processed_inp: + yield processed_inp + else: + # Activation recomputation is used and this is the second forward phase. + if self.fp8 and in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta, self.quantizers) + else: + assert inp.is_cuda, "TransformerEngine needs CUDA." + + if self.tp_size > 1: + assert self.tp_group_initialized, "TP group not initialized." + + self.set_activation_dtype(inp) + self.init_fp8_metadata(num_gemms=num_gemms) + + if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): + assert self.fp8_meta["recipe"].reduce_amax, ( + "Amax reduction across tensor parallel group is " + "necessary when using sequence parallelism with FP8." + ) + + if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) + + # Activation recomputation is used and this is the first forward phase. + if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): + FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) + + with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): + if not allow_non_contiguous and not inp.is_contiguous(): + inp = inp.contiguous() + yield inp + + if self.fp8 and in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta, self.quantizers) +##HACK(huang.huang) + +def pytorch_module_base_workaround(): + from transformer_engine.pytorch.module.base import TransformerEngineBaseModule + from transformer_engine.pytorch import fp8 + wrap_attr(TransformerEngineBaseModule, "set_meta_tensor", musa_set_meta_tensor) + replace_attr(TransformerEngineBaseModule, "prepare_forward", TransformerEngineBaseModule_prepare_forward) +pytorch_module_base_workaround() diff --git a/transformer_engine/musa/pytorch/module/grouped_linear.py b/transformer_engine/musa/pytorch/module/grouped_linear.py new file mode 100644 index 0000000000..5b32f57a09 --- /dev/null +++ b/transformer_engine/musa/pytorch/module/grouped_linear.py @@ -0,0 +1,680 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""GroupedLinear API""" +import os +from typing import Union, Optional, Callable, Tuple, List + +import torch + +import transformer_engine_torch as tex + +from transformer_engine.pytorch.module.base import ( + get_multi_stream_cublas_workspace, + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, +) +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.utils import ( + cast_if_needed, + assert_dim_for_fp8_exec, + clear_tensor_data, + requires_grad, +) +from transformer_engine.pytorch.distributed import ( + is_fp8_activation_recompute_enabled, + in_fp8_activation_recompute_phase, +) +from transformer_engine.pytorch import cpp_extensions as ceg + +from transformer_engine.pytorch.cpp_extensions import ( + general_grouped_gemm, +) +from transformer_engine.pytorch.jit import no_torch_dynamo +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.graph import is_graph_capturing +from transformer_engine.pytorch.cpu_offload import is_cpu_offload_enabled + +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) + +# HACK(huang.huang): recompute-variance for groupedLinear: add functions "backward_custom" +@no_torch_dynamo() +def backward_custom(self, inp, m_splits, grad_output: torch.Tensor, is_first_microbatch=None) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring + #recompute forward before gemm + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if skip_fp8_weight_update is not None: + is_first_microbatch = False + + with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: + num_gemms = len(m_splits) + weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + if not self.fp8: + weight_tensors = [ + w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors + ] + input_quantizers, weight_quantizers, output_quantizers = ( + [None] * self.num_gemms, + [None] * self.num_gemms, + [None] * self.num_gemms, + ) + grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms + if self.fp8: + input_quantizers = [ + self.quantizers["scaling_fwd"][self._offsets["input"] + i] + for i in range(self.num_gemms) + ] + for i in range(self.num_gemms): + input_quantizers[i].internal = True + weight_quantizers = [ + self.quantizers["scaling_fwd"][self._offsets["weight"] + i] + for i in range(self.num_gemms) + ] + for i in range(self.num_gemms): + weight_quantizers[i].internal = True + if torch.is_grad_enabled(): + grad_output_quantizers = [ + self.quantizers["scaling_bwd"][self._offsets["input"] + i] + for i in range(self.num_gemms) + ] + for i in range(self.num_gemms): + grad_output_quantizers[i].internal = True + + # code in _GroupedLinear-forward + + num_gemms = len(m_splits) + weights = weight_tensors + biases = bias_tensors + device = inp.device + + # Make sure input dimensions are compatible + in_features = weights[0].shape[-1] + assert inp.shape[-1] == in_features, "GEMM not possible" + inputmats = torch.split(inp.view(-1, in_features), m_splits) + if self.fp8: + assert_dim_for_fp8_exec(*inputmats, *weights) + + # Cast input to expected dtype + inputmats_no_fp8 = [cast_if_needed(mat, self.activation_dtype) for mat in inputmats] + inputmats = [] + + weight_requires_grad = weights[0].requires_grad + if input_quantizers[0] is not None: + for input_quantizer in input_quantizers: + input_quantizer.set_usage( + rowwise=True, + columnwise=(torch.is_grad_enabled and weight_requires_grad), + ) + columnwise_usage = torch.is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + if weight_quantizers[0] is not None: + for weight_quantizer in weight_quantizers: + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + if output_quantizers[0] is not None: + for output_quantizer in output_quantizers: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + if self.fp8: + if os.getenv("ENABLE_CAST_BATCH_INIT", "0") == "0": + inputmats = tex.fused_multi_quantize( + inputmats_no_fp8, None, input_quantizers, TE_DType[self.activation_dtype] + ) + else: + inputmats = tex.fused_multi_quantize_batch_init( + inputmats_no_fp8, + in_features, + m_splits, + input_quantizers, + TE_DType[self.activation_dtype], + ) + weights_fp8 = [] + bias_dtype = torch.bfloat16 if self.activation_dtype == torch.float32 else self.activation_dtype + if not isinstance(weights[0], QuantizedTensor): + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + for i in range(num_gemms): + weight_fp8 = self.get_weight_workspace( + tensor=weights[i], + quantizer=weight_quantizers[i], + cache_name=(None if is_first_microbatch is None else f"weight{i}"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + ) + weights_fp8.append(weight_fp8) + else: + weights_fp8 = weights + + else: + inputmats = inputmats_no_fp8 + bias_dtype = self.activation_dtype + weights_fp8 = [cast_if_needed(weight, self.activation_dtype) for weight in weights] + use_bias = self.apply_bias and not self.gemm_bias_unfused_add + assert use_bias is False, 'not support bias for custom backwrd now!' #TODO: support bias + biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases + + # general_grouped_gemm forward + if self.fp8_calibration: + for i in range(num_gemms): + # amax of input + for i in range(num_gemms): + input_quantizers[i].calibrate(inputmats[i]) + for i in range(num_gemms): + weight_quantizers[i].calibrate(weights[i]) + #in ori forward, above is all code before gemm which get output + main_grads = [w.main_grad for w in weights] + weights_requires_grad = weights[0].requires_grad + weights_shape_1 = weights[0].shape[1] + + reduce_and_update_bwd_fp8_tensors = False + if self.fp8 and requires_grad(inp, weights[0], biases[0]): + _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() + if in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + #backward + with torch.cuda.nvtx.range("_GroupedLinear_backward"): + weights = weights_fp8 + + if is_cpu_offload_enabled() and self.fuse_wgrad_accumulation: # TOSO + for i in self.num_gemms: + w = torch.nn.Parameter(weights[i], weights[i].requires_grad) + w.main_grad = main_grads[i] + weights[i] = w + + # preprocess grad_output + + grad_output = grad_output.contiguous() + grad_output_mats = torch.split( + grad_output.view(-1, grad_output.shape[-1]), m_splits + ) + in_features = grad_output.shape[-1] + grad_output = [None] * self.num_gemms + grad_biases = [None] * self.num_gemms + + if self.fp8: + if use_bias: + for i in range(self.num_gemms): + grad_biases[i], grad_output[i] = tex.bgrad_quantize( + grad_output_mats[i], grad_output_quantizers[i] + ) + else: + if os.getenv("ENABLE_CAST_BATCH_INIT", "0") == "0": + grad_output = tex.fused_multi_quantize( + grad_output_mats, + None, + grad_output_quantizers, + TE_DType[self.activation_dtype], + ) + else: + grad_output = tex.fused_multi_quantize_batch_init( + grad_output_mats, + in_features, + m_splits, + grad_output_quantizers, + TE_DType[self.activation_dtype], + ) + else: + grad_output = grad_output_mats + if is_first_microbatch is not None: + accumulate_wgrad_into_param_main_grad = ( + self.fuse_wgrad_accumulation and not is_first_microbatch + ) + else: + accumulate_wgrad_into_param_main_grad = self.fuse_wgrad_accumulation + + if inp.requires_grad: + dgrad = torch.empty( + (sum(m_splits), weights_shape_1), + dtype=self.activation_dtype, + device=device, + ) + + general_grouped_gemm( + weights, + grad_output, + torch.split(dgrad, m_splits), + self.activation_dtype, + get_multi_stream_cublas_workspace(), + layout="NN", + m_splits=m_splits, + grad=True, + use_split_accumulator=_2X_ACC_DGRAD, + ) + + if weights_requires_grad: + if self.fuse_wgrad_accumulation: + # wgrad_list = [w.main_grad for w in weights] + wgrad_list = main_grads + else: + wgrad_list = [ + torch.empty(w.size(), dtype=self.activation_dtype, device=device) + for w in weights + ] + # WGRAD + _, grad_biases_, _ = ceg.general_grouped_gemm( + inputmats, + grad_output, + wgrad_list, + self.activation_dtype, + get_multi_stream_cublas_workspace(), + layout="NT", + grad=True, + m_splits=m_splits, + use_bias=use_bias if grad_biases[0] is None else None, + bias=biases, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_wgrad_into_param_main_grad, + ) + for i in range(self.num_gemms): + if grad_biases[i] is None: + grad_biases[i] = grad_biases_[i] + del grad_biases_ + + if os.getenv("ENABLE_ZERO_BUBBLE", "0") == "0": + # Deallocate input tensor + clear_tensor_data(*inputmats) + + def handle_custom_ddp_from_mcore(w, wgrad): + if weights_requires_grad: + if self.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"): + w.grad_added_to_main_grad = True + if getattr(w, "zero_out_wgrad", False): + wgrad = torch.zeros( + w.main_grad.shape, + dtype=w.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + wgrad = torch.empty( + w.main_grad.shape, + dtype=w.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + elif self.fuse_wgrad_accumulation: + wgrad = None + else: + wgrad = None + return wgrad + + wgrad_list = [ + handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list) + ] + else: + wgrad_list = [None] * self.num_gemms + + if not use_bias: + grad_biases = [None] * self.num_gemms + + if reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + dgrad = dgrad.view(inp.shape) if inp.requires_grad else None + + inp.grad = dgrad #TODO: really need set grad for input? will it cause memory leak? + #call post-backward hook mannually + for weight, wgrad in zip(weights, wgrad_list): + if weight.requires_grad and not self.fuse_wgrad_accumulation: + weight.grad = wgrad + if weight.grad is not None and ( + not weight.grad_added_to_main_grad or getattr(weight, 'zero_out_wgrad', False) + ): + weight.main_grad.add_(weight.grad.data) + weight.grad = None + #TODO: support grad bias update + return dgrad +# HACK(huang.huang) + + +# HACK(huang.huang): optimze multi cast in groupedLinear: use fused_multi_quantize_batch_init in fwd and bwd when ENABLE_CAST_BATCH_INIT=1 +@staticmethod +def _GroupedLinear_forward( + ctx, + inp: torch.Tensor, + m_splits: List[int], + use_bias: bool, + is_first_microbatch: Union[bool, None], + fp8: bool, + fp8_calibration: bool, + input_quantizers: List[Quantizer], + weight_quantizers: List[Quantizer], + output_quantizers: List[Quantizer], + grad_output_quantizers: List[Quantizer], + fuse_wgrad_accumulation: bool, + cpu_offloading: bool, + sequence_parallel: bool, + activation_dtype: torch.dtype, + is_grad_enabled: bool, + module, + skip_fp8_weight_update, + *weights_and_biases, +) -> torch.Tensor: + + # pylint: disable=missing-function-docstring + num_gemms = len(m_splits) + weights = weights_and_biases[:num_gemms] + biases = weights_and_biases[num_gemms:] + device = inp.device + + # Make sure input dimensions are compatible + in_features = weights[0].shape[-1] + assert inp.shape[-1] == in_features, "GEMM not possible" + inputmats = torch.split(inp.view(-1, in_features), m_splits) + if fp8: + assert_dim_for_fp8_exec(*inputmats, *weights) + + # Cast input to expected dtype + inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] + inputmats = [] + + weight_requires_grad = weights[0].requires_grad + + if input_quantizers[0] is not None: + for input_quantizer in input_quantizers: + input_quantizer.set_usage( + rowwise=True, + columnwise=(is_grad_enabled and weight_requires_grad), + ) + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + if weight_quantizers[0] is not None: + for weight_quantizer in weight_quantizers: + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + if output_quantizers[0] is not None: + for output_quantizer in output_quantizers: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + if fp8: + if os.getenv("ENABLE_CAST_BATCH_INIT", "0") == "0": + inputmats = tex.fused_multi_quantize( + inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype] + ) + else: + inputmats = tex.fused_multi_quantize_batch_init( + inputmats_no_fp8, + in_features, + m_splits, + input_quantizers, + TE_DType[activation_dtype], + ) + weights_fp8 = [] + bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype + if not isinstance(weights[0], QuantizedTensor): + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + for i in range(num_gemms): + weight_fp8 = module.get_weight_workspace( + tensor=weights[i], + quantizer=weight_quantizers[i], + cache_name=(None if is_first_microbatch is None else f"weight{i}"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + ) + weights_fp8.append(weight_fp8) + else: + weights_fp8 = weights + + else: + inputmats = inputmats_no_fp8 + bias_dtype = activation_dtype + weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] + + biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases + + out = torch.empty( + [sum(m_splits), weights_fp8[0].size(0)], + dtype=activation_dtype, + device=device, + ) + + _ = general_grouped_gemm( + weights_fp8, + inputmats, + [out], + activation_dtype, + get_multi_stream_cublas_workspace(), + single_output=True, + m_splits=m_splits, + bias=biases, + use_bias=use_bias, + use_split_accumulator=_2X_ACC_FPROP, + ) + + if fp8_calibration: + for i in range(num_gemms): + # amax of input + for i in range(num_gemms): + input_quantizers[i].calibrate(inputmats[i]) + for i in range(num_gemms): + weight_quantizers[i].calibrate(weights[i]) + + if is_grad_enabled: + + ctx.weights_shape_1 = weights[0].shape[1] + + tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.weights_requires_grad = weights[0].requires_grad + if fuse_wgrad_accumulation and ctx.weights_requires_grad: + ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)] + else: + ctx.main_grads = [None] * num_gemms + ctx.device = device + ctx.grad_output_quantizers = grad_output_quantizers + ctx.m_splits = m_splits + ctx.num_gemms = num_gemms + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.cpu_offloading = cpu_offloading + ctx.is_first_microbatch = is_first_microbatch + ctx.use_bias = use_bias + ctx.sequence_parallel = sequence_parallel + ctx.inp_shape = inp.shape + ctx.requires_dgrad = inp.requires_grad + ctx.reduce_and_update_bwd_fp8_tensors = False + if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): + _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() + if in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + + # [*, in_features] -> [*, out_features] except first dimension changes for SP + return out.view(-1, *inp.shape[1:-1], out.shape[-1]) + + +@staticmethod +def _GroupedLinear_backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring + with torch.cuda.nvtx.range("_GroupedLinear_backward"): + saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + N = ctx.num_gemms + inputmats = saved_tensors[:N] + weights = saved_tensors[N : 2 * N] + biases = saved_tensors[2 * N : 3 * N] + main_grads = ctx.main_grads + + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO + for i in ctx.num_gemms: + w = torch.nn.Parameter(weights[i], weights[i].requires_grad) + w.main_grad = main_grads[i] + weights[i] = w + + # preprocess grad_output + + grad_output = grad_output.contiguous() + in_features = grad_output.shape[-1] + grad_output_view = grad_output.view(-1, in_features) + grad_output_mats = torch.split( + grad_output_view, ctx.m_splits + ) + grad_output = [None] * ctx.num_gemms + grad_biases = [None] * ctx.num_gemms + if ctx.fp8: + if ctx.use_bias: + for i in range(ctx.num_gemms): + grad_biases[i], grad_output[i] = tex.bgrad_quantize( + grad_output_mats[i], ctx.grad_output_quantizers[i] + ) + else: + if os.getenv("ENABLE_CAST_BATCH_INIT", "0") == "0": + grad_output = tex.fused_multi_quantize( + grad_output_mats, + None, + ctx.grad_output_quantizers, + TE_DType[ctx.activation_dtype], + ) + else: + grad_output = tex.fused_multi_quantize_batch_init( + grad_output_mats, + in_features, + ctx.m_splits, + ctx.grad_output_quantizers, + TE_DType[ctx.activation_dtype], + ) + else: + grad_output = grad_output_mats + + if ctx.is_first_microbatch is not None: + accumulate_wgrad_into_param_main_grad = ( + ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + ) + else: + accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + + if ctx.requires_dgrad: + dgrad = torch.empty( + (sum(ctx.m_splits), ctx.weights_shape_1), + dtype=ctx.activation_dtype, + device=ctx.device, + ) + + general_grouped_gemm( + weights, + grad_output, + torch.split(dgrad, ctx.m_splits), + ctx.activation_dtype, + get_multi_stream_cublas_workspace(), + layout="NN", + m_splits=ctx.m_splits, + grad=True, + use_split_accumulator=_2X_ACC_DGRAD, + ) + + if ctx.weights_requires_grad: + if ctx.fuse_wgrad_accumulation: + wgrad_list = main_grads + else: + wgrad_list = [ + torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) + for w in weights + ] + # WGRAD + _, grad_biases_, _ = ceg.general_grouped_gemm( + inputmats, + grad_output, + wgrad_list, + ctx.activation_dtype, + get_multi_stream_cublas_workspace(), + layout="NT", + grad=True, + m_splits=ctx.m_splits, + use_bias=ctx.use_bias if grad_biases[0] is None else None, + bias=biases, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_wgrad_into_param_main_grad, + ) + for i in range(ctx.num_gemms): + if grad_biases[i] is None: + grad_biases[i] = grad_biases_[i] + del grad_biases_ + + if os.getenv("ENABLE_ZERO_BUBBLE", "0") == "0": + # Deallocate input tensor + clear_tensor_data(*inputmats) + + def handle_custom_ddp_from_mcore(w, wgrad): + if ctx.weights_requires_grad: + if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"): + w.grad_added_to_main_grad = True + if getattr(w, "zero_out_wgrad", False): + wgrad = torch.zeros( + w.main_grad.shape, + dtype=w.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + wgrad = torch.empty( + w.main_grad.shape, + dtype=w.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + elif ctx.fuse_wgrad_accumulation: + wgrad = None + else: + wgrad = None + return wgrad + + wgrad_list = [ + handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list) + ] + else: + wgrad_list = [None] * ctx.num_gemms + + if not ctx.use_bias: + grad_biases = [None] * ctx.num_gemms + + if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + return ( + dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, # is_grad_enabled + None, # is_grad_enabled + *wgrad_list, + *grad_biases, + ) +## HACK(huang.huang) + +from ..utils import add_attr, replace_attr +from transformer_engine.pytorch.module import GroupedLinear +add_attr(GroupedLinear, "backward_custom", backward_custom) + + +from transformer_engine.pytorch.module.grouped_linear import _GroupedLinear +if os.getenv("ENABLE_CAST_BATCH_INIT", "0") == "1": + replace_attr(_GroupedLinear, 'forward', _GroupedLinear_forward) + replace_attr(_GroupedLinear, 'backward', _GroupedLinear_backward) \ No newline at end of file diff --git a/transformer_engine/musa/pytorch/module/linear.py b/transformer_engine/musa/pytorch/module/linear.py new file mode 100644 index 0000000000..0f26f35969 --- /dev/null +++ b/transformer_engine/musa/pytorch/module/linear.py @@ -0,0 +1,592 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Linear API""" +import os +import logging +from typing import Any, Callable, Dict, Optional, Tuple, Union +from functools import reduce +from operator import mul as multiply_op + +import torch + +import transformer_engine_torch as tex + +from transformer_engine.pytorch.module.base import ( + get_workspace, + get_ub, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, +) +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.utils import ( + clear_tensor_data, + requires_grad, + non_tn_fp8_gemm_supported, + non_tn_fp8_gemm_supported +) +from transformer_engine.pytorch.distributed import ( + get_distributed_world_size, + allreduce, + reduce_scatter_along_first_dim, + gather_along_first_dim, + _fsdp_scatter_tensors, +) +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + prepare_for_saving, + restore_from_saved, +) + +from transformer_engine.pytorch import cpp_extensions as ceg +from transformer_engine.pytorch.cpp_extensions import ( + general_gemm, +) +from transformer_engine.pytorch.graph import is_graph_capturing +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.module._common import noop_cat, _fix_gathered_fp8_transpose +from transformer_engine.pytorch.cpu_offload import is_cpu_offload_enabled +from transformer_engine.pytorch.utils import cast_if_needed + +import torch + +import transformer_engine_torch as tex + +from transformer_engine.pytorch.module.base import ( + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, +) +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.utils import ( + cast_if_needed, + clear_tensor_data, + requires_grad, +) +from transformer_engine.pytorch.distributed import ( + is_fp8_activation_recompute_enabled, + in_fp8_activation_recompute_phase, +) +from transformer_engine.pytorch.jit import no_torch_dynamo +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.graph import is_graph_capturing +from transformer_engine.pytorch.cpu_offload import is_cpu_offload_enabled + +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, +) + +# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 +_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) +# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 +_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) +log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL +log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} +logging.basicConfig( + format="[%(levelname)-8s | %(name)-19s]: %(message)s", + level=log_levels[log_level if log_level in [0, 1, 2] else 2], +) + +from transformer_engine.pytorch.module import Linear + +# HACK(huang.huang): recompute-variance for linear: add functions "backward_custom" +def backward_custom(self, inp, grad_output, is_first_microbatch=None, fp8_output=False, fp8_grad=False): + #recompute forward before gemm + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None + if skip_fp8_weight_update is not None: + is_first_microbatch = False + with self.prepare_forward( + inp, + allow_non_contiguous=isinstance(inp, QuantizedTensor), + ) as inp: + + # Get concatenated weight and bias tensors + unfused_weights = [getattr(self, name) for name in self.weight_names] + if any(isinstance(w, QuantizedTensor) for w in unfused_weights): + if self.fp8: + if len(unfused_weights) != 1: + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) + else: + unfused_weights = [w.dequantize() for w in unfused_weights] + weight_tensor = noop_cat(unfused_weights) + if self.use_bias: + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + else: + bias_tensor = None + + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) = self._get_quantizers(fp8_output, fp8_grad) + + # Make sure weight tensor has correct quantizer + # Note: Quantizer might have changed if quantization + # recipe changed + if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor): + weight_tensor._quantizer = weight_quantizer + + #input args of _Linear-forward + weight = weight_tensor + bias = bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None + fp8 = self.fp8 + fp8_calibration = self.fp8_calibration + cpu_offloading = is_cpu_offload_enabled() + tp_group = self.tp_group + tp_size = self.tp_size + sequence_parallel = self.sequence_parallel + activation_dtype = self.activation_dtype + parallel_mode = self.parallel_mode + is_grad_enabled = torch.is_grad_enabled() + ub_overlap_rs_fprop = self.ub_overlap_rs_fprop + ub_overlap_ag = self.ub_overlap_ag_dgrad + ub_overlap_ag_fprop = self.ub_overlap_ag_fprop + ub_overlap_rs_dgrad = self.ub_overlap_rs_dgrad + ub_bulk_dgrad = self.ub_bulk_dgrad + ub_bulk_wgrad = self.ub_bulk_wgrad + ub_name = self.ub_name + fsdp_group = self.fsdp_group + use_bias = self.use_bias + + ## code in _Linear-forward + + # Make sure input dimensions are compatible + out_features, in_features = weight.shape + inp_shape = inp.shape + assert inp_shape[-1] == in_features, "GEMM not possible" + + tp_world_size = get_distributed_world_size(tp_group) + backward_needs_input = is_grad_enabled and weight.requires_grad + + # Prepare input tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + inputmat = inp + inputmat_total = None + with_input_all_gather_nccl = ( + parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop + ) + own_quantized_input = False + if fp8: + if ( + any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) + + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if with_input_all_gather_nccl: + assert not isinstance( + inputmat, QuantizedTensor + ), "All gather of fp8 input is not supported" + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, + tp_group, + quantizer=input_quantizer, + ) + else: + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input, + ) + if not isinstance(inputmat, QuantizedTensor): + inputmat = input_quantizer(inputmat) + elif backward_needs_input: + inputmat.update_usage(rowwise_usage=True, columnwise_usage=True) + inputmat_total = inputmat + else: + inputmat = cast_if_needed(inp, activation_dtype) + if with_input_all_gather_nccl: + inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) + else: + inputmat_total = inputmat + + # Cast weight to expected dtype + weightmat = weight + if not fp8: + weightmat = cast_if_needed(weightmat, activation_dtype) + else: + if not isinstance(weight, QuantizedTensor): + # Configure quantizer + if weight_quantizer is not None: + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = self.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) + + # Cast bias to expected dtype + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias + + # Configure output quantizer + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if input_quantizer is not None: + input_quantizer.calibrate(inputmat_total) + if weight_quantizer is not None: + weight_quantizer.calibrate(weight) + + ub_obj = None + ub_type = None + rs_out = None + out_dtype = activation_dtype + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.RS + out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features] + rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.AG + if fp8: + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." + ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True) + inputmat_total = ub_obj.get_buffer(input_quantizer) + + tensors_to_save, tensor_objects = prepare_for_saving( + inputmat, + weightmat, + weight, + bias, + ) #will not save actually, only to match the code format + saved_tensors = tensors_to_save + fuse_wgrad_accumulation = self.fuse_wgrad_accumulation + if fuse_wgrad_accumulation and weight.requires_grad: + main_grad = weight.main_grad + requires_dgrad = inp.requires_grad + requires_wgrad = weight.requires_grad + reduce_and_update_bwd_fp8_tensors = False + # owns_input = saved_inputmat is not inp + owns_input = True # set True mannually, inp is not need after custom backward anyway + # owns_input = False #set False manually now, because we do not cache any tensor in ctx, so clear_tensor_data not needed + is_input_fp8 = not own_quantized_input + if fp8 and requires_grad(inp, weight, bias): + _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE + reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() + if in_fp8_activation_recompute_phase(): + FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + + + with torch.cuda.nvtx.range("_Linear_backward"): + if ( + fp8 + and any( + [ + ub_overlap_ag, + ub_overlap_rs_dgrad, + ub_bulk_dgrad, + ub_bulk_wgrad, + ] + ) + and not FP8GlobalStateManager.get_fp8_recipe().delayed() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling" + ) + + inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking + restore_from_saved(tensor_objects, saved_tensors) + ) + + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ( + main_grad + if weight is not None and fuse_wgrad_accumulation and requires_wgrad + else None + ) + + if cpu_offloading and fuse_wgrad_accumulation: + weight = torch.nn.Parameter(weight, weight.requires_grad) + weight.main_grad = main_grad + + # Gather intermediate/activation tensors if needed + # NOTE: weight_fp8 = weight when fp8 == False and torch.disttributed.FSDP already + # shards/unshards the base weights so we don't do it ourselves + assert fsdp_group is None, 'not support fsdp in backward_custom' + + ub_obj_gradout = None + ub_obj_dgrad = None + ub_obj_wgrad = None + ub_type_dgrad = None + ub_type_wgrad = None + dgrad_shape = [reduce(multiply_op, inp_shape[:-1]), inp_shape[-1]] + rs_out = None + dgrad_bulk = None + if ub_overlap_ag: + # Overlap grad_output all-gather with dgrad compute + ub_obj_gradout = get_ub(ub_name + "_dgrad") + ub_obj_dgrad = ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + + elif ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ub_obj_gradout = get_ub(ub_name + "_dgrad") + ub_obj_dgrad = ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.RS + rs_out = torch.empty( + dgrad_shape, dtype=activation_dtype, device=grad_output.device + ) + + else: + if ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. + ub_obj_gradout = get_ub(ub_name + "_dgrad") + ub_obj_dgrad = ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + ub_obj_dgrad.copy_into_buffer(inputmat, input_quantizer, local_chunk=True) + + if ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_obj_wgrad = get_ub(ub_name + "_wgrad") + ub_type_wgrad = tex.CommOverlapType.RS + ub_obj_wgrad.set_buffer_params(grad_input_quantizer) + dgrad_bulk = ub_obj_wgrad.get_buffer(grad_input_quantizer) + + # Prepare grad output tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + grad_bias = None + if grad_output_quantizer is not None: + grad_output_quantizer.set_usage(rowwise=True, columnwise=True) + if use_bias: + grad_output, grad_bias = tex.bgrad_quantize( + grad_output, grad_output_quantizer + ) + else: + # grad_output = grad_output_quantizer(grad_output) #same usage as input + + # usage copy from gourpedlinear + grad_output = tex.fused_multi_quantize( + [grad_output], + None, + [grad_output_quantizer], + TE_DType[activation_dtype], + ) + grad_output = grad_output[0] + + # Prepare input tensor + # Note: Perform tensor-parallel communication if needed + inputmat_total = None + inputmat_total_work = None + if ( + requires_wgrad + and parallel_mode == "column" + and sequence_parallel + and not ub_bulk_dgrad + ): + quantizer = None + if fp8: + quantizer = input_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + inputmat_total, inputmat_total_work = gather_along_first_dim( + inputmat, + tp_group, + async_op=True, + quantizer=quantizer, + ) + else: + inputmat_total = inputmat + + # Check whether to output wgrad GEMM directly into main grad + if is_first_microbatch is not None: + accumulate_wgrad_into_param_main_grad = ( + fuse_wgrad_accumulation and not is_first_microbatch + ) + else: + accumulate_wgrad_into_param_main_grad = fuse_wgrad_accumulation + + # Compute grad input tensor + dgrad = None + dgrad_work = None + if requires_dgrad: + + # Update quantizer + if grad_input_quantizer is not None: + grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + # dgrad GEMM + dgrad, *_, rs_out = general_gemm( + weight_fp8, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=grad_input_quantizer, + out=dgrad_bulk, + out_dtype=activation_dtype, + use_split_accumulator=_2X_ACC_DGRAD, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=rs_out, + bulk_overlap=ub_bulk_dgrad, + ) + + # Launch tensor-parallel communication + if ub_overlap_rs_dgrad: + dgrad = rs_out + elif parallel_mode == "column" and not ub_bulk_wgrad: + if sequence_parallel: + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + tp_group, + async_op=True, + ) + else: + dgrad, dgrad_work = allreduce(dgrad, tp_group, async_op=True) + + # Compute grad weight tensor + wgrad = None + if requires_wgrad: + if ub_bulk_dgrad: + inputmat_total = ub_obj_dgrad.get_buffer(input_quantizer) + if fp8: + if inputmat._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + inputmat_total = _fix_gathered_fp8_transpose( + inputmat_total, tp_size + ) + elif not non_tn_fp8_gemm_supported(): + # FP8 GEMM on Hopper only supports TN layout so the gathered input must + # have a valid transpose. + inputmat_total._create_transpose() + + else: + if inputmat_total_work is not None: + # Synchronize tensor-parallel communication + inputmat_total_work.wait() + inputmat_total_work = None + + if isinstance(grad_output, QuantizedTensor): + # This is a no-op if platform supports non-TN FP8 GEMM or the transpose + # already exists. + grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + + if ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + rs_out = torch.empty( + dgrad_shape, dtype=activation_dtype, device=grad_output.device + ) + + # wgrad GEMM + # Note: Fuse with bgrad computation if needed + wgrad, grad_bias_, _, rs_out = ceg.general_gemm( + inputmat_total, + grad_output, + get_workspace(), + layout="NT", + grad=True, + out_dtype=( + main_grad.dtype if fuse_wgrad_accumulation else activation_dtype + ), + bias=(bias if (grad_bias is None and not fp8) else None), + out=main_grad if fuse_wgrad_accumulation else None, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_wgrad_into_param_main_grad, + ub=ub_obj_wgrad, + ub_type=ub_type_wgrad, + extra_output=rs_out, + bulk_overlap=ub_bulk_wgrad, + ) + + if ub_bulk_wgrad: + if ub_obj_wgrad.is_fp8_ubuf(): + dgrad = rs_out + else: + dgrad = ub_obj_wgrad.get_buffer(grad_input_quantizer, local_chunk=True) + + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ + + # Deallocate input tensor + if os.getenv("ENABLE_ZERO_BUBBLE", "0") == "0": + if owns_input: + clear_tensor_data(inputmat_total) + + # Don't return grad bias if not needed + if not use_bias: + grad_bias = None + + # Synchronize tensor parallel communication + if inputmat_total_work is not None: + inputmat_total_work.wait() + inputmat_total_work = None + if dgrad_work is not None: + dgrad_work.wait() + dgrad_work = None + + if requires_wgrad: + # Handle custom DDP from mcore. + if ( + fuse_wgrad_accumulation + and weight is not None + and hasattr(weight, "grad_added_to_main_grad") + ): + weight.grad_added_to_main_grad = True + if getattr(weight, "zero_out_wgrad", False): + wgrad = torch.zeros( + weight.main_grad.shape, + dtype=weight.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + wgrad = None + elif fuse_wgrad_accumulation: + wgrad = None + else: + wgrad = None + + if reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + + # Scatter fp8 weight buffers + if fp8 and not isinstance(weight, QuantizedTensor): + _fsdp_scatter_tensors(fsdp_group, weight_fp8) + dgrad = dgrad.view(inp.shape) if inp.requires_grad else None + inp.grad = dgrad #TODO: really need set grad for input? will it cause memory leak? + #call post-backward hook mannually + if weight.requires_grad and not self.fuse_wgrad_accumulation: + weight.grad = wgrad + if weight.grad is not None and ( + not weight.grad_added_to_main_grad or getattr(weight, 'zero_out_wgrad', False) + ): + weight.main_grad.add_(weight.grad.data) + weight.grad = None + return dgrad +# HACK(huang.huang) + +from ..utils import add_attr +add_attr(Linear, "backward_custom", backward_custom) diff --git a/transformer_engine/musa/pytorch/ops/__init__.py b/transformer_engine/musa/pytorch/ops/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/transformer_engine/musa/pytorch/ops/op.py b/transformer_engine/musa/pytorch/ops/op.py new file mode 100644 index 0000000000..67e529a28f --- /dev/null +++ b/transformer_engine/musa/pytorch/ops/op.py @@ -0,0 +1,87 @@ +from typing import Optional + +import torch + +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + Recipe, + DelayedScalingRecipeState, + MXFP8BlockScalingRecipeState, +) +from ..fp8 import MTFP8BlockScalingRecipeState + +from ..utils import replace_attr + +def musa__update_quantization_recipe_state( + self, + *, + recipe: Optional[Recipe] = None, +) -> None: + # Quantization recipe + if recipe is None: + recipe = FP8GlobalStateManager.get_fp8_recipe() + + # Reset quantization state if needed + if self._fp8_metas is None or self._quantizers is None: + self._reset_quantization_recipe_state(recipe=recipe) + return + for mode in ("forward", "backward"): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]: + continue + recipe_state = self._fp8_metas[mode][fp8_meta_key] + need_to_reset_recipe_state = ( + recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) + ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState) + ) or (recipe.mtfp8() and not isinstance(recipe_state, MTFP8BlockScalingRecipeState)) + if need_to_reset_recipe_state: + self._reset_quantization_recipe_state(recipe=recipe) + return + + # Quantization recipe state for forward and backward pass + for mode in ("forward", "backward"): + num_quantizers = self.num_quantizers(mode) + if num_quantizers == 0: + continue + + # Update FP8 metadata + fp8_meta = self._fp8_metas[mode] + fp8_meta["recipe"] = recipe + fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + + # Get recipe state + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + recipe_state = fp8_meta[fp8_meta_key] + + # Reallocate amax history if needed + if recipe.mxfp8() or recipe.mtfp8(): + continue + + current_length = recipe_state.amax_history.size(0) + target_length = recipe.amax_history_len + if current_length != target_length: + with torch.no_grad(): + if target_length < current_length: + recipe_state.amax_history = recipe_state.amax_history[ + :target_length + ].clone() + else: + recipe_state.amax_history = torch.nn.functional.pad( + recipe_state.amax_history, + pad=(0, 0, 0, target_length - current_length), + ) + self._quantizers[mode] = recipe_state.make_quantizers() + + +def pytorch_ops_op_workaround(): + from transformer_engine.pytorch.ops.op import BasicOperation + replace_attr( + BasicOperation, + "_update_quantization_recipe_state", + musa__update_quantization_recipe_state, + ) +pytorch_ops_op_workaround() diff --git a/transformer_engine/musa/pytorch/tensor/__init__.py b/transformer_engine/musa/pytorch/tensor/__init__.py new file mode 100644 index 0000000000..19d7f6bfd4 --- /dev/null +++ b/transformer_engine/musa/pytorch/tensor/__init__.py @@ -0,0 +1,9 @@ +from ..utils import add_attr + + +def pytorch_mtfp8_tensor_workaround(): + from transformer_engine.pytorch import tensor + from . import mtfp8_tensor, mtfp8_tensor_base + add_attr(tensor, "mtfp8_tensor", mtfp8_tensor) + add_attr(tensor._internal, "mtfp8_tensor_base", mtfp8_tensor_base) +pytorch_mtfp8_tensor_workaround() diff --git a/transformer_engine/musa/pytorch/tensor/mtfp8_tensor.py b/transformer_engine/musa/pytorch/tensor/mtfp8_tensor.py new file mode 100644 index 0000000000..60eb25ad1b --- /dev/null +++ b/transformer_engine/musa/pytorch/tensor/mtfp8_tensor.py @@ -0,0 +1,510 @@ +"""Tensor class with FP8 data and FP32 scales""" +from __future__ import annotations +import math +from typing import List, Optional, Iterable, Tuple + +import torch, torch_musa +import transformer_engine_torch as tex + +from transformer_engine.pytorch.tensor.quantized_tensor import ( + _IdentityFunc, +) +from transformer_engine.pytorch.tensor import ( + Quantizer, + QuantizedTensor, +) +from transformer_engine.pytorch.utils import ( + devices_match, +) + +from .mtfp8_tensor_base import ( + MTFP8TensorBase, + _FromMTFP8Func, +) + +aten = torch.ops.aten + +class MTFP8Quantizer(Quantizer): + dtype: tex.DType + block_m: int + block_n: int + + def __init__( + self, + fp8_dtype: tex.DType, + block_m: int, + block_n: int, + *, + rowwise: bool = True, + columnwise: bool = True, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = fp8_dtype + self.block_m = block_m + self.block_n = block_n + assert self.block_m == 1 or (self.block_m == self.block_n) + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + + assert isinstance(dst, MTFP8Tensor), f"Cannot store quantized MTFP8 in {type(dst)} type." + + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + tex.quantize(src, self, dst, noop_flag) + + dst._fp8_dtype = self.dtype + + return dst + + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. + The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM. + + Parameters + + """ + M, K = 1, 1 + for i in range(len(shape) - 1): + M *= shape[i] + if len(shape) > 0: + K = shape[-1] + + def ceil_div(a, b): + return (a + b - 1) // b + + outer = ceil_div(M, self.block_m) + inner = ceil_div(K, self.block_n) + + return (outer, inner) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> MTFP8Tensor: + + if device is None: + device = torch.device("musa") + + def ceil_div(a, b): + return (a + b - 1) // b + + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_inv = torch.empty( + ceil_div(math.prod(shape[:-1]), self.block_m), + ceil_div(shape[-1], self.block_n), + dtype=torch.float, + device=device, + ) + + columnwise_data = None + columnwise_scale_inv = None + if self.columnwise_usage and self.block_m != self.block_n: + columnwise_data = torch.empty_like(data) + columnwise_scale_inv = torch.empty( + ceil_div(math.prod(shape[:-1]), self.block_n), + ceil_div(shape[-1], self.block_m), + dtype=torch.float, + device=device, + ) + + return MTFP8Tensor( + shape=shape, + dtype=dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=self.dtype, + quantizer=self, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + pass + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"rowwise_usage={self.rowwise_usage}, " + f"columnwise_usage={self.columnwise_usage}, " + f"internal={self.internal}, " + f"block_m={self.block_m}, " + f"block_n={self.block_n}, " + f"dtype={self.dtype}, " + ")" + ) + + +class MTFP8Tensor(MTFP8TensorBase, QuantizedTensor): + def __repr__(self, *, tensor_contents=None): + return f"MTFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if dtype is None: + dtype = self.dtype + + if torch.is_grad_enabled(): + return _FromMTFP8Func.apply(self, dtype) + return _FromMTFP8Func.forward(None, self, dtype) + + def _get_quantizer(self) -> Quantizer: + assert self._quantizer is not None + return self._quantizer + # if self._quantizer is not None: + # return self._quantizer + + # rowwise_data_shape = self._rowwise_data.shape + # rowwise_scale_inv_shape = self._rowwise_scale_inv.shape + # assert len(rowwise_data_shape) == 2 + # assert len(rowwise_scale_inv_shape) == 2 + + # m, n = rowwise_data_shape[0], rowwise_data_shape[1] + # sinv_m, sinv_n = rowwise_scale_inv_shape[0], rowwise_scale_inv_shape[1] + + # def next_power_of_2(x): + # assert x >= 1 + # return 2 ** math.ceil(math.log2(x)) + + # if m == 1 or m == sinv_m: + # block_m = 1 + # else: + # block_m = next_power_of_2(m // sinv_m) + # block_n = next_power_of_2(n // sinv_n) + + # return MTFP8Quantizer( + # fp8_dtype=self._fp8_dtype, + # block_m=block_m, + # block_n=block_n, + # ) + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> MTFP8Tensor: + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def detach(self) -> MTFP8Tensor: + return MTFP8Tensor.make_like(self) + + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor." + + if columnwise_usage and rowwise_usage: + assert ( + self._rowwise_data is not None + and self._rowwise_scale_inv is not None + and self._columnwise_data is not None + and self._columnwise_scale_inv is not None + ), "Cannot update to rowwise and columnwise usage." + return + + if rowwise_usage: + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "Cannot update to rowwise usage." + self._columnwise_data = None + self._columnwise_scale_inv = None + return + + assert ( + self._columnwise_data is not None and self._columnwise_scale_inv is not None + ), "Cannot update to columnwise usage." + self._rowwise_data = None + self._rowwise_scale_inv = None + return + + def clone(self) -> MTFP8Tensor: + rowwise_data = None + if self._rowwise_data is not None: + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> MTFP8Tensor: + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> MTFP8Tensor: + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> MTFP8Tensor: + if ( + self._rowwise_data is not None + and self._rowwise_data.is_contiguous(memory_format=memory_format) + and ( + (self._columnwise_data is None) + or (self._columnwise_data.is_contiguous(memory_format=memory_format)) + ) + ): + return self + raise ValueError("MTFP8Tensor does not support different memory formats!") + + def clear(self): + self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None + self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + if func == aten.view.default: + tensor = args[0] + data = tensor._rowwise_data + orig_size = data.size() + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + if orig_size != out_data.size(): + raise NotImplementedError( + "Changing shape with view not implemented " + " (scales and columnwise data untouched)." + ) + return MTFP8Tensor.make_like(tensor) + + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: tex.DType, + dtype: torch.dtype, + quantizer: Quantizer, + ) -> MTFP8Tensor: + return MTFP8Tensor( + shape=shape, + dtype=dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=fp8_dtype, + quantizer=quantizer, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + return ( + MTFP8Tensor._make_in_reduce_ex, + ( + self.shape, + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + self._quantizer, + ), + ) + + def _get_data(self) -> MTFP8Tensor: + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + new_device = tensor.device if tensor.is_musa else self.device + + if isinstance(tensor, MTFP8Tensor): + if ( + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + MTFP8Tensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + super(MTFP8Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + return + + assert self._quantizer is not None, "Can't quantize without a quantizer" + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + tensor: MTFP8Tensor, + shape: Optional[list[int]] = None, + ) -> MTFP8Tensor: + ctx.shape = tensor.shape + if shape is None: + return tensor + + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(tensor.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + new_rowwise_data = tensor._rowwise_data.view(*shape) + if tensor._columnwise_data is not None: + new_columnwise_data = tensor._columnwise_data.view(*shape) + return MTFP8Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + fp8_dtype=tensor._fp8_dtype, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + if isinstance(grad, MTFP8Tensor): + new_data = ( + grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None + ) + new_columnwise_data = ( + grad._columnwise_data.view(*ctx.shape) if grad._columnwise_data is not None else None + ) + dgrad = MTFP8Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + fp8_dtype=grad._fp8_dtype, + quantizer=grad._quantizer, + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + tensor: MTFP8Tensor, + shape: Optional[list[int]] = None, + ) -> MTFP8Tensor: + ctx.shape = tensor.shape + if shape is None: + return tensor + + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(tensor.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + new_rowwise_data = tensor._rowwise_data.reshape(*shape) + if tensor._columnwise_data is not None: + new_columnwise_data = tensor._columnwise_data.reshape(*shape) + + return MTFP8Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + fp8_dtype=tensor._fp8_dtype, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + if isinstance(grad, MTFP8Tensor): + new_data = ( + grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None + ) + new_columnwise_data = ( + grad._columnwise_data.view(*ctx.shape) if grad._columnwise_data is not None else None + ) + dgrad = MTFP8Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + fp8_dtype=grad._fp8_dtype, + quantizer=grad._quantizer, + ) + return dgrad, None + return grad.view(ctx.shape), None diff --git a/transformer_engine/musa/pytorch/tensor/mtfp8_tensor_base.py b/transformer_engine/musa/pytorch/tensor/mtfp8_tensor_base.py new file mode 100644 index 0000000000..9d048be63d --- /dev/null +++ b/transformer_engine/musa/pytorch/tensor/mtfp8_tensor_base.py @@ -0,0 +1,111 @@ +from __future__ import annotations +from typing import Optional, Dict, Any, Tuple + +import torch + +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import ( + TE_DType as torch_to_transformer_engine_dtype, +) +from transformer_engine.pytorch.tensor import Quantizer + + +class _FromMTFP8Func(torch.autograd.Function): + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: MTFP8TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + dtype = torch_to_transformer_engine_dtype[dtype] + + if tensor._rowwise_data is not None: + return tex.dequantize(tensor, dtype) + raise NotImplementedError("Casting back from the transpose not implemented yet!") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + return grad, None + + +class MTFP8TensorBase: + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _fp8_dtype: tex.DType + _rowwise_scale_inv: Optional[torch.Tensor] + _columnwise_scale_inv: Optional[torch.Tensor] + + def __new__( + cls, + *args, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + fp8_dtype: tex.DType, + quantizer: Optional[Quantizer] = None, + **kwargs, + ): + instance = super().__new__(cls, *args, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + instance._fp8_dtype = fp8_dtype + instance._quantizer = quantizer + + return instance + + def get_metadata(self) -> Dict[str, Any]: + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "fp8_dtype": self._fp8_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MTFP8TensorBase]: + tensors = [self._rowwise_data, self._columnwise_data] + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + return self._rowwise_data, self._columnwise_data + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + return _FromMTFP8Func.forward(None, self, dtype) + + def size(self, *args, **kwargs): + if self._rowwise_data is not None: + return self._rowwise_data.size(*args, **kwargs) + return self._columnwise_data.size(*args, **kwargs) + + def __repr__(self): + data = self.dequantize() + if self._rowwise_data is not None: + descriptor = "rowwise" + sinv = self._rowwise_scale_inv + else: + descriptor = "columnwise" + sinv = self._columnwise_scale_inv + + return ( + "MTFP8TensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"{descriptor}_scaled_data={data}, " + f"{descriptor}_scale_inv={sinv}" + ")" + ) diff --git a/transformer_engine/musa/pytorch/utils.py b/transformer_engine/musa/pytorch/utils.py new file mode 100644 index 0000000000..0a6ff67830 --- /dev/null +++ b/transformer_engine/musa/pytorch/utils.py @@ -0,0 +1,28 @@ +from typing import List +import torch + + +def wrap_name(src): + return f"_orig_{src}" + +def add_attr(module, name, target): + setattr(module, name, target) + +def wrap_attr(module, name, wrapper): + target = getattr(module, name) + setattr(module, wrap_name(name), target) + setattr(module, name, wrapper) + +def replace_attr(module, name, target): + wrap_attr(module, name, target) + + +def musa_assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: + return +# TODO(yehua.zhang) do not work +import sys +for k in sys.modules: + if 'utils' in k: + for target in ['assert_dim_for_fp8_exec']: + if getattr(sys.modules[k], target, None): + setattr(sys.modules[k], target, musa_assert_dim_for_fp8_exec) \ No newline at end of file diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 95558e30da..02c74db571 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -66,28 +66,28 @@ ) from transformer_engine.pytorch import export from transformer_engine.pytorch.export import is_in_onnx_export_mode - +from flash_attn import flash_attn_func # Global vars for flash attn v2 and v3 imports -flash_attn_cuda_bwd = None -flash_attn_func = None +flash_attn_musa_bwd = None +# flash_attn_func = None flash_attn_varlen_func = None _flash_attn_fwd = None _flash_attn_bwd = None _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None try: - fa_utils.version = PkgVersion(get_pkg_version("flash-attn")) + fa_utils.version = PkgVersion(get_pkg_version("flash-attn11")) except PackageNotFoundError: pass # only print warning if use_flash_attention_2 = True in get_attention_backend else: - if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0): + if torch.musa.is_available() and get_device_compute_capability() >= (10, 0): if fa_utils.version_required_blackwell <= fa_utils.version <= fa_utils.max_version: fa_utils.is_installed = True elif fa_utils.version_required <= fa_utils.version <= fa_utils.max_version: fa_utils.is_installed = True if fa_utils.is_installed: - from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd + from flash_attn_2_musa import varlen_bwd as flash_attn_musa_bwd from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd @@ -101,7 +101,7 @@ # Setup Flash attention utils fa_utils.set_flash_attention_version() elif ( - torch.cuda.is_available() + torch.musa.is_available() and get_device_compute_capability() >= (8, 0) and dpa_utils._NVTE_FLASH_ATTN ): @@ -369,7 +369,7 @@ def forward( output_size[2], output_size[3], dtype=query_layer.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), ) scale = self.softmax_scale @@ -387,10 +387,10 @@ def forward( fp8_recipe = fp8_meta["local_recipes"][0] if fp8_recipe.float8_current_scaling(): S_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=S_quantizer.dtype, device="cuda" + fp8_dtype=S_quantizer.dtype, device="musa" ) dP_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=dP_quantizer.dtype, device="cuda" + fp8_dtype=dP_quantizer.dtype, device="musa" ) if "2" in qkv_layout or "3" in qkv_layout: @@ -660,7 +660,7 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, - cp_stream: torch.cuda.Stream = None, + cp_stream: torch.musa.Stream = None, cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, @@ -676,7 +676,7 @@ def forward( for x in [query_layer, key_layer, value_layer] ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors." assert ( - query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda + query_layer.is_musa and key_layer.is_musa and value_layer.is_musa ), "FlashAttention currently only supports CUDA tensors." assert ( qkv_layout in QKVLayouts @@ -933,6 +933,7 @@ def forward( 1 )[:batch_size] ) + print(func) output = func( query_layer, key_layer, @@ -1413,8 +1414,8 @@ def backward(ctx, d_out, *_args): dk = torch.empty_like(k) dv = torch.empty_like(v) d_out, q, k, v, out = [dpa_utils.maybe_contiguous(x) for x in (d_out, q, k, v, out)] - # from transformer_engine.pytorch.attention.dot_product_attention import flash_attn_cuda_bwd - flash_attn_cuda_bwd( + # from transformer_engine.pytorch.attention.dot_product_attention import flash_attn_musa_bwd + flash_attn_musa_bwd( d_out, q, k, @@ -1719,7 +1720,7 @@ def forward( fast_zero_fill: bool = True, cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, - cp_stream: torch.cuda.Stream = None, + cp_stream: torch.musa.Stream = None, cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, @@ -1738,7 +1739,7 @@ def forward( for x in [query_layer, key_layer, value_layer] ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors." assert ( - query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda + query_layer.is_musa and key_layer.is_musa and value_layer.is_musa ), "FusedAttention only supports CUDA tensors." assert ( qkv_layout in QKVLayouts diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index e127d91595..ab102ddb1f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -264,7 +264,7 @@ def flash_attn_a2a_communicate( seq_dim: int, cp_size: int, cp_group: dist_group_type, - cp_stream: torch.cuda.Stream, + cp_stream: torch.musa.Stream, before_attn: bool, ) -> Union[torch.Tensor, List[torch.Tensor]]: """A2A communication for context parallelism.""" @@ -278,7 +278,7 @@ def flash_attn_a2a_communicate( a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True ) if i > 1: - with torch.cuda.stream(cp_stream): + with torch.musa.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] # reorder the sequence chunks @@ -313,7 +313,7 @@ def flash_attn_a2a_communicate( x, chunk_ids_for_a2a, seq_dim, cp_size ) if i > 1: - with torch.cuda.stream(cp_stream): + with torch.musa.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] # [cp, 2, b, s//2, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] @@ -322,7 +322,7 @@ def flash_attn_a2a_communicate( # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] # or [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) - torch.cuda.current_stream().wait_stream(cp_stream) + torch.musa.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs @@ -331,7 +331,7 @@ def flash_attn_a2a_communicate_softmax_offset( h_dim: int, cp_size: int, cp_group: dist_group_type, - cp_stream: torch.cuda.Stream, + cp_stream: torch.musa.Stream, before_attn: bool, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Split/AllGather communication for softmax offset.""" @@ -361,14 +361,14 @@ def flash_attn_a2a_communicate_softmax_offset( # [1, h//cp, 1, 1] -> [1, h, 1, 1] inp = tensor.view(-1) output = torch.empty(cp_size * inp.shape[0], dtype=tensor.dtype, device=device) - with torch.cuda.stream(cp_stream): + with torch.musa.stream(cp_stream): torch.distributed.all_gather_into_tensor( output, inp, group=cp_group, async_op=False, ) - torch.cuda.current_stream().wait_stream(cp_stream) + torch.musa.current_stream().wait_stream(cp_stream) output = output.view( *tensor.shape[:h_dim], cp_size * tensor.shape[h_dim], *tensor.shape[h_dim + 1 :] ) @@ -1352,9 +1352,9 @@ def forward( attn_biases = [None for _ in range(cp_size)] # create two streams to resolve wave quantization issue of Flash Attn in each step - flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + flash_attn_streams = [torch.musa.current_stream(), cp_stream] # synchronize fwd results correction across steps - fwd_results_correction_done = torch.cuda.Event() + fwd_results_correction_done = torch.musa.Event() p2p_comm_buffers = [None for _ in range(cp_size)] k_shape = k.shape @@ -1369,7 +1369,7 @@ def forward( out = None for i in range(cp_size + 1): if i < cp_size: - with torch.cuda.stream(flash_attn_streams[i % 2]): + with torch.musa.stream(flash_attn_streams[i % 2]): # wait until KV is received for req in send_recv_reqs[(i + 1) % 2]: req.wait() @@ -1572,7 +1572,7 @@ def forward( if i > 1: flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) - with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): + with torch.musa.stream(flash_attn_streams[(i - 1) % 2]): if use_fused_attention: # [b, h, sq, 1] -> [b, h, sq] or # [t, h, 1] -> [t, np] @@ -1628,7 +1628,7 @@ def forward( if i < cp_size: flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done) - torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) + torch.musa.current_stream().wait_stream(flash_attn_streams[1]) if return_max_logit: torch.distributed.all_reduce( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group @@ -2705,10 +2705,10 @@ def forward( # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) - cp_stream.wait_stream(torch.cuda.current_stream()) + cp_stream.wait_stream(torch.musa.current_stream()) # create two streams to resolve wave quantization issue of Flash Attn in each step - flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + flash_attn_streams = [torch.musa.current_stream(), cp_stream] local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] kv_seq_range_per_step = [None, None] @@ -2723,7 +2723,7 @@ def forward( for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): - with torch.cuda.stream(flash_attn_streams[i]): + with torch.musa.stream(flash_attn_streams[i]): # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] q_ = q.select(seq_dim, i).contiguous() @@ -2815,7 +2815,7 @@ def forward( if return_max_logit and i == 0: max_logit = torch.clone(max_logit_per_step[0]) if i > 0: - with torch.cuda.stream(flash_attn_streams[i - 1]): + with torch.musa.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": out[:, i - 1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": @@ -2823,7 +2823,7 @@ def forward( if return_max_logit: max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) - torch.cuda.current_stream().wait_stream(cp_stream) + torch.musa.current_stream().wait_stream(cp_stream) if return_max_logit: torch.distributed.all_reduce( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group @@ -2896,9 +2896,9 @@ def backward(ctx, dout, *_args): dv_per_step = [None, None] # create two streams to resolve wave quantization issue of Flash Attn in each step - flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream] + flash_attn_streams = [torch.musa.current_stream(), ctx.cp_stream] # synchronize dkv update across steps - dkv_update_done = torch.cuda.Event() + dkv_update_done = torch.musa.Event() # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) @@ -2913,7 +2913,7 @@ def backward(ctx, dout, *_args): # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) - ctx.cp_stream.wait_stream(torch.cuda.current_stream()) + ctx.cp_stream.wait_stream(torch.musa.current_stream()) local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] @@ -2950,7 +2950,7 @@ def backward(ctx, dout, *_args): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): - with torch.cuda.stream(flash_attn_streams[i]): + with torch.musa.stream(flash_attn_streams[i]): # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] q_ = q.select(seq_dim, i).contiguous() @@ -3028,7 +3028,7 @@ def backward(ctx, dout, *_args): ) if i > 0: - with torch.cuda.stream(flash_attn_streams[i - 1]): + with torch.musa.stream(flash_attn_streams[i - 1]): if ctx.qkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": @@ -3050,7 +3050,7 @@ def backward(ctx, dout, *_args): if i < len(local_seq_chunk_ids): flash_attn_streams[i - 1].record_event(dkv_update_done) - torch.cuda.current_stream().wait_stream(ctx.cp_stream) + torch.musa.current_stream().wait_stream(ctx.cp_stream) # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index d62bcc92ac..85ec5bc026 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -325,7 +325,7 @@ def __init__( attention_type: str = "self", cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, - cp_stream: torch.cuda.Stream = None, + cp_stream: torch.musa.Stream = None, cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, softmax_type: str = "vanilla", @@ -416,16 +416,17 @@ def __init__( self.return_max_logit = return_max_logit self.softmax_type = softmax_type + self.softmax_offset = None if self.softmax_type == "vanilla": self.softmax_offset = None if self.softmax_type == "off-by-one": self.softmax_offset = torch.zeros( - self.num_attention_heads // self.tp_size, device="cuda" + self.num_attention_heads // self.tp_size, device="musa" ) if self.softmax_type == "learnable": self.register_parameter( "softmax_offset", - Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")), + Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="musa")), get_rng_state_tracker=get_rng_state_tracker, ) @@ -523,7 +524,7 @@ def set_context_parallel_group( self, cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], - cp_stream: torch.cuda.Stream, + cp_stream: torch.musa.Stream, cp_comm_type: str = "p2p", ) -> None: """ @@ -539,8 +540,8 @@ def set_context_parallel_group( and cp_group[1] are for a2a and p2p communications respectively. cp_global_ranks : List[int] list of global ranks in the context group. - cp_stream : torch.cuda.Stream - cuda stream for context parallel execution. + cp_stream : torch.musa.Stream + musa stream for context parallel execution. cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p". @@ -986,12 +987,13 @@ def forward( fp8_output: Optional[bool], default = `False` Whether to enforce output to be in FP8 or not. """ - - with torch.cuda.device(query_layer.device), self.prepare_forward( + self.softmax_type = "vanilla" + with torch.musa.device(query_layer.device), self.prepare_forward( query_layer, num_gemms=3, allow_non_contiguous=True, - allow_different_data_and_param_types=self.softmax_type != "vanilla", + # allow_different_data_and_param_types=self.softmax_type != "vanilla", + # allow_different_data_and_param_types=False, ) as query_layer: # checks for RNG if self.rng_states_tracker is not None and is_graph_capturing(): @@ -1000,7 +1002,7 @@ def forward( ), "Unsupported RNG states tracker." assert ( graph_safe_rng_available() - ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." + ), "Upgrade PyTorch version to get RNG manipulation support for musa graph capture." # checks for FP8 if self.fp8: @@ -1026,7 +1028,7 @@ def forward( # checks for q/k/v shapes assert ( - query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda + query_layer.is_musa and key_layer.is_musa and value_layer.is_musa ), "DotProductAttention only supports CUDA tensors." assert ( query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype @@ -1325,7 +1327,7 @@ def forward( fp8_meta=self.fp8_meta, inference_params=inference_params, softmax_type=self.softmax_type, - return_max_logit=self.return_max_logit, + # return_max_logit=self.return_max_logit, ) global _attention_backends if is_in_onnx_export_mode(): @@ -1385,6 +1387,7 @@ def forward( " disabling all backends." ) + self.softmax_offset = None # run attention softmax_offset = ( self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/softmax.py b/transformer_engine/pytorch/attention/dot_product_attention/softmax.py index df10fc7905..83dce468c6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/softmax.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/softmax.py @@ -24,7 +24,7 @@ def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor: def _get_mask(): diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1 return torch.triu( - torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset + torch.ones(sq, sk, dtype=torch.bool, device="musa"), diagonal=diagonal_offset ) if is_in_onnx_export_mode(): diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 6bcc9f25da..f315d3795b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -66,6 +66,27 @@ _cu_seqlens_cache = {} +def _get_full_cu_seqlens( + batch_size: int, + max_seqlen: int, + device: torch.device, +) -> torch.Tensor: + """Cumulative sequence lengths in full data batch + + All sequences in batch have the maximum sequence length. + + """ + global _cu_seqlens_cache + if (batch_size, max_seqlen) not in _cu_seqlens_cache: + _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange( + 0, + (batch_size + 1) * max_seqlen, + step=max_seqlen, + dtype=torch.int32, + device=device, + ) + return _cu_seqlens_cache[(batch_size, max_seqlen)] + class AttentionLogging: """ Manage logging for attention module @@ -1066,7 +1087,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed and not FlashAttentionUtils.v3_warning_printed - and torch.cuda.current_device() == 0 + and torch.musa.current_device() == 0 ): logger.warning( "flash-attn v3 may provide important feature support or performance improvement." @@ -1078,7 +1099,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_flash_attention_2 and not FlashAttentionUtils.is_installed and not FlashAttentionUtils.warning_printed - and torch.cuda.current_device() == 0 + and torch.musa.current_device() == 0 ): logger.warning( "flash-attn may provide important feature support or performance improvement." @@ -1193,13 +1214,13 @@ def get_padding_mask( ], dim=0, ) - attention_mask_q = attention_mask_q.to(device="cuda") + attention_mask_q = attention_mask_q.to(device="musa") if attention_type == "self": attention_mask = attention_mask_q else: attention_mask = ( attention_mask_q, - attention_mask_kv.to(device="cuda"), + attention_mask_kv.to(device="musa"), ) return attention_mask @@ -1318,9 +1339,9 @@ def get_full_mask( actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) # apply SWA mask - mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="musa").view( 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv) + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="musa").view(1, 1, 1, max_seqlen_kv) swa_left = None swa_right = None if attn_mask_type == "causal_bottom_right" or ( @@ -1416,7 +1437,7 @@ def get_alibi( m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2)) m = torch.cat([m, m_hat]) - _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda") + _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="musa") _alibi_cache["_num_heads"] = num_heads _alibi_cache["_alibi_slopes_require_update"] = False @@ -1429,9 +1450,9 @@ def get_alibi( else: raise ValueError("ALiBi slopes cannot exceed 2 dimensions.") - bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="musa").view( 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="musa").view( 1, 1, 1, max_seqlen_kv ) if actual_seqlens_q is None and actual_seqlens_kv is None: @@ -1451,7 +1472,7 @@ def get_alibi( _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment bias_dtype = torch.float32 if bias_dtype is None else bias_dtype - _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda") + _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="musa") _alibi_cache["_alibi_bias_require_update"] = False return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"] @@ -1466,7 +1487,7 @@ def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor: mask = mask.squeeze(1).squeeze(1) reduced_mask = mask.logical_not().sum(dim=1) cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) - zero = torch.zeros(1, dtype=torch.int32, device="cuda") + zero = torch.zeros(1, dtype=torch.int32, device="musa") cu_seqlens = torch.cat((zero, cu_seqlens)) return cu_seqlens @@ -1484,7 +1505,7 @@ def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch. reduced_mask = mask.logical_not().sum(dim=1) cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) - zero = torch.zeros(1, dtype=torch.int32, device="cuda") + zero = torch.zeros(1, dtype=torch.int32, device="musa") cu_seqlens = torch.cat((zero, cu_seqlens)) mask = mask.reshape(-1) @@ -1509,7 +1530,7 @@ def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor: bs = len(cu_seqlens) - 1 seqlens = cu_seqlens[1:] - cu_seqlens[:-1] indices = [i * max_seqlen + ii for i, j in enumerate(seqlens) for ii in range(j)] - indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(dtype=torch.int64, device="cuda") + indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(dtype=torch.int64, device="musa") num_nonzeros = indices.shape[0] pad_amount = bs * max_seqlen - num_nonzeros diff --git a/transformer_engine/pytorch/attention/inference.py b/transformer_engine/pytorch/attention/inference.py index f0ef8d0bd5..e8591d597d 100644 --- a/transformer_engine/pytorch/attention/inference.py +++ b/transformer_engine/pytorch/attention/inference.py @@ -207,12 +207,12 @@ def __init__( self.cu_seqlens_q = torch.zeros( self.max_batch_size + 1, dtype=torch.int32, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), ) self.cu_seqlens_kv = torch.zeros( self.max_batch_size + 1, dtype=torch.int32, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), ) # This internal buffer holds the running length of each @@ -223,7 +223,7 @@ def __init__( self.pre_step_seqlens = torch.zeros( self.max_batch_size, dtype=torch.int32, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), ) def reset(self): @@ -429,14 +429,14 @@ def __init__( self.batch_indices = torch.zeros( self.max_batch_size, dtype=torch.int32, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), ) # after re-indexing, batch indices are always [0, ..., b-1] self.batch_indices_post_step = torch.range( 0, self.max_batch_size - 1, dtype=torch.int32, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), ) # whether reindexing is needed, i.e. when batch seq_ids have changed self.need_reindex = True @@ -449,7 +449,7 @@ def allocate_memory(self, layer_number): self.num_heads, self.head_dim_k, dtype=self.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), ) v_cache = torch.zeros( self.max_batch_size, @@ -457,7 +457,7 @@ def allocate_memory(self, layer_number): self.num_heads, self.head_dim_v, dtype=self.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), ) self.cache[layer_number] = (k_cache, v_cache) @@ -626,7 +626,7 @@ def __init__( self.allocated_pages = defaultdict(list) # page table, [batch_size, max_pages_per_seq] self.page_table = torch.zeros( - self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" + self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="musa" ) def reset(self): @@ -646,7 +646,7 @@ def allocate_memory(self, layer_number): self.num_heads, self.head_dim_k, dtype=self.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), ) v_cache = torch.zeros( self.total_num_pages, @@ -654,7 +654,7 @@ def allocate_memory(self, layer_number): self.num_heads, self.head_dim_v, dtype=self.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), ) self.cache[layer_number] = (k_cache, v_cache) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index b3bda677bb..3bb5d37d67 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -128,7 +128,7 @@ class MultiheadAttention(torch.nn.Module): whether to use interleaved rotary position embeddings. bias : bool, default = `True` if set to `False`, the transformer layer will not learn any additive biases. - device : Union[torch.device, str], default = "cuda" + device : Union[torch.device, str], default = "musa" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -255,7 +255,7 @@ def __init__( ub_bulk_wgrad: bool = False, bias: bool = True, normalization: str = "LayerNorm", - device: Union[torch.device, str] = "cuda", + device: Union[torch.device, str] = "musa", qkv_format: str = "sbhd", name: str = None, qk_norm_type: Optional[str] = None, @@ -542,7 +542,7 @@ def set_context_parallel_group( self, cp_group: Union[dist_group_type, List[dist_group_type], None], cp_global_ranks: List[int], - cp_stream: torch.cuda.Stream, + cp_stream: torch.musa.Stream, cp_comm_type: str = "p2p", ) -> None: """ @@ -558,8 +558,8 @@ def set_context_parallel_group( and cp_group[1] are for a2a and p2p communications respectively. cp_global_ranks : List[int] list of global ranks in the context group. - cp_stream : torch.cuda.Stream - cuda stream for context parallel execution. + cp_stream : torch.musa.Stream + musa stream for context parallel execution. cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. Can be "p2p" or "all_gather" or "a2a", "a2a+p2p". diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index cc23d65a3e..1c4d7fb8c9 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -54,7 +54,7 @@ def __init__( inv_freq = 1.0 / ( self.rotary_base ** ( - torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) + torch.arange(0, dim, 2, dtype=torch.float32, device=torch.musa.current_device()) / dim ) ) @@ -76,7 +76,7 @@ def forward(self, max_seq_len: int, offset: int = 0): offset: int, default = 0 Fixed offset for frequencies. """ - with torch.autocast(enabled=False, device_type="cuda"): + with torch.autocast(enabled=False, device_type="musa"): seq = ( torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + offset diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 648b21eb4d..e111989492 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -18,6 +18,194 @@ CPUOffloadEnabled = False CPUOffloadedLayer = False +def set_offloading_param(tensor, param_name, value): + """Set the type of the offloading needed for a tensor.""" + assert param_name in ["weight_offloading", "activation_offloading", "fine_grained_offloading"] + if tensor is None: + return + if type(tensor) in [torch.Tensor, torch.nn.Parameter]: + setattr(tensor, param_name, value) + else: + data_tensors = tensor.get_data_tensors() + for tensor in data_tensors: + if tensor is not None: + setattr(tensor, param_name, value) + +class OffloadHandler: + """A base class for CPU offload-handler.""" + + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError( + "`tensor_push is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_push." + ) + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError( + "`tensor_pop is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_pop." + ) + +class _FineGrainedAsyncDoubleBufferGroupOffloadHandler(OffloadHandler): + + def __init__(self) -> None: + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_state = {} + # Tracking the number of layers offloaded + self.current_layer_id = 0 + # Tracking the number of microbatches offloaded + self.current_microbatch_id = 0 + + self.reloading_tensor = {} + self.to_offload_tensor = {} + + # allocate streams and events for synchronization + self.d2h_stream = None + self.h2d_stream = None + + self.OFFLOAD_TENSOR_ATTR_KEY = 'fine_grained_offloading' + + self.num_layers = None + self.pp_size = None + self.is_pipeline_last_stage = None + + self.pin_memory_tensor_pool = {} + self.to_release_tensor = {} + + + def is_last_layer(self): + return self.is_pipeline_last_stage and self.current_layer_id >= self.num_layers - 1 + + + def end_microbatch(self): + current_microbatch_tensor_tags = [] + for tensor_tag in self.to_release_tensor: + if tensor_tag[0] == self.current_microbatch_id: + current_microbatch_tensor_tags.append(tensor_tag) + + for tensor_tag in current_microbatch_tensor_tags: + copy_done_event, release_src_tensor = self.to_release_tensor.pop(tensor_tag) + copy_done_event.synchronize() + release_src_tensor.data = torch.Tensor() + + + def register_offload(self, src_tensor): + assert hasattr(src_tensor, self.OFFLOAD_TENSOR_ATTR_KEY) + tensor_name = getattr(src_tensor, self.OFFLOAD_TENSOR_ATTR_KEY) + tensor_tag = (self.current_microbatch_id, self.current_layer_id, tensor_name) + self.to_offload_tensor[tensor_tag] = src_tensor + return tensor_tag + + + def launch_offload(self, tensor_name): + if self.d2h_stream is None: + self.d2h_stream = torch.cuda.Stream() + + prev_layer_id = self.current_layer_id - 1 + prev_tensor_tag = (self.current_microbatch_id, prev_layer_id, tensor_name) + if prev_tensor_tag in self.to_release_tensor: + copy_done_event, release_src_tensor = self.to_release_tensor.pop(prev_tensor_tag) + copy_done_event.synchronize() + release_src_tensor.data = torch.Tensor() + + if self.is_last_layer(): + return + + copy_done_event = torch.cuda.Event() + tensor_tag = (self.current_microbatch_id, self.current_layer_id, tensor_name) + # print(f'start to offload {tensor_tag}') + src_tensor = self.to_offload_tensor.pop(tensor_tag) + token_num = src_tensor.size(0) + hidden_dim = src_tensor.size()[1:] + device = src_tensor.device + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + pin_memory_tag = (self.current_microbatch_id % self.pp_size, self.current_layer_id, tensor_name) + with torch.cuda.stream(self.d2h_stream): + existing_buffer = self.pin_memory_tensor_pool.get(pin_memory_tag) + # existing_buffer = self.pin_memory_tensor_pool.get(tensor_tag) + if existing_buffer is None or existing_buffer.size() < src_tensor.size(): + buffer_shape = [token_num * 2] + list(hidden_dim) + new_buffer = torch.empty( + buffer_shape, + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=True, + ) + self.pin_memory_tensor_pool[pin_memory_tag] = new_buffer + # self.pin_memory_tensor_pool[tensor_tag] = new_buffer + + # buffer = self.pin_memory_tensor_pool[tensor_tag] + buffer = self.pin_memory_tensor_pool[pin_memory_tag] + buffer[:token_num, ...].copy_(src_tensor.detach(), non_blocking=True) + cpu_backup = buffer[:token_num, ...] + + copy_done_event.record(stream=self.d2h_stream) + + self.to_release_tensor[tensor_tag] = (copy_done_event, src_tensor) + + state = (device, cpu_backup, copy_done_event) + self.tensor_tag_to_state[tensor_tag] = state + + return tensor_tag + + + def launch_reload(self, tensor_name): + if self.h2d_stream is None: + self.h2d_stream = torch.cuda.Stream() + #reload fc1-input in layer i-1 in layer i + tensor_tag = (self.current_microbatch_id, self.current_layer_id - 1, tensor_name) + if not tensor_tag in self.tensor_tag_to_state: + return + (device, cpu_backup, copy_done_event) = self.tensor_tag_to_state.pop(tensor_tag) + + copy_done_event = torch.cuda.Event() + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.h2d_stream): + device_tensor = cpu_backup.to(device, non_blocking=True) + copy_done_event.record(stream=self.h2d_stream) + state = (copy_done_event, device_tensor) + self.reloading_tensor[tensor_tag] = state + + + def wait_reload(self, tensor_tag): + assert tensor_tag in self.reloading_tensor + (copy_done_event, device_tensor) = self.reloading_tensor.pop(tensor_tag) + copy_done_event.wait() + return device_tensor + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + if hasattr(tensor, self.OFFLOAD_TENSOR_ATTR_KEY): + return self.register_offload(tensor) + return tensor + + def tensor_pop(self, tensor_tag, **kwargs): + if tensor_tag in self.reloading_tensor: + return self.wait_reload(tensor_tag) + return tensor_tag + + def start_microbatch_forward(self, current_microbatch_id): + self.current_microbatch_id = current_microbatch_id + self.current_layer_id = 0 + + def start_microbatch_backward(self, current_microbatch_id): + self.current_microbatch_id = current_microbatch_id + self.current_layer_id = self.num_layers + #reload fc1-input in last layer of this rank in the begining of backward + self.launch_reload("fc1_inp") + + +_fg_offload_handler_instance = _FineGrainedAsyncDoubleBufferGroupOffloadHandler() + + +def get_fine_grained_offload_handler(): + return _fg_offload_handler_instance + def mark_activation_offload(*tensors): """Set the type of the offloading needed for a tensor.""" diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 6151ecafd3..5302b35f9e 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -9,13 +9,48 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch +from functools import reduce from .. import cpp_extensions as tex from ..constants import TE_DType from ..export import is_in_onnx_export_mode from ..utils import get_default_init_method +from ..tensor.float8_tensor import Float8Tensor + +def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor: + """Reorder FP8 transposes after Userbuffers gather. + + The all-gather is performed in-place in the Float8Tensor's + row-wise data, and afterwards we need to do a transpose to get the + correct ordering. This misuses data fields in Float8Tensor and + should be considered an evil hack. It would be best to move + transpose logic into CommOverlap::get_buffer. + + Responsibility for fixing: adener, tmoon + + """ + assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor" + assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1" + assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data" + assert ( + fp8_tensor._data.shape[0] % tp_size == 0 + ), "Leading dimension of data is not divisble by TP size" + + data = fp8_tensor._data + batched_size = reduce(multiply_op, data.shape[1:]) + interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size] + transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size] + fp8_tensor._transpose = ( + data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape) + ) + + fp8_tensor._transpose_invalid = False + fp8_tensor._data = None + + return fp8_tensor + def _get_normalization_func(normalization: str, forward: bool): fwd_normalization_funcs = { "LayerNorm": tex.layernorm_fwd, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d16455b5b4..4cc9e5cc4c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -76,7 +76,7 @@ class UserBufferQuantizationMode(Enum): def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: + if torch.musa.get_device_properties(torch.musa.current_device()).major >= 9: # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales return 32 * 1024 * 1024 + 1024 return 4_194_304 @@ -87,7 +87,7 @@ def get_workspace() -> torch.Tensor: global _cublas_workspace if _cublas_workspace is None: _cublas_workspace = torch.empty( - get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" + get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="musa" ) return _cublas_workspace @@ -98,7 +98,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: if not _multi_stream_cublas_workspace: for _ in range(tex.get_num_cublas_streams()): _multi_stream_cublas_workspace.append( - torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda") + torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="musa") ) return _multi_stream_cublas_workspace @@ -111,7 +111,7 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( shape, dtype=dtype, - device="cuda", + device="musa", requires_grad=False, ) if zero: @@ -182,8 +182,8 @@ def initialize_ub( """ if not tex.device_supports_multicast(): assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( - "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " - + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." + "MUSA device, driver and/or toolkit version does not support comm+GEMM overlap with " + + "MUSA Multicast. Launch app with UB_SKIPMC=1 to try MUSA IPC instead." ) if not quantization_modes: @@ -282,7 +282,7 @@ def initialize_ub( elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS: # This ensures we don't do `.repeat()` on an already expanded workspace _cublas_workspace = torch.empty( - get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" + get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="musa" ).repeat(_NUM_MAX_UB_STREAMS) # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe @@ -640,7 +640,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self) -> None: super().__init__() - assert torch.cuda.is_available(), "TransformerEngine needs CUDA." + assert torch.musa.is_available(), "TransformerEngine needs MUSA." self.name = None self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False @@ -895,7 +895,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: state["extra_fp8_variables"] = extra # Serialize state into byte tensor - torch.cuda.synchronize() + torch.musa.synchronize() state_serialized = bytearray(pickle.dumps(state)) state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) return state_serialized @@ -917,7 +917,7 @@ def set_extra_state(self, state: torch.Tensor) -> None: elif isinstance(state, io.BytesIO): # Deprecated format with io.BytesIO state.seek(0) - state = torch.load(state, map_location="cuda") + state = torch.load(state, map_location="musa") else: raise RuntimeError("Unsupported checkpoint format.") @@ -956,7 +956,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history) - torch.cuda.synchronize() + torch.musa.synchronize() def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" @@ -1080,7 +1080,7 @@ def prepare_forward( if self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) else: - assert inp.is_cuda, "TransformerEngine needs CUDA." + assert inp.is_musa, "TransformerEngine needs MUSA." if self.tp_size > 1: assert self.tp_group_initialized, "TP group not initialized." @@ -1115,16 +1115,16 @@ def set_nccl_overlap_warning_if_tp(self) -> None: before the GEMM for there to be a guaranteed overlap. From the host side in TE, the comm calls are always launched first, but to ensure that the GEMM isn't scheduled first, the environment - variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to + variable `MUSA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to force a single channel. """ if self.tp_size == 1: return - num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) - if num_cuda_work_queues != 1: + num_musa_work_queues = int(os.getenv("MUSA_DEVICE_MAX_CONNECTIONS", "0")) + if num_musa_work_queues != 1: warnings.warn( "To guarantee overlapping TP and SP collectives with the backward" - "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1" + "GEMMs, set environment variable MUSA_DEVICE_MAX_CONNECTIONS = 1" ) @staticmethod @@ -1248,7 +1248,7 @@ def register_parameter(self, name, param, **kwargs): def reset_parameters(self, defer_init: Optional[bool] = False) -> None: """ Reset all module parameters to initial values. Unless deferred initialization - is specified, all parameters on a 'meta' device are also materialized on a real cuda + is specified, all parameters on a 'meta' device are also materialized on a real musa device before the values are reset to initial. """ if defer_init: @@ -1257,7 +1257,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: for name, param in self.named_parameters(recurse=False): # Ensure parameter is on a real device if param.device == torch.device("meta"): - param = torch.empty_like(param, device="cuda") + param = torch.empty_like(param, device="musa") # Initialize the parameter values on device init_fn = self.param_init_meta[name].init_fn diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a5bf21ee17..548f4276d0 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -55,7 +55,7 @@ class _GroupedLinear(torch.autograd.Function): """GroupedLinear semi-top level module - Calls custom cuda extensions. + Calls custom musa extensions. """ @staticmethod @@ -450,14 +450,14 @@ def handle_custom_ddp_from_mcore(weight, wgrad): wgrad = torch.zeros( weight.main_grad.shape, dtype=weight.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), requires_grad=False, ) else: wgrad = torch.empty( weight.main_grad.shape, dtype=weight.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), requires_grad=False, ) elif ctx.fuse_wgrad_accumulation: @@ -528,7 +528,7 @@ class GroupedLinear(TransformerEngineBaseModule): used to get the random number generator state tracker for initializing weights. rng_tracker_name : str, default = `None` the param passed to get_rng_state_tracker to get the specific rng tracker. - device : Union[torch.device, str], default = "cuda" + device : Union[torch.device, str], default = "musa" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -581,7 +581,7 @@ def __init__( return_bias: bool = False, params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, - device: Union[torch.device, str] = "cuda", + device: Union[torch.device, str] = "musa", ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, ub_name: Optional[str] = None, @@ -768,7 +768,7 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with torch.cuda.device( + with torch.musa.device( getattr(self, list(self.named_parameters())[0][0]).device ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: weight_tensors = self._get_weight_tensors() diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 6d13544e4f..fae568358c 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -33,7 +33,7 @@ class LayerNorm(_LayerNormOp): eps : float, default = 1e-5 A value added to the denominator of layer normalization for numerical stability - device: torch.device, default = default CUDA device + device: torch.device, default = default MUSA device Tensor device dtype: torch.dtype, default = default dtype Tensor datatype @@ -45,7 +45,7 @@ class LayerNorm(_LayerNormOp): y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta sm_margin: int or dict, default = 0 - Number of SMs to exclude when launching CUDA kernels. This + Number of SMs to exclude when launching MUSA kernels. This helps overlap with other kernels, e.g. communication kernels. For more fine-grained control, provide a dict with the SM margin at each compute stage ("forward", "backward", diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6c0f969e47..29a75c4041 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -81,7 +81,7 @@ class _LayerNormLinear(torch.autograd.Function): """LayerNormLinear semi-top level module - Calls custom cuda extensions. + Calls custom musa extensions. """ @staticmethod @@ -788,7 +788,7 @@ def backward( # We use the send stream to copy into the userbuffers. # This is the same stream that we will use to access the data in the AG, # so we dont need to add any syncs yet. - with torch.cuda.stream(dgrad_send_stream): + with torch.musa.stream(dgrad_send_stream): grad_output, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_overlap_wgrad, grad_outputs[0], @@ -1101,7 +1101,7 @@ class LayerNormLinear(TransformerEngineBaseModule): .. math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta - device : Union[torch.device, str], default = "cuda" + device : Union[torch.device, str], default = "musa" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -1175,7 +1175,7 @@ def __init__( return_layernorm_output_gathered: bool = False, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, zero_centered_gamma: bool = False, - device: Union[torch.device, str] = "cuda", + device: Union[torch.device, str] = "musa", ub_overlap_ag: bool = False, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, @@ -1535,7 +1535,7 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with torch.cuda.device( + with torch.musa.device( getattr(self, list(self.named_parameters())[0][0]).device ), self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a2ddb970af..2a55b9e69e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -153,7 +153,7 @@ def _act_func(activation: str, recipe: Optional[Recipe] = None): class _LayerNormMLP(torch.autograd.Function): """LayerNormMLP semi-top level module - Calls custom cuda extensions. + Calls custom musa extensions. """ @staticmethod @@ -904,7 +904,7 @@ def backward( # We use the send stream to copy into the userbuffers. # This is the same stream that we will use to access the data in the AG, # so we dont need to add any syncs yet. - with torch.cuda.stream(dgrad_send_stream): + with torch.musa.stream(dgrad_send_stream): grad_output, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc2_wgrad, grad_outputs[0], @@ -1098,7 +1098,7 @@ def fc2_wgrad_gemm( reduce_scatter_out = None if ctx.ub_overlap_rs_dgrad: reduce_scatter_out = torch.empty( - fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" + fc1_dgrad_shape, dtype=ctx.activation_dtype, device="musa" ) if ctx.ub_bulk_wgrad: gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False) @@ -1181,7 +1181,7 @@ def fc2_wgrad_gemm( reduce_scatter_out = None if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf(): reduce_scatter_out = torch.empty( - fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" + fc1_dgrad_shape, dtype=ctx.activation_dtype, device="musa" ) # Arguments to include in wgrad GEMM closure @@ -1309,14 +1309,14 @@ def fc1_wgrad_gemm( fc1_wgrad = torch.zeros( origin_fc1_weight.main_grad.shape, dtype=origin_fc1_weight.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), requires_grad=False, ) else: fc1_wgrad = torch.empty( origin_fc1_weight.main_grad.shape, dtype=origin_fc1_weight.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), requires_grad=False, ) elif ctx.fuse_wgrad_accumulation: @@ -1334,14 +1334,14 @@ def fc1_wgrad_gemm( fc2_wgrad = torch.zeros( origin_fc2_weight.main_grad.shape, dtype=origin_fc2_weight.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), requires_grad=False, ) else: fc2_wgrad = torch.empty( origin_fc2_weight.main_grad.shape, dtype=origin_fc2_weight.dtype, - device=torch.cuda.current_device(), + device=torch.musa.current_device(), requires_grad=False, ) elif ctx.fuse_wgrad_accumulation: @@ -1462,7 +1462,7 @@ class LayerNormMLP(TransformerEngineBaseModule): .. math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta - device : Union[torch.device, str], default = "cuda" + device : Union[torch.device, str], default = "musa" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -1546,7 +1546,7 @@ def __init__( micro_batch_size: Optional[int] = None, set_parallel_mode: bool = False, zero_centered_gamma: bool = False, - device: Union[torch.device, str] = "cuda", + device: Union[torch.device, str] = "musa", ub_overlap_ag: bool = False, name: str = None, ub_overlap_rs: bool = False, @@ -1806,7 +1806,7 @@ def forward( if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True - with torch.cuda.device( + with torch.musa.device( getattr(self, list(self.named_parameters())[0][0]).device ), self.prepare_forward(inp, num_gemms=2) as inp: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 42f29d06ee..5a02aac8a5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -77,7 +77,7 @@ class _Linear(torch.autograd.Function): """Linear semi-top level module - Calls custom cuda extensions. + Calls custom musa extensions. """ @staticmethod @@ -794,7 +794,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # We use the send stream to copy into the userbuffers. # This is the same stream that we will use to access the data in the AG, # so we dont need to add any syncs yet. - with torch.cuda.stream(dgrad_send_stream): + with torch.musa.stream(dgrad_send_stream): grad_output, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_overlap_wgrad, grad_output_arg, @@ -1035,7 +1035,7 @@ class Linear(TransformerEngineBaseModule): values as split sizes along dim 0. The resulting parameters will have names that end in `_weight` or `_bias`, so trailing underscores are stripped from any provided names. - device : Union[torch.device, str], default = "cuda" + device : Union[torch.device, str], default = "musa" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. @@ -1110,7 +1110,7 @@ def __init__( params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, - device: Union[torch.device, str] = "cuda", + device: Union[torch.device, str] = "musa", ub_overlap_ag: bool = False, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, @@ -1420,7 +1420,7 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with torch.cuda.device( + with torch.musa.device( getattr(self, list(self.named_parameters())[0][0]).device ), self.prepare_forward( inp, diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index fb267d8a9b..b5dcd4a7df 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -37,7 +37,7 @@ class RMSNorm(_RMSNormOp): Inner dimensions of input tensor eps : float, default = 1e-5 A value added to the denominator for numerical stability - device: torch.device, default = default CUDA device + device: torch.device, default = default MUSA device Tensor device dtype: torch.dtype, default = default dtype Tensor datatype @@ -49,7 +49,7 @@ class RMSNorm(_RMSNormOp): y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) sm_margin: int, default = 0 - Number of SMs to exclude when launching CUDA kernels. This + Number of SMs to exclude when launching MUSA kernels. This helps overlap with other kernels, e.g. communication kernels. For more fine-grained control, provide a dict with the SM margin at each compute stage ("forward", "backward", diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 639817ada7..de78103ae2 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -222,6 +222,73 @@ def get_grad_output_quantizer(self) -> Optional[Quantizer]: return self.get_quantizer("backward", 0) return None + + + def _update_quantization_recipe_state( + self, + *, + recipe: Optional[Recipe] = None, + ) -> None: + """Make sure quantizer state matches quantization recipe""" + + # Quantization recipe + if recipe is None: + recipe = FP8GlobalStateManager.get_fp8_recipe() + + # Reset quantization state if needed + if self._fp8_metas is None or self._quantizers is None: + self._reset_quantization_recipe_state(recipe=recipe) + return + for mode in ("forward", "backward"): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]: + continue + recipe_state = self._fp8_metas[mode][fp8_meta_key] + need_to_reset_recipe_state = ( + recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) + ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + if need_to_reset_recipe_state: + self._reset_quantization_recipe_state(recipe=recipe) + return + + # Quantization recipe state for forward and backward pass + for mode in ("forward", "backward"): + num_quantizers = self.num_quantizers(mode) + if num_quantizers == 0: + continue + + # Update FP8 metadata + fp8_meta = self._fp8_metas[mode] + fp8_meta["recipe"] = recipe + fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + + # Get recipe state + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + recipe_state = fp8_meta[fp8_meta_key] + + # Reallocate amax history if needed + if recipe.mxfp8(): + continue + + current_length = recipe_state.amax_history.size(0) + target_length = recipe.amax_history_len + if current_length != target_length: + with torch.no_grad(): + if target_length < current_length: + recipe_state.amax_history = recipe_state.amax_history[ + :target_length + ].clone() + else: + recipe_state.amax_history = torch.nn.functional.pad( + recipe_state.amax_history, + pad=(0, 0, 0, target_length - current_length), + ) + self._quantizers[mode] = recipe_state.make_quantizers() + def reset_recipe_state( self, *, diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 7689e20194..b7b90fcbac 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -5,7 +5,7 @@ """Custom tensor classes""" import torch - +from . import _internal from .quantized_tensor import ( QuantizedTensorStorage, QuantizedTensor, @@ -24,6 +24,7 @@ from .utils import cast_master_weights_to_fp8, replace_raw_data __all__ = [ + "_internal", "Quantizer", "Float8Quantizer", "Float8CurrentScalingQuantizer", diff --git a/transformer_engine/pytorch/tensor/_internal/__init__.py b/transformer_engine/pytorch/tensor/_internal/__init__.py new file mode 100644 index 0000000000..e13014bf75 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Internal data structures for quantized tensors.""" diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py new file mode 100644 index 0000000000..b0b6f98e6c --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -0,0 +1,137 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for Float8Tensor""" + +from __future__ import annotations +from typing import Any, Dict, Optional, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype + +from ..quantized_tensor import Quantizer + + +class _FromFloat8Func(torch.autograd.Function): + """Cast from FP8 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: Float8TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + if tensor._data is not None: + # Cast from FP8 + return tex.dequantize(tensor, dtype) + + raise NotImplementedError("Casting back from the transpose not implemented yet!") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class Float8TensorBase: + """Mixin class that holds data attributes of Float8Tensor. + + Float8Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + + """ + + _data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _fp8_dtype: TE_DType + _scale_inv: torch.Tensor + + # FP8 transpose cache + _transpose: Optional[torch.Tensor] + _transpose_invalid: bool + + def __new__( + cls, + *args, + data: Optional[torch.Tensor], + fp8_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + data_transpose: Optional[torch.Tensor] = None, + quantizer: Optional[Quantizer] = None, + **kwargs, + ): + if cls is Float8TensorBase: + instance = object.__new__(cls) + else: + instance = super().__new__(cls, *args, **kwargs) + instance._data = data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._scale_inv = fp8_scale_inv + instance._transpose = data_transpose + instance._transpose_invalid = instance._transpose is None + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "data": self._data, + "fp8_scale_inv": self._scale_inv, + "fp8_dtype": self._fp8_dtype, + "data_transpose": self._transpose, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + """Prepare the tensor base for saving for backward + + After calling this, the tensor instance does not hold any + data. + + """ + tensors = [self._data, self._transpose] + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list""" + self._data = tensors[0] + self._transpose = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._data, self._transpose + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromFloat8Func.forward(None, self, dtype) + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + return self._data.size(*args, **kwargs) + + def __repr__(self): + return ( + "Float8TensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.dequantize()}" + ")" + ) diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py new file mode 100644 index 0000000000..bd581feab1 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -0,0 +1,134 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for MXFP8Tensor""" + +from __future__ import annotations +from typing import Optional, Dict, Any, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype + +from ..quantized_tensor import Quantizer + + +class _FromMXFP8Func(torch.autograd.Function): + """Cast from MXFP8 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: MXFP8TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + if tensor._rowwise_data is not None: + return tex.dequantize(tensor, dtype) + raise NotImplementedError("Casting back from the transpose not implemented yet!") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class MXFP8TensorBase: + """Mixin class that holds data attributes of MXFP8Tensor. + + MXFP8Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _fp8_dtype: TE_DType + _rowwise_scale_inv: torch.Tensor + _columnwise_scale_inv: torch.Tensor + + def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + quantizer: Optional[Quantizer] = None, + **kwargs, + ): + instance = super().__new__(cls, *args, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "fp8_dtype": self._fp8_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: + """Prepare the tensor base for saving for backward + + After calling this, the tensor instance does not hold any + data. + + """ + tensors = [self._rowwise_data, self._columnwise_data] + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromMXFP8Func.forward(None, self, dtype) + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + return self._rowwise_data.size(*args, **kwargs) + + def __repr__(self): + data_rowwise = self.dequantize() + + return ( + "MXFP8TensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"rowwise_scaled_data={data_rowwise}" + f"rowwise_scale_inv={self._rowwise_scale_inv}, " + ")" + ) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index a4e68e53b0..775eac664b 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType - +from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 2be0aed4a8..98b749d682 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -33,6 +33,13 @@ def _empty_tensor() -> torch.Tensor: return torch.Tensor().cuda() + +def non_tn_fp8_gemm_supported() -> bool: + """Checks whether the device supports + non-TN layouts for FP8 GEMMs. + """ + return torch.cuda.get_device_capability() >= (10, 0) + def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """ Trick to deallocate tensor memory when delete operation does not