Skip to content

Commit

Permalink
support bf16 on disc backend (pytorch#40)
Browse files Browse the repository at this point in the history
add float-norm pass to support bf16 amp training
  • Loading branch information
Yancey1989 authored and yitongh committed Aug 8, 2024
1 parent f9057f5 commit 0d40623
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 5 deletions.
2 changes: 1 addition & 1 deletion third_party/BladeDISC
Submodule BladeDISC updated 60 files
+2 −2 .github/workflows/pytorch113_gpu.yml
+2 −1 docker/dev/Dockerfile.aarch64
+1 −1 docker/scripts/install-python.sh
+0 −1 pytorch_blade/pytorch_blade/compiler/jit/torch/constant_propagation.cpp
+1 −1 pytorch_blade/pytorch_blade/compiler/jit/torch/freeze_module.cpp
+1 −0 pytorch_blade/pytorch_blade/compiler/mlir/runtime/BUILD
+0 −1 pytorch_blade/scripts/pip/requirements-dev-1.13.1+cpu.txt
+0 −1 pytorch_blade/scripts/pip/requirements-dev-1.13.1+cu116.txt
+12 −41 pytorch_blade/tests/disc/ops/test_scatter.py
+8 −0 tao_compiler/ci_build/platforms/tao/gpu/env.conf.cuda11_8
+3 −1 tao_compiler/mlir/custom_ops/custom_library/transpose_gpu.cu.cc
+3 −1 tao_compiler/mlir/custom_ops/transpose_impl.cc
+183 −0 tao_compiler/mlir/disc/BUILD
+25 −0 tao_compiler/mlir/disc/IR/lhlo_disc_ops.td
+25 −4 tao_compiler/mlir/disc/disc_compiler.cc
+1 −0 tao_compiler/mlir/disc/tests/BUILD
+1 −0 tao_compiler/mlir/disc/tools/disc-opt/disc-opt.cc
+1 −0 tao_compiler/mlir/disc/tools/disc-replay/BUILD
+172 −5 tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
+108 −0 tao_compiler/mlir/disc/transforms/disc_argsmutation_expand.cc
+2 −2 tao_compiler/mlir/disc/transforms/disc_assign_memory_space.cc
+16 −14 tao_compiler/mlir/disc/transforms/disc_bf16_expansion.cc
+355 −0 tao_compiler/mlir/disc/transforms/disc_collective_ops_rewriter.cc
+143 −8 tao_compiler/mlir/disc/transforms/disc_hlo_legalize_to_lhlo.cc
+163 −0 tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc
+0 −16 tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc
+2 −3 tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
+999 −0 tao_compiler/mlir/disc/transforms/disc_op_schedule.cc
+113 −0 tao_compiler/mlir/disc/transforms/disc_optimization_barrier_expand.cc
+20 −0 tao_compiler/mlir/disc/transforms/disc_passes.td
+90 −0 tao_compiler/mlir/disc/transforms/disc_reduce_buffer_live_range.cc
+731 −0 tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc
+1 −1 tao_compiler/mlir/disc/transforms/disc_supported_list.h.inc
+12 −1 tao_compiler/mlir/disc/transforms/disc_to_llvm.cc
+19 −17 tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc
+6 −3 tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
+97 −1 tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
+15 −0 tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td
+1 −1 tao_compiler/mlir/disc/transforms/mhlo_placer.cc
+19 −0 tao_compiler/mlir/disc/transforms/passes.h
+36 −0 tao_compiler/mlir/disc/transforms/tests/disc-algebraic-simplifier.mlir
+14 −0 tao_compiler/mlir/disc/transforms/tests/disc-async-collective-ops-rewriter.mlir
+23 −0 tao_compiler/mlir/disc/transforms/tests/disc-collective-ops-rewriter.mlir
+11 −0 tao_compiler/mlir/disc/transforms/tests/disc-hlo-legalize-to-lhlo.mlir
+11 −0 tao_compiler/mlir/disc/transforms/tests/disc-input-output-alias.mlir
+22 −0 tao_compiler/mlir/disc/transforms/tests/disc-optimization-barrier-expand.mlir
+221 −0 tao_compiler/mlir/disc/transforms/tests/disc-shape-propagate.mlir
+1 −1 tao_compiler/mlir/disc/transforms/tests/input-mutation.mlir
+18 −0 tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir
+186 −154 tao_compiler/mlir/ral/BUILD
+518 −0 tao_compiler/mlir/ral/collective.cu.cc
+33 −0 tao_compiler/mlir/ral/collective.h
+9 −0 tao_compiler/mlir/ral/context/base/base_context.cc
+1 −0 tao_compiler/mlir/ral/context/base/cpu/cpu_context_impl.cc
+27 −6 tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
+12 −0 tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
+22 −0 tao_compiler/mlir/ral/context/common_context_impl_cuda.cc
+18 −3 tao_compiler/mlir/ral/context/stream_executor_based_impl.cc
+10 −0 tao_compiler/mlir/ral/context/tensorflow/tf_context_impl.cc
+3 −1 tao_compiler/mlir/ral/ral_logging.cc
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ cc_library(
"@xla//xla/client:xla_computation",
"//torch_xla/csrc/runtime/disc:disc_ral",
"//torch_xla/csrc/runtime/disc:disc_compile",
"@xla//xla/service:float_normalization",
"@xla//xla/service/gpu:gpu_float_support",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace tao {
namespace ral {

DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16");
DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");

struct FlashAttentionBackwardParams {
using index_t = uint32_t;
Expand Down Expand Up @@ -235,6 +236,7 @@ custom_call_flash_attention_backward(
memset(&launch_params, 0, sizeof(launch_params));

launch_params.is_bf16 = params.is_bf16;
launch_params.is_bf16 = true;

// Set the pointers and strides.
launch_params.q_ptr = q.data;
Expand Down Expand Up @@ -380,6 +382,8 @@ TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward<float, float, 3>);
TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward<Eigen::half, float, 3>);
TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward<Eigen::bfloat16, float, 3>);

} // namespace ral
} // namespace tao
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ TAO_RAL_API("custom_call_flash_attention_forward", "gpu",
custom_call_flash_attention_forward<float, float, 3>);
TAO_RAL_API("custom_call_flash_attention_forward", "gpu",
custom_call_flash_attention_forward<Eigen::half, float, 3>);
TAO_RAL_API("custom_call_flash_attention_forward", "gpu",
custom_call_flash_attention_forward<bfloat16, float, 3>);

} // namespace ral
} // namespace tao
48 changes: 44 additions & 4 deletions torch_xla/csrc/runtime/disc_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/cuda.h>

#include <vector>
Expand All @@ -18,6 +19,11 @@
#include "torch_xla/csrc/runtime/disc/disc_compile.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/float_normalization.h"
#include "xla/service/gpu/gpu_float_support.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/service/hlo_proto_util.h"

namespace torch_xla {
namespace runtime {
Expand Down Expand Up @@ -172,10 +178,44 @@ std::vector<ComputationClient::ComputationPtr> DISCComputationClient::Compile(
mlir::MLIRContext context;
mlir::ModuleOp mlir_module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
auto status = torch_xla::ConvertHloToMhlo(
instance.computation.mutable_proto(), &mlir_module);
XLA_CHECK(status.ok()) << "StableHLO -> MHLO conversion failed.\n"
<< status.message();

auto hlo_proto = instance.computation.proto();
auto program_shape = instance.computation.GetProgramShape().value();
xla::HloModuleConfig module_config(program_shape);
module_config.set_debug_options(xla::GetDebugOptionsFromFlags());
xla::ComputationLayout* entry_layout =
module_config.mutable_entry_computation_layout();
for (int64_t i = 0; i < entry_layout->parameter_count(); ++i) {
auto status =
entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
program_shape.parameters(i));
if (!status.ok()) {
XLA_ERROR() << "Error copying layout from shape: ";
return {};
}
}

std::unique_ptr<xla::HloModule> hlo_module =
xla::CreateModuleFromProto(hlo_proto, module_config).value();
xla::HloPassPipeline pipeline("pre-stablehlo");
stream_executor::CudaComputeCapability gpu_version;
auto dprops = at::cuda::getCurrentDeviceProperties();
gpu_version.major = dprops->major;
gpu_version.minor = dprops->minor;
xla::gpu::GpuFloatSupport bf16_support(gpu_version, xla::BF16);
pipeline.AddPass<xla::FloatNormalization>(&bf16_support);
auto status = pipeline.Run(hlo_module.get()).status();
if (!status.ok()) {
XLA_ERROR() << "Error running pre-stablehlo pass pipeline: ";
return {};
}
{
auto mutable_hlo_proto = hlo_module->ToProto();
auto status =
torch_xla::ConvertHloToMhlo(&mutable_hlo_proto, &mlir_module);
XLA_CHECK(status.ok()) << "StableHLO -> MHLO conversion failed.\n"
<< status.message();
}

// Add input and output attributes
auto entry_func_identifier =
Expand Down

0 comments on commit 0d40623

Please sign in to comment.