From b7ece1a16890193de6fc2586d6827762aa7ba8e4 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Wed, 22 May 2024 11:50:25 +0800 Subject: [PATCH] support bf16 on disc backend (#40) add float-norm pass to support bf16 amp training --- third_party/BladeDISC | 2 +- torch_xla/csrc/runtime/BUILD | 2 + .../custom_call_flash_attention_backward.cc | 4 ++ .../custom_call_flash_attention_forward.cc | 2 + .../csrc/runtime/disc_computation_client.cc | 48 +++++++++++++++++-- 5 files changed, 53 insertions(+), 5 deletions(-) diff --git a/third_party/BladeDISC b/third_party/BladeDISC index 67c324289c36..fbe39bce9ae2 160000 --- a/third_party/BladeDISC +++ b/third_party/BladeDISC @@ -1 +1 @@ -Subproject commit 67c324289c36da5187405c18600403a0d3681b61 +Subproject commit fbe39bce9ae2d365d77842af38a33fa76d37237a diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 9bd276c856f2..5d2992bf802a 100755 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -167,6 +167,8 @@ cc_library( "@xla//xla/client:xla_computation", "//torch_xla/csrc/runtime/disc:disc_ral", "//torch_xla/csrc/runtime/disc:disc_compile", + "@xla//xla/service:float_normalization", + "@xla//xla/service/gpu:gpu_float_support", ], ) diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc index efd4f775f489..8f4460d81457 100644 --- a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc +++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc @@ -27,6 +27,7 @@ namespace tao { namespace ral { DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16"); +DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16"); struct FlashAttentionBackwardParams { using index_t = uint32_t; @@ -235,6 +236,7 @@ custom_call_flash_attention_backward( memset(&launch_params, 0, sizeof(launch_params)); launch_params.is_bf16 = params.is_bf16; + launch_params.is_bf16 = true; // Set the pointers and strides. launch_params.q_ptr = q.data; @@ -380,6 +382,8 @@ TAO_RAL_API("custom_call_flash_attention_backward", "gpu", custom_call_flash_attention_backward); TAO_RAL_API("custom_call_flash_attention_backward", "gpu", custom_call_flash_attention_backward); +TAO_RAL_API("custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward); } // namespace ral } // namespace tao \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc index 7d5ded3ebeb2..ca281319b856 100644 --- a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc +++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc @@ -245,6 +245,8 @@ TAO_RAL_API("custom_call_flash_attention_forward", "gpu", custom_call_flash_attention_forward); TAO_RAL_API("custom_call_flash_attention_forward", "gpu", custom_call_flash_attention_forward); +TAO_RAL_API("custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward); } // namespace ral } // namespace tao diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc index d4066f441174..dbf5ca065c47 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.cc +++ b/torch_xla/csrc/runtime/disc_computation_client.cc @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -18,6 +19,11 @@ #include "torch_xla/csrc/runtime/disc/disc_compile.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/float_normalization.h" +#include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/service/hlo_proto_util.h" namespace torch_xla { namespace runtime { @@ -172,10 +178,44 @@ std::vector DISCComputationClient::Compile( mlir::MLIRContext context; mlir::ModuleOp mlir_module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); - auto status = torch_xla::ConvertHloToMhlo( - instance.computation.mutable_proto(), &mlir_module); - XLA_CHECK(status.ok()) << "StableHLO -> MHLO conversion failed.\n" - << status.message(); + + auto hlo_proto = instance.computation.proto(); + auto program_shape = instance.computation.GetProgramShape().value(); + xla::HloModuleConfig module_config(program_shape); + module_config.set_debug_options(xla::GetDebugOptionsFromFlags()); + xla::ComputationLayout* entry_layout = + module_config.mutable_entry_computation_layout(); + for (int64_t i = 0; i < entry_layout->parameter_count(); ++i) { + auto status = + entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + program_shape.parameters(i)); + if (!status.ok()) { + XLA_ERROR() << "Error copying layout from shape: "; + return {}; + } + } + + std::unique_ptr hlo_module = + xla::CreateModuleFromProto(hlo_proto, module_config).value(); + xla::HloPassPipeline pipeline("pre-stablehlo"); + stream_executor::CudaComputeCapability gpu_version; + auto dprops = at::cuda::getCurrentDeviceProperties(); + gpu_version.major = dprops->major; + gpu_version.minor = dprops->minor; + xla::gpu::GpuFloatSupport bf16_support(gpu_version, xla::BF16); + pipeline.AddPass(&bf16_support); + auto status = pipeline.Run(hlo_module.get()).status(); + if (!status.ok()) { + XLA_ERROR() << "Error running pre-stablehlo pass pipeline: "; + return {}; + } + { + auto mutable_hlo_proto = hlo_module->ToProto(); + auto status = + torch_xla::ConvertHloToMhlo(&mutable_hlo_proto, &mlir_module); + XLA_CHECK(status.ok()) << "StableHLO -> MHLO conversion failed.\n" + << status.message(); + } // Add input and output attributes auto entry_func_identifier =