diff --git a/.bazelrc b/.bazelrc index 34c41167982..d92caaf0fed 100644 --- a/.bazelrc +++ b/.bazelrc @@ -223,3 +223,6 @@ build:linux --copt="-Wno-error=unused-but-set-variable" # Only include debug info for files not under XLA. build:dbg -c dbg build:dbg --per_file_copt=external/xla/.*@-g0,-DNDEBUG + +# build with DISC backend +build --define enable_disc=true diff --git a/.gitmodules b/.gitmodules index 32423922406..9a14e40cb54 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "third_party/flash-attention"] path = third_party/flash-attention url = https://github.com/Dao-AILab/flash-attention.git +[submodule "third_party/BladeDISC"] + path = third_party/BladeDISC + url = https://github.com/alibaba/BladeDISC.git diff --git a/BUILD b/BUILD index 6949f6dc748..6b00ea55070 100644 --- a/BUILD +++ b/BUILD @@ -3,6 +3,11 @@ load( "if_cuda_is_configured", ) +load( + "//bazel:rules_def.bzl", + "if_enable_disc", +) + cc_binary( name = "_XLAC.so", copts = [ @@ -28,5 +33,7 @@ cc_binary( "@torch//:libtorch_python", ] + if_cuda_is_configured([ "@xla//xla/stream_executor:cuda_platform", + ]) + if_enable_disc([ + "//torch_xla/csrc/runtime/disc:disc_ral", ]), ) diff --git a/WORKSPACE b/WORKSPACE index c007f07d271..9e10a077ae7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -89,3 +89,10 @@ new_local_repository( build_file = "//bazel:flash_attn.BUILD", path = "third_party/flash-attention/", ) +################################ BladeDISC Setup ################################ + +new_local_repository( + name = "disc_compiler", + build_file = "//bazel:disc.BUILD", + path = "third_party/BladeDISC/", +) diff --git a/bazel/disc.BUILD b/bazel/disc.BUILD new file mode 100644 index 00000000000..7ed33604618 --- /dev/null +++ b/bazel/disc.BUILD @@ -0,0 +1,49 @@ + +package( + default_visibility = [ + "//visibility:public", + ], +) + +cc_library( + name = "headers", + hdrs = glob( + [ + "mlir/ral/*.h", + "mlir/ral/context/base/cuda/*.h", + "mlir/ral/context/base/cuda/cuda_context_impl.h", + "mlir/ral/device/cpu/*.h", + "mlir/ral/device/gpu/*.h", + ], + ), + includes = [ + "tao_compiler", + "tao_compiler/mlir", + ], + strip_include_prefix = "external/disc_compiler/tao_compiler/mlir", +) + +cc_import( + name="disc_ral_cuda", + shared_library = ":libral_base_context.so", +) + +cc_import( + name="disc_custom_op", + shared_library = ":libdisc_custom_ops.so", +) + +genrule( + name = "build_disc", + outs = ["libral_base_context.so", "libdisc_custom_ops.so", "disc_compiler_main", "torch-mlir-opt"], + local = True, + cmd = ';'.join(['export PATH=/root/bin:/usr/local/cuda/bin:$${PATH}', + 'pushd external/disc_compiler/pytorch_blade/', + 'python ../scripts/python/common_setup.py', + 'TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0,8.6,9.0" TORCH_CUDA_ARCH_LIST="7.0 8.0 8.6 9.0" python setup.py bdist_wheel', + 'popd', + 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/ral/libral_base_context.so $(location libral_base_context.so)', + 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/custom_ops/libdisc_custom_ops.so $(location libdisc_custom_ops.so)', + 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/disc/disc_compiler_main $(location disc_compiler_main)', + 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/tests/mhlo/torch-mlir-opt/torch-mlir-opt $(location torch-mlir-opt)']), +) diff --git a/bazel/flash_attn.BUILD b/bazel/flash_attn.BUILD index e5fd5ca6013..6be811b826b 100644 --- a/bazel/flash_attn.BUILD +++ b/bazel/flash_attn.BUILD @@ 
-23,9 +23,8 @@ genrule( name = "build_flash_attn", srcs = ["setup.py"], outs = ["flash_attn_cuda.so"], - cmd = ';'.join(['pushd third_party/flash-attention/', - 'MAX_JOBS=50 FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel 2>&1 | tee build.log', - 'cp build/*/*.so flash_attn_cuda.so', + cmd = ';'.join(['pushd external/flash_attn/', + 'FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel', 'popd', - 'cp third_party/flash-attention/flash_attn_cuda.so $(OUTS)']), + 'cp external/flash_attn/build/*/*.so $(location flash_attn_cuda.so)']), ) diff --git a/bazel/rules_def.bzl b/bazel/rules_def.bzl index 4569630f170..b10cb659e9a 100644 --- a/bazel/rules_def.bzl +++ b/bazel/rules_def.bzl @@ -39,3 +39,9 @@ def ptxla_cc_test( ], **kwargs ) + +def if_enable_disc(if_true, if_false=[]): + return select({ + "//torch_xla/csrc/runtime:enable_disc": if_true, + "//conditions:default": if_false + }) \ No newline at end of file diff --git a/bazel/torch.BUILD b/bazel/torch.BUILD index b91d75f9f0b..d2be73c6bcb 100644 --- a/bazel/torch.BUILD +++ b/bazel/torch.BUILD @@ -55,6 +55,10 @@ cc_import( shared_library = "build/lib/libtorch_cpu.so", ) +cc_import( + name = "libtorch_cuda", + shared_library = "build/lib/libtorch_cuda.so", +) cc_import( name = "libtorch_python", shared_library = "build/lib/libtorch_python.so", @@ -64,3 +68,7 @@ cc_import( name = "libc10", shared_library = "build/lib/libc10.so", ) +cc_import( + name = "libc10_cuda", + shared_library = "build/lib/libc10_cuda.so", +) diff --git a/setup.py b/setup.py index c1912d832b4..3531a61e211 100644 --- a/setup.py +++ b/setup.py @@ -231,6 +231,9 @@ def bazel_build(self, ext): bazel_argv.extend(build_util.bazel_options_from_env()) + if not build_util.check_env_flag('ENABLE_DISC', 'false'): + bazel_argv.append('--define=enable_disc=false') + self.spawn(bazel_argv) ext_bazel_bin_path = os.path.join(self.build_temp, 'bazel-bin', ext.relpath, @@ -244,9 +247,24 @@ def bazel_build(self, ext): # copy flash attention cuda so file flash_attn_so_name = 'flash_attn_cuda.so' - shutil.copyfile( - '/'.join(['third_party/flash-attention', flash_attn_so_name]), - '/'.join([ext_dest_dir, flash_attn_so_name])) + bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/flash_attn/' + shutil.copyfile('/'.join([bazel_bin_path, flash_attn_so_name]), + '/'.join([ext_dest_dir, flash_attn_so_name])) + + # package BladeDISC distribution files + # please note, TorchBlade also create some symbolic links to 'torch_blade' dir + if build_util.check_env_flag('ENABLE_DISC', 'false'): + disc_ral_so_name = 'libral_base_context.so' + bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' + shutil.copyfile( + os.path.join(bazel_bin_path, disc_ral_so_name), + '/'.join([ext_dest_dir, disc_ral_so_name])) + + disc_customop_so_name = 'libdisc_custom_ops.so' + bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' + shutil.copyfile( + os.path.join(bazel_bin_path, disc_customop_so_name), + '/'.join([ext_dest_dir, disc_customop_so_name])) class Develop(develop.develop): diff --git a/test/test_flash_attention_backward.py b/test/test_flash_attention_backward.py old mode 100644 new mode 100755 diff --git a/third_party/BladeDISC b/third_party/BladeDISC new file mode 160000 index 00000000000..fbe39bce9ae --- /dev/null +++ b/third_party/BladeDISC @@ -0,0 +1 @@ +Subproject commit fbe39bce9ae2d365d77842af38a33fa76d37237a diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 728c4eacd56..55fab72fb5d 
100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -1,6 +1,7 @@ load( "//bazel:rules_def.bzl", "ptxla_cc_library", + "ptxla_cc_test", ) genrule( diff --git a/torch_xla/csrc/ops/flash_attention_forward.cpp b/torch_xla/csrc/ops/flash_attention_forward.cpp index 5c478f69a51..9a73f26a9ba 100644 --- a/torch_xla/csrc/ops/flash_attention_forward.cpp +++ b/torch_xla/csrc/ops/flash_attention_forward.cpp @@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(int batch_size, int num_heads, int seqlen_q, xla::PrimitiveType::F32, {batch_size, num_heads, seqlen_q}); xla::Shape out_shape = GetXlaShape(q); xla::Shape rng_state_shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::U64, {2}); + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2}); return xla::ShapeUtil::MakeTupleShape( {softmax_lse_shape, out_shape, rng_state_shape}); } diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD old mode 100644 new mode 100755 index 720452b93ca..9ef0c9a6df9 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -10,13 +10,20 @@ load( load( "//bazel:rules_def.bzl", + "ptxla_cc_library", "ptxla_cc_test", + "if_enable_disc", ) licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) +config_setting( + name = "enable_disc", + define_values = {"enable_disc": "true"}, +) + cc_library( name = "runtime", srcs = [ @@ -31,7 +38,12 @@ cc_library( ":pjrt_computation_client", ":ifrt_computation_client", "@tsl//tsl/platform:stacktrace", - ], + ] + if_enable_disc([ + ":disc_computation_client", + ]), + copts = if_enable_disc([ + "-DTORCHACC_ENABLE_DISC", + ]), ) cc_library( @@ -137,6 +149,29 @@ cc_library( ], ) +cc_library( + name = "disc_computation_client", + srcs = [ + "disc_computation_client.cc", + ], + hdrs = [ + "disc_computation_client.h", + ], + deps = [ + ":computation_client", + ":stablehlo_helper", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla/client:xla_computation", + "//torch_xla/csrc/runtime/disc:disc_ral", + "//torch_xla/csrc/runtime/disc:disc_compile", + "@xla//xla/service:float_normalization", + "@xla//xla/service/gpu:gpu_float_support", + ], +) + cc_library( name = "cache", hdrs = ["cache.h"], @@ -519,3 +554,25 @@ ptxla_cc_test( "@tsl//tsl/platform:test_main", ], ) + +ptxla_cc_test( + name = "disc_computation_client_test", + srcs = ["disc_computation_client_test.cc"], + deps = [ + ":disc_computation_client", + "@xla//xla:literal", + "@xla//xla:literal_util", + "@xla//xla:shape_util", + "@xla//xla:status", + "@xla//xla:statusor", + "@xla//xla/client:xla_builder", + "@xla//xla/client:xla_computation", + "@xla//xla/tests:literal_test_util", + "@xla//xla/tools:hlo_module_loader", + "@stablehlo//:register", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/BUILD b/torch_xla/csrc/runtime/disc/BUILD new file mode 100755 index 00000000000..999aa85ea64 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/BUILD @@ -0,0 +1,81 @@ +load( + "@tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) + +load( + "//bazel:rules_def.bzl", + "ptxla_cc_library", + "ptxla_cc_test", +) + + +ptxla_cc_library( + name = "disc_ral", + srcs = [ + "disc_ral.cc", + "custom_call_flash_attention_forward.cc", + "custom_call_flash_attention_backward.cc" + ], + hdrs = [ + "disc_ral.h", + ], + deps = [ + ":disc_utils", 
+ "@disc_compiler//:disc_ral_cuda", + "@disc_compiler//:disc_custom_op", + "@disc_compiler//:headers", + "@local_config_cuda//cuda:cuda_headers", + "@nccl_archive//:nccl_headers", + "@torch//:libc10", + "@torch//:libc10_cuda", + "@torch//:libtorch_cuda", + "@flash_attn//:headers", + "@flash_attn//:flash_attn_cuda", + ], + copts = [ + "-DGOOGLE_CUDA", + ] +) + +ptxla_cc_library( + name = "disc_utils", + srcs = ["disc_utils.cc"], + hdrs = [ + "disc_utils.h", + ], + deps = [ + "//torch_xla/csrc/runtime:tf_logging", + ] +) + +ptxla_cc_library( + name = "disc_compile", + srcs = ["disc_compile.cc"], + hdrs = [ + "disc_compile.h", + ], + deps = [ + ":disc_ral", + ":disc_utils", + "//torch_xla/csrc/runtime:tf_logging", + "//torch_xla/csrc/runtime:sys_util", + "//torch_xla/csrc/runtime:env_vars", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + ], + copts = [ + "-DGOOGLE_CUDA", + ] +) + +ptxla_cc_test( + name = "disc_ral_test", + srcs = ["disc_ral_test.cc"], + deps = [ + ":disc_ral", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ] +) diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc new file mode 100644 index 00000000000..402eeedc3de --- /dev/null +++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc @@ -0,0 +1,466 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "cutlass/numeric_types.h" +#include "flash.h" +#include "mlir/ral/context/pdll_util.h" +#include "mlir/ral/context/stream_executor_based_impl.h" +#include "static_switch.h" +#include "torch_xla/csrc/runtime/tf_logging.h" + +namespace tao { +namespace ral { + +DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16"); + +struct FlashAttentionBackwardParams { + using index_t = uint32_t; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k + // could be different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The dimensions. + int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + + int total_q; + int total_k; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). 
+ float rp_dropout; + float scale_softmax_rp_dropout; + + bool is_bf16; + bool is_causal; + int window_size_left; + int window_size_right; + int alibi_slopes_batch_stride; + bool enable_alibi_slopes; + bool is_seqlens_k_cumulative; + int num_splits; + + // Backward specific params + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + bool deterministic; + + void FromString(const std::string& str) { + std::vector params_list = absl::StrSplit(str, "|"); + TORCH_CHECK(params_list.size() == 51); + + // Forward specific param + absl::SimpleAtoi(params_list[0], &this->q_batch_stride); + absl::SimpleAtoi(params_list[1], &this->k_batch_stride); + absl::SimpleAtoi(params_list[2], &this->v_batch_stride); + absl::SimpleAtoi(params_list[3], &this->q_row_stride); + absl::SimpleAtoi(params_list[4], &this->k_row_stride); + absl::SimpleAtoi(params_list[5], &this->v_row_stride); + absl::SimpleAtoi(params_list[6], &this->q_head_stride); + absl::SimpleAtoi(params_list[7], &this->k_head_stride); + absl::SimpleAtoi(params_list[8], &this->v_head_stride); + absl::SimpleAtoi(params_list[9], &this->total_q); + absl::SimpleAtoi(params_list[10], &this->total_k); + absl::SimpleAtoi(params_list[11], &this->h); + absl::SimpleAtoi(params_list[12], &this->h_k); + absl::SimpleAtoi(params_list[13], &this->h_h_k_ratio); + absl::SimpleAtoi(params_list[14], &this->o_batch_stride); + absl::SimpleAtoi(params_list[15], &this->o_row_stride); + absl::SimpleAtoi(params_list[16], &this->o_head_stride); + absl::SimpleAtoi(params_list[17], &this->b); + absl::SimpleAtoi(params_list[18], &this->seqlen_q); + absl::SimpleAtoi(params_list[19], &this->seqlen_k); + absl::SimpleAtoi(params_list[20], &this->d); + absl::SimpleAtoi(params_list[21], &this->seqlen_q_rounded); + absl::SimpleAtoi(params_list[22], &this->seqlen_k_rounded); + absl::SimpleAtoi(params_list[23], &this->d_rounded); + absl::SimpleAtof(params_list[24], &this->scale_softmax); + absl::SimpleAtof(params_list[25], &this->scale_softmax_log2); + absl::SimpleAtof(params_list[26], &this->p_dropout); + uint32_t tmp; + absl::SimpleAtoi(params_list[27], &tmp); + this->p_dropout_in_uint8_t = uint8_t(tmp); + absl::SimpleAtof(params_list[28], &this->rp_dropout); + absl::SimpleAtof(params_list[29], &this->scale_softmax_rp_dropout); + absl::SimpleAtob(params_list[30], &this->is_bf16); + absl::SimpleAtob(params_list[31], &this->is_causal); + absl::SimpleAtoi(params_list[32], &this->window_size_left); + absl::SimpleAtoi(params_list[33], &this->window_size_right); + absl::SimpleAtoi(params_list[34], &this->alibi_slopes_batch_stride); + absl::SimpleAtob(params_list[35], &this->is_seqlens_k_cumulative); + absl::SimpleAtoi(params_list[36], &this->num_splits); + absl::SimpleAtob(params_list[37], &this->enable_alibi_slopes); + + // backward specific params + const int offset = 38; // FlashAttentionForwardParams has 38 variables + absl::SimpleAtoi(params_list[offset + 0], &this->do_batch_stride); + absl::SimpleAtoi(params_list[offset + 1], &this->do_row_stride); + absl::SimpleAtoi(params_list[offset + 2], &this->do_head_stride); + absl::SimpleAtoi(params_list[offset + 3], &this->dq_batch_stride); + absl::SimpleAtoi(params_list[offset + 4], &this->dk_batch_stride); + absl::SimpleAtoi(params_list[offset + 5], &this->dv_batch_stride); + 
absl::SimpleAtoi(params_list[offset + 6], &this->dq_row_stride); + absl::SimpleAtoi(params_list[offset + 7], &this->dk_row_stride); + absl::SimpleAtoi(params_list[offset + 8], &this->dv_row_stride); + absl::SimpleAtoi(params_list[offset + 9], &this->dq_head_stride); + absl::SimpleAtoi(params_list[offset + 10], &this->dk_head_stride); + absl::SimpleAtoi(params_list[offset + 11], &this->dv_head_stride); + absl::SimpleAtob(params_list[offset + 12], &this->deterministic); + } +}; + +void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream, + const bool configure) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, + [&] { run_mha_bwd_(params, stream); }); + }); +} + +// Layout of `buffers` listed above: +// buffers[0] = dout +// buffers[1] = q +// buffers[2] = k +// buffers[3] = v +// buffers[4] = out +// buffers[5] = softmax_lse +// buffers[6] = cu_seqlens_q +// buffers[7] = cu_seqlens_k +// buffers[8] = rng_state +// buffers[9] = alibi_slopes +// buffers[10] = dq // this is output +// buffers[11] = dk // this is output +// buffers[12] = dv // this is output +// buffers[13] = softmax_d // this is output +template +std::tuple, MemRefType, MemRefType, + MemRefType> +custom_call_flash_attention_backward_impl( + ExecutionContext* ctx, void* stream_handle, MemRefType dout, + MemRefType q, MemRefType k, MemRefType v, + MemRefType out, MemRefType softmax_lse, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType rng_state, void* alibi_slopes_ptr, + void* customAttrs) { + auto attr = getOrParsePDLAttr(ctx, customAttrs, + "custom_call_flash_attention_backward"); + if (!attr) { + ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n"); + } + auto& dictAttr = attr->as(); + std::string backend_config = + dictAttr.get("backend_config").template as().getValue(); + + auto gpu_driver = ctx->getDriver( + tao::ral::gpu::GPUDriver::name()); + auto gpu_stream = + static_cast(gpu_driver->asCUStream(ctx, stream_handle)); + + int softmax_element_count = 1, q_element_count = 1, k_element_count = 1, + v_element_count = 1; + for (int i = 0; i < M; i++) { + q_element_count *= q.sizes[i]; + k_element_count *= k.sizes[i]; + v_element_count *= v.sizes[i]; + softmax_element_count *= softmax_lse.sizes[i]; + } + + auto dq_ptr = static_cast( + gpu_driver->alloc(ctx, q_element_count * sizeof(T_IN))); + auto dq_res = assignMemRef(dq_ptr, q.sizes); + + auto dk_ptr = static_cast( + gpu_driver->alloc(ctx, k_element_count * sizeof(T_IN))); + auto dk_res = assignMemRef(dk_ptr, k.sizes); + + auto dv_ptr = static_cast( + gpu_driver->alloc(ctx, v_element_count * sizeof(T_IN))); + auto dv_res = assignMemRef(dv_ptr, v.sizes); + + auto dsoftmax_ptr = static_cast( + gpu_driver->alloc(ctx, softmax_element_count * sizeof(SOFT_MAX_TYPE))); + auto dsoftmax = + assignMemRef(dsoftmax_ptr, softmax_lse.sizes); + + FlashAttentionBackwardParams params; + params.FromString(std::move(backend_config)); + Flash_bwd_params launch_params; + + // Reset the parameters + memset(&launch_params, 0, sizeof(launch_params)); + + launch_params.is_bf16 = params.is_bf16; + + // Set the pointers and strides. + launch_params.q_ptr = q.data; + launch_params.k_ptr = k.data; + launch_params.v_ptr = v.data; + // All stride are in elements, not bytes. 
+ launch_params.q_row_stride = params.q_row_stride; + launch_params.k_row_stride = params.k_row_stride; + launch_params.v_row_stride = params.v_row_stride; + launch_params.q_head_stride = params.q_head_stride; + launch_params.k_head_stride = params.k_head_stride; + launch_params.v_head_stride = params.v_head_stride; + launch_params.o_ptr = out.data; + launch_params.o_row_stride = params.o_row_stride; + launch_params.o_head_stride = params.o_head_stride; + + launch_params.cu_seqlens_q = static_cast(seqlens_q.data); + launch_params.cu_seqlens_k = static_cast(seqlens_k.data); + + launch_params.alibi_slopes_ptr = alibi_slopes_ptr; + launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride; + + // P = softmax(QK^T) + launch_params.p_ptr = nullptr; // no softmax returned always + + // Softmax sum + launch_params.softmax_lse_ptr = softmax_lse.data; + + // Set the dimensions. + launch_params.b = params.b; + launch_params.h = params.h; + launch_params.h_k = params.h_k; + launch_params.h_h_k_ratio = params.h_h_k_ratio; + launch_params.seqlen_q = params.seqlen_q; + launch_params.seqlen_k = params.seqlen_k; + launch_params.seqlen_q_rounded = params.seqlen_q_rounded; + launch_params.seqlen_k_rounded = params.seqlen_k_rounded; + launch_params.d = params.d; + launch_params.d_rounded = params.d_rounded; + + // Set the different scale values. + launch_params.scale_softmax = params.scale_softmax; + launch_params.scale_softmax_log2 = params.scale_softmax_log2; + + launch_params.p_dropout = params.p_dropout; + launch_params.p_dropout_in_uint8_t = params.p_dropout_in_uint8_t; + launch_params.rp_dropout = params.rp_dropout; + launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout; + + launch_params.is_causal = params.is_causal; + launch_params.window_size_left = params.window_size_left; + launch_params.window_size_right = params.window_size_right; + + launch_params.is_seqlens_k_cumulative = true; + + launch_params.do_ptr = dout.data; + launch_params.do_row_stride = params.do_row_stride; + launch_params.do_head_stride = params.do_head_stride; + launch_params.dq_ptr = dq_res.data; + launch_params.dk_ptr = dk_res.data; + launch_params.dv_ptr = dv_res.data; + launch_params.dq_row_stride = params.dq_row_stride; + launch_params.dk_row_stride = params.dk_row_stride; + launch_params.dv_row_stride = params.dv_row_stride; + launch_params.dq_head_stride = params.dq_head_stride; + launch_params.dk_head_stride = params.dk_head_stride; + launch_params.dv_head_stride = params.dv_head_stride; + + // bool loop = max_seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + auto scalar_type = params.is_bf16 ? 
torch::kBFloat16 : torch::kFloat16; + auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA); + at::Tensor dq_accum; + if (loop) { + if (!params.deterministic) { + dq_accum = torch::empty({params.total_q + 128 * launch_params.b, + launch_params.h, launch_params.d_rounded}, + opts.dtype(at::kFloat)); + } else { + auto dprops = at::cuda::getCurrentDeviceProperties(); + const int nsplits = (dprops->multiProcessorCount + + launch_params.b * launch_params.h - 1) / + (launch_params.b * launch_params.h); + dq_accum = torch::zeros({nsplits, params.total_q + 128 * launch_params.b, + launch_params.h, launch_params.d_rounded}, + opts.dtype(at::kFloat)); + } + } + + at::Tensor dk = torch::from_blob( + dk_res.data, {params.total_k, launch_params.h_k, launch_params.d}, opts); + at::Tensor dv = torch::from_blob( + dv_res.data, {params.total_k, launch_params.h_k, launch_params.d}, opts); + + at::Tensor dk_expanded, dv_expanded; + + if (launch_params.h_k != launch_params.h) { // MQA / GQA + TF_VLOG(2) << "Running FlashAttention Backward as MQA/GQA"; + dk_expanded = + torch::empty({params.total_k, launch_params.h, launch_params.d}, opts); + dv_expanded = + torch::empty({params.total_k, launch_params.h, launch_params.d}, opts); + + launch_params.dk_ptr = dk_expanded.data_ptr(); + launch_params.dv_ptr = dv_expanded.data_ptr(); + launch_params.dk_row_stride = dk_expanded.stride(-3); + launch_params.dv_row_stride = dv_expanded.stride(-3); + launch_params.dk_head_stride = dk_expanded.stride(-2); + launch_params.dv_head_stride = dv_expanded.stride(-2); + } else { + TF_VLOG(2) << "Running FlashAttention Backward"; + dk_expanded = dk; + dv_expanded = dv; + } + + launch_params.dq_accum_ptr = loop ? dq_accum.data_ptr() : nullptr; + launch_params.dk_accum_ptr = nullptr; + launch_params.dv_accum_ptr = nullptr; + + // Softmax sum + launch_params.dsoftmax_sum = dsoftmax.data; + + launch_params.deterministic = params.deterministic; + launch_params.dq_accum_split_stride = + !launch_params.deterministic ? 0 : dq_accum.stride(0); + + auto launch = &run_mha_bwd; + + auto gen = at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = launch_params.b * launch_params.h * 32; + + bool is_dropout = (1.f - launch_params.p_dropout) > 0.0; + // TODO(wenting.swt): According to the implementation in + // `flash_attn_varlen_func` of flash-attn v2.5.6, the forward generates + // `rng_state` which is passed as ctx to the backward. Hence, for simplifying + // the logic, the redundant branch where `rng_state` is None has been omitted. 
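+  // `rng_state` holds the philox {seed, offset} pair that the forward kernel
+  // recorded, so the backward kernel can replay exactly the same dropout mask.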
+ launch_params.rng_state = reinterpret_cast(rng_state.data); + + launch(launch_params, gpu_stream, /*configure=*/false); + + // For MQA/GQA we need to sum dK and dV across the groups + if (launch_params.h_k != launch_params.h) { + at::sum_out(dk, + at::reshape(dk_expanded, {params.total_k, launch_params.h_k, + launch_params.h / launch_params.h_k, + launch_params.d}), + {2}); + at::sum_out(dv, + at::reshape(dv_expanded, {params.total_k, launch_params.h_k, + launch_params.h / launch_params.h_k, + launch_params.d}), + {2}); + } + + return std::make_tuple(dq_res, dk_res, dv_res, dsoftmax); +} + +template +std::tuple, MemRefType, MemRefType, + MemRefType> +custom_call_flash_attention_backward_noalibi( + ExecutionContext* ctx, void* stream_handle, MemRefType dout, + MemRefType q, MemRefType k, MemRefType v, + MemRefType out, MemRefType softmax_lse, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType rng_state, void* customAttrs) { + return custom_call_flash_attention_backward_impl( + ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k, + rng_state, nullptr, customAttrs); +} + +template +std::tuple, MemRefType, MemRefType, + MemRefType> +custom_call_flash_attention_backward_alibi_v1( + ExecutionContext* ctx, void* stream_handle, MemRefType dout, + MemRefType q, MemRefType k, MemRefType v, + MemRefType out, MemRefType softmax_lse, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType rng_state, MemRefType alibi_slopes, + void* customAttrs) { + return custom_call_flash_attention_backward_impl( + ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k, + rng_state, alibi_slopes.data, customAttrs); +} + +template +std::tuple, MemRefType, MemRefType, + MemRefType> +custom_call_flash_attention_backward_alibi_v2( + ExecutionContext* ctx, void* stream_handle, MemRefType dout, + MemRefType q, MemRefType k, MemRefType v, + MemRefType out, MemRefType softmax_lse, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType rng_state, MemRefType alibi_slopes, + void* customAttrs) { + return custom_call_flash_attention_backward_impl( + ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k, + rng_state, alibi_slopes.data, customAttrs); +} + +TAO_RAL_API( + "custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward_noalibi); +TAO_RAL_API( + "custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward_alibi_v1); +TAO_RAL_API( + "custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward_alibi_v2); +TAO_RAL_API("custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward_noalibi); +TAO_RAL_API("custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward_alibi_v1); +TAO_RAL_API("custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward_alibi_v2); + +} // namespace ral +} // namespace tao \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc new file mode 100644 index 00000000000..fcac32fa5c3 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc @@ -0,0 +1,331 @@ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include 
"cutlass/numeric_types.h" +#include "flash.h" +#include "mlir/ral/context/pdll_util.h" +#include "mlir/ral/context/stream_executor_based_impl.h" +#include "static_switch.h" +#include "torch_xla/csrc/runtime/tf_logging.h" + +namespace tao { +namespace ral { + +DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16"); + +struct FlashAttentionForwardParams { + using index_t = uint32_t; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k + // could be different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The dimensions. + int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + + int total_q; + int total_k; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + bool is_bf16; + bool is_causal; + int window_size_left; + int window_size_right; + int alibi_slopes_batch_stride; + bool enable_alibi_slopes; + bool is_seqlens_k_cumulative; + int num_splits; + + void FromString(const std::string& str) { + std::vector params_list = absl::StrSplit(str, "|"); + TORCH_CHECK(params_list.size() >= 38); // at least 38 variables + absl::SimpleAtoi(params_list[0], &this->q_batch_stride); + absl::SimpleAtoi(params_list[1], &this->k_batch_stride); + absl::SimpleAtoi(params_list[2], &this->v_batch_stride); + absl::SimpleAtoi(params_list[3], &this->q_row_stride); + absl::SimpleAtoi(params_list[4], &this->k_row_stride); + absl::SimpleAtoi(params_list[5], &this->v_row_stride); + absl::SimpleAtoi(params_list[6], &this->q_head_stride); + absl::SimpleAtoi(params_list[7], &this->k_head_stride); + absl::SimpleAtoi(params_list[8], &this->v_head_stride); + absl::SimpleAtoi(params_list[9], &this->total_q); + absl::SimpleAtoi(params_list[10], &this->total_k); + absl::SimpleAtoi(params_list[11], &this->h); + absl::SimpleAtoi(params_list[12], &this->h_k); + absl::SimpleAtoi(params_list[13], &this->h_h_k_ratio); + absl::SimpleAtoi(params_list[14], &this->o_batch_stride); + absl::SimpleAtoi(params_list[15], &this->o_row_stride); + absl::SimpleAtoi(params_list[16], &this->o_head_stride); + absl::SimpleAtoi(params_list[17], &this->b); + absl::SimpleAtoi(params_list[18], &this->seqlen_q); + absl::SimpleAtoi(params_list[19], &this->seqlen_k); + absl::SimpleAtoi(params_list[20], &this->d); + absl::SimpleAtoi(params_list[21], &this->seqlen_q_rounded); + absl::SimpleAtoi(params_list[22], &this->seqlen_k_rounded); + absl::SimpleAtoi(params_list[23], &this->d_rounded); + absl::SimpleAtof(params_list[24], &this->scale_softmax); + absl::SimpleAtof(params_list[25], &this->scale_softmax_log2); + absl::SimpleAtof(params_list[26], &this->p_dropout); + uint32_t tmp; + absl::SimpleAtoi(params_list[27], &tmp); + this->p_dropout_in_uint8_t = uint8_t(tmp); + absl::SimpleAtof(params_list[28], &this->rp_dropout); + absl::SimpleAtof(params_list[29], &this->scale_softmax_rp_dropout); + 
absl::SimpleAtob(params_list[30], &this->is_bf16); + absl::SimpleAtob(params_list[31], &this->is_causal); + absl::SimpleAtoi(params_list[32], &this->window_size_left); + absl::SimpleAtoi(params_list[33], &this->window_size_right); + absl::SimpleAtoi(params_list[34], &this->alibi_slopes_batch_stride); + absl::SimpleAtob(params_list[35], &this->is_seqlens_k_cumulative); + absl::SimpleAtoi(params_list[36], &this->num_splits); + absl::SimpleAtob(params_list[37], &this->enable_alibi_slopes); + } +}; + +// Layout of `buffers` listed above: +// buffers[0] = q +// buffers[1] = k +// buffers[2] = v +// buffers[3] = cu_seqlens_q +// buffers[4] = cu_seqlens_k +// result[0] = softmax_lse // this is output +// result[1] = out_for_output // this is output +template +std::tuple, MemRefType, + MemRefType> +custom_call_flash_attention_forward_impl( + ExecutionContext* ctx, void* stream_handle, MemRefType q, + MemRefType k, MemRefType v, + MemRefType seqlens_q, MemRefType seqlens_k, + void* alibi_slopes_ptr, void* customAttrs) { + auto attr = getOrParsePDLAttr(ctx, customAttrs, + "custom_call_flash_attention_forward"); + if (!attr) { + ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n"); + } + auto& dictAttr = attr->as(); + std::string backend_config = + dictAttr.get("backend_config").template as().getValue(); + + auto gpu_driver = ctx->getDriver( + tao::ral::gpu::GPUDriver::name()); + auto gpu_stream = + static_cast(gpu_driver->asCUStream(ctx, stream_handle)); + + int output_element_count = 1; + for (int i = 0; i < M; i++) { + output_element_count *= q.sizes[i]; + } + + int bs = seqlens_q.sizes[0] - 1; + int nheads = q.sizes[1]; + int seqlen = q.sizes[0] / bs; + std::vector softmax_lse_sizes{bs, nheads, seqlen}; + + auto softmax_lse_ptr = static_cast( + gpu_driver->alloc(ctx, bs * nheads * seqlen * sizeof(SOFT_MAX_TYPE))); + auto softmax_lse = + assignMemRef(softmax_lse_ptr, softmax_lse_sizes); + + auto output_ptr = static_cast( + gpu_driver->alloc(ctx, output_element_count * sizeof(T_IN))); + auto output = assignMemRef(output_ptr, q.sizes); + + auto rng_state_ptr = + static_cast(gpu_driver->alloc(ctx, 2 * sizeof(int64_t))); + auto rng_state = + assignMemRef(rng_state_ptr, std::vector{2}); + + cudaMemsetAsync(rng_state_ptr, 0, 2 * sizeof(int64_t), gpu_stream); + + FlashAttentionForwardParams params; + params.FromString(std::move(backend_config)); + + Flash_fwd_params launch_params; + + // Reset the parameters + memset(&launch_params, 0, sizeof(launch_params)); + + launch_params.is_bf16 = params.is_bf16; + + // Set the pointers and strides. + launch_params.q_ptr = q.data; + launch_params.k_ptr = k.data; + launch_params.v_ptr = v.data; + // All stride are in elements, not bytes. 
+ launch_params.q_row_stride = params.q_row_stride; + launch_params.k_row_stride = params.k_row_stride; + launch_params.v_row_stride = params.v_row_stride; + launch_params.q_head_stride = params.q_head_stride; + launch_params.k_head_stride = params.k_head_stride; + launch_params.v_head_stride = params.v_head_stride; + launch_params.o_ptr = output.data; + launch_params.o_row_stride = params.o_row_stride; + launch_params.o_head_stride = params.o_head_stride; + + launch_params.cu_seqlens_q = seqlens_q.data; + launch_params.cu_seqlens_k = seqlens_k.data; + launch_params.alibi_slopes_ptr = alibi_slopes_ptr; + launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride; + + // P = softmax(QK^T) + launch_params.p_ptr = nullptr; // no softmax returned always + + // Softmax sum + launch_params.softmax_lse_ptr = softmax_lse.data; + + // Set the dimensions. + launch_params.b = params.b; + launch_params.h = params.h; + launch_params.h_k = params.h_k; + launch_params.h_h_k_ratio = params.h_h_k_ratio; + launch_params.seqlen_q = params.seqlen_q; + launch_params.seqlen_k = params.seqlen_k; + launch_params.seqlen_q_rounded = params.seqlen_q_rounded; + launch_params.seqlen_k_rounded = params.seqlen_k_rounded; + launch_params.d = params.d; + launch_params.d_rounded = params.d_rounded; + + // Set the different scale values. + launch_params.scale_softmax = params.scale_softmax; + launch_params.scale_softmax_log2 = params.scale_softmax_log2; + + launch_params.p_dropout = params.p_dropout; + launch_params.p_dropout_in_uint8_t = params.p_dropout_in_uint8_t; + launch_params.rp_dropout = params.rp_dropout; + launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout; + + launch_params.is_causal = params.is_causal; + launch_params.window_size_left = params.window_size_left; + launch_params.window_size_right = params.window_size_right; + + launch_params.is_seqlens_k_cumulative = params.is_seqlens_k_cumulative; + + // set params splitkv + launch_params.num_splits = params.num_splits; + + // Forward kernel will populate memory with the seed and offset. + launch_params.rng_state = reinterpret_cast(rng_state_ptr); + + if ((1.f - launch_params.p_dropout) > 0.0) { + // number of times random will be generated per thread, to offset philox + // counter in thc random state We use a custom RNG that increases the offset + // by batch_size * nheads * 32. 
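+    // Note: `p_dropout` stores the keep probability, so this branch is only
+    // taken when dropout is actually enabled and the philox state must be set up.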
+ int64_t counter_offset = launch_params.b * launch_params.h * 32; + auto gen = at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + launch_params.philox_args = gen->philox_cuda_state(counter_offset); + } + + FP16_SWITCH(!launch_params.is_bf16, [&] { + HEADDIM_SWITCH(launch_params.d, [&] { + // TODO(wenting.swt): support split_kv + run_mha_fwd_(launch_params, gpu_stream); + }); + }); + + return std::make_tuple(softmax_lse, output, rng_state); +} + +template +std::tuple, MemRefType, + MemRefType> +custom_call_flash_attention_forward_noalibi( + ExecutionContext* ctx, void* stream_handle, MemRefType q, + MemRefType k, MemRefType v, + MemRefType seqlens_q, MemRefType seqlens_k, + void* customAttrs) { + return custom_call_flash_attention_forward_impl( + ctx, stream_handle, q, k, v, seqlens_q, seqlens_k, nullptr, customAttrs); +} + +template +std::tuple, MemRefType, + MemRefType> +custom_call_flash_attention_forward_alibi_v1( + ExecutionContext* ctx, void* stream_handle, MemRefType q, + MemRefType k, MemRefType v, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType alibi_slopes, void* customAttrs) { + return custom_call_flash_attention_forward_impl( + ctx, stream_handle, q, k, v, seqlens_q, seqlens_k, alibi_slopes.data, + customAttrs); +} + +template +std::tuple, MemRefType, + MemRefType> +custom_call_flash_attention_forward_alibi_v2( + ExecutionContext* ctx, void* stream_handle, MemRefType q, + MemRefType k, MemRefType v, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType alibi_slopes, void* customAttrs) { + return custom_call_flash_attention_forward_impl( + ctx, stream_handle, q, k, v, seqlens_q, seqlens_k, alibi_slopes.data, + customAttrs); +} + +TAO_RAL_API("custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward_noalibi); +TAO_RAL_API( + "custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward_alibi_v1); +TAO_RAL_API( + "custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward_alibi_v2); +TAO_RAL_API("custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward_noalibi); +TAO_RAL_API("custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward_alibi_v1); +TAO_RAL_API("custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward_alibi_v2); + +} // namespace ral +} // namespace tao diff --git a/torch_xla/csrc/runtime/disc/disc_compile.cc b/torch_xla/csrc/runtime/disc/disc_compile.cc new file mode 100644 index 00000000000..053535f5e2e --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_compile.cc @@ -0,0 +1,107 @@ +#include "torch_xla/csrc/runtime/disc/disc_compile.h" + +#include + +#include + +#include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/runtime/tf_logging.h" +using namespace std::filesystem; + +namespace torch_xla { +namespace runtime { +namespace disc { + +bool IsDiscDebugMode() { return sys_util::GetEnvBool("DISC_DEBUG", false); } + +std::string GetDebugDumpDir() { + return sys_util::GetEnvString("DISC_DEBUG_DUMP_DIR", "./dump_dir"); +} + +std::string CurrentLibLocation() { + Dl_info dl_info; + dladdr((void *)CurrentLibLocation, &dl_info); + auto fname = std::string(dl_info.dli_fname); + return fname.substr(0, fname.find_last_of("/")); +} + +std::string CompileCMD(const std::string &mlir_fname, + const std::string &out_fname) 
{ + std::stringstream ss; + std::string logf = absl::StrCat(mlir_fname, ".log"); + // unset XLA_FLAGS, otherwise TF will throw a parse error + std::string compile_cmd = "unset XLA_FLAGS"; + if (IsDiscDebugMode()) { + absl::StrAppend(&compile_cmd, " && export TF_CPP_VMODULE=disc_compiler=1 "); + } + absl::StrAppend(&compile_cmd, "&&", CurrentLibLocation(), + "/disc_compiler_main", " ", mlir_fname, " ", out_fname, " > ", + logf, " 2>&1"); + return compile_cmd; +} + +std::tuple CallDiscCompiler( + const std::string &mlir_fname) { + std::string out_fname = mlir_fname + ".out"; + std::string cmd = CompileCMD(mlir_fname, out_fname); + TF_VLOG(1) << "Executing command: " << cmd << " to compile mhlo..."; + auto ret = std::system(cmd.c_str()); + return {cmd, out_fname, ret}; +} + +std::shared_ptr DumpMlir(mlir::ModuleOp &stablehlo_module) { + std::string model_dump_str; + llvm::raw_string_ostream os(model_dump_str); + stablehlo_module.print(os); + os.flush(); + std::shared_ptr stablehlo_file = std::make_shared("mlir"); + stablehlo_file->WriteBytesToFile(model_dump_str); + return stablehlo_file; +} + +DISCComplationResult Compile(mlir::ModuleOp &module, + const std::vector &inputs, + const std::vector &outputs) { + // Dump stablehlo to file + DISCComplationResult res; + auto mlir_file = DumpMlir(module); + + // Compile mhlo + auto compile_res = CallDiscCompiler(mlir_file->GetFilename()); + auto output_fname = std::get<1>(compile_res); + + if (IsDiscDebugMode()) { + std::string base_path = GetDebugDumpDir(); + auto ret = std::filesystem::create_directory(base_path); + if (!ret) { + TF_VLOG(0) << "Failed to create dump dir: " << base_path + << ", it may already exist.\n"; + } + std::string mlir_fname = mlir_file->GetFilename(); + std::string log_fname = absl::StrCat(mlir_fname, ".log"); + std::filesystem::copy_file( + log_fname, + absl::StrCat(base_path, "/", + std::filesystem::path(mlir_fname).stem().string(), + ".log")); + std::filesystem::copy_file( + mlir_fname, + absl::StrCat(base_path, "/", + std::filesystem::path(mlir_fname).stem().string(), + ".mlir")); + TF_VLOG(1) << "Dumping mlir to file: " << mlir_file->GetFilename(); + } + + // Construct compiled result + res.ral_lib = ReadFileBytes(output_fname); + res.ral_mate_pb = ReadFileBytes(absl::StrCat(output_fname, ".pbtxt")); + res.inputs = inputs; + res.outputs = outputs; + + return res; +} + +} // namespace disc +} // namespace runtime +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/disc_compile.h b/torch_xla/csrc/runtime/disc/disc_compile.h new file mode 100644 index 00000000000..0f3ac885211 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_compile.h @@ -0,0 +1,25 @@ +#ifndef XLA_TORCH_XLA_CSRC_RUNTIME_DISC_COMPILE_H_ +#define XLA_TORCH_XLA_CSRC_RUNTIME_DISC_COMPILE_H_ + +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "torch_xla/csrc/runtime/disc/disc_ral.h" + +namespace torch_xla { +namespace runtime { +namespace disc { +DISCComplationResult Compile(mlir::ModuleOp& module, + const std::vector& inputs, + const std::vector& outputs); + +} // namespace disc +} // namespace runtime +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_RUNTIME_DISC_COMPILE_H_ \ No newline at end of file diff --git
a/torch_xla/csrc/runtime/disc/disc_ral.cc b/torch_xla/csrc/runtime/disc/disc_ral.cc new file mode 100644 index 00000000000..ca2216c79b8 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_ral.cc @@ -0,0 +1,284 @@ +#include "torch_xla/csrc/runtime/disc/disc_ral.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "torch_xla/csrc/runtime/tf_logging.h" + +namespace torch_xla { +namespace runtime { +namespace disc { + +class RalAllocator : public tao::ral::Allocator { + public: + using buffer_t = tao::ral::buffer_t; + using alloc_t = tao::ral::alloc_t; + using dealloc_t = tao::ral::dealloc_t; + RalAllocator(alloc_t alloc_func, dealloc_t dealloc_func) + : alloc_func_(alloc_func), dealloc_func_(dealloc_func) {} + + buffer_t alloc(size_t bytes) { return alloc_func_(bytes); } + + void dealloc(buffer_t buffer) { dealloc_func_(buffer); } + + private: + alloc_t alloc_func_; + dealloc_t dealloc_func_; +}; + +RalContext::RalContext(const DISCComplationResult& disc_result) + : disc_result_(disc_result) { + auto is_ok = meta_tmpf_.WriteBytesToFile(disc_result_.ral_mate_pb); + TORCH_CHECK(is_ok, "Failed to dump model proto to file."); + default_opt_.metadata_file_path = meta_tmpf_.GetFilename(); + default_opt_.cache_workspace_mem_across_execution = true; + auto torch_allocator = c10::GetAllocator(torch::kCPU); + TORCH_CHECK(torch_allocator != nullptr); + auto cpu_alloc = [torch_allocator](size_t n) { + return torch_allocator->raw_allocate(n); + }; + auto cpu_delete = [torch_allocator](void* ptr) { + torch_allocator->raw_deallocate(ptr); + }; + cpu_opt_.cpu_allocator.reset(new RalAllocator(cpu_alloc, cpu_delete)); + + at::globalContext().lazyInitCUDA(); + + void* func_handle = nullptr; + std::tie(tao_lib_, func_handle) = LoadEngine(disc_result_.ral_lib); + + using func_t = void (*)(void**); + entry_func_ = (func_t)func_handle; + + CHECK(entry_func_ != nullptr); +} + +std::tuple RalContext::LoadEngine( + const std::string& ral_engine_bytes) { + auto is_ok = lib_tmpf_.WriteBytesToFile(ral_engine_bytes); + TORCH_CHECK(is_ok, "Failed to dump RAL engine to file"); + std::string filename = lib_tmpf_.GetFilename(); + + void* tao_lib = dlopen(filename.c_str(), RTLD_NOW | RTLD_LOCAL); + TORCH_CHECK(tao_lib, "Fail to open ral engine"); + + void* func_handle = dlsym(tao_lib, kMlirLoweredEntry); + TORCH_CHECK(func_handle, "Fail to find kMlirLoweredEntry"); + return std::make_tuple(tao_lib, func_handle); +} + +RalContext::~RalContext() { + if (tao_lib_ != nullptr) { + dlclose(tao_lib_); + } +} + +void RalContext::CheckCurrentDevice(const std::vector& inputs) { + int64_t gpu_device = LazyInitCurrentDevice(); + // Engine Context + if (inputs.empty()) { + return; + } + + torch::Device cur_cuda_device = torch::Device(torch::kCUDA, gpu_device); + + TORCH_CHECK(disc_result_.inputs.size() == inputs.size()); + for (size_t k = 0; k < inputs.size(); ++k) { + at::Tensor inp = inputs[k]; + auto device = disc_result_.inputs[k].device; + if (device == "cuda") { + TORCH_CHECK(inp.device() == cur_cuda_device, "Input tensor ", k, + " device mismatch. Expect: ", cur_cuda_device, + ", got: ", inp.device()); + } + } + return; +} + +int64_t RalContext::LazyInitCurrentDevice() { + int64_t cur_device = c10::cuda::current_device(); + int64_t prev_device = NULL_GPU_DEVICE; + bool success = gpu_device_.compare_exchange_strong(prev_device, cur_device); + if (!success) { + TORCH_CHECK(prev_device == cur_device, + "Device changed during inference. 
Please do NOT change CUDA " + "current device during inference."); + } + TORCH_CHECK(gpu_device_ != NULL_GPU_DEVICE); + return cur_device; +} + +tao::ral::BaseContext* RalContext::LoadCache() { + int64_t gpu_device = LazyInitCurrentDevice(); + TORCH_CHECK(gpu_device >= 0, "expect gpu device id >= 0, but got ", + gpu_device); + c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(gpu_device); + + tao::ral::gpu::BaseCudaContextOption gpu_opt; + gpu_opt.device_ordinal = gpu_device; + gpu_opt.use_stream_executor = true; + gpu_opt.gpu_allocator.reset( + new RalAllocator(c10::cuda::CUDACachingAllocator::raw_alloc, + c10::cuda::CUDACachingAllocator::raw_delete)); + + std::lock_guard guard(mtx_); + tao::ral::BaseContext* ral_ctx_ptr; + auto it = ral_ctx_map_.find(stream); + if (it == ral_ctx_map_.end()) { + gpu_opt.stream = stream.stream(); + auto ral_ctx = + tao::ral::gpu::MakeBaseCudaContext(default_opt_, cpu_opt_, gpu_opt); + ral_ctx_ptr = ral_ctx.get(); + ral_ctx_map_[stream].reset(ral_ctx.release()); + } else { + ral_ctx_ptr = it->second.get(); + } + return ral_ctx_ptr; +} + +std::vector RalContext::PreProcessInputs( + const std::vector& inputs) { + CheckCurrentDevice(inputs); + + std::vector contiguous_inputs; + for (at::Tensor inp_tensor : inputs) { + // make sure the input is in contiguous layout + contiguous_inputs.push_back(inp_tensor.contiguous()); + } + return contiguous_inputs; +} + +inline bool IsEmptyTensor(const tao::ral::buffer_shape_t& shape) { + return shape.size() > 0 && std::any_of(shape.begin(), shape.end(), + [](int64_t dim) { return dim == 0; }); +} + +inline bool IsSameShape(const tao::ral::buffer_shape_t& shape, + at::Tensor input_tensor) { + if (input_tensor.dim() != shape.size()) { + return false; + } + + for (int i = 0; i < shape.size(); i++) { + if (input_tensor.sizes()[i] != shape[i]) { + return false; + } + } + + return true; +} + +std::vector RalContext::CreateAndBindingOutputs( + const std::vector& inputs, + tao::ral::ExecutionContext& exec_ctx) { + std::vector outputs; + + auto num_outputs = disc_result_.outputs.size(); + outputs.reserve(num_outputs); + std::vector> out_bufs( + num_outputs); + for (size_t idx = 0; idx < num_outputs; ++idx) { + auto& out_buf = out_bufs[idx]; + // Note: RAL has a memory allocator that allocates memory on each forward run, + // so it is thread-safe to reuse the underlying memory. + exec_ctx.bindOutput(idx, &out_buf); + + const auto& output_info = disc_result_.outputs[idx]; + auto scalar_type = output_info.scalar_type; + + torch::DeviceType dev_type = torch::kCUDA; + dev_type = (output_info.device == "cuda") ? torch::kCUDA : torch::kCPU; + + auto option = torch::device(dev_type) + .dtype(scalar_type) + .memory_format(torch::MemoryFormat::Contiguous); + at::Tensor out_tensor; + if (IsEmptyTensor(out_buf->shape())) { + out_tensor = torch::zeros(out_buf->shape(), option); + } else if (out_buf->owned()) { + auto cpu_allocator = c10::GetAllocator(torch::kCPU); + TORCH_CHECK(cpu_allocator != nullptr); + std::function deleter = [cpu_allocator](void* ptr) { + cpu_allocator->raw_deallocate(ptr); + }; + if (output_info.device == "cuda") { + deleter = c10::cuda::CUDACachingAllocator::raw_delete; + } + out_tensor = torch::from_blob(const_cast(out_buf->data()), + out_buf->shape(), deleter, option); + out_buf->release(); + } else { + // (@yuanxiulong.yxl) For input/output aliasing, we currently only have full + // tensor memory reuse.
+ // We will support partial memory space reuse in the future + bool alias_input = false; + for (auto& input_tensor : inputs) { + // same address, same shape, same dtype + if (input_tensor.data_ptr() == out_buf->data() && + scalar_type == input_tensor.dtype() && + IsSameShape(out_buf->shape(), input_tensor)) { + out_tensor = input_tensor; + alias_input = true; + } + } + if (!alias_input) { + out_tensor = torch::from_blob(const_cast(out_buf->data()), + out_buf->shape(), option) + .clone(); + } + } + outputs.push_back(out_tensor); + } + return outputs; +} + +void RalContext::BindingInputs(const std::vector& inputs, + tao::ral::ExecutionContext& exec_ctx) { + for (size_t idx = 0; idx < inputs.size(); ++idx) { + at::Tensor inp = inputs[idx]; + const auto& shape = inp.sizes(); + exec_ctx.bindInput(idx, inp.data_ptr(), shape.vec()); + } +} + +std::vector RalContext::Execute( + const std::vector& inputs) { + // inputs are always contiguous + auto ral_ctx = LoadCache(); + // the execution context is a per-inference context and is thread-safe + auto exec_ctx = + tao::ral::MakeExecutionContext( + ral_ctx); + + BindingInputs(inputs, *exec_ctx.get()); + + auto tao_ral_func_ptr = reinterpret_cast(&tao_ral_call_impl); + + // execute + void* ctx_struct[] = {exec_ctx.get(), tao_ral_func_ptr}; + try { + entry_func_(ctx_struct); + } catch (std::exception& ex) { + LOG(ERROR) << ex.what(); + throw ex; + } + + // Support input/output buffer reuse. + // For now we only have full buffer reuse for aliases. + auto outputs = CreateAndBindingOutputs(inputs, *exec_ctx.get()); + + return outputs; +} + +} // namespace disc +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/disc/disc_ral.h b/torch_xla/csrc/runtime/disc/disc_ral.h new file mode 100644 index 00000000000..f47431689c5 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_ral.h @@ -0,0 +1,73 @@ +#ifndef XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCRAL_H_ +#define XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCRAL_H_ + +#include +#include +#include + +#include "torch_xla/csrc/runtime/disc/disc_utils.h" + +namespace torch_xla { +namespace runtime { +namespace disc { + +using tao::ral::ExecutionContext; + +struct DataMeta { + std::string device; + c10::ScalarType scalar_type; +}; + +struct DISCComplationResult { + std::string ral_lib; + std::string ral_mate_pb; + std::vector inputs; + std::vector outputs; +}; + +class RalContext { + using EntryFunc = std::function; + + public: + RalContext(const DISCComplationResult& disc_result); + ~RalContext(); + + std::vector Execute(const std::vector& inputs); + + private: + void BindingInputs(const std::vector& inputs, + tao::ral::ExecutionContext& exec_ctx); + void CheckCurrentDevice(const std::vector& inputs); + std::vector CreateAndBindingOutputs( + const std::vector& inputs, + tao::ral::ExecutionContext& exec_ctx); + std::vector PreProcessInputs( + const std::vector& inputs); + std::tuple LoadEngine(const std::string& ral_engine_bytes); + + int64_t LazyInitCurrentDevice(); + + constexpr static int64_t NULL_GPU_DEVICE = -1; + std::atomic gpu_device_{NULL_GPU_DEVICE}; + std::mutex mtx_; + std::unordered_map> + ral_ctx_map_; + tao::ral::BaseContext* LoadCache(); + + tao::ral::BaseContextOption default_opt_; + tao::ral::cpu::BaseCpuContextOption cpu_opt_; + + DISCComplationResult disc_result_; + + void* tao_lib_; + EntryFunc entry_func_; + + TempFile lib_tmpf_{"ral_lib.so"}; + TempFile meta_tmpf_{"ral_meta.pb"}; +}; +} // namespace disc +} // namespace runtime +} // namespace torch_xla + +#endif //
XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCRAL_H_ diff --git a/torch_xla/csrc/runtime/disc/disc_ral_test.cc b/torch_xla/csrc/runtime/disc/disc_ral_test.cc new file mode 100644 index 00000000000..7d25bef0fb6 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_ral_test.cc @@ -0,0 +1,19 @@ +#include "torch_xla/csrc/runtime/disc/disc_ral.h" + +#include + +namespace torch_xla { +namespace runtime { +namespace disc { +TEST(DISCRAlTest, E2E) { + // TODO(disc): need compile API to output the compilation result + std::shared_ptr disc_result = + std::make_shared(); + RalContext ral_ctx(disc_result); + std::vector inputs; + ral_ctx.Execute(at::List()); +} + +} // namespace disc +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/disc/disc_utils.cc b/torch_xla/csrc/runtime/disc/disc_utils.cc new file mode 100644 index 00000000000..b9c6303bac0 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_utils.cc @@ -0,0 +1,114 @@ +#include "torch_xla/csrc/runtime/disc/disc_utils.h" + +#include +#include + +#include + +#include "torch_xla/csrc/runtime/tf_logging.h" + +namespace torch_xla { +namespace runtime { +namespace disc { + +std::string ReadStringFromEnvVar(const char* env_var_name, + std::string default_val) { + const char* env_var_val = std::getenv(env_var_name); + if (env_var_val == nullptr) { + return default_val; + } + return std::string(env_var_val); +} + +// This function is copied from c10/util/tempfile.h, so it follows the same +// temporary directory env variables, too. +std::vector make_filename(std::string name_prefix) { + // The filename argument to `mkstemp` needs "XXXXXX" at the end according to + // http://pubs.opengroup.org/onlinepubs/009695399/functions/mkstemp.html + static const std::string kRandomPattern = "XXXXXX"; + // We check whether any of these environment variables is set and use its value, + // or else default the temporary directory to `/tmp`.
+ static const char* env_variables[] = {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}; + std::string tmp_directory = "/tmp"; + for (const char* variable : env_variables) { + auto path = ReadStringFromEnvVar(variable, ""); + if (!path.empty()) { + tmp_directory = path; + break; + } + } + std::vector filename; + filename.reserve(tmp_directory.size() + name_prefix.size() + + kRandomPattern.size() + 2); + filename.insert(filename.end(), tmp_directory.begin(), tmp_directory.end()); + filename.push_back('/'); + filename.insert(filename.end(), name_prefix.begin(), name_prefix.end()); + filename.insert(filename.end(), kRandomPattern.begin(), kRandomPattern.end()); + filename.push_back('\0'); + return filename; +} + +std::string ReadFileBytes(const std::string& fname) { + std::ifstream input(fname, std::ios::binary); + std::vector bytes((std::istreambuf_iterator(input)), + (std::istreambuf_iterator())); + return std::string(bytes.begin(), bytes.end()); +} + +TempFile::TempFile(std::string prefix) : fname_(""), fd_(-1) { + auto fname = make_filename(prefix); + fd_ = mkstemp(fname.data()); + fname_ = std::string(fname.data()); + TORCH_CHECK(fd_ != -1, "Error generating temporary file, file name: ", fname_, + ", error: ", std::strerror(errno)); +} + +TempFile::~TempFile() { + if (!fname_.empty()) { + ::unlink(fname_.c_str()); + } + if (fd_ > 0) { + ::close(fd_); + } +} + +bool TempFile::WriteBytesToFile(const std::string& bytes) { + ssize_t left_len = bytes.length(); + const char* data = bytes.data(); + errno = 0; + while (left_len > 0) { + auto sz = ::write(fd_, data, left_len); + if (sz <= 0) { + if (errno != EINTR && errno != EAGAIN) { + TF_VLOG(1) << "Failed to write content to temp file: " << GetFilename() + << ", error: " << strerror(errno); + return false; + } + errno = 0; + continue; + } + left_len -= sz; + data += sz; + } + return true; +} + +const std::string& TempFile::GetFilename() const { return fname_; } + +std::string TempFile::ReadBytesFromFile() { + std::ifstream infile(fname_, std::ios::binary); + std::string str((std::istreambuf_iterator(infile)), + std::istreambuf_iterator()); + return str; +} + +std::string TempFile::ReadStringFromFile() { + std::ifstream infile(fname_); + std::string str((std::istreambuf_iterator(infile)), + std::istreambuf_iterator()); + return str; +} + +} // namespace disc +} // namespace runtime +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/disc_utils.h b/torch_xla/csrc/runtime/disc/disc_utils.h new file mode 100644 index 00000000000..df7476e77df --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_utils.h @@ -0,0 +1,36 @@ +#ifndef XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCUTILS_H_ +#define XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCUTILS_H_ + +#include +#include + +namespace torch_xla { +namespace runtime { +namespace disc { +std::vector make_filename(std::string name_prefix); +std::string ReadFileBytes(const std::string& fname); +class TempFile { + public: + TempFile(std::string prefix = ""); + ~TempFile(); + TempFile(const TempFile&) = delete; + void operator=(const TempFile&) = delete; + /// Write bytes content to temp file and return true on success. + bool WriteBytesToFile(const std::string& bytes); + /// Read byte content from temp file. + std::string ReadBytesFromFile(); + /// Read string content from temp file.. + std::string ReadStringFromFile(); + /// Get the filename of the temp file. 
+ const std::string& GetFilename() const; + + private: + std::string fname_; + int fd_; +}; + +} // namespace disc +} // namespace runtime +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCUTILS_H_ \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc new file mode 100644 index 00000000000..6465551dbde --- /dev/null +++ b/torch_xla/csrc/runtime/disc_computation_client.cc @@ -0,0 +1,410 @@ +#include "torch_xla/csrc/runtime/disc_computation_client.h" + +#include +#include +#include +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/disc/disc_compile.h" +#include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/stablehlo_helper.h" +#include "torch_xla/csrc/runtime/sys_util.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/float_normalization.h" +#include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/service/hlo_proto_util.h" + +namespace torch_xla { +namespace runtime { + +at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) { + switch (xla_type) { + case xla::PrimitiveType::BF16: + return at::ScalarType::BFloat16; + case xla::PrimitiveType::F16: + return at::ScalarType::Half; + case xla::PrimitiveType::F32: + return at::ScalarType::Float; + case xla::PrimitiveType::F64: + return at::ScalarType::Double; + case xla::PrimitiveType::PRED: + return at::ScalarType::Bool; + case xla::PrimitiveType::U8: + return at::ScalarType::Byte; + case xla::PrimitiveType::S8: + return at::ScalarType::Char; + case xla::PrimitiveType::S16: + case xla::PrimitiveType::U16: + return at::ScalarType::Short; + case xla::PrimitiveType::S32: + case xla::PrimitiveType::U32: + return at::ScalarType::Int; + case xla::PrimitiveType::S64: + case xla::PrimitiveType::U64: + return at::ScalarType::Long; + case xla::PrimitiveType::C64: + return at::ScalarType::ComplexFloat; + case xla::PrimitiveType::C128: + return at::ScalarType::ComplexDouble; + default: + XLA_ERROR() << "XLA type not supported: " << xla_type; + } +} + +xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) { + switch (scalar_type) { + case at::ScalarType::Double: + return xla::PrimitiveType::F64; + case at::ScalarType::Float: + return xla::PrimitiveType::F32; + case at::ScalarType::BFloat16: + return xla::PrimitiveType::BF16; + case at::ScalarType::Half: + return xla::PrimitiveType::F16; + case at::ScalarType::Bool: + return xla::PrimitiveType::PRED; + case at::ScalarType::Byte: + return xla::PrimitiveType::U8; + case at::ScalarType::Char: + return xla::PrimitiveType::S8; + case at::ScalarType::Short: + return xla::PrimitiveType::S16; + case at::ScalarType::Int: + return xla::PrimitiveType::S32; + case at::ScalarType::Long: + return xla::PrimitiveType::S64; + case at::ScalarType::ComplexFloat: + return xla::PrimitiveType::C64; + case at::ScalarType::ComplexDouble: + return xla::PrimitiveType::C128; + default: + XLA_ERROR() << "Type not supported: " << scalar_type; + } +} + 
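+// DISCComputationClient plugs the BladeDISC runtime into torch_xla's
+// ComputationClient interface: device buffers are ordinary CUDA at::Tensors,
+// and computations are lowered from HLO to MHLO before DISC compilation.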
+DISCComputationClient::DISCComputationClient() { + world_size_ = sys_util::GetEnvInt("WORLD_SIZE", 1); + local_rank_ = sys_util::GetEnvInt("LOCAL_RANK", 0); + global_rank_ = sys_util::GetEnvInt("RANK", local_rank_); + device_type_ = sys_util::GetEnvString(env::kEnvDISCDevice, "CUDA"); + if (device_type_ != "CUDA") { + XLA_ERROR() << "Only CUDA device is supported by DISC backend"; + } +} + +DISCComputationClient::~DISCComputationClient() {} + +void DISCComputationClient::DISCData::Assign( + const torch::lazy::BackendData& data) { + const DISCData& disc_data = dynamic_cast(data); + if (&disc_data != this) { + buffer = disc_data.buffer; + } +} + +ComputationClient::DataPtr DISCComputationClient::CreateDataPlaceholder( + std::string device, xla::Shape shape, + std::optional sharding) { + return std::make_shared(std::move(device), std::move(shape)); +} + +std::vector DISCComputationClient::TransferToDevice( + absl::Span> tensors) { + std::vector datas; + datas.reserve(tensors.size()); + + size_t total_transfered_bytes = 0; + + for (auto& tensor : tensors) { + std::vector sizes; + for (auto& dim_val : tensor->shape().dimensions()) { + sizes.push_back(dim_val); + } + + auto dtype = + at::TensorOptions(TorchTypeFromXlaType(tensor->shape().element_type())); + auto ret = at::empty(sizes, dtype).contiguous(); + // tensor->populate_fn(tensor, ret.data_ptr(), + // ret.element_size() * ret.numel()); + std::memcpy(ret.data_ptr(), tensor->data(), + ret.element_size() * ret.numel()); + + total_transfered_bytes += ret.element_size() * ret.numel(); + + if (!torch::cuda::is_available()) { + XLA_ERROR() << "CUDA is not available."; + } + + auto device_ret = ret.to(at::kCUDA); + ComputationClient::DataPtr data = std::make_shared( + tensor->device(), tensor->shape(), device_ret); + datas.push_back(data); + } + + return datas; +} + +std::vector DISCComputationClient::TransferFromDevice( + absl::Span handles) { + std::vector literals; + literals.reserve(handles.size()); + for (auto handle : handles) { + std::shared_ptr disc_data = + std::dynamic_pointer_cast(handle); + xla::Shape target_shape = + xla::ShapeUtil::DeviceShapeToHostShape(xla::ShapeUtil::MakeShape( + XlaTypeFromTorchType(disc_data->buffer.dtype().toScalarType()), + disc_data->buffer.sizes())); + auto& literal = literals.emplace_back(target_shape); + auto host_data = disc_data->buffer.to(at::kCPU); + std::memcpy(literal.untyped_data(), host_data.data_ptr(), + literal.size_bytes()); + } + + return literals; +} + +std::vector DISCComputationClient::Compile( + std::vector instances) { + std::vector computations{}; + for (auto& instance : instances) { + mlir::MLIRContext context; + mlir::ModuleOp mlir_module = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + + auto hlo_proto = instance.computation.proto(); + auto program_shape = instance.computation.GetProgramShape().value(); + xla::HloModuleConfig module_config(program_shape); + module_config.set_debug_options(xla::GetDebugOptionsFromFlags()); + xla::ComputationLayout* entry_layout = + module_config.mutable_entry_computation_layout(); + for (int64_t i = 0; i < entry_layout->parameter_count(); ++i) { + auto status = + entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + program_shape.parameters(i)); + if (!status.ok()) { + XLA_ERROR() << "Error copying layout from shape: "; + return {}; + } + } + + std::unique_ptr hlo_module = + xla::CreateModuleFromProto(hlo_proto, module_config).value(); + xla::HloPassPipeline pipeline("pre-stablehlo"); + stream_executor::CudaComputeCapability 
gpu_version; + auto dprops = at::cuda::getCurrentDeviceProperties(); + gpu_version.major = dprops->major; + gpu_version.minor = dprops->minor; + xla::gpu::GpuFloatSupport bf16_support(gpu_version, xla::BF16); + pipeline.AddPass(&bf16_support); + auto status = pipeline.Run(hlo_module.get()).status(); + if (!status.ok()) { + XLA_ERROR() << "Error running pre-stablehlo pass pipeline: "; + return {}; + } + { + auto mutable_hlo_proto = hlo_module->ToProto(); + auto status = + torch_xla::ConvertHloToMhlo(&mutable_hlo_proto, &mlir_module); + XLA_CHECK(status.ok()) << "StableHLO -> MHLO conversion failed.\n" + << status.message(); + } + + // Add input and output attributes + auto entry_func_identifier = + mlir::StringAttr::get(&context, "tf.entry_function"); + auto input_placement_key = + mlir::StringAttr::get(&context, "input_placements"); + auto output_placement_key = + mlir::StringAttr::get(&context, "output_placements"); + auto input_output_alias_params_key = + mlir::StringAttr::get(&context, "input_output_alias_params"); + auto input_output_alias_outputs_key = + mlir::StringAttr::get(&context, "input_output_alias_outputs"); + + std::string input_placement = ""; + std::string output_placement = ""; + std::string input_output_alias_params = ""; + std::string input_output_alias_outputs = ""; + + std::vector inputs, outputs; + + auto input_output_alias = instance.computation.proto().input_output_alias(); + if (sys_util::GetEnvString("ENBALE_DISC_INPUT_OUTPUT_ALIAS", "") != "OFF") { + for (const auto& entry : input_output_alias.entries()) { + input_output_alias_params += + std::to_string(entry.parameter_number()) + ","; + input_output_alias_outputs += + std::to_string(entry.output_shape_index(0)) + ","; + } + } + if (!input_output_alias_params.empty()) { + input_output_alias_params.pop_back(); + input_output_alias_outputs.pop_back(); + } + + // Set attribute for entry function + mlir::func::FuncOp entry_func; + for (auto func : mlir_module.getOps()) { + if (func.getName().str() == "main") { + entry_func = func; + break; + } + } + + for (int i = 0; i < entry_func.getFunctionType().getNumInputs(); i++) { + absl::StrAppend(&input_placement, "gpu,"); + disc::DataMeta tensor_info; + tensor_info.device = "cuda"; + inputs.push_back(tensor_info); + } + if (!input_placement.empty()) { + input_placement.pop_back(); + } + + if (instance.output_shape->IsTuple()) { + for (auto& sub_shape : instance.output_shape->tuple_shapes()) { + absl::StrAppend(&output_placement, "gpu,"); + disc::DataMeta tensor_info; + tensor_info.device = "cuda"; + tensor_info.scalar_type = + TorchTypeFromXlaType(sub_shape.element_type()); + outputs.push_back(tensor_info); + } + } else { + absl::StrAppend(&output_placement, "gpu,"); + disc::DataMeta tensor_info; + tensor_info.device = "cuda"; + tensor_info.scalar_type = + TorchTypeFromXlaType(instance.output_shape->element_type()); + outputs.push_back(tensor_info); + } + + if (!output_placement.empty()) { + output_placement.pop_back(); + } + + auto input_placement_value = + mlir::StringAttr::get(&context, input_placement); + auto output_placement_value = + mlir::StringAttr::get(&context, output_placement); + + auto input_output_alias_outputs_value = + mlir::StringAttr::get(&context, input_output_alias_outputs); + auto input_output_alias_params_value = + mlir::StringAttr::get(&context, input_output_alias_params); + + auto dict = mlir::DictionaryAttr::get( + &context, + {mlir::NamedAttribute(input_placement_key, input_placement_value), + mlir::NamedAttribute(output_placement_key, 
output_placement_value), + mlir::NamedAttribute(input_output_alias_params_key, + input_output_alias_params_value), + mlir::NamedAttribute(input_output_alias_outputs_key, + input_output_alias_outputs_value)}); + + entry_func->setAttr(entry_func_identifier, dict); + mlir_module->setAttr(entry_func_identifier, dict); + + // Trigger disc compilation + disc::DISCComplationResult compile_res = + disc::Compile(mlir_module, inputs, outputs); + std::shared_ptr disc_computation = + std::make_shared( + std::move(xla::XlaComputation(instance.computation.proto())), + instance.devices, std::make_unique(compile_res)); + computations.push_back(disc_computation); + } + + return computations; +} + +std::vector +DISCComputationClient::ExecuteComputation( + const ComputationClient::Computation& computation, + absl::Span arguments, + const std::string& device, const ExecuteComputationOptions& options) { + const DISCComputation& disc_computation = + dynamic_cast(computation); + + std::vector buffers; + buffers.reserve(arguments.size()); + for (auto& argument : arguments) { + std::shared_ptr disc_data = + std::dynamic_pointer_cast(argument); + buffers.push_back(disc_data->buffer); + } + + std::vector results = + disc_computation.executable->Execute(buffers); + + std::vector datas; + datas.reserve(results.size()); + for (auto& result : results) { + std::shared_ptr data = std::make_shared( + device, xla::ShapeUtil::MakeShape(xla::F32, result.sizes()), result); + + datas.push_back(data); + } + + return datas; +} + +std::map DISCComputationClient::GetMetrics() const { + return {}; +} + +std::string DISCComputationClient::GetDefaultDevice() const { + return absl::StrCat(device_type_, ":", std::to_string(local_rank_)); +} + +std::vector DISCComputationClient::GetLocalDevices() const { + std::vector all_devices; + all_devices.push_back(GetDefaultDevice()); + return all_devices; +} + +std::optional DISCComputationClient::GetDataSharding( + ComputationClient::DataPtr handle) { + return std::optional(); +} + +void DISCComputationClient::SetReplicationDevices( + std::shared_ptr> devices) { + replication_devices_ = std::move(devices); +} + +std::shared_ptr> +DISCComputationClient::GetReplicationDevices() { + return replication_devices_; +} + +std::vector DISCComputationClient::GetAllDevices() const { + std::vector all_devices; + int device_count = world_size_; + for (int idx = 0; idx < device_count; idx++) { + all_devices.push_back(absl::StrCat(device_type_, ":", std::to_string(idx))); + } + return all_devices; +} + +size_t DISCComputationClient::GetNumDevices() const { return world_size_; } + +int DISCComputationClient::GetProcessIndex() const { return local_rank_; } + +int DISCComputationClient::GetNumProcesses() const { return world_size_; } + +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/disc_computation_client.h b/torch_xla/csrc/runtime/disc_computation_client.h new file mode 100644 index 00000000000..0701d7b3591 --- /dev/null +++ b/torch_xla/csrc/runtime/disc_computation_client.h @@ -0,0 +1,199 @@ +#ifndef XLA_CLIENT_DISC_COMPUTATION_CLIENT_H_ +#define XLA_CLIENT_DISC_COMPUTATION_CLIENT_H_ + +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/disc/disc_ral.h" +#include "torch_xla/csrc/runtime/stablehlo_helper.h" +#include "xla/client/xla_computation.h" + +namespace torch_xla { +namespace runtime { + +class DISCComputationClient : public ComputationClient { + public: + DISCComputationClient(); + ~DISCComputationClient(); + + DataPtr 
CreateDataPlaceholder( + std::string device, xla::Shape shape, + std::optional sharding = std::nullopt) override; + + std::vector GetDataShards(DataPtr data) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + DataPtr GetDataShard(DataPtr data, size_t index) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::vector ReshardData( + absl::Span handles, + absl::Span shardings) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + DataPtr WrapDataShards(absl::Span shards, std::string device, + xla::Shape shape, xla::OpSharding sharding) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::optional GetDataSharding(DataPtr handle) override; + + std::vector TransferToDevice( + absl::Span> tensors) override; + + std::vector TransferFromDevice( + absl::Span handles) override; + + DataPtr TransferShardsToDevice( + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + DataPtr CopyToDevice(DataPtr data, std::string dst) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::string SerializeComputation(const ComputationPtr computation) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + ComputationPtr DeserializeComputation( + const std::string& serialized) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + torch::lazy::hash_t HashCompilationEnv() override { + // TODO(wangang.wa): Improve this function. + return torch::lazy::hash_t(); + } + + torch_xla::DeviceType GetDeviceType() const override { + return torch_xla::DeviceType("CUDA"); + }; + + bool CoordinatorInitialized() const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + XlaCoordinator& GetCoordinator() override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::vector Compile( + std::vector instances) override; + + std::vector ExecuteComputation( + const Computation& computation, absl::Span arguments, + const std::string& device, + const ExecuteComputationOptions& options) override; + + std::vector ExecuteReplicated( + const Computation& computation, absl::Span arguments, + absl::Span devices, + const ExecuteReplicatedOptions& options) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + size_t GetNumDevices() const override; + + std::string GetDefaultDevice() const override; + + std::vector GetLocalDevices() const override; + + std::vector GetAllDevices() const override; + + int GetProcessIndex() const override; + + int GetNumProcesses() const override; + + const absl::flat_hash_map< + std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& + GetDeviceAttributes(const std::string& device) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + void SetReplicationDevices( + std::shared_ptr> devices) override; + + std::shared_ptr> GetReplicationDevices() override; + + void WaitDeviceOps(absl::Span devices) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::map GetMetrics() const override; + + MemoryInfo GetMemoryInfo(const std::string& device) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + private: + std::shared_ptr> replication_devices_; + int world_size_; + int local_rank_; 
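+  // Global rank falls back to LOCAL_RANK when the RANK env variable is unset.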
+ int global_rank_; + std::string device_type_; + struct DISCData : public Data { + DISCData(std::string device, xla::Shape device_shape) + : Data(std::move(device), std::move(device_shape)) {} + + DISCData(std::string device, xla::Shape device_shape, at::Tensor buffer) + : Data(std::move(device), std::move(device_shape)), buffer(buffer) {} + + void Assign(const torch::lazy::BackendData& data) override; + + bool HasValue() const override { + return buffer.defined() && buffer.element_size() > 0; + } + + Handle GetHandle() override { + return reinterpret_cast(buffer.const_data_ptr()); + } + + bool HasSharding() const override { return false; } + + xla::OpSharding GetSharding() const override { + XLA_CHECK(false) << "GetSharding should not be called on DISCData, check " + "HasSharding first"; + return xla::OpSharding(); + } + + std::string ToString() const override { + std::stringstream ss; + ss << "XLAData: \n"; + ss << " Data Device: " << device() << "\n"; + ss << " Data Shape: " << shape().ToString() << "\n"; + ss << " Data Handle: "; + if (HasValue()) { + ss << reinterpret_cast(buffer.const_data_ptr()) << "\n"; + } else { + ss << "None\n"; + } + return ss.str(); + } + + at::Tensor buffer; + }; + + struct DISCComputation : public Computation { + DISCComputation(xla::XlaComputation computation, + std::vector devices, + std::unique_ptr executable) + : Computation(std::move(computation), std::move(devices)), + executable(std::move(executable)) {} + + std::unique_ptr executable; + }; +}; + +} // namespace runtime +} // namespace torch_xla +#endif // XLA_CLIENT_DISC_COMPUTATION_CLIENT_H_ diff --git a/torch_xla/csrc/runtime/disc_computation_client_test.cc b/torch_xla/csrc/runtime/disc_computation_client_test.cc new file mode 100644 index 00000000000..902c3ab151c --- /dev/null +++ b/torch_xla/csrc/runtime/disc_computation_client_test.cc @@ -0,0 +1,78 @@ +#include "torch_xla/csrc/runtime/disc_computation_client.h" + +#include +#include + +#include +#include +#include +#include + +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test.h" +#include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/tests/literal_test_util.h" + +namespace torch_xla { +namespace runtime { + +tsl::StatusOr MakeComputation() { + xla::Shape input_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2}); + xla::XlaBuilder builder("AddComputation"); + xla::XlaOp x = xla::Parameter(&builder, 0, input_shape, "x"); + xla::XlaOp y = xla::Parameter(&builder, 1, input_shape, "y"); + xla::XlaOp sum = xla::Add(x, y); + return builder.Build(); +} + +TEST(DISCComputationClientTest, Init) { + tsl::setenv("DISC_DEVICE", "GPU", true); + auto client = std::make_unique(); + std::string device = "GPU:0"; + + // Compose a computation. + auto shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2}); + std::vector instances; + instances.push_back(ComputationClient::CompileInstance( + std::move(MakeComputation().value()), device, {"cuda:0"}, &shape)); + + // Prepare inputs. + xla::Literal literal_x = + xla::LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}}); + xla::Literal literal_y = + xla::LiteralUtil::CreateR2({{5.0f, 6.0f}, {7.0f, 8.0f}}); + + // Compile the graph. + std::vector computations = + client->Compile(std::move(instances)); + + // Copy inputs to device. 
+ ComputationClient::ExecuteComputationOptions options{}; + std::vector> args = { + std::make_shared(std::move(literal_x), device), + std::make_shared(std::move(literal_y), device)}; + + // Execute the graph. + std::vector results = client->ExecuteComputation( + *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)), + device, options); + + // Copy the output from device back to host and assert correctness.. + ASSERT_EQ(results.size(), 1); + auto result_literals = client->TransferFromDevice(results); + ASSERT_THAT(result_literals, ::testing::SizeIs(1)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal( + xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), + result_literals[0])); +} + +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/env_vars.cc b/torch_xla/csrc/runtime/env_vars.cc index d04d99ab761..a682b273e69 100644 --- a/torch_xla/csrc/runtime/env_vars.cc +++ b/torch_xla/csrc/runtime/env_vars.cc @@ -9,6 +9,7 @@ const char* const kEnvNumGpu = "GPU_NUM_DEVICES"; const char* const kEnvNumCpu = "CPU_NUM_DEVICES"; const char* const kEnvTpuvmMode = "TPUVM_MODE"; const char* const kEnvPjRtDevice = "PJRT_DEVICE"; +const char* const kEnvDISCDevice = "DISC_DEVICE"; const char* const kEnvPjRtTpuMaxInflightComputations = "PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS"; const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT"; diff --git a/torch_xla/csrc/runtime/env_vars.h b/torch_xla/csrc/runtime/env_vars.h old mode 100644 new mode 100755 index ef2535c230f..6002497d7b0 --- a/torch_xla/csrc/runtime/env_vars.h +++ b/torch_xla/csrc/runtime/env_vars.h @@ -19,6 +19,7 @@ extern const char* const kEnvHostOrdinal; extern const char* const kEnvShardOrdinal; extern const char* const kEnvStartService; extern const char* const kEnvTpuvmMode; +extern const char* const kEnvDISCDevice; extern const char* const kEnvPjRtDevice; extern const char* const kEnvPjRtTpuMaxInflightComputations; extern const char* const kEnvPjrtAsyncCpuClient; diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index feb2a0844c6..d31fab90e86 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -2,6 +2,9 @@ #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/computation_client.h" +#ifdef TORCHACC_ENABLE_DISC +#include "torch_xla/csrc/runtime/disc_computation_client.h" +#endif #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/ifrt_computation_client.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" @@ -21,7 +24,13 @@ ComputationClient* GetComputationClient() { std::unique_ptr client; static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false); - if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { + if (sys_util::GetEnvString(env::kEnvDISCDevice, "") != "") { +#ifdef TORCHACC_ENABLE_DISC + client = std::make_unique(); +#else + XLA_ERROR() << "should build with ENABLE_DISC=ON" << std::endl; +#endif + } else if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { if (use_ifrt) { client = std::make_unique(); } else { diff --git a/torch_xla/csrc/runtime/stablehlo_helper.cc b/torch_xla/csrc/runtime/stablehlo_helper.cc index 1d9a740e52c..6ff9292fd70 100644 --- a/torch_xla/csrc/runtime/stablehlo_helper.cc +++ b/torch_xla/csrc/runtime/stablehlo_helper.cc @@ -50,8 +50,8 @@ static std::string getMlirModuleBytecode(mlir::ModuleOp& mlir_module) { return txt_mlir_module; } -static absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto, - mlir::ModuleOp* 
mlir_module) {
+absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto,
+                              mlir::ModuleOp* mlir_module) {
   auto status = xla::ConvertHloToMlirHlo(*mlir_module, proto,
                                          /*import_all_computations=*/false);
   if (!status.ok()) {
@@ -62,6 +62,17 @@ static absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto,
         absl::StatusCode::kInternal,
         "MHLO Module from HLO -> MHLO conversion is not legal.");
   }
+  mlir::PassManager pm(mlir_module->getContext());
+  // Apply pass to remove HLO tuple output, as MHLO/StableHLO supports multiple
+  // outputs.
+  pm.addPass(mlir::mhlo::createExpandHloTuplesPass());
+  // Canonicalization after tuple flattening, to remove unused tuple ops.
+  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
+
+  XLA_CHECK(mlir::succeeded(pm.run(*mlir_module)))
+      << "HLO -> MHLO conversion failed.\n"
+      << status.message();
+
   return absl::OkStatus();
 }
diff --git a/torch_xla/csrc/runtime/stablehlo_helper.h b/torch_xla/csrc/runtime/stablehlo_helper.h
index 235dc6b38b7..cfc65b817e2 100644
--- a/torch_xla/csrc/runtime/stablehlo_helper.h
+++ b/torch_xla/csrc/runtime/stablehlo_helper.h
@@ -20,6 +20,9 @@ void ConvertStableHloToHlo(mlir::ModuleOp* mlir_module,
                            mlir::MLIRContext* context,
                            xla::HloProto* hlo_proto);
 
+absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto,
+                              mlir::ModuleOp* mlir_module);
+
 std::string GetHloModuleStr(const xla::HloModuleProto* proto);
 
 const std::string GetTorchDtypeToStablehloDtype(const std::string& dtype);