From f00b30bb89cbe0e351a20cbf753dcee354c54da2 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Fri, 5 Jan 2024 14:12:39 +0800 Subject: [PATCH 01/12] build with BladeDISC (#8) --- .gitmodules | 3 + BUILD | 1 + WORKSPACE | 7 + bazel/disc.BUILD | 43 +++++ bazel/torch.BUILD | 8 + setup.py | 8 + third_party/BladeDISC | 1 + torch_xla/csrc/BUILD | 1 + torch_xla/csrc/runtime/disc/BUILD | 40 +++++ torch_xla/csrc/runtime/disc/disc_ral.cc | 166 +++++++++++++++++++ torch_xla/csrc/runtime/disc/disc_ral.h | 63 +++++++ torch_xla/csrc/runtime/disc/disc_ral_test.cc | 18 ++ 12 files changed, 359 insertions(+) create mode 100644 bazel/disc.BUILD create mode 160000 third_party/BladeDISC create mode 100644 torch_xla/csrc/runtime/disc/BUILD create mode 100644 torch_xla/csrc/runtime/disc/disc_ral.cc create mode 100644 torch_xla/csrc/runtime/disc/disc_ral.h create mode 100644 torch_xla/csrc/runtime/disc/disc_ral_test.cc diff --git a/.gitmodules b/.gitmodules index 32423922406..9a14e40cb54 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "third_party/flash-attention"] path = third_party/flash-attention url = https://github.com/Dao-AILab/flash-attention.git +[submodule "third_party/BladeDISC"] + path = third_party/BladeDISC + url = https://github.com/alibaba/BladeDISC.git diff --git a/BUILD b/BUILD index 6949f6dc748..5efbd38c0b6 100644 --- a/BUILD +++ b/BUILD @@ -21,6 +21,7 @@ cc_binary( visibility = ["//visibility:public"], deps = [ "//torch_xla/csrc:init_python_bindings", + "//torch_xla/csrc/runtime/disc:disc_ral", "@torch//:headers", "@torch//:libc10", "@torch//:libtorch", diff --git a/WORKSPACE b/WORKSPACE index c007f07d271..9e10a077ae7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -89,3 +89,10 @@ new_local_repository( build_file = "//bazel:flash_attn.BUILD", path = "third_party/flash-attention/", ) +################################ BladeDISC Setup ################################ + +new_local_repository( + name = "disc_compiler", + build_file = "//bazel:disc.BUILD", + path = "third_party/BladeDISC/", +) diff --git a/bazel/disc.BUILD b/bazel/disc.BUILD new file mode 100644 index 00000000000..c6ffbb187bd --- /dev/null +++ b/bazel/disc.BUILD @@ -0,0 +1,43 @@ + +package( + default_visibility = [ + "//visibility:public", + ], +) + +cc_library( + name = "headers", + hdrs = glob( + [ + "mlir/ral/*.h", + "mlir/ral/context/base/cuda/*.h", + "mlir/ral/context/base/cuda/cuda_context_impl.h", + "mlir/ral/device/cpu/*.h", + "mlir/ral/device/gpu/*.h", + ], + ), + includes = [ + "tao_compiler", + "tao_compiler/mlir", + ], + strip_include_prefix = "external/disc_compiler/tao_compiler/mlir", +) + +cc_import( + name="disc_ral_cuda", + shared_library = ":libral_base_context.so", +) + +genrule( + name = "build_disc", + outs = ["libral_base_context.so", "disc_compiler_main", "torch-mlir-opt"], + local = True, + cmd = ';'.join(['export PATH=/root/bin:/usr/local/cuda/bin:$${PATH}', + 'pushd third_party/BladeDISC/pytorch_blade', + 'python ../scripts/python/common_setup.py', + 'TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0,8.6,9.0" TORCH_CUDA_ARCH_LIST="7.0 8.0 8.6 9.0" python setup.py bdist_wheel 2>&1 | tee build.log', + 'popd', + 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/ral/libral_base_context.so $(location libral_base_context.so)', + 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/disc/disc_compiler_main $(location disc_compiler_main)', + 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/tests/mhlo/torch-mlir-opt/torch-mlir-opt $(location 
torch-mlir-opt)']), +) diff --git a/bazel/torch.BUILD b/bazel/torch.BUILD index b91d75f9f0b..d2be73c6bcb 100644 --- a/bazel/torch.BUILD +++ b/bazel/torch.BUILD @@ -55,6 +55,10 @@ cc_import( shared_library = "build/lib/libtorch_cpu.so", ) +cc_import( + name = "libtorch_cuda", + shared_library = "build/lib/libtorch_cuda.so", +) cc_import( name = "libtorch_python", shared_library = "build/lib/libtorch_python.so", @@ -64,3 +68,7 @@ cc_import( name = "libc10", shared_library = "build/lib/libc10.so", ) +cc_import( + name = "libc10_cuda", + shared_library = "build/lib/libc10_cuda.so", +) diff --git a/setup.py b/setup.py index c1912d832b4..e095867812d 100644 --- a/setup.py +++ b/setup.py @@ -248,6 +248,14 @@ def bazel_build(self, ext): '/'.join(['third_party/flash-attention', flash_attn_so_name]), '/'.join([ext_dest_dir, flash_attn_so_name])) + # package BladeDISC distribution files + # please note, TorchBlade also create some symbolic links to 'torch_blade' dir + disc_ral_so_name = 'libral_base_context.so' + bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' + shutil.copyfile( + os.path.join(bazel_bin_path, disc_ral_so_name), + '/'.join([ext_dest_dir, disc_ral_so_name])) + class Develop(develop.develop): diff --git a/third_party/BladeDISC b/third_party/BladeDISC new file mode 160000 index 00000000000..67c324289c3 --- /dev/null +++ b/third_party/BladeDISC @@ -0,0 +1 @@ +Subproject commit 67c324289c36da5187405c18600403a0d3681b61 diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 728c4eacd56..55fab72fb5d 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -1,6 +1,7 @@ load( "//bazel:rules_def.bzl", "ptxla_cc_library", + "ptxla_cc_test", ) genrule( diff --git a/torch_xla/csrc/runtime/disc/BUILD b/torch_xla/csrc/runtime/disc/BUILD new file mode 100644 index 00000000000..9b69e2205ea --- /dev/null +++ b/torch_xla/csrc/runtime/disc/BUILD @@ -0,0 +1,40 @@ +load( + "@tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) + +load( + "//bazel:rules_def.bzl", + "ptxla_cc_library", + "ptxla_cc_test", +) + + +ptxla_cc_library( + name = "disc_ral", + srcs = ["disc_ral.cc"], + hdrs = [ + "disc_ral.h", + ], + deps = [ + "@disc_compiler//:disc_ral_cuda", + "@disc_compiler//:headers", + "@local_config_cuda//cuda:cuda_headers", + "@torch//:libc10", + "@torch//:libc10_cuda", + "@torch//:libtorch_cuda", + ], + copts = [ + "-DGOOGLE_CUDA", + ] +) + +ptxla_cc_test( + name = "disc_ral_test", + srcs = ["disc_ral_test.cc"], + deps = [ + ":disc_ral", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ] +) diff --git a/torch_xla/csrc/runtime/disc/disc_ral.cc b/torch_xla/csrc/runtime/disc/disc_ral.cc new file mode 100644 index 00000000000..e8571a6e0e9 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_ral.cc @@ -0,0 +1,166 @@ +#include "torch_xla/csrc/runtime/disc/disc_ral.h" + +#include +#include +#include +#include +#include +#include +#include +namespace torch_xla { +namespace runtime { +class RalAllocator : public tao::ral::Allocator { + public: + using buffer_t = tao::ral::buffer_t; + using alloc_t = tao::ral::alloc_t; + using dealloc_t = tao::ral::dealloc_t; + RalAllocator(alloc_t alloc_func, dealloc_t dealloc_func) + : alloc_func_(alloc_func), dealloc_func_(dealloc_func) {} + + buffer_t alloc(size_t bytes) { return alloc_func_(bytes); } + + void dealloc(buffer_t buffer) { dealloc_func_(buffer); } + + private: + alloc_t alloc_func_; + dealloc_t dealloc_func_; +}; + +int64_t RalContext::LazyInitCurrentDevice() { + int64_t 
cur_device = c10::cuda::current_device(); + int64_t prev_device = NULL_GPU_DEVICE; + bool success = gpu_device_.compare_exchange_strong(prev_device, cur_device); + if (!success) { + TORCH_CHECK(prev_device == cur_device, + "Device changed during inference. Please do NOT change CUDA " + "current device during inference."); + } + TORCH_CHECK(gpu_device_ != NULL_GPU_DEVICE); + return cur_device; +} + +BaseContext* RalContext::LoadCache() { + int64_t gpu_device = LazyInitCurrentDevice(); + TORCH_CHECK(gpu_device >= 0, "expect gpu device id >= 0, but got ", + gpu_device); + c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(gpu_device); + + // TODO: take care of the duplicated const + // which currently is managed per context + tao::ral::gpu::BaseCudaContextOption gpu_opt; + gpu_opt.device_ordinal = gpu_device; + gpu_opt.use_stream_executor = true; + gpu_opt.gpu_allocator.reset( + new RalAllocator(c10::cuda::CUDACachingAllocator::raw_alloc, + c10::cuda::CUDACachingAllocator::raw_delete)); + + std::lock_guard guard(mtx_); + tao::ral::BaseContext* ral_ctx_ptr; + auto it = ral_ctx_map_.find(stream); + if (it == ral_ctx_map_.end()) { + gpu_opt.stream = stream.stream(); + auto ral_ctx = + tao::ral::gpu::MakeBaseCudaContext(default_opt_, cpu_opt_, gpu_opt); + ral_ctx_ptr = ral_ctx.get(); + ral_ctx_map_[stream].reset(ral_ctx.release()); + } else { + ral_ctx_ptr = it->second.get(); + } + return ral_ctx_ptr; +} +at::List RalContext::PreProcessInputs( + const at::List& inputs) { + at::List contiguous_inputs; + for (at::Tensor inp_tensor : inputs) { + // make sure the input is in contiguous layout + contiguous_inputs.push_back(inp_tensor.contiguous()); + } + return contiguous_inputs; +} + +inline bool IsEmptyTensor(const tao::ral::buffer_shape_t& shape) { + return shape.size() > 0 && std::any_of(shape.begin(), shape.end(), + [](int64_t dim) { return dim == 0; }); +} + +at::List RalContext::CreateAndBindingOutputs( + tao::ral::ExecutionContext& exec_ctx) { + at::List outputs; + + auto num_outputs = disc_result_->outputs.size(); + outputs.reserve(num_outputs); + std::vector> out_bufs( + num_outputs); + for (size_t idx = 0; idx < num_outputs; ++idx) { + auto& out_buf = out_bufs[idx]; + // Note: Ral has memory allocator that allocate memory each time forward. + // So it's thread-safe to reuse the underline memory. + exec_ctx.bindOutput(idx, &out_buf); + + const auto& output_info = disc_result_->outputs[idx]; + auto scalar_type = output_info.scalar_type; + torch::DeviceType dev_type = torch::kCUDA; + dev_type = (output_info.device == "cuda") ? 
torch::kCUDA : torch::kCPU; + + auto option = torch::device(dev_type) + .dtype(scalar_type) + .memory_format(torch::MemoryFormat::Contiguous); + at::Tensor out_tensor; + if (IsEmptyTensor(out_buf->shape())) { + out_tensor = torch::zeros(out_buf->shape(), option); + } else if (out_buf->owned()) { + auto cpu_allocator = c10::GetAllocator(torch::kCPU); + TORCH_CHECK(cpu_allocator != nullptr); + std::function deleter = [cpu_allocator](void* ptr) { + cpu_allocator->raw_deallocate(ptr); + }; + if (output_info.device == "cuda") { + deleter = c10::cuda::CUDACachingAllocator::raw_delete; + } + out_tensor = torch::from_blob(const_cast(out_buf->data()), + out_buf->shape(), deleter, option); + out_buf->release(); + } else { + out_tensor = torch::from_blob(const_cast(out_buf->data()), + out_buf->shape(), option) + .clone(); + } + outputs.push_back(out_tensor); + } + return outputs; +} +void RalContext::BindingInputs(const at::List& inputs, + tao::ral::ExecutionContext& exec_ctx) { + for (size_t idx = 0; idx < inputs.size(); ++idx) { + at::Tensor inp = inputs[idx]; + const auto& shape = inp.sizes(); + exec_ctx.bindInput(idx, inp.data_ptr(), shape.vec()); + } +} +at::List RalContext::Execute(const at::List& inputs) { + auto ral_ctx = LoadCache(); + // execution context is per-inference context and thread-safe + auto exec_ctx = + tao::ral::MakeExecutionContext( + ral_ctx); + + auto contiguous_inputs = PreProcessInputs(inputs); + BindingInputs(contiguous_inputs, *exec_ctx.get()); + + auto tao_ral_func_ptr = reinterpret_cast(&tao_ral_call_impl); + + // execute + void* ctx_struct[] = {exec_ctx.get(), tao_ral_func_ptr}; + try { + entry_func_(ctx_struct); + } catch (std::exception& ex) { + LOG(ERROR) << ex.what(); + throw ex; + } + + auto outputs = CreateAndBindingOutputs(*exec_ctx.get()); + return outputs; +} + +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/disc/disc_ral.h b/torch_xla/csrc/runtime/disc/disc_ral.h new file mode 100644 index 00000000000..0ea71350d27 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_ral.h @@ -0,0 +1,63 @@ +#include +#include +#include +#include + +namespace torch_xla { +namespace runtime { + +using tao::ral::BaseContext; +using tao::ral::ExecutionContext; + +class DataMeta { + public: + std::string device; + c10::ScalarType scalar_type; +}; + +class DISCComplationResult { + public: + std::string ral_lib; + std::string ral_mate_pb; + std::vector inputs; + std::vector outputs; +}; + +class RalContext { + using EntryFunc = std::function; + + public: + RalContext(std::shared_ptr disc_result) + : disc_result_(disc_result){}; + ~RalContext(){}; + + at::List Execute(const at::List&); + + private: + void BindingInputs(const at::List& inputs, + tao::ral::ExecutionContext& exec_ctx); + void CheckCurrentDevice(const at::List& inputs); + at::List CreateAndBindingOutputs( + tao::ral::ExecutionContext& exec_ctx); + at::List PreProcessInputs(const at::List& inputs); + + int64_t LazyInitCurrentDevice(); + + constexpr static int64_t NULL_GPU_DEVICE = -1; + std::atomic gpu_device_{NULL_GPU_DEVICE}; + std::mutex mtx_; + std::unordered_map> + ral_ctx_map_; + tao::ral::BaseContext* LoadCache(); + + tao::ral::BaseContextOption default_opt_; + tao::ral::cpu::BaseCpuContextOption cpu_opt_; + + std::shared_ptr disc_result_; + + void* tao_lib_; + EntryFunc entry_func_; +}; +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/disc/disc_ral_test.cc b/torch_xla/csrc/runtime/disc/disc_ral_test.cc new file mode 100644 index 
00000000000..4ab01b06f38 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_ral_test.cc @@ -0,0 +1,18 @@ +#include "torch_xla/csrc/runtime/disc/disc_ral.h" + +#include + +namespace torch_xla { +namespace runtime { + +TEST(DISCRAlTest, E2E) { + // TODO(disc): need compile API to output the compilation result + std::shared_ptr disc_result = + std::make_shared(); + RalContext ral_ctx(disc_result); + std::vector inputs; + ral_ctx.Execute(at::List()); +} + +} // namespace runtime +} // namespace torch_xla From 17220c2d85c97e9e392e921375e9fcd849281bd1 Mon Sep 17 00:00:00 2001 From: "wenting.swt" Date: Tue, 12 Dec 2023 17:12:55 +0800 Subject: [PATCH 02/12] [to #53687860] feat: DISC client header, implement DISCComputation and DISCData POC implement in : https://code.alibaba-inc.com/torchx/xla/codereview/14984824 Link: https://code.alibaba-inc.com/torchx/xla/codereview/14987956 --- torch_xla/csrc/runtime/BUILD | 40 ++++ .../csrc/runtime/disc_computation_client.cc | 56 +++++ .../csrc/runtime/disc_computation_client.h | 215 ++++++++++++++++++ .../runtime/disc_computation_client_test.cc | 89 ++++++++ torch_xla/csrc/runtime/env_vars.cc | 1 + torch_xla/csrc/runtime/env_vars.h | 1 + torch_xla/csrc/runtime/runtime.cc | 5 +- 7 files changed, 406 insertions(+), 1 deletion(-) create mode 100644 torch_xla/csrc/runtime/disc_computation_client.cc create mode 100644 torch_xla/csrc/runtime/disc_computation_client.h create mode 100644 torch_xla/csrc/runtime/disc_computation_client_test.cc diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 41450933f67..585ab4a6198 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -10,6 +10,7 @@ load( load( "//bazel:rules_def.bzl", + "ptxla_cc_library", "ptxla_cc_test", ) @@ -30,6 +31,7 @@ cc_library( ":env_vars", ":pjrt_computation_client", ":ifrt_computation_client", + ":disc_computation_client", "@tsl//tsl/platform:stacktrace", ], ) @@ -137,6 +139,22 @@ cc_library( ], ) +cc_library( + name = "disc_computation_client", + srcs = [ + "disc_computation_client.cc", + ], + hdrs = [ + "disc_computation_client.h", + ], + deps = [ + ":computation_client", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla/client:xla_computation" + ], +) + cc_library( name = "cache", hdrs = ["cache.h"], @@ -506,3 +524,25 @@ ptxla_cc_test( "@tsl//tsl/platform:test_main", ], ) + +ptxla_cc_test( + name = "disc_computation_client_test", + srcs = ["disc_computation_client_test.cc"], + deps = [ + ":disc_computation_client", + "@xla//xla:literal", + "@xla//xla:literal_util", + "@xla//xla:shape_util", + "@xla//xla:status", + "@xla//xla:statusor", + "@xla//xla/client:xla_builder", + "@xla//xla/client:xla_computation", + "@xla//xla/tests:literal_test_util", + "@xla//xla/tools:hlo_module_loader", + "@stablehlo//:register", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc new file mode 100644 index 00000000000..c51cb9662c1 --- /dev/null +++ b/torch_xla/csrc/runtime/disc_computation_client.cc @@ -0,0 +1,56 @@ +#include "torch_xla/csrc/runtime/disc_computation_client.h" + +#include + +#include "torch_xla/csrc/runtime/computation_client.h" + +namespace torch_xla { +namespace runtime { + +DISCComputationClient::DISCComputationClient() {} + +DISCComputationClient::~DISCComputationClient() {} + +void 
DISCComputationClient::DISCData::Assign( + const torch::lazy::BackendData& data) { + const DISCData& disc_data = dynamic_cast(data); + if (&disc_data != this) { + buffer = disc_data.buffer; + } +} + +ComputationClient::DataPtr DISCComputationClient::CreateDataPlaceholder( + std::string device, xla::Shape shape, + std::optional sharding) { + return std::make_shared(std::move(device), std::move(shape)); +} + +std::vector TransferToDevice( + absl::Span> tensors) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; +} + +std::vector DISCComputationClient::TransferFromDevice( + absl::Span handles) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; +} + +std::vector DISCComputationClient::Compile( + std::vector instances) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; +} + +std::vector +DISCComputationClient::ExecuteComputation( + const ComputationClient::Computation& computation, + absl::Span arguments, + const std::string& device, const ExecuteComputationOptions& options) { + XLA_ERROR() << __FUNCTION__ << " not implemented"; +} + +std::map DISCComputationClient::GetMetrics() const { + return {}; +} + +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/disc_computation_client.h b/torch_xla/csrc/runtime/disc_computation_client.h new file mode 100644 index 00000000000..3f25fe45b16 --- /dev/null +++ b/torch_xla/csrc/runtime/disc_computation_client.h @@ -0,0 +1,215 @@ +#ifndef XLA_CLIENT_DISC_COMPUTATION_CLIENT_H_ +#define XLA_CLIENT_DISC_COMPUTATION_CLIENT_H_ + +#include "torch_xla/csrc/runtime/computation_client.h" +#include "xla/client/xla_computation.h" + +namespace torch_xla { +namespace runtime { + +namespace disc { +class DISCLoadedExecutable {}; +} // namespace disc + +class DISCComputationClient : public ComputationClient { + public: + DISCComputationClient(); + ~DISCComputationClient(); + + DataPtr CreateDataPlaceholder( + std::string device, xla::Shape shape, + std::optional sharding = std::nullopt) override; + + std::vector GetDataShards(DataPtr data) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + DataPtr GetDataShard(DataPtr data, size_t index) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::vector ReshardData( + absl::Span handles, + absl::Span shardings) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + DataPtr WrapDataShards(absl::Span shards, std::string device, + xla::Shape shape, xla::OpSharding sharding) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::optional GetDataSharding(DataPtr handle) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::vector TransferToDevice( + absl::Span> tensors) override; + + std::vector TransferFromDevice( + absl::Span handles) override; + + DataPtr TransferShardsToDevice( + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + DataPtr CopyToDevice(DataPtr data, std::string dst) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::string SerializeComputation(const ComputationPtr computation) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + ComputationPtr DeserializeComputation( + const std::string& serialized) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + torch::lazy::hash_t HashCompilationEnv() override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + torch_xla::DeviceType 
GetDeviceType() const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + }; + + bool CoordinatorInitialized() const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + XlaCoordinator& GetCoordinator() override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::vector Compile( + std::vector instances) override; + + std::vector ExecuteComputation( + const Computation& computation, absl::Span arguments, + const std::string& device, + const ExecuteComputationOptions& options) override; + + std::vector ExecuteReplicated( + const Computation& computation, absl::Span arguments, + absl::Span devices, + const ExecuteReplicatedOptions& options) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + size_t GetNumDevices() const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::string GetDefaultDevice() const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::vector GetLocalDevices() const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::vector GetAllDevices() const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + int GetProcessIndex() const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + }; + + int GetNumProcesses() const override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + const absl::flat_hash_map< + std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& + GetDeviceAttributes(const std::string& device) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + void SetReplicationDevices( + std::shared_ptr> devices) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::shared_ptr> GetReplicationDevices() override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + void WaitDeviceOps(absl::Span devices) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + std::map GetMetrics() const override; + + MemoryInfo GetMemoryInfo(const std::string& device) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + private: + struct DISCData : public Data { + DISCData(std::string device, xla::Shape device_shape) + : Data(std::move(device), std::move(device_shape)) {} + + DISCData(std::string device, xla::Shape device_shape, + std::shared_ptr buffer) + : Data(std::move(device), std::move(device_shape)), buffer(buffer) {} + + void Assign(const torch::lazy::BackendData& data) override; + + bool HasValue() const override { + return buffer->defined() && buffer->element_size() > 0; + } + + Handle GetHandle() override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + + bool HasSharding() const override { return false; } + + xla::OpSharding GetSharding() const override { + XLA_CHECK(false) << "GetSharding should not be called on DISCData, check " + "HasSharding first"; + return xla::OpSharding(); + } + + std::string ToString() const override { + std::stringstream ss; + ss << "XLAData: \n"; + ss << " Data Device: " << device() << "\n"; + ss << " Data Shape: " << shape().ToString() << "\n"; + ss << " Data Handle: "; + if (HasValue()) { + ss << reinterpret_cast(buffer->const_data_ptr()) + << "\n"; + } else { + ss << "None\n"; + } + return ss.str(); + } + + std::shared_ptr buffer; + }; + + struct DISCComputation : public Computation { + 
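+    // Bundles the original XlaComputation (and its target devices) with the
+    // handle to the DISC-compiled executable for that computation, so that
+    // ExecuteComputation can dispatch to the compiled artifact later.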
DISCComputation(xla::XlaComputation computation, + std::vector devices, + std::unique_ptr executable) + : Computation(std::move(computation), std::move(devices)), + executable(std::move(executable)) {} + + std::unique_ptr executable; + }; +}; + +} // namespace runtime +} // namespace torch_xla +#endif // XLA_CLIENT_DISC_COMPUTATION_CLIENT_H_ diff --git a/torch_xla/csrc/runtime/disc_computation_client_test.cc b/torch_xla/csrc/runtime/disc_computation_client_test.cc new file mode 100644 index 00000000000..f54de6079a8 --- /dev/null +++ b/torch_xla/csrc/runtime/disc_computation_client_test.cc @@ -0,0 +1,89 @@ +#include "torch_xla/csrc/runtime/disc_computation_client.h" + +#include + +#include +#include +#include +#include + +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test.h" +#include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/status.h" +#include "xla/statusor.h" +#include "xla/tests/literal_test_util.h" + +namespace torch_xla { +namespace runtime { + +tsl::StatusOr MakeComputation() { + xla::Shape input_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2}); + xla::XlaBuilder builder("AddComputation"); + xla::XlaOp x = xla::Parameter(&builder, 0, input_shape, "x"); + xla::XlaOp y = xla::Parameter(&builder, 1, input_shape, "y"); + xla::XlaOp sum = xla::Add(x, y); + return builder.Build(); +} + +ComputationClient::TensorSource TensorSourceFromLiteral( + const std::string& device, const xla::Literal& literal) { + auto populate_fn = [&](const ComputationClient::TensorSource& source_tensor, + void* dest_buffer, size_t dest_buffer_size) { + std::memcpy(dest_buffer, literal.data().data(), + dest_buffer_size * sizeof(literal.data().data())); + }; + return ComputationClient::TensorSource(literal.shape(), device, + std::move(populate_fn)); +} + +TEST(DISCComputationClientTest, Init) { + tsl::setenv("DISC_DEVICE", "GPU", true); + auto client = std::make_unique(); + std::string device = "cuda:0"; + + // // Compose a computation. + auto shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2}); + std::vector instances; + instances.push_back(ComputationClient::CompileInstance( + std::move(MakeComputation().value()), device, {"cuda:0"}, + &shape)); + + // // Prepare inputs. + xla::Literal literal_x = + xla::LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}}); + xla::Literal literal_y = + xla::LiteralUtil::CreateR2({{5.0f, 6.0f}, {7.0f, 8.0f}}); + + // // Compile the graph. + std::vector computations = + client->Compile(std::move(instances)); + + // // Copy inputs to device. + ComputationClient::ExecuteComputationOptions options{}; + std::vector args = { + TensorSourceFromLiteral(device, literal_x), + TensorSourceFromLiteral(device, literal_y)}; + + // // Execute the graph. + auto inputs = client->TransferToServer(absl::MakeConstSpan(args)); + + std::vector results = client->ExecuteComputation( + *computations[0], inputs, + device, options); + + // // Copy the output from device back to host and assert correctness.. 
+ auto result_literals = client->TransferFromServer(results); + EXPECT_TRUE(xla::LiteralTestUtil::Equal( + xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), + result_literals[0])); +} + +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/env_vars.cc b/torch_xla/csrc/runtime/env_vars.cc index f774d578ca6..ab3a972dd19 100644 --- a/torch_xla/csrc/runtime/env_vars.cc +++ b/torch_xla/csrc/runtime/env_vars.cc @@ -9,6 +9,7 @@ const char* const kEnvNumGpu = "GPU_NUM_DEVICES"; const char* const kEnvNumCpu = "CPU_NUM_DEVICES"; const char* const kEnvTpuvmMode = "TPUVM_MODE"; const char* const kEnvPjRtDevice = "PJRT_DEVICE"; +const char* const kEnvDISCDevice = "DISC_DEVICE"; const char* const kEnvPjRtTpuMaxInflightComputations = "PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS"; const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT"; diff --git a/torch_xla/csrc/runtime/env_vars.h b/torch_xla/csrc/runtime/env_vars.h index 3affac2031e..52aa270e1d6 100644 --- a/torch_xla/csrc/runtime/env_vars.h +++ b/torch_xla/csrc/runtime/env_vars.h @@ -19,6 +19,7 @@ extern const char* const kEnvHostOrdinal; extern const char* const kEnvShardOrdinal; extern const char* const kEnvStartService; extern const char* const kEnvTpuvmMode; +extern const char* const kEnvDISCDevice; extern const char* const kEnvPjRtDevice; extern const char* const kEnvPjRtTpuMaxInflightComputations; extern const char* const kEnvPjrtAsyncCpuClient; diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index feb2a0844c6..202f56c9a70 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -5,6 +5,7 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/ifrt_computation_client.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" +#include "torch_xla/csrc/runtime/disc_computation_client.h" #include "tsl/platform/stacktrace_handler.h" namespace torch_xla { @@ -21,7 +22,9 @@ ComputationClient* GetComputationClient() { std::unique_ptr client; static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false); - if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { + if (sys_util::GetEnvString(env::kEnvDISCDevice, "") != "") { + client = std::make_unique(); + } else if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { if (use_ifrt) { client = std::make_unique(); } else { From 85927c6493c45ae83b9a02b778555b999585c47b Mon Sep 17 00:00:00 2001 From: Dalong Date: Wed, 10 Jan 2024 19:16:23 +0800 Subject: [PATCH 03/12] Disc computation (#2) Support Disc as backend Co-authored-by: yancey.yx Co-authored-by: wangang.wa --- bazel/disc.BUILD | 4 +- bazel/flash_attn.BUILD | 7 +- setup.py | 6 +- torch_xla/csrc/runtime/BUILD | 7 +- torch_xla/csrc/runtime/disc/BUILD | 31 ++ torch_xla/csrc/runtime/disc/disc_compile.cc | 73 ++++ torch_xla/csrc/runtime/disc/disc_compile.h | 25 ++ torch_xla/csrc/runtime/disc/disc_ral.cc | 162 +++++++-- torch_xla/csrc/runtime/disc/disc_ral.h | 44 ++- torch_xla/csrc/runtime/disc/disc_ral_test.cc | 3 +- torch_xla/csrc/runtime/disc/disc_utils.cc | 114 +++++++ torch_xla/csrc/runtime/disc/disc_utils.h | 36 ++ .../csrc/runtime/disc_computation_client.cc | 322 +++++++++++++++++- .../csrc/runtime/disc_computation_client.h | 69 ++-- .../runtime/disc_computation_client_test.cc | 43 +-- torch_xla/csrc/runtime/env_vars.h | 0 torch_xla/csrc/runtime/runtime.cc | 2 +- torch_xla/csrc/runtime/stablehlo_helper.cc | 15 +- torch_xla/csrc/runtime/stablehlo_helper.h | 3 + 19 files changed, 838 insertions(+), 
128 deletions(-) mode change 100644 => 100755 torch_xla/csrc/runtime/BUILD mode change 100644 => 100755 torch_xla/csrc/runtime/disc/BUILD create mode 100644 torch_xla/csrc/runtime/disc/disc_compile.cc create mode 100644 torch_xla/csrc/runtime/disc/disc_compile.h create mode 100644 torch_xla/csrc/runtime/disc/disc_utils.cc create mode 100644 torch_xla/csrc/runtime/disc/disc_utils.h mode change 100644 => 100755 torch_xla/csrc/runtime/env_vars.h diff --git a/bazel/disc.BUILD b/bazel/disc.BUILD index c6ffbb187bd..fc72c6c0025 100644 --- a/bazel/disc.BUILD +++ b/bazel/disc.BUILD @@ -33,9 +33,9 @@ genrule( outs = ["libral_base_context.so", "disc_compiler_main", "torch-mlir-opt"], local = True, cmd = ';'.join(['export PATH=/root/bin:/usr/local/cuda/bin:$${PATH}', - 'pushd third_party/BladeDISC/pytorch_blade', + 'pushd external/disc_compiler/pytorch_blade/', 'python ../scripts/python/common_setup.py', - 'TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0,8.6,9.0" TORCH_CUDA_ARCH_LIST="7.0 8.0 8.6 9.0" python setup.py bdist_wheel 2>&1 | tee build.log', + 'TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0,8.6,9.0" TORCH_CUDA_ARCH_LIST="7.0 8.0 8.6 9.0" python setup.py bdist_wheel', 'popd', 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/ral/libral_base_context.so $(location libral_base_context.so)', 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/disc/disc_compiler_main $(location disc_compiler_main)', diff --git a/bazel/flash_attn.BUILD b/bazel/flash_attn.BUILD index e5fd5ca6013..6be811b826b 100644 --- a/bazel/flash_attn.BUILD +++ b/bazel/flash_attn.BUILD @@ -23,9 +23,8 @@ genrule( name = "build_flash_attn", srcs = ["setup.py"], outs = ["flash_attn_cuda.so"], - cmd = ';'.join(['pushd third_party/flash-attention/', - 'MAX_JOBS=50 FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel 2>&1 | tee build.log', - 'cp build/*/*.so flash_attn_cuda.so', + cmd = ';'.join(['pushd external/flash_attn/', + 'FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel', 'popd', - 'cp third_party/flash-attention/flash_attn_cuda.so $(OUTS)']), + 'cp external/flash_attn/build/*/*.so $(location flash_attn_cuda.so)']), ) diff --git a/setup.py b/setup.py index e095867812d..599a9c03bd7 100644 --- a/setup.py +++ b/setup.py @@ -244,9 +244,9 @@ def bazel_build(self, ext): # copy flash attention cuda so file flash_attn_so_name = 'flash_attn_cuda.so' - shutil.copyfile( - '/'.join(['third_party/flash-attention', flash_attn_so_name]), - '/'.join([ext_dest_dir, flash_attn_so_name])) + bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/flash_attn/' + shutil.copyfile('/'.join([bazel_bin_path, flash_attn_so_name]), + '/'.join([ext_dest_dir, flash_attn_so_name])) # package BladeDISC distribution files # please note, TorchBlade also create some symbolic links to 'torch_blade' dir diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD old mode 100644 new mode 100755 index 585ab4a6198..73bee67fd84 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -149,9 +149,14 @@ cc_library( ], deps = [ ":computation_client", + ":stablehlo_helper", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", "@xla//xla:literal", "@xla//xla:shape_util", - "@xla//xla/client:xla_computation" + "@xla//xla/client:xla_computation", + "//torch_xla/csrc/runtime/disc:disc_ral", + "//torch_xla/csrc/runtime/disc:disc_compile", ], ) diff --git a/torch_xla/csrc/runtime/disc/BUILD b/torch_xla/csrc/runtime/disc/BUILD old mode 100644 new mode 100755 
index 9b69e2205ea..817c4299493 --- a/torch_xla/csrc/runtime/disc/BUILD +++ b/torch_xla/csrc/runtime/disc/BUILD @@ -17,6 +17,7 @@ ptxla_cc_library( "disc_ral.h", ], deps = [ + ":disc_utils", "@disc_compiler//:disc_ral_cuda", "@disc_compiler//:headers", "@local_config_cuda//cuda:cuda_headers", @@ -29,6 +30,36 @@ ptxla_cc_library( ] ) +ptxla_cc_library( + name = "disc_utils", + srcs = ["disc_utils.cc"], + hdrs = [ + "disc_utils.h", + ], + deps = [ + "//torch_xla/csrc/runtime:tf_logging", + ] +) + +ptxla_cc_library( + name = "disc_compile", + srcs = ["disc_compile.cc"], + hdrs = [ + "disc_compile.h", + ], + deps = [ + ":disc_ral", + ":disc_utils", + "//torch_xla/csrc/runtime:tf_logging", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + ], + copts = [ + "-DGOOGLE_CUDA", + ] +) + ptxla_cc_test( name = "disc_ral_test", srcs = ["disc_ral_test.cc"], diff --git a/torch_xla/csrc/runtime/disc/disc_compile.cc b/torch_xla/csrc/runtime/disc/disc_compile.cc new file mode 100644 index 00000000000..ffff24e3eba --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_compile.cc @@ -0,0 +1,73 @@ +#include "torch_xla/csrc/runtime/disc/disc_compile.h" + +#include + +#include + +#include "torch_xla/csrc/runtime/tf_logging.h" + +namespace torch_xla { +namespace runtime { +namespace disc { + +std::string CurrentLibLocation() { + Dl_info dl_info; + dladdr((void*)CurrentLibLocation, &dl_info); + auto fname = std::string(dl_info.dli_fname); + return fname.substr(0, fname.find_last_of("/")); +} + +std::string CompileCMD(const std::string& mlir_fname, + const std::string& out_fname) { + std::stringstream ss; + std::string logf = absl::StrCat(mlir_fname, ".log"); + // unset XLA_FLAGS, otherwise tf will throw parse error + std::string compile_cmd = "unset XLA_FLAGS"; + absl::StrAppend(&compile_cmd, "&&", CurrentLibLocation(), + "/disc_compiler_main", " ", mlir_fname, " ", out_fname, " > ", + logf, " 2>&1"); + return compile_cmd; +} + +std::tuple CallDiscCompiler( + const std::string& mlir_fname) { + std::string out_fname = mlir_fname + ".out"; + std::string cmd = CompileCMD(mlir_fname, out_fname); + TF_VLOG(1) << "Executing command: " << cmd << " to compile mhlo..."; + auto ret = std::system(cmd.c_str()); + return {cmd, out_fname, ret}; +} + +std::shared_ptr DumpMlir(mlir::ModuleOp& stablehlo_module) { + std::string model_dump_str; + llvm::raw_string_ostream os(model_dump_str); + stablehlo_module.print(os); + os.flush(); + std::shared_ptr stablehlo_file = std::make_shared("mlir"); + stablehlo_file->WriteBytesToFile(model_dump_str); + return stablehlo_file; +} + +DISCComplationResult Compile(mlir::ModuleOp& module, + const std::vector& inputs, + const std::vector& outputs) { + // Dump stablehlo to file + DISCComplationResult res; + auto mlir_file = DumpMlir(module); + + // Compile mhlo + auto compile_res = CallDiscCompiler(mlir_file->GetFilename()); + auto output_fname = std::get<1>(compile_res); + + // Construct compiled result + res.ral_lib = ReadFileBytes(output_fname); + res.ral_mate_pb = ReadFileBytes(absl::StrCat(output_fname, ".pbtxt")); + res.inputs = inputs; + res.outputs = outputs; + + return res; +} + +} // namespace disc +} // namespace runtime +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/disc_compile.h b/torch_xla/csrc/runtime/disc/disc_compile.h new file mode 100644 index 00000000000..0f3ac885211 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_compile.h @@ -0,0 +1,25 @@ +#ifndef 
XLA_TORCH_XLA_CSRC_RUNTIME_DISC_COMPILE_H_ +#define XLA_TORCH_XLA_CSRC_RUNTIME_DISC_COMPILE_H_ + +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "torch_xla/csrc/runtime/disc/disc_ral.h" + +namespace torch_xla { +namespace runtime { +namespace disc { +DISCComplationResult Compile(mlir::ModuleOp& module, + const std::vector& inputs, + const std::vector& outputs); + +} // namespace disc +} // namespace runtime +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_RUNTIME_DISC_COMPILE_H_ \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/disc_ral.cc b/torch_xla/csrc/runtime/disc/disc_ral.cc index e8571a6e0e9..ca2216c79b8 100644 --- a/torch_xla/csrc/runtime/disc/disc_ral.cc +++ b/torch_xla/csrc/runtime/disc/disc_ral.cc @@ -1,14 +1,23 @@ #include "torch_xla/csrc/runtime/disc/disc_ral.h" -#include +#include #include #include -#include +#include #include #include +#include #include + +#include + +#include "absl/strings/str_cat.h" +#include "torch_xla/csrc/runtime/tf_logging.h" + namespace torch_xla { namespace runtime { +namespace disc { + class RalAllocator : public tao::ral::Allocator { public: using buffer_t = tao::ral::buffer_t; @@ -26,6 +35,75 @@ class RalAllocator : public tao::ral::Allocator { dealloc_t dealloc_func_; }; +RalContext::RalContext(const DISCComplationResult& disc_result) + : disc_result_(disc_result) { + auto is_ok = meta_tmpf_.WriteBytesToFile(disc_result_.ral_mate_pb); + TORCH_CHECK(is_ok, "Failed to dump model proto to file."); + default_opt_.metadata_file_path = meta_tmpf_.GetFilename(); + default_opt_.cache_workspace_mem_across_execution = true; + auto torch_allocator = c10::GetAllocator(torch::kCPU); + TORCH_CHECK(torch_allocator != nullptr); + auto cpu_alloc = [torch_allocator](size_t n) { + return torch_allocator->raw_allocate(n); + }; + auto cpu_delete = [torch_allocator](void* ptr) { + torch_allocator->raw_deallocate(ptr); + }; + cpu_opt_.cpu_allocator.reset(new RalAllocator(cpu_alloc, cpu_delete)); + + at::globalContext().lazyInitCUDA(); + + void* func_handle = nullptr; + std::tie(tao_lib_, func_handle) = LoadEngine(disc_result_.ral_lib); + + using func_t = void (*)(void**); + entry_func_ = (func_t)func_handle; + + CHECK(entry_func_ != nullptr); +} + +std::tuple RalContext::LoadEngine( + const std::string& ral_engine_bytes) { + auto is_ok = lib_tmpf_.WriteBytesToFile(ral_engine_bytes); + TORCH_CHECK(is_ok, "Failed to dump RAL engine to file"); + std::string filename = lib_tmpf_.GetFilename(); + + void* tao_lib = dlopen(filename.c_str(), RTLD_NOW | RTLD_LOCAL); + TORCH_CHECK(tao_lib, "Fail to open ral engine"); + + void* func_handle = dlsym(tao_lib, kMlirLoweredEntry); + TORCH_CHECK(func_handle, "Fail to find kMlirLoweredEntry"); + return std::make_tuple(tao_lib, func_handle); +} + +RalContext::~RalContext() { + if (tao_lib_ != nullptr) { + dlclose(tao_lib_); + } +} + +void RalContext::CheckCurrentDevice(const std::vector& inputs) { + int64_t gpu_device = LazyInitCurrentDevice(); + // Engine Context + if (inputs.empty()) { + return; + } + + torch::Device cur_cuda_device = torch::Device(torch::kCUDA, gpu_device); + + TORCH_CHECK(disc_result_.inputs.size() == inputs.size()); + for (size_t k = 0; k < inputs.size(); ++k) { + at::Tensor inp = inputs[k]; + auto device = 
disc_result_.inputs[k].device; + if (device == "cuda") { + TORCH_CHECK(inp.device() == cur_cuda_device, "Input tensor ", k, + " device mismatch. Expect: ", cur_cuda_device, + ", got: ", inp.device()); + } + } + return; +} + int64_t RalContext::LazyInitCurrentDevice() { int64_t cur_device = c10::cuda::current_device(); int64_t prev_device = NULL_GPU_DEVICE; @@ -39,14 +117,12 @@ int64_t RalContext::LazyInitCurrentDevice() { return cur_device; } -BaseContext* RalContext::LoadCache() { +tao::ral::BaseContext* RalContext::LoadCache() { int64_t gpu_device = LazyInitCurrentDevice(); TORCH_CHECK(gpu_device >= 0, "expect gpu device id >= 0, but got ", gpu_device); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(gpu_device); - // TODO: take care of the duplicated const - // which currently is managed per context tao::ral::gpu::BaseCudaContextOption gpu_opt; gpu_opt.device_ordinal = gpu_device; gpu_opt.use_stream_executor = true; @@ -68,9 +144,12 @@ BaseContext* RalContext::LoadCache() { } return ral_ctx_ptr; } -at::List RalContext::PreProcessInputs( - const at::List& inputs) { - at::List contiguous_inputs; + +std::vector RalContext::PreProcessInputs( + const std::vector& inputs) { + CheckCurrentDevice(inputs); + + std::vector contiguous_inputs; for (at::Tensor inp_tensor : inputs) { // make sure the input is in contiguous layout contiguous_inputs.push_back(inp_tensor.contiguous()); @@ -83,11 +162,27 @@ inline bool IsEmptyTensor(const tao::ral::buffer_shape_t& shape) { [](int64_t dim) { return dim == 0; }); } -at::List RalContext::CreateAndBindingOutputs( +inline bool IsSameShape(const tao::ral::buffer_shape_t& shape, + at::Tensor input_tensor) { + if (input_tensor.dim() != shape.size()) { + return false; + } + + for (int i = 0; i < shape.size(); i++) { + if (input_tensor.sizes()[i] != shape[i]) { + return false; + } + } + + return true; +} + +std::vector RalContext::CreateAndBindingOutputs( + const std::vector& inputs, tao::ral::ExecutionContext& exec_ctx) { - at::List outputs; + std::vector outputs; - auto num_outputs = disc_result_->outputs.size(); + auto num_outputs = disc_result_.outputs.size(); outputs.reserve(num_outputs); std::vector> out_bufs( num_outputs); @@ -97,8 +192,9 @@ at::List RalContext::CreateAndBindingOutputs( // So it's thread-safe to reuse the underline memory. exec_ctx.bindOutput(idx, &out_buf); - const auto& output_info = disc_result_->outputs[idx]; + const auto& output_info = disc_result_.outputs[idx]; auto scalar_type = output_info.scalar_type; + torch::DeviceType dev_type = torch::kCUDA; dev_type = (output_info.device == "cuda") ? torch::kCUDA : torch::kCPU; @@ -121,15 +217,31 @@ at::List RalContext::CreateAndBindingOutputs( out_buf->shape(), deleter, option); out_buf->release(); } else { - out_tensor = torch::from_blob(const_cast(out_buf->data()), - out_buf->shape(), option) - .clone(); + //(@yuanxiulong.yxl) For input output alias, now we will only have full + // tensor memory reuse. 
+ // We will support partial memory space reuse in the future + bool alias_input = false; + for (auto& input_tensor : inputs) { + // same address, same shape, same dtype + if (input_tensor.data_ptr() == out_buf->data() && + scalar_type == input_tensor.dtype() && + IsSameShape(out_buf->shape(), input_tensor)) { + out_tensor = input_tensor; + alias_input = true; + } + } + if (!alias_input) { + out_tensor = torch::from_blob(const_cast(out_buf->data()), + out_buf->shape(), option) + .clone(); + } } outputs.push_back(out_tensor); } return outputs; } -void RalContext::BindingInputs(const at::List& inputs, + +void RalContext::BindingInputs(const std::vector& inputs, tao::ral::ExecutionContext& exec_ctx) { for (size_t idx = 0; idx < inputs.size(); ++idx) { at::Tensor inp = inputs[idx]; @@ -137,15 +249,17 @@ void RalContext::BindingInputs(const at::List& inputs, exec_ctx.bindInput(idx, inp.data_ptr(), shape.vec()); } } -at::List RalContext::Execute(const at::List& inputs) { + +std::vector RalContext::Execute( + const std::vector& inputs) { + // inputs are always contigous auto ral_ctx = LoadCache(); // execution context is per-inference context and thread-safe auto exec_ctx = tao::ral::MakeExecutionContext( ral_ctx); - auto contiguous_inputs = PreProcessInputs(inputs); - BindingInputs(contiguous_inputs, *exec_ctx.get()); + BindingInputs(inputs, *exec_ctx.get()); auto tao_ral_func_ptr = reinterpret_cast(&tao_ral_call_impl); @@ -158,9 +272,13 @@ at::List RalContext::Execute(const at::List& inputs) { throw ex; } - auto outputs = CreateAndBindingOutputs(*exec_ctx.get()); + // Support input output buffer reuse + // Now we only have full buffer reuse for alias + auto outputs = CreateAndBindingOutputs(inputs, *exec_ctx.get()); + return outputs; } -} // namespace runtime -} // namespace torch_xla +} // namespace disc +} // namespace runtime +} // namespace torch_xla diff --git a/torch_xla/csrc/runtime/disc/disc_ral.h b/torch_xla/csrc/runtime/disc/disc_ral.h index 0ea71350d27..f47431689c5 100644 --- a/torch_xla/csrc/runtime/disc/disc_ral.h +++ b/torch_xla/csrc/runtime/disc/disc_ral.h @@ -1,22 +1,24 @@ -#include -#include +#ifndef XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCRAL_H_ +#define XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCRAL_H_ + +#include +#include #include -#include + +#include "torch_xla/csrc/runtime/disc/disc_utils.h" namespace torch_xla { namespace runtime { +namespace disc { -using tao::ral::BaseContext; using tao::ral::ExecutionContext; -class DataMeta { - public: +struct DataMeta { std::string device; c10::ScalarType scalar_type; }; -class DISCComplationResult { - public: +struct DISCComplationResult { std::string ral_lib; std::string ral_mate_pb; std::vector inputs; @@ -27,19 +29,21 @@ class RalContext { using EntryFunc = std::function; public: - RalContext(std::shared_ptr disc_result) - : disc_result_(disc_result){}; - ~RalContext(){}; + RalContext(const DISCComplationResult& disc_result); + ~RalContext(); - at::List Execute(const at::List&); + std::vector Execute(const std::vector& inputs); private: - void BindingInputs(const at::List& inputs, + void BindingInputs(const std::vector& inputs, tao::ral::ExecutionContext& exec_ctx); - void CheckCurrentDevice(const at::List& inputs); - at::List CreateAndBindingOutputs( + void CheckCurrentDevice(const std::vector& inputs); + std::vector CreateAndBindingOutputs( + const std::vector& inputs, tao::ral::ExecutionContext& exec_ctx); - at::List PreProcessInputs(const at::List& inputs); + std::vector PreProcessInputs( + const std::vector& inputs); + std::tuple 
LoadEngine(const std::string& ral_engine_bytes); int64_t LazyInitCurrentDevice(); @@ -54,10 +58,16 @@ class RalContext { tao::ral::BaseContextOption default_opt_; tao::ral::cpu::BaseCpuContextOption cpu_opt_; - std::shared_ptr disc_result_; + DISCComplationResult disc_result_; void* tao_lib_; EntryFunc entry_func_; + + TempFile lib_tmpf_{"ral_lib.so"}; + TempFile meta_tmpf_{"ral_meta.pb"}; }; +} // namespace disc } // namespace runtime } // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCRAL_H_ diff --git a/torch_xla/csrc/runtime/disc/disc_ral_test.cc b/torch_xla/csrc/runtime/disc/disc_ral_test.cc index 4ab01b06f38..7d25bef0fb6 100644 --- a/torch_xla/csrc/runtime/disc/disc_ral_test.cc +++ b/torch_xla/csrc/runtime/disc/disc_ral_test.cc @@ -4,7 +4,7 @@ namespace torch_xla { namespace runtime { - +namespace disc { TEST(DISCRAlTest, E2E) { // TODO(disc): need compile API to output the compilation result std::shared_ptr disc_result = @@ -14,5 +14,6 @@ TEST(DISCRAlTest, E2E) { ral_ctx.Execute(at::List()); } +} // namespace disc } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/disc/disc_utils.cc b/torch_xla/csrc/runtime/disc/disc_utils.cc new file mode 100644 index 00000000000..b9c6303bac0 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_utils.cc @@ -0,0 +1,114 @@ +#include "torch_xla/csrc/runtime/disc/disc_utils.h" + +#include +#include + +#include + +#include "torch_xla/csrc/runtime/tf_logging.h" + +namespace torch_xla { +namespace runtime { +namespace disc { + +std::string ReadStringFromEnvVar(const char* env_var_name, + std::string default_val) { + const char* env_var_val = std::getenv(env_var_name); + if (env_var_val == nullptr) { + return default_val; + } + return std::string(env_var_val); +} + +// This function is copied from c10/util/tempfile.h, so it follows to these +// temperary directory env variables, too. +std::vector make_filename(std::string name_prefix) { + // The filename argument to `mkstemp` needs "XXXXXX" at the end according to + // http://pubs.opengroup.org/onlinepubs/009695399/functions/mkstemp.html + static const std::string kRandomPattern = "XXXXXX"; + // We see if any of these environment variables is set and use their value, or + // else default the temporary directory to `/tmp`. 
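+  // The variables below are checked in order; the first one set to a
+  // non-empty path wins.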
+ static const char* env_variables[] = {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}; + std::string tmp_directory = "/tmp"; + for (const char* variable : env_variables) { + auto path = ReadStringFromEnvVar(variable, ""); + if (!path.empty()) { + tmp_directory = path; + break; + } + } + std::vector filename; + filename.reserve(tmp_directory.size() + name_prefix.size() + + kRandomPattern.size() + 2); + filename.insert(filename.end(), tmp_directory.begin(), tmp_directory.end()); + filename.push_back('/'); + filename.insert(filename.end(), name_prefix.begin(), name_prefix.end()); + filename.insert(filename.end(), kRandomPattern.begin(), kRandomPattern.end()); + filename.push_back('\0'); + return filename; +} + +std::string ReadFileBytes(const std::string& fname) { + std::ifstream input(fname, std::ios::binary); + std::vector bytes((std::istreambuf_iterator(input)), + (std::istreambuf_iterator())); + return std::string(bytes.begin(), bytes.end()); +} + +TempFile::TempFile(std::string prefix) : fname_(""), fd_(-1) { + auto fname = make_filename(prefix); + fd_ = mkstemp(fname.data()); + fname_ = std::string(fname.data()); + TORCH_CHECK(fd_ != -1, "Error generating temporary file, file name: ", fname_, + ", error: ", std::strerror(errno)); +} + +TempFile::~TempFile() { + if (!fname_.empty()) { + ::unlink(fname_.c_str()); + } + if (fd_ > 0) { + ::close(fd_); + } +} + +bool TempFile::WriteBytesToFile(const std::string& bytes) { + ssize_t left_len = bytes.length(); + const char* data = bytes.data(); + errno = 0; + while (left_len > 0) { + auto sz = ::write(fd_, data, left_len); + if (sz <= 0) { + if (errno != EINTR && errno != EAGAIN) { + TF_VLOG(1) << "Failed to write content to temp file: " << GetFilename() + << ", error: " << strerror(errno); + return false; + } + errno = 0; + continue; + } + left_len -= sz; + data += sz; + } + return true; +} + +const std::string& TempFile::GetFilename() const { return fname_; } + +std::string TempFile::ReadBytesFromFile() { + std::ifstream infile(fname_, std::ios::binary); + std::string str((std::istreambuf_iterator(infile)), + std::istreambuf_iterator()); + return str; +} + +std::string TempFile::ReadStringFromFile() { + std::ifstream infile(fname_); + std::string str((std::istreambuf_iterator(infile)), + std::istreambuf_iterator()); + return str; +} + +} // namespace disc +} // namespace runtime +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/disc_utils.h b/torch_xla/csrc/runtime/disc/disc_utils.h new file mode 100644 index 00000000000..df7476e77df --- /dev/null +++ b/torch_xla/csrc/runtime/disc/disc_utils.h @@ -0,0 +1,36 @@ +#ifndef XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCUTILS_H_ +#define XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCUTILS_H_ + +#include +#include + +namespace torch_xla { +namespace runtime { +namespace disc { +std::vector make_filename(std::string name_prefix); +std::string ReadFileBytes(const std::string& fname); +class TempFile { + public: + TempFile(std::string prefix = ""); + ~TempFile(); + TempFile(const TempFile&) = delete; + void operator=(const TempFile&) = delete; + /// Write bytes content to temp file and return true on success. + bool WriteBytesToFile(const std::string& bytes); + /// Read byte content from temp file. + std::string ReadBytesFromFile(); + /// Read string content from temp file.. + std::string ReadStringFromFile(); + /// Get the filename of the temp file. 
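+  /// Note: the underlying file is unlinked in the destructor, so the returned
+  /// name is only valid for the lifetime of the TempFile object.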
+ const std::string& GetFilename() const; + + private: + std::string fname_; + int fd_; +}; + +} // namespace disc +} // namespace runtime +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_RUNTIME_DISC_DISCUTILS_H_ \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc index c51cb9662c1..d4066f44117 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.cc +++ b/torch_xla/csrc/runtime/disc_computation_client.cc @@ -1,13 +1,97 @@ #include "torch_xla/csrc/runtime/disc_computation_client.h" +#include +#include +#include +#include + #include +#include "absl/strings/str_cat.h" +#include "llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/disc/disc_compile.h" +#include "torch_xla/csrc/runtime/stablehlo_helper.h" +#include "torch_xla/csrc/runtime/sys_util.h" namespace torch_xla { namespace runtime { -DISCComputationClient::DISCComputationClient() {} +at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) { + switch (xla_type) { + case xla::PrimitiveType::BF16: + return at::ScalarType::BFloat16; + case xla::PrimitiveType::F16: + return at::ScalarType::Half; + case xla::PrimitiveType::F32: + return at::ScalarType::Float; + case xla::PrimitiveType::F64: + return at::ScalarType::Double; + case xla::PrimitiveType::PRED: + return at::ScalarType::Bool; + case xla::PrimitiveType::U8: + return at::ScalarType::Byte; + case xla::PrimitiveType::S8: + return at::ScalarType::Char; + case xla::PrimitiveType::S16: + case xla::PrimitiveType::U16: + return at::ScalarType::Short; + case xla::PrimitiveType::S32: + case xla::PrimitiveType::U32: + return at::ScalarType::Int; + case xla::PrimitiveType::S64: + case xla::PrimitiveType::U64: + return at::ScalarType::Long; + case xla::PrimitiveType::C64: + return at::ScalarType::ComplexFloat; + case xla::PrimitiveType::C128: + return at::ScalarType::ComplexDouble; + default: + XLA_ERROR() << "XLA type not supported: " << xla_type; + } +} + +xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) { + switch (scalar_type) { + case at::ScalarType::Double: + return xla::PrimitiveType::F64; + case at::ScalarType::Float: + return xla::PrimitiveType::F32; + case at::ScalarType::BFloat16: + return xla::PrimitiveType::BF16; + case at::ScalarType::Half: + return xla::PrimitiveType::F16; + case at::ScalarType::Bool: + return xla::PrimitiveType::PRED; + case at::ScalarType::Byte: + return xla::PrimitiveType::U8; + case at::ScalarType::Char: + return xla::PrimitiveType::S8; + case at::ScalarType::Short: + return xla::PrimitiveType::S16; + case at::ScalarType::Int: + return xla::PrimitiveType::S32; + case at::ScalarType::Long: + return xla::PrimitiveType::S64; + case at::ScalarType::ComplexFloat: + return xla::PrimitiveType::C64; + case at::ScalarType::ComplexDouble: + return xla::PrimitiveType::C128; + default: + XLA_ERROR() << "Type not supported: " << scalar_type; + } +} + +DISCComputationClient::DISCComputationClient() { + world_size_ = sys_util::GetEnvInt("WORLD_SIZE", 1); + local_rank_ = sys_util::GetEnvInt("LOCAL_RANK", 0); + global_rank_ = sys_util::GetEnvInt("RANK", local_rank_); +} 
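+
+// Note: this client is only constructed when the DISC_DEVICE environment
+// variable is set (see GetComputationClient() in runtime.cc). A minimal way
+// to exercise it, mirroring the unit test in disc_computation_client_test.cc:
+//
+//   tsl::setenv("DISC_DEVICE", "GPU", /*overwrite=*/true);
+//   auto client = std::make_unique<DISCComputationClient>();
+//
+// WORLD_SIZE / LOCAL_RANK / RANK are the torchrun-style variables read above;
+// when they are unset the client defaults to a single-process, rank-0 setup.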
DISCComputationClient::~DISCComputationClient() {} @@ -25,19 +109,181 @@ ComputationClient::DataPtr DISCComputationClient::CreateDataPlaceholder( return std::make_shared(std::move(device), std::move(shape)); } -std::vector TransferToDevice( +std::vector DISCComputationClient::TransferToDevice( absl::Span> tensors) { - XLA_ERROR() << __FUNCTION__ << " not implemented"; + std::vector datas; + datas.reserve(tensors.size()); + + size_t total_transfered_bytes = 0; + + for (auto& tensor : tensors) { + std::vector sizes; + for (auto& dim_val : tensor->shape().dimensions()) { + sizes.push_back(dim_val); + } + + auto dtype = + at::TensorOptions(TorchTypeFromXlaType(tensor->shape().element_type())); + auto ret = at::empty(sizes, dtype).contiguous(); + // tensor->populate_fn(tensor, ret.data_ptr(), + // ret.element_size() * ret.numel()); + std::memcpy(ret.data_ptr(), tensor->data(), + ret.element_size() * ret.numel()); + + total_transfered_bytes += ret.element_size() * ret.numel(); + + if (!torch::cuda::is_available()) { + XLA_ERROR() << "CUDA is not available."; + } + + auto device_ret = ret.to(at::kCUDA); + ComputationClient::DataPtr data = std::make_shared( + tensor->device(), tensor->shape(), device_ret); + datas.push_back(data); + } + + return datas; } std::vector DISCComputationClient::TransferFromDevice( absl::Span handles) { - XLA_ERROR() << __FUNCTION__ << " not implemented"; + std::vector literals; + literals.reserve(handles.size()); + for (auto handle : handles) { + std::shared_ptr disc_data = + std::dynamic_pointer_cast(handle); + xla::Shape target_shape = + xla::ShapeUtil::DeviceShapeToHostShape(xla::ShapeUtil::MakeShape( + XlaTypeFromTorchType(disc_data->buffer.dtype().toScalarType()), + disc_data->buffer.sizes())); + auto& literal = literals.emplace_back(target_shape); + auto host_data = disc_data->buffer.to(at::kCPU); + std::memcpy(literal.untyped_data(), host_data.data_ptr(), + literal.size_bytes()); + } + + return literals; } std::vector DISCComputationClient::Compile( std::vector instances) { - XLA_ERROR() << __FUNCTION__ << " not implemented"; + std::vector computations{}; + for (auto& instance : instances) { + mlir::MLIRContext context; + mlir::ModuleOp mlir_module = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + auto status = torch_xla::ConvertHloToMhlo( + instance.computation.mutable_proto(), &mlir_module); + XLA_CHECK(status.ok()) << "StableHLO -> MHLO conversion failed.\n" + << status.message(); + + // Add input and output attributes + auto entry_func_identifier = + mlir::StringAttr::get(&context, "tf.entry_function"); + auto input_placement_key = + mlir::StringAttr::get(&context, "input_placements"); + auto output_placement_key = + mlir::StringAttr::get(&context, "output_placements"); + auto input_output_alias_params_key = + mlir::StringAttr::get(&context, "input_output_alias_params"); + auto input_output_alias_outputs_key = + mlir::StringAttr::get(&context, "input_output_alias_outputs"); + + std::string input_placement = ""; + std::string output_placement = ""; + std::string input_output_alias_params = ""; + std::string input_output_alias_outputs = ""; + + std::vector inputs, outputs; + + auto input_output_alias = instance.computation.proto().input_output_alias(); + if (sys_util::GetEnvString("ENBALE_DISC_INPUT_OUTPUT_ALIAS", "") != "OFF") { + for (const auto& entry : input_output_alias.entries()) { + input_output_alias_params += + std::to_string(entry.parameter_number()) + ","; + input_output_alias_outputs += + std::to_string(entry.output_shape_index(0)) + 
","; + } + } + if (!input_output_alias_params.empty()) { + input_output_alias_params.pop_back(); + input_output_alias_outputs.pop_back(); + } + + // Set attribute for entry function + mlir::func::FuncOp entry_func; + for (auto func : mlir_module.getOps()) { + if (func.getName().str() == "main") { + entry_func = func; + break; + } + } + + for (int i = 0; i < entry_func.getFunctionType().getNumInputs(); i++) { + absl::StrAppend(&input_placement, "gpu,"); + disc::DataMeta tensor_info; + tensor_info.device = "cuda"; + inputs.push_back(tensor_info); + } + if (!input_placement.empty()) { + input_placement.pop_back(); + } + + if (instance.output_shape->IsTuple()) { + for (auto& sub_shape : instance.output_shape->tuple_shapes()) { + absl::StrAppend(&output_placement, "gpu,"); + disc::DataMeta tensor_info; + tensor_info.device = "cuda"; + tensor_info.scalar_type = + TorchTypeFromXlaType(sub_shape.element_type()); + outputs.push_back(tensor_info); + } + } else { + absl::StrAppend(&output_placement, "gpu,"); + disc::DataMeta tensor_info; + tensor_info.device = "cuda"; + tensor_info.scalar_type = + TorchTypeFromXlaType(instance.output_shape->element_type()); + outputs.push_back(tensor_info); + } + + if (!output_placement.empty()) { + output_placement.pop_back(); + } + + auto input_placement_value = + mlir::StringAttr::get(&context, input_placement); + auto output_placement_value = + mlir::StringAttr::get(&context, output_placement); + + auto input_output_alias_outputs_value = + mlir::StringAttr::get(&context, input_output_alias_outputs); + auto input_output_alias_params_value = + mlir::StringAttr::get(&context, input_output_alias_params); + + auto dict = mlir::DictionaryAttr::get( + &context, + {mlir::NamedAttribute(input_placement_key, input_placement_value), + mlir::NamedAttribute(output_placement_key, output_placement_value), + mlir::NamedAttribute(input_output_alias_params_key, + input_output_alias_params_value), + mlir::NamedAttribute(input_output_alias_outputs_key, + input_output_alias_outputs_value)}); + + entry_func->setAttr(entry_func_identifier, dict); + mlir_module->setAttr(entry_func_identifier, dict); + + // Trigger disc compilation + disc::DISCComplationResult compile_res = + disc::Compile(mlir_module, inputs, outputs); + std::shared_ptr disc_computation = + std::make_shared( + std::move(xla::XlaComputation(instance.computation.proto())), + instance.devices, std::make_unique(compile_res)); + computations.push_back(disc_computation); + } + + return computations; } std::vector @@ -45,12 +291,76 @@ DISCComputationClient::ExecuteComputation( const ComputationClient::Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) { - XLA_ERROR() << __FUNCTION__ << " not implemented"; + const DISCComputation& disc_computation = + dynamic_cast(computation); + + std::vector buffers; + buffers.reserve(arguments.size()); + for (auto& argument : arguments) { + std::shared_ptr disc_data = + std::dynamic_pointer_cast(argument); + buffers.push_back(disc_data->buffer); + } + + std::vector results = + disc_computation.executable->Execute(buffers); + + std::vector datas; + datas.reserve(results.size()); + for (auto& result : results) { + std::shared_ptr data = std::make_shared( + device, xla::ShapeUtil::MakeShape(xla::F32, result.sizes()), result); + + datas.push_back(data); + } + + return datas; } std::map DISCComputationClient::GetMetrics() const { return {}; } +std::string DISCComputationClient::GetDefaultDevice() const { + return 
absl::StrCat(DefaultDevicePrefix, std::to_string(local_rank_)); +} + +std::vector DISCComputationClient::GetLocalDevices() const { + std::vector all_devices; + all_devices.push_back(GetDefaultDevice()); + return all_devices; +} + +std::optional DISCComputationClient::GetDataSharding( + ComputationClient::DataPtr handle) { + return std::optional(); +} + +void DISCComputationClient::SetReplicationDevices( + std::shared_ptr> devices) { + replication_devices_ = std::move(devices); +} + +std::shared_ptr> +DISCComputationClient::GetReplicationDevices() { + return replication_devices_; +} + +std::vector DISCComputationClient::GetAllDevices() const { + std::vector all_devices; + int device_count = world_size_; + for (int idx = 0; idx < device_count; idx++) { + all_devices.push_back( + absl::StrCat(DefaultDevicePrefix, std::to_string(idx))); + } + return all_devices; +} + +size_t DISCComputationClient::GetNumDevices() const { return world_size_; } + +int DISCComputationClient::GetProcessIndex() const { return local_rank_; } + +int DISCComputationClient::GetNumProcesses() const { return world_size_; } + } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/disc_computation_client.h b/torch_xla/csrc/runtime/disc_computation_client.h index 3f25fe45b16..8df3a6c0e73 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.h +++ b/torch_xla/csrc/runtime/disc_computation_client.h @@ -2,17 +2,17 @@ #define XLA_CLIENT_DISC_COMPUTATION_CLIENT_H_ #include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/disc/disc_ral.h" +#include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "xla/client/xla_computation.h" namespace torch_xla { namespace runtime { -namespace disc { -class DISCLoadedExecutable {}; -} // namespace disc - class DISCComputationClient : public ComputationClient { public: + const std::string DefaultDevicePrefix = "GPU:"; + DISCComputationClient(); ~DISCComputationClient(); @@ -39,9 +39,7 @@ class DISCComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; } - std::optional GetDataSharding(DataPtr handle) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } + std::optional GetDataSharding(DataPtr handle) override; std::vector TransferToDevice( absl::Span> tensors) override; @@ -69,11 +67,12 @@ class DISCComputationClient : public ComputationClient { } torch::lazy::hash_t HashCompilationEnv() override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; + // TODO(wangang.wa): Improve this function. 
+ return torch::lazy::hash_t(); } torch_xla::DeviceType GetDeviceType() const override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; + return torch_xla::DeviceType("CUDA"); }; bool CoordinatorInitialized() const override { @@ -105,29 +104,17 @@ class DISCComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; } - size_t GetNumDevices() const override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } + size_t GetNumDevices() const override; - std::string GetDefaultDevice() const override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } + std::string GetDefaultDevice() const override; - std::vector GetLocalDevices() const override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } + std::vector GetLocalDevices() const override; - std::vector GetAllDevices() const override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } + std::vector GetAllDevices() const override; - int GetProcessIndex() const override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - }; + int GetProcessIndex() const override; - int GetNumProcesses() const override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } + int GetNumProcesses() const override; const absl::flat_hash_map< std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& @@ -136,13 +123,9 @@ class DISCComputationClient : public ComputationClient { } void SetReplicationDevices( - std::shared_ptr> devices) override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } + std::shared_ptr> devices) override; - std::shared_ptr> GetReplicationDevices() override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } + std::shared_ptr> GetReplicationDevices() override; void WaitDeviceOps(absl::Span devices) override { XLA_ERROR() << __FUNCTION__ << " not implemented"; @@ -155,22 +138,25 @@ class DISCComputationClient : public ComputationClient { } private: + std::shared_ptr> replication_devices_; + int world_size_; + int local_rank_; + int global_rank_; struct DISCData : public Data { DISCData(std::string device, xla::Shape device_shape) : Data(std::move(device), std::move(device_shape)) {} - DISCData(std::string device, xla::Shape device_shape, - std::shared_ptr buffer) + DISCData(std::string device, xla::Shape device_shape, at::Tensor buffer) : Data(std::move(device), std::move(device_shape)), buffer(buffer) {} void Assign(const torch::lazy::BackendData& data) override; bool HasValue() const override { - return buffer->defined() && buffer->element_size() > 0; + return buffer.defined() && buffer.element_size() > 0; } Handle GetHandle() override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; + return reinterpret_cast(buffer.const_data_ptr()); } bool HasSharding() const override { return false; } @@ -188,25 +174,24 @@ class DISCComputationClient : public ComputationClient { ss << " Data Shape: " << shape().ToString() << "\n"; ss << " Data Handle: "; if (HasValue()) { - ss << reinterpret_cast(buffer->const_data_ptr()) - << "\n"; + ss << reinterpret_cast(buffer.const_data_ptr()) << "\n"; } else { ss << "None\n"; } return ss.str(); } - std::shared_ptr buffer; + at::Tensor buffer; }; struct DISCComputation : public Computation { DISCComputation(xla::XlaComputation computation, std::vector devices, - std::unique_ptr executable) + std::unique_ptr executable) : Computation(std::move(computation), std::move(devices)), executable(std::move(executable)) {} - std::unique_ptr executable; + std::unique_ptr executable; }; }; diff --git 
a/torch_xla/csrc/runtime/disc_computation_client_test.cc b/torch_xla/csrc/runtime/disc_computation_client_test.cc index f54de6079a8..902c3ab151c 100644 --- a/torch_xla/csrc/runtime/disc_computation_client_test.cc +++ b/torch_xla/csrc/runtime/disc_computation_client_test.cc @@ -1,6 +1,7 @@ #include "torch_xla/csrc/runtime/disc_computation_client.h" #include +#include #include #include @@ -32,54 +33,42 @@ tsl::StatusOr MakeComputation() { return builder.Build(); } -ComputationClient::TensorSource TensorSourceFromLiteral( - const std::string& device, const xla::Literal& literal) { - auto populate_fn = [&](const ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - std::memcpy(dest_buffer, literal.data().data(), - dest_buffer_size * sizeof(literal.data().data())); - }; - return ComputationClient::TensorSource(literal.shape(), device, - std::move(populate_fn)); -} - TEST(DISCComputationClientTest, Init) { tsl::setenv("DISC_DEVICE", "GPU", true); auto client = std::make_unique(); - std::string device = "cuda:0"; + std::string device = "GPU:0"; - // // Compose a computation. + // Compose a computation. auto shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2}); std::vector instances; instances.push_back(ComputationClient::CompileInstance( - std::move(MakeComputation().value()), device, {"cuda:0"}, - &shape)); + std::move(MakeComputation().value()), device, {"cuda:0"}, &shape)); - // // Prepare inputs. + // Prepare inputs. xla::Literal literal_x = xla::LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}}); xla::Literal literal_y = xla::LiteralUtil::CreateR2({{5.0f, 6.0f}, {7.0f, 8.0f}}); - // // Compile the graph. + // Compile the graph. std::vector computations = client->Compile(std::move(instances)); - // // Copy inputs to device. + // Copy inputs to device. ComputationClient::ExecuteComputationOptions options{}; - std::vector args = { - TensorSourceFromLiteral(device, literal_x), - TensorSourceFromLiteral(device, literal_y)}; - - // // Execute the graph. - auto inputs = client->TransferToServer(absl::MakeConstSpan(args)); + std::vector> args = { + std::make_shared(std::move(literal_x), device), + std::make_shared(std::move(literal_y), device)}; + // Execute the graph. std::vector results = client->ExecuteComputation( - *computations[0], inputs, + *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)), device, options); - // // Copy the output from device back to host and assert correctness.. - auto result_literals = client->TransferFromServer(results); + // Copy the output from device back to host and assert correctness.. 
+ ASSERT_EQ(results.size(), 1); + auto result_literals = client->TransferFromDevice(results); + ASSERT_THAT(result_literals, ::testing::SizeIs(1)); EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), result_literals[0])); diff --git a/torch_xla/csrc/runtime/env_vars.h b/torch_xla/csrc/runtime/env_vars.h old mode 100644 new mode 100755 diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index 202f56c9a70..be4bd41b503 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -2,10 +2,10 @@ #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/disc_computation_client.h" #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/ifrt_computation_client.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" -#include "torch_xla/csrc/runtime/disc_computation_client.h" #include "tsl/platform/stacktrace_handler.h" namespace torch_xla { diff --git a/torch_xla/csrc/runtime/stablehlo_helper.cc b/torch_xla/csrc/runtime/stablehlo_helper.cc index 1d9a740e52c..6ff9292fd70 100644 --- a/torch_xla/csrc/runtime/stablehlo_helper.cc +++ b/torch_xla/csrc/runtime/stablehlo_helper.cc @@ -50,8 +50,8 @@ static std::string getMlirModuleBytecode(mlir::ModuleOp& mlir_module) { return txt_mlir_module; } -static absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto, - mlir::ModuleOp* mlir_module) { +absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto, + mlir::ModuleOp* mlir_module) { auto status = xla::ConvertHloToMlirHlo(*mlir_module, proto, /*import_all_computations=*/false); if (!status.ok()) { @@ -62,6 +62,17 @@ static absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto, absl::StatusCode::kInternal, "MHLO Module from HLO -> MHLO conversion is not legal."); } + mlir::PassManager pm(mlir_module->getContext()); + // Apply pass to remove HLO tuple output, as MHLO/StableHLO supports multiple + // outputs. + pm.addPass(mlir::mhlo::createExpandHloTuplesPass()); + // Canonicalization after tuple flatten, to remove unused tuple op. 
+ pm.addNestedPass(mlir::createCanonicalizerPass()); + + XLA_CHECK(mlir::succeeded(pm.run(*mlir_module))) + << "HLO -> MHLO conversion failed.\n" + << status.message(); + return absl::OkStatus(); } diff --git a/torch_xla/csrc/runtime/stablehlo_helper.h b/torch_xla/csrc/runtime/stablehlo_helper.h index 235dc6b38b7..cfc65b817e2 100644 --- a/torch_xla/csrc/runtime/stablehlo_helper.h +++ b/torch_xla/csrc/runtime/stablehlo_helper.h @@ -20,6 +20,9 @@ void ConvertStableHloToHlo(mlir::ModuleOp* mlir_module, mlir::MLIRContext* context, xla::HloProto* hlo_proto); +absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto, + mlir::ModuleOp* mlir_module); + std::string GetHloModuleStr(const xla::HloModuleProto* proto); const std::string GetTorchDtypeToStablehloDtype(const std::string& dtype); From ff96e9b54f9f13e120b7deab3297ac797acd53be Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Tue, 30 Jan 2024 19:28:43 +0800 Subject: [PATCH 04/12] add bazel flag to disable disc backend (#23) * add flag to disable disc backend in bazel workspace --- .bazelrc | 3 +++ BUILD | 8 +++++++- bazel/rules_def.bzl | 6 ++++++ setup.py | 3 +++ torch_xla/csrc/runtime/BUILD | 14 ++++++++++++-- torch_xla/csrc/runtime/runtime.cc | 6 ++++++ 6 files changed, 37 insertions(+), 3 deletions(-) diff --git a/.bazelrc b/.bazelrc index 34c41167982..d92caaf0fed 100644 --- a/.bazelrc +++ b/.bazelrc @@ -223,3 +223,6 @@ build:linux --copt="-Wno-error=unused-but-set-variable" # Only include debug info for files not under XLA. build:dbg -c dbg build:dbg --per_file_copt=external/xla/.*@-g0,-DNDEBUG + +# build with DISC backend +build --define enable_disc=true diff --git a/BUILD b/BUILD index 5efbd38c0b6..6b00ea55070 100644 --- a/BUILD +++ b/BUILD @@ -3,6 +3,11 @@ load( "if_cuda_is_configured", ) +load( + "//bazel:rules_def.bzl", + "if_enable_disc", +) + cc_binary( name = "_XLAC.so", copts = [ @@ -21,7 +26,6 @@ cc_binary( visibility = ["//visibility:public"], deps = [ "//torch_xla/csrc:init_python_bindings", - "//torch_xla/csrc/runtime/disc:disc_ral", "@torch//:headers", "@torch//:libc10", "@torch//:libtorch", @@ -29,5 +33,7 @@ cc_binary( "@torch//:libtorch_python", ] + if_cuda_is_configured([ "@xla//xla/stream_executor:cuda_platform", + ]) + if_enable_disc([ + "//torch_xla/csrc/runtime/disc:disc_ral", ]), ) diff --git a/bazel/rules_def.bzl b/bazel/rules_def.bzl index 4569630f170..b10cb659e9a 100644 --- a/bazel/rules_def.bzl +++ b/bazel/rules_def.bzl @@ -39,3 +39,9 @@ def ptxla_cc_test( ], **kwargs ) + +def if_enable_disc(if_true, if_false=[]): + return select({ + "//torch_xla/csrc/runtime:enable_disc": if_true, + "//conditions:default": if_false + }) \ No newline at end of file diff --git a/setup.py b/setup.py index 599a9c03bd7..34827fb1dc4 100644 --- a/setup.py +++ b/setup.py @@ -231,6 +231,9 @@ def bazel_build(self, ext): bazel_argv.extend(build_util.bazel_options_from_env()) + if not build_util.check_env_flag('ENABLE_DISC', 'true'): + bazel_argv.append('--define=enable_disc=false') + self.spawn(bazel_argv) ext_bazel_bin_path = os.path.join(self.build_temp, 'bazel-bin', ext.relpath, diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 73bee67fd84..9bd276c856f 100755 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -12,12 +12,18 @@ load( "//bazel:rules_def.bzl", "ptxla_cc_library", "ptxla_cc_test", + "if_enable_disc", ) licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) +config_setting( + name = "enable_disc", + define_values = {"enable_disc": "true"}, +) 
+ cc_library( name = "runtime", srcs = [ @@ -31,9 +37,13 @@ cc_library( ":env_vars", ":pjrt_computation_client", ":ifrt_computation_client", - ":disc_computation_client", "@tsl//tsl/platform:stacktrace", - ], + ] + if_enable_disc([ + ":disc_computation_client", + ]), + copts = if_enable_disc([ + "-DTORCHACC_ENABLE_DISC", + ]), ) cc_library( diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index be4bd41b503..d31fab90e86 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -2,7 +2,9 @@ #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/computation_client.h" +#ifdef TORCHACC_ENABLE_DISC #include "torch_xla/csrc/runtime/disc_computation_client.h" +#endif #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/ifrt_computation_client.h" #include "torch_xla/csrc/runtime/pjrt_computation_client.h" @@ -23,7 +25,11 @@ ComputationClient* GetComputationClient() { static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false); if (sys_util::GetEnvString(env::kEnvDISCDevice, "") != "") { +#ifdef TORCHACC_ENABLE_DISC client = std::make_unique(); +#else + XLA_ERROR() << "should build with ENABLE_DISC=ON" << std::endl; +#endif } else if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { if (use_ifrt) { client = std::make_unique(); From 2d70a88a6cf02ec81905dbd561905675ee47f747 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Tue, 27 Feb 2024 13:45:50 +0800 Subject: [PATCH 05/12] support disc debug mode to dump mhlo and logs (#25) support disc backend debug mode to dump DISC compilation logs --- torch_xla/csrc/runtime/disc/disc_compile.cc | 50 +++++++++++++++++---- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/torch_xla/csrc/runtime/disc/disc_compile.cc b/torch_xla/csrc/runtime/disc/disc_compile.cc index ffff24e3eba..053535f5e2e 100644 --- a/torch_xla/csrc/runtime/disc/disc_compile.cc +++ b/torch_xla/csrc/runtime/disc/disc_compile.cc @@ -4,25 +4,37 @@ #include +#include "torch_xla/csrc/runtime/env_vars.h" +#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" +using namespace std::filesystem; namespace torch_xla { namespace runtime { namespace disc { +bool IsDiscDebugMode() { return sys_util::GetEnvBool("DISC_DEBUG", false); } + +std::string GetDebugDumpDir() { + return sys_util::GetEnvString("DISC_DEBUG_DUMP_DIR", "./dump_dir"); +} + std::string CurrentLibLocation() { Dl_info dl_info; - dladdr((void*)CurrentLibLocation, &dl_info); + dladdr((void *)CurrentLibLocation, &dl_info); auto fname = std::string(dl_info.dli_fname); return fname.substr(0, fname.find_last_of("/")); } -std::string CompileCMD(const std::string& mlir_fname, - const std::string& out_fname) { +std::string CompileCMD(const std::string &mlir_fname, + const std::string &out_fname) { std::stringstream ss; std::string logf = absl::StrCat(mlir_fname, ".log"); // unset XLA_FLAGS, otherwise tf will throw parse error std::string compile_cmd = "unset XLA_FLAGS"; + if (IsDiscDebugMode()) { + absl::StrAppend(&compile_cmd, " && export TF_CPP_VMODULE=disc_compiler=1 "); + } absl::StrAppend(&compile_cmd, "&&", CurrentLibLocation(), "/disc_compiler_main", " ", mlir_fname, " ", out_fname, " > ", logf, " 2>&1"); @@ -30,7 +42,7 @@ std::string CompileCMD(const std::string& mlir_fname, } std::tuple CallDiscCompiler( - const std::string& mlir_fname) { + const std::string &mlir_fname) { std::string out_fname = mlir_fname + ".out"; std::string cmd = CompileCMD(mlir_fname, out_fname); TF_VLOG(1) 
<< "Executing command: " << cmd << " to compile mhlo..."; @@ -38,7 +50,7 @@ std::tuple CallDiscCompiler( return {cmd, out_fname, ret}; } -std::shared_ptr DumpMlir(mlir::ModuleOp& stablehlo_module) { +std::shared_ptr DumpMlir(mlir::ModuleOp &stablehlo_module) { std::string model_dump_str; llvm::raw_string_ostream os(model_dump_str); stablehlo_module.print(os); @@ -48,9 +60,9 @@ std::shared_ptr DumpMlir(mlir::ModuleOp& stablehlo_module) { return stablehlo_file; } -DISCComplationResult Compile(mlir::ModuleOp& module, - const std::vector& inputs, - const std::vector& outputs) { +DISCComplationResult Compile(mlir::ModuleOp &module, + const std::vector &inputs, + const std::vector &outputs) { // Dump stablehlo to file DISCComplationResult res; auto mlir_file = DumpMlir(module); @@ -59,6 +71,28 @@ DISCComplationResult Compile(mlir::ModuleOp& module, auto compile_res = CallDiscCompiler(mlir_file->GetFilename()); auto output_fname = std::get<1>(compile_res); + if (IsDiscDebugMode()) { + std::string base_path = GetDebugDumpDir(); + auto ret = std::filesystem::create_directory(base_path); + if (ret != 0) { + TF_VLOG(0) << "Failed to create dump dir: " << base_path + << ", it maybe exists.\n"; + } + std::string mlir_fname = mlir_file->GetFilename(); + std::string log_fname = absl::StrCat(mlir_fname, ".log"); + std::filesystem::copy_file( + log_fname, + absl::StrCat(base_path, "/", + std::filesystem::path(mlir_fname).stem().string(), + ".log")); + std::filesystem::copy_file( + mlir_fname, + absl::StrCat(base_path, "/", + std::filesystem::path(mlir_fname).stem().string(), + ".mlir")); + TF_VLOG(1) << "Dumping mlir to file: " << mlir_file->GetFilename(); + } + // Construct compiled result res.ral_lib = ReadFileBytes(output_fname); res.ral_mate_pb = ReadFileBytes(absl::StrCat(output_fname, ".pbtxt")); From d14788bf47f1581d47d93441d222ca7b362ea3e7 Mon Sep 17 00:00:00 2001 From: Dalong Date: Fri, 1 Mar 2024 17:22:05 +0800 Subject: [PATCH 06/12] support flash attention in disc (#34) --- bazel/disc.BUILD | 8 +- setup.py | 6 + test/test_flash_attention_backward.py | 7 +- torch_xla/csrc/runtime/disc/BUILD | 9 +- .../custom_call_flash_attention_backward.cc | 385 ++++++++++++++++++ .../custom_call_flash_attention_forward.cc | 250 ++++++++++++ 6 files changed, 661 insertions(+), 4 deletions(-) mode change 100644 => 100755 test/test_flash_attention_backward.py create mode 100644 torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc create mode 100644 torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc diff --git a/bazel/disc.BUILD b/bazel/disc.BUILD index fc72c6c0025..7ed33604618 100644 --- a/bazel/disc.BUILD +++ b/bazel/disc.BUILD @@ -28,9 +28,14 @@ cc_import( shared_library = ":libral_base_context.so", ) +cc_import( + name="disc_custom_op", + shared_library = ":libdisc_custom_ops.so", +) + genrule( name = "build_disc", - outs = ["libral_base_context.so", "disc_compiler_main", "torch-mlir-opt"], + outs = ["libral_base_context.so", "libdisc_custom_ops.so", "disc_compiler_main", "torch-mlir-opt"], local = True, cmd = ';'.join(['export PATH=/root/bin:/usr/local/cuda/bin:$${PATH}', 'pushd external/disc_compiler/pytorch_blade/', @@ -38,6 +43,7 @@ genrule( 'TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0,8.6,9.0" TORCH_CUDA_ARCH_LIST="7.0 8.0 8.6 9.0" python setup.py bdist_wheel', 'popd', 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/ral/libral_base_context.so $(location libral_base_context.so)', + 'cp 
third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/custom_ops/libdisc_custom_ops.so $(location libdisc_custom_ops.so)', 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/disc/disc_compiler_main $(location disc_compiler_main)', 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/tests/mhlo/torch-mlir-opt/torch-mlir-opt $(location torch-mlir-opt)']), ) diff --git a/setup.py b/setup.py index 34827fb1dc4..51b0b3d52d9 100644 --- a/setup.py +++ b/setup.py @@ -259,6 +259,12 @@ def bazel_build(self, ext): os.path.join(bazel_bin_path, disc_ral_so_name), '/'.join([ext_dest_dir, disc_ral_so_name])) + disc_customop_so_name = 'libdisc_custom_ops.so' + bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' + shutil.copyfile( + os.path.join(bazel_bin_path, disc_customop_so_name), + '/'.join([ext_dest_dir, disc_customop_so_name])) + class Develop(develop.develop): diff --git a/test/test_flash_attention_backward.py b/test/test_flash_attention_backward.py old mode 100644 new mode 100755 index 35b8a7c37fa..df8c15efc81 --- a/test/test_flash_attention_backward.py +++ b/test/test_flash_attention_backward.py @@ -1,3 +1,4 @@ +import os import sys import unittest @@ -149,13 +150,15 @@ def test_flash_attn_gqa_backward_fp16(self): self._backward_internal(torch.float16, n_heads_kv=int(N_HEADS // 2)) def test_flash_attn_gqa_backward_bf16(self): - self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2)) + if not os.environ.get('DISC_DEVICE'): + self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2)) def test_flash_attn_backward_fp16(self): self._backward_internal(torch.float16, n_heads_kv=N_HEADS) def test_flash_attn_backward_bf16(self): - self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS) + if not os.environ.get('DISC_DEVICE'): + self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS) def test_flash_attn_gqa_backward_fp16_alibi(self): self._backward_internal( diff --git a/torch_xla/csrc/runtime/disc/BUILD b/torch_xla/csrc/runtime/disc/BUILD index 817c4299493..d6d25ced337 100755 --- a/torch_xla/csrc/runtime/disc/BUILD +++ b/torch_xla/csrc/runtime/disc/BUILD @@ -12,18 +12,25 @@ load( ptxla_cc_library( name = "disc_ral", - srcs = ["disc_ral.cc"], + srcs = [ + "disc_ral.cc", + "custom_call_flash_attention_forward.cc", + "custom_call_flash_attention_backward.cc" + ], hdrs = [ "disc_ral.h", ], deps = [ ":disc_utils", "@disc_compiler//:disc_ral_cuda", + "@disc_compiler//:disc_custom_op", "@disc_compiler//:headers", "@local_config_cuda//cuda:cuda_headers", "@torch//:libc10", "@torch//:libc10_cuda", "@torch//:libtorch_cuda", + "@flash_attn//:headers", + "@flash_attn//:flash_attn_cuda", ], copts = [ "-DGOOGLE_CUDA", diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc new file mode 100644 index 00000000000..efd4f775f48 --- /dev/null +++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc @@ -0,0 +1,385 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "cutlass/numeric_types.h" +#include "flash.h" +#include "mlir/ral/context/pdll_util.h" +#include "mlir/ral/context/stream_executor_based_impl.h" +#include "static_switch.h" +#include 
"torch_xla/csrc/runtime/tf_logging.h" + +namespace tao { +namespace ral { + +DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16"); + +struct FlashAttentionBackwardParams { + using index_t = uint32_t; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k + // could be different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The dimensions. + int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + + int total_k; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + bool is_bf16; + bool is_causal; + + // Backward specific params + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + void FromString(const std::string& str) { + std::vector params_list = absl::StrSplit(str, "|"); + TORCH_CHECK(params_list.size() == 43); + + // Forward specific param + absl::SimpleAtoi(params_list[0], &this->q_batch_stride); + absl::SimpleAtoi(params_list[1], &this->k_batch_stride); + absl::SimpleAtoi(params_list[2], &this->v_batch_stride); + absl::SimpleAtoi(params_list[3], &this->q_row_stride); + absl::SimpleAtoi(params_list[4], &this->k_row_stride); + absl::SimpleAtoi(params_list[5], &this->v_row_stride); + absl::SimpleAtoi(params_list[6], &this->q_head_stride); + absl::SimpleAtoi(params_list[7], &this->k_head_stride); + absl::SimpleAtoi(params_list[8], &this->v_head_stride); + absl::SimpleAtoi(params_list[9], &this->total_k); + absl::SimpleAtoi(params_list[10], &this->h); + absl::SimpleAtoi(params_list[11], &this->h_k); + absl::SimpleAtoi(params_list[12], &this->h_h_k_ratio); + absl::SimpleAtoi(params_list[13], &this->o_batch_stride); + absl::SimpleAtoi(params_list[14], &this->o_row_stride); + absl::SimpleAtoi(params_list[15], &this->o_head_stride); + absl::SimpleAtoi(params_list[16], &this->b); + absl::SimpleAtoi(params_list[17], &this->seqlen_q); + absl::SimpleAtoi(params_list[18], &this->seqlen_k); + absl::SimpleAtoi(params_list[19], &this->d); + absl::SimpleAtoi(params_list[20], &this->seqlen_q_rounded); + absl::SimpleAtoi(params_list[21], &this->seqlen_k_rounded); + absl::SimpleAtoi(params_list[22], &this->d_rounded); + absl::SimpleAtof(params_list[23], &this->scale_softmax); + absl::SimpleAtof(params_list[24], &this->scale_softmax_log2); + absl::SimpleAtof(params_list[25], &this->p_dropout); + uint32_t tmp; + absl::SimpleAtoi(params_list[26], &tmp); + this->p_dropout_in_uint8_t = uint8_t(tmp); + absl::SimpleAtof(params_list[27], &this->rp_dropout); + absl::SimpleAtof(params_list[28], &this->scale_softmax_rp_dropout); + absl::SimpleAtob(params_list[29], &this->is_bf16); + 
absl::SimpleAtob(params_list[30], &this->is_causal); + + // backward specific params + absl::SimpleAtoi(params_list[31], &this->do_batch_stride); + absl::SimpleAtoi(params_list[32], &this->do_row_stride); + absl::SimpleAtoi(params_list[33], &this->do_head_stride); + absl::SimpleAtoi(params_list[34], &this->dq_batch_stride); + absl::SimpleAtoi(params_list[35], &this->dk_batch_stride); + absl::SimpleAtoi(params_list[36], &this->dv_batch_stride); + absl::SimpleAtoi(params_list[37], &this->dq_row_stride); + absl::SimpleAtoi(params_list[38], &this->dk_row_stride); + absl::SimpleAtoi(params_list[39], &this->dv_row_stride); + absl::SimpleAtoi(params_list[40], &this->dq_head_stride); + absl::SimpleAtoi(params_list[41], &this->dk_head_stride); + absl::SimpleAtoi(params_list[42], &this->dv_head_stride); + } +}; + +void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream, + const bool configure) { + FP16_SWITCH(!params.is_bf16, [&] { + if (params.d <= 32) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 64) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 96) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 128) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 160) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 192) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 224) { + run_mha_bwd_(params, stream, configure); + } else if (params.d <= 256) { + run_mha_bwd_(params, stream, configure); + } + }); +} + +// Layout of `buffers` listed above: +// buffers[0] = dout +// buffers[1] = q +// buffers[2] = k +// buffers[3] = v +// buffers[4] = out +// buffers[5] = softmax_lse +// buffers[6] = cu_seqlens_q +// buffers[7] = cu_seqlens_k +// buffers[8] = dq // this is output +// buffers[9] = dk // this is output +// buffers[10] = dv // this is output +// buffers[11] = softmax_d // this is output +template +std::tuple, MemRefType, MemRefType, + MemRefType> +custom_call_flash_attention_backward( + ExecutionContext* ctx, void* stream_handle, MemRefType dout, + MemRefType q, MemRefType k, MemRefType v, + MemRefType out, MemRefType softmax_lse, + MemRefType seqlens_q, MemRefType seqlens_k, + void* customAttrs) { + auto attr = getOrParsePDLAttr(ctx, customAttrs, + "custom_call_flash_attention_backward"); + if (!attr) { + ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n"); + } + auto& dictAttr = attr->as(); + std::string backend_config = + dictAttr.get("backend_config").template as().getValue(); + + auto gpu_driver = ctx->getDriver( + tao::ral::gpu::GPUDriver::name()); + auto gpu_stream = + static_cast(gpu_driver->asCUStream(ctx, stream_handle)); + + int softmax_element_count = 1, q_element_count = 1, k_element_count = 1, + v_element_count = 1; + for (int i = 0; i < M; i++) { + q_element_count *= q.sizes[i]; + k_element_count *= k.sizes[i]; + v_element_count *= v.sizes[i]; + softmax_element_count *= softmax_lse.sizes[i]; + } + + auto dq_ptr = static_cast( + gpu_driver->alloc(ctx, q_element_count * sizeof(T_IN))); + auto dq_res = assignMemRef(dq_ptr, q.sizes); + + auto dk_ptr = static_cast( + gpu_driver->alloc(ctx, k_element_count * sizeof(T_IN))); + auto dk_res = assignMemRef(dk_ptr, k.sizes); + + auto dv_ptr = static_cast( + gpu_driver->alloc(ctx, v_element_count * sizeof(T_IN))); + auto dv_res = assignMemRef(dv_ptr, v.sizes); + + auto dsoftmax_ptr = static_cast( + gpu_driver->alloc(ctx, softmax_element_count * sizeof(SOFT_MAX_TYPE))); + auto dsoftmax = + 
assignMemRef(dsoftmax_ptr, softmax_lse.sizes); + + FlashAttentionBackwardParams params; + params.FromString(std::move(backend_config)); + Flash_bwd_params launch_params; + + // Reset the parameters + memset(&launch_params, 0, sizeof(launch_params)); + + launch_params.is_bf16 = params.is_bf16; + + // Set the pointers and strides. + launch_params.q_ptr = q.data; + launch_params.k_ptr = k.data; + launch_params.v_ptr = v.data; + // All stride are in elements, not bytes. + launch_params.q_row_stride = params.q_row_stride; + launch_params.k_row_stride = params.k_row_stride; + launch_params.v_row_stride = params.v_row_stride; + launch_params.q_head_stride = params.q_head_stride; + launch_params.k_head_stride = params.k_head_stride; + launch_params.v_head_stride = params.v_head_stride; + launch_params.o_ptr = out.data; + launch_params.o_row_stride = params.o_row_stride; + launch_params.o_head_stride = params.o_head_stride; + + launch_params.cu_seqlens_q = static_cast(seqlens_q.data); + launch_params.cu_seqlens_k = static_cast(seqlens_k.data); + + // P = softmax(QK^T) + launch_params.p_ptr = nullptr; // no softmax returned always + + // Softmax sum + launch_params.softmax_lse_ptr = softmax_lse.data; + + // Set the dimensions. + launch_params.b = params.b; + launch_params.h = params.h; + launch_params.h_k = params.h_k; + launch_params.h_h_k_ratio = params.h_h_k_ratio; + launch_params.seqlen_q = params.seqlen_q; + launch_params.seqlen_k = params.seqlen_k; + launch_params.seqlen_q_rounded = params.seqlen_q_rounded; + launch_params.seqlen_k_rounded = params.seqlen_k_rounded; + launch_params.d = params.d; + launch_params.d_rounded = params.d_rounded; + + // Set the different scale values. + launch_params.scale_softmax = params.scale_softmax; + launch_params.scale_softmax_log2 = params.scale_softmax_log2; + + launch_params.p_dropout = params.p_dropout; + launch_params.p_dropout_in_uint8_t = params.p_dropout_in_uint8_t; + launch_params.rp_dropout = params.rp_dropout; + launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout; + + launch_params.is_causal = params.is_causal; + + launch_params.do_ptr = dout.data; + launch_params.do_row_stride = params.do_row_stride; + launch_params.do_head_stride = params.do_head_stride; + launch_params.dq_ptr = dq_res.data; + launch_params.dk_ptr = dk_res.data; + launch_params.dv_ptr = dv_res.data; + launch_params.dq_row_stride = params.dq_row_stride; + launch_params.dk_row_stride = params.dk_row_stride; + launch_params.dv_row_stride = params.dv_row_stride; + launch_params.dq_head_stride = params.dq_head_stride; + launch_params.dk_head_stride = params.dk_head_stride; + launch_params.dv_head_stride = params.dv_head_stride; + + // bool loop = max_seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + auto scalar_type = params.is_bf16 ? 
torch::kBFloat16 : torch::kFloat16; + auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA); + at::Tensor dq_accum; + if (loop) { + dq_accum = + torch::empty({launch_params.b, launch_params.h, + launch_params.seqlen_q_rounded, launch_params.d_rounded}, + opts.dtype(at::kFloat)); + } + + at::Tensor dk = torch::from_blob( + dk_res.data, {params.total_k, launch_params.h_k, launch_params.d}, opts); + at::Tensor dv = torch::from_blob( + dv_res.data, {params.total_k, launch_params.h_k, launch_params.d}, opts); + + at::Tensor dk_expanded, dv_expanded; + + if (launch_params.h_k != launch_params.h) { // MQA / GQA + TF_VLOG(2) << "Running FlashAttention Backward as MQA/GQA"; + dk_expanded = + torch::empty({params.total_k, launch_params.h, launch_params.d}, opts); + dv_expanded = + torch::empty({params.total_k, launch_params.h, launch_params.d}, opts); + + launch_params.dk_ptr = dk_expanded.data_ptr(); + launch_params.dv_ptr = dv_expanded.data_ptr(); + launch_params.dk_row_stride = dk_expanded.stride(-3); + launch_params.dv_row_stride = dv_expanded.stride(-3); + launch_params.dk_head_stride = dk_expanded.stride(-2); + launch_params.dv_head_stride = dv_expanded.stride(-2); + } else { + TF_VLOG(2) << "Running FlashAttention Backward"; + dk_expanded = dk; + dv_expanded = dv; + } + + launch_params.dq_accum_ptr = loop ? dq_accum.data_ptr() : nullptr; + launch_params.dk_accum_ptr = nullptr; + launch_params.dv_accum_ptr = nullptr; + + // Softmax sum + launch_params.dsoftmax_sum = dsoftmax.data; + + auto launch = &run_mha_bwd; + + auto gen = at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = launch_params.b * launch_params.h * 32; + + bool is_dropout = (1.f - launch_params.p_dropout) > 0.0; + if (is_dropout) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + launch_params.philox_args = gen->philox_cuda_state(counter_offset); + } + + launch(launch_params, gpu_stream, /*configure=*/false); + + // For MQA/GQA we need to sum dK and dV across the groups + if (launch_params.h_k != launch_params.h) { + at::sum_out(dk, + at::reshape(dk_expanded, {params.total_k, launch_params.h_k, + launch_params.h / launch_params.h_k, + launch_params.d}), + {2}); + at::sum_out(dv, + at::reshape(dv_expanded, {params.total_k, launch_params.h_k, + launch_params.h / launch_params.h_k, + launch_params.d}), + {2}); + } + + return std::make_tuple(dq_res, dk_res, dv_res, dsoftmax); +} + +TAO_RAL_API("custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward); +TAO_RAL_API("custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward); + +} // namespace ral +} // namespace tao \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc new file mode 100644 index 00000000000..7d5ded3ebeb --- /dev/null +++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc @@ -0,0 +1,250 @@ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "cutlass/numeric_types.h" +#include "flash.h" +#include "mlir/ral/context/pdll_util.h" +#include 
"mlir/ral/context/stream_executor_based_impl.h" +#include "static_switch.h" +#include "torch_xla/csrc/runtime/tf_logging.h" + +namespace tao { +namespace ral { + +DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16"); + +struct FlashAttentionForwardParams { + using index_t = uint32_t; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k + // could be different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The dimensions. + int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + + int total_k; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + bool is_bf16; + bool is_causal; + + void FromString(const std::string& str) { + std::vector params_list = absl::StrSplit(str, "|"); + TORCH_CHECK(params_list.size() >= 31); // at least 31 variables + absl::SimpleAtoi(params_list[0], &this->q_batch_stride); + absl::SimpleAtoi(params_list[1], &this->k_batch_stride); + absl::SimpleAtoi(params_list[2], &this->v_batch_stride); + absl::SimpleAtoi(params_list[3], &this->q_row_stride); + absl::SimpleAtoi(params_list[4], &this->k_row_stride); + absl::SimpleAtoi(params_list[5], &this->v_row_stride); + absl::SimpleAtoi(params_list[6], &this->q_head_stride); + absl::SimpleAtoi(params_list[7], &this->k_head_stride); + absl::SimpleAtoi(params_list[8], &this->v_head_stride); + absl::SimpleAtoi(params_list[9], &this->total_k); + absl::SimpleAtoi(params_list[10], &this->h); + absl::SimpleAtoi(params_list[11], &this->h_k); + absl::SimpleAtoi(params_list[12], &this->h_h_k_ratio); + absl::SimpleAtoi(params_list[13], &this->o_batch_stride); + absl::SimpleAtoi(params_list[14], &this->o_row_stride); + absl::SimpleAtoi(params_list[15], &this->o_head_stride); + absl::SimpleAtoi(params_list[16], &this->b); + absl::SimpleAtoi(params_list[17], &this->seqlen_q); + absl::SimpleAtoi(params_list[18], &this->seqlen_k); + absl::SimpleAtoi(params_list[19], &this->d); + absl::SimpleAtoi(params_list[20], &this->seqlen_q_rounded); + absl::SimpleAtoi(params_list[21], &this->seqlen_k_rounded); + absl::SimpleAtoi(params_list[22], &this->d_rounded); + absl::SimpleAtof(params_list[23], &this->scale_softmax); + absl::SimpleAtof(params_list[24], &this->scale_softmax_log2); + absl::SimpleAtof(params_list[25], &this->p_dropout); + uint32_t tmp; + absl::SimpleAtoi(params_list[26], &tmp); + this->p_dropout_in_uint8_t = uint8_t(tmp); + absl::SimpleAtof(params_list[27], &this->rp_dropout); + absl::SimpleAtof(params_list[28], &this->scale_softmax_rp_dropout); + absl::SimpleAtob(params_list[29], &this->is_bf16); + absl::SimpleAtob(params_list[30], &this->is_causal); + } +}; + +// Layout of `buffers` listed above: +// buffers[0] = q +// buffers[1] = k +// buffers[2] = v +// buffers[3] = cu_seqlens_q +// buffers[4] = cu_seqlens_k +// result[0] = softmax_lse // this is output +// result[1] = 
out_for_output // this is output +template +std::tuple, MemRefType> +custom_call_flash_attention_forward(ExecutionContext* ctx, void* stream_handle, + MemRefType q, + MemRefType k, + MemRefType v, + MemRefType seqlens_q, + MemRefType seqlens_k, + void* customAttrs) { + auto attr = getOrParsePDLAttr(ctx, customAttrs, + "custom_call_flash_attention_forward"); + if (!attr) { + ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n"); + } + auto& dictAttr = attr->as(); + std::string backend_config = + dictAttr.get("backend_config").template as().getValue(); + + auto gpu_driver = ctx->getDriver( + tao::ral::gpu::GPUDriver::name()); + auto gpu_stream = + static_cast(gpu_driver->asCUStream(ctx, stream_handle)); + + int output_element_count = 1; + for (int i = 0; i < M; i++) { + output_element_count *= q.sizes[i]; + } + + int bs = seqlens_q.sizes[0] - 1; + int nheads = q.sizes[1]; + int seqlen = q.sizes[0] / bs; + std::vector softmax_lse_sizes{bs, nheads, seqlen}; + + auto softmax_lse_ptr = static_cast( + gpu_driver->alloc(ctx, bs * nheads * seqlen * sizeof(SOFT_MAX_TYPE))); + auto softmax_lse = + assignMemRef(softmax_lse_ptr, softmax_lse_sizes); + + auto output_ptr = static_cast( + gpu_driver->alloc(ctx, output_element_count * sizeof(T_IN))); + auto output = assignMemRef(output_ptr, q.sizes); + + FlashAttentionForwardParams params; + params.FromString(std::move(backend_config)); + + Flash_fwd_params launch_params; + + // Reset the parameters + memset(&launch_params, 0, sizeof(launch_params)); + + launch_params.is_bf16 = params.is_bf16; + + // Set the pointers and strides. + launch_params.q_ptr = q.data; + launch_params.k_ptr = k.data; + launch_params.v_ptr = v.data; + // All stride are in elements, not bytes. + launch_params.q_row_stride = params.q_row_stride; + launch_params.k_row_stride = params.k_row_stride; + launch_params.v_row_stride = params.v_row_stride; + launch_params.q_head_stride = params.q_head_stride; + launch_params.k_head_stride = params.k_head_stride; + launch_params.v_head_stride = params.v_head_stride; + launch_params.o_ptr = output.data; + launch_params.o_row_stride = params.o_row_stride; + launch_params.o_head_stride = params.o_head_stride; + + launch_params.cu_seqlens_q = seqlens_q.data; + launch_params.cu_seqlens_k = seqlens_k.data; + + // P = softmax(QK^T) + launch_params.p_ptr = nullptr; // no softmax returned always + + // Softmax sum + launch_params.softmax_lse_ptr = softmax_lse.data; + + // Set the dimensions. + launch_params.b = params.b; + launch_params.h = params.h; + launch_params.h_k = params.h_k; + launch_params.h_h_k_ratio = params.h_h_k_ratio; + launch_params.seqlen_q = params.seqlen_q; + launch_params.seqlen_k = params.seqlen_k; + launch_params.seqlen_q_rounded = params.seqlen_q_rounded; + launch_params.seqlen_k_rounded = params.seqlen_k_rounded; + launch_params.d = params.d; + launch_params.d_rounded = params.d_rounded; + + // Set the different scale values. 
+ launch_params.scale_softmax = params.scale_softmax; + launch_params.scale_softmax_log2 = params.scale_softmax_log2; + + launch_params.p_dropout = params.p_dropout; + launch_params.p_dropout_in_uint8_t = params.p_dropout_in_uint8_t; + launch_params.rp_dropout = params.rp_dropout; + launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout; + + launch_params.is_causal = params.is_causal; + + if ((1.f - launch_params.p_dropout) > 0.0) { + // number of times random will be generated per thread, to offset philox + // counter in thc random state We use a custom RNG that increases the offset + // by batch_size * nheads * 32. + int64_t counter_offset = launch_params.b * launch_params.h * 32; + auto gen = at::get_generator_or_default( + c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + launch_params.philox_args = gen->philox_cuda_state(counter_offset); + } + + FP16_SWITCH(!launch_params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(launch_params.d, [&] { + run_mha_fwd_(launch_params, gpu_stream); + }); + }); + + return std::make_tuple(softmax_lse, output); +} + +TAO_RAL_API("custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward); +TAO_RAL_API("custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward); + +} // namespace ral +} // namespace tao From 29d1ae6946c3555fce639141ebb7ca38ccd7191f Mon Sep 17 00:00:00 2001 From: Baole Ai Date: Sun, 7 Apr 2024 14:51:31 +0800 Subject: [PATCH 07/12] fix disc flag when compiling python (#39) * fix bazel flag when compiling python * fix lint. --- setup.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 51b0b3d52d9..5d9fceacfe4 100644 --- a/setup.py +++ b/setup.py @@ -253,17 +253,18 @@ def bazel_build(self, ext): # package BladeDISC distribution files # please note, TorchBlade also create some symbolic links to 'torch_blade' dir - disc_ral_so_name = 'libral_base_context.so' - bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' - shutil.copyfile( - os.path.join(bazel_bin_path, disc_ral_so_name), - '/'.join([ext_dest_dir, disc_ral_so_name])) - - disc_customop_so_name = 'libdisc_custom_ops.so' - bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' - shutil.copyfile( - os.path.join(bazel_bin_path, disc_customop_so_name), - '/'.join([ext_dest_dir, disc_customop_so_name])) + if build_util.check_env_flag('ENABLE_DISC', 'true'): + disc_ral_so_name = 'libral_base_context.so' + bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' + shutil.copyfile( + os.path.join(bazel_bin_path, disc_ral_so_name), + '/'.join([ext_dest_dir, disc_ral_so_name])) + + disc_customop_so_name = 'libdisc_custom_ops.so' + bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' + shutil.copyfile( + os.path.join(bazel_bin_path, disc_customop_so_name), + '/'.join([ext_dest_dir, disc_customop_so_name])) class Develop(develop.develop): From e1b8ed776202539acd9c217fae8056c5fbc0194e Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Wed, 22 May 2024 11:50:25 +0800 Subject: [PATCH 08/12] support bf16 on disc backend (#40) add float-norm pass to support bf16 amp training --- third_party/BladeDISC | 2 +- torch_xla/csrc/runtime/BUILD | 2 + .../custom_call_flash_attention_backward.cc | 4 ++ .../custom_call_flash_attention_forward.cc | 2 +
.../csrc/runtime/disc_computation_client.cc | 48 +++++++++++++++++-- 5 files changed, 53 insertions(+), 5 deletions(-) diff --git a/third_party/BladeDISC b/third_party/BladeDISC index 67c324289c3..fbe39bce9ae 160000 --- a/third_party/BladeDISC +++ b/third_party/BladeDISC @@ -1 +1 @@ -Subproject commit 67c324289c36da5187405c18600403a0d3681b61 +Subproject commit fbe39bce9ae2d365d77842af38a33fa76d37237a diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 9bd276c856f..5d2992bf802 100755 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -167,6 +167,8 @@ cc_library( "@xla//xla/client:xla_computation", "//torch_xla/csrc/runtime/disc:disc_ral", "//torch_xla/csrc/runtime/disc:disc_compile", + "@xla//xla/service:float_normalization", + "@xla//xla/service/gpu:gpu_float_support", ], ) diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc index efd4f775f48..8f4460d8145 100644 --- a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc +++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc @@ -27,6 +27,7 @@ namespace tao { namespace ral { DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16"); +DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16"); struct FlashAttentionBackwardParams { using index_t = uint32_t; @@ -235,6 +236,7 @@ custom_call_flash_attention_backward( memset(&launch_params, 0, sizeof(launch_params)); launch_params.is_bf16 = params.is_bf16; + launch_params.is_bf16 = true; // Set the pointers and strides. launch_params.q_ptr = q.data; @@ -380,6 +382,8 @@ TAO_RAL_API("custom_call_flash_attention_backward", "gpu", custom_call_flash_attention_backward); TAO_RAL_API("custom_call_flash_attention_backward", "gpu", custom_call_flash_attention_backward); +TAO_RAL_API("custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward); } // namespace ral } // namespace tao \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc index 7d5ded3ebeb..ca281319b85 100644 --- a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc +++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc @@ -245,6 +245,8 @@ TAO_RAL_API("custom_call_flash_attention_forward", "gpu", custom_call_flash_attention_forward); TAO_RAL_API("custom_call_flash_attention_forward", "gpu", custom_call_flash_attention_forward); +TAO_RAL_API("custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward); } // namespace ral } // namespace tao diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc index d4066f44117..dbf5ca065c4 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.cc +++ b/torch_xla/csrc/runtime/disc_computation_client.cc @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -18,6 +19,11 @@ #include "torch_xla/csrc/runtime/disc/disc_compile.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/float_normalization.h" +#include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/service/hlo_proto_util.h" namespace torch_xla { namespace runtime { @@ -172,10 +178,44 @@ std::vector DISCComputationClient::Compile( mlir::MLIRContext context; 
mlir::ModuleOp mlir_module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); - auto status = torch_xla::ConvertHloToMhlo( - instance.computation.mutable_proto(), &mlir_module); - XLA_CHECK(status.ok()) << "StableHLO -> MHLO conversion failed.\n" - << status.message(); + + auto hlo_proto = instance.computation.proto(); + auto program_shape = instance.computation.GetProgramShape().value(); + xla::HloModuleConfig module_config(program_shape); + module_config.set_debug_options(xla::GetDebugOptionsFromFlags()); + xla::ComputationLayout* entry_layout = + module_config.mutable_entry_computation_layout(); + for (int64_t i = 0; i < entry_layout->parameter_count(); ++i) { + auto status = + entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + program_shape.parameters(i)); + if (!status.ok()) { + XLA_ERROR() << "Error copying layout from shape: "; + return {}; + } + } + + std::unique_ptr hlo_module = + xla::CreateModuleFromProto(hlo_proto, module_config).value(); + xla::HloPassPipeline pipeline("pre-stablehlo"); + stream_executor::CudaComputeCapability gpu_version; + auto dprops = at::cuda::getCurrentDeviceProperties(); + gpu_version.major = dprops->major; + gpu_version.minor = dprops->minor; + xla::gpu::GpuFloatSupport bf16_support(gpu_version, xla::BF16); + pipeline.AddPass(&bf16_support); + auto status = pipeline.Run(hlo_module.get()).status(); + if (!status.ok()) { + XLA_ERROR() << "Error running pre-stablehlo pass pipeline: "; + return {}; + } + { + auto mutable_hlo_proto = hlo_module->ToProto(); + auto status = + torch_xla::ConvertHloToMhlo(&mutable_hlo_proto, &mlir_module); + XLA_CHECK(status.ok()) << "StableHLO -> MHLO conversion failed.\n" + << status.message(); + } // Add input and output attributes auto entry_func_identifier = From 41baac7e6724a4b147b5b4f9e5e9131a0ee184b4 Mon Sep 17 00:00:00 2001 From: Yitong Huang Date: Mon, 12 Aug 2024 14:45:57 +0800 Subject: [PATCH 09/12] Support Flash Attention 2.5.6 for disc backend (#4) --- test/test_flash_attention_backward.py | 7 +- .../csrc/ops/flash_attention_forward.cpp | 2 +- .../custom_call_flash_attention_backward.cc | 219 ++++++++++++------ .../custom_call_flash_attention_forward.cc | 151 +++++++++--- 4 files changed, 266 insertions(+), 113 deletions(-) diff --git a/test/test_flash_attention_backward.py b/test/test_flash_attention_backward.py index df8c15efc81..35b8a7c37fa 100755 --- a/test/test_flash_attention_backward.py +++ b/test/test_flash_attention_backward.py @@ -1,4 +1,3 @@ -import os import sys import unittest @@ -150,15 +149,13 @@ def test_flash_attn_gqa_backward_fp16(self): self._backward_internal(torch.float16, n_heads_kv=int(N_HEADS // 2)) def test_flash_attn_gqa_backward_bf16(self): - if not os.environ.get('DISC_DEVICE'): - self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2)) + self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2)) def test_flash_attn_backward_fp16(self): self._backward_internal(torch.float16, n_heads_kv=N_HEADS) def test_flash_attn_backward_bf16(self): - if not os.environ.get('DISC_DEVICE'): - self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS) + self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS) def test_flash_attn_gqa_backward_fp16_alibi(self): self._backward_internal( diff --git a/torch_xla/csrc/ops/flash_attention_forward.cpp b/torch_xla/csrc/ops/flash_attention_forward.cpp index 5c478f69a51..9a73f26a9ba 100644 --- a/torch_xla/csrc/ops/flash_attention_forward.cpp +++ b/torch_xla/csrc/ops/flash_attention_forward.cpp @@ -23,7 +23,7 @@ 
xla::Shape NodeOutputShape(int batch_size, int num_heads, int seqlen_q, xla::PrimitiveType::F32, {batch_size, num_heads, seqlen_q}); xla::Shape out_shape = GetXlaShape(q); xla::Shape rng_state_shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::U64, {2}); + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2}); return xla::ShapeUtil::MakeTupleShape( {softmax_lse_shape, out_shape, rng_state_shape}); } diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc index 8f4460d8145..402eeedc3de 100644 --- a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc +++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc @@ -27,7 +27,6 @@ namespace tao { namespace ral { DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16"); -DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16"); struct FlashAttentionBackwardParams { using index_t = uint32_t; @@ -57,6 +56,7 @@ struct FlashAttentionBackwardParams { // The dimensions. int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + int total_q; int total_k; // The scaling factors for the kernel. @@ -73,6 +73,12 @@ struct FlashAttentionBackwardParams { bool is_bf16; bool is_causal; + int window_size_left; + int window_size_right; + int alibi_slopes_batch_stride; + bool enable_alibi_slopes; + bool is_seqlens_k_cumulative; + int num_splits; // Backward specific params index_t do_batch_stride; @@ -88,9 +94,11 @@ struct FlashAttentionBackwardParams { index_t dk_head_stride; index_t dv_head_stride; + bool deterministic; + void FromString(const std::string& str) { std::vector params_list = absl::StrSplit(str, "|"); - TORCH_CHECK(params_list.size() == 43); + TORCH_CHECK(params_list.size() == 51); // Forward specific param absl::SimpleAtoi(params_list[0], &this->q_batch_stride); @@ -102,67 +110,61 @@ struct FlashAttentionBackwardParams { absl::SimpleAtoi(params_list[6], &this->q_head_stride); absl::SimpleAtoi(params_list[7], &this->k_head_stride); absl::SimpleAtoi(params_list[8], &this->v_head_stride); - absl::SimpleAtoi(params_list[9], &this->total_k); - absl::SimpleAtoi(params_list[10], &this->h); - absl::SimpleAtoi(params_list[11], &this->h_k); - absl::SimpleAtoi(params_list[12], &this->h_h_k_ratio); - absl::SimpleAtoi(params_list[13], &this->o_batch_stride); - absl::SimpleAtoi(params_list[14], &this->o_row_stride); - absl::SimpleAtoi(params_list[15], &this->o_head_stride); - absl::SimpleAtoi(params_list[16], &this->b); - absl::SimpleAtoi(params_list[17], &this->seqlen_q); - absl::SimpleAtoi(params_list[18], &this->seqlen_k); - absl::SimpleAtoi(params_list[19], &this->d); - absl::SimpleAtoi(params_list[20], &this->seqlen_q_rounded); - absl::SimpleAtoi(params_list[21], &this->seqlen_k_rounded); - absl::SimpleAtoi(params_list[22], &this->d_rounded); - absl::SimpleAtof(params_list[23], &this->scale_softmax); - absl::SimpleAtof(params_list[24], &this->scale_softmax_log2); - absl::SimpleAtof(params_list[25], &this->p_dropout); + absl::SimpleAtoi(params_list[9], &this->total_q); + absl::SimpleAtoi(params_list[10], &this->total_k); + absl::SimpleAtoi(params_list[11], &this->h); + absl::SimpleAtoi(params_list[12], &this->h_k); + absl::SimpleAtoi(params_list[13], &this->h_h_k_ratio); + absl::SimpleAtoi(params_list[14], &this->o_batch_stride); + absl::SimpleAtoi(params_list[15], &this->o_row_stride); + absl::SimpleAtoi(params_list[16], &this->o_head_stride); + absl::SimpleAtoi(params_list[17], &this->b); + 
absl::SimpleAtoi(params_list[18], &this->seqlen_q); + absl::SimpleAtoi(params_list[19], &this->seqlen_k); + absl::SimpleAtoi(params_list[20], &this->d); + absl::SimpleAtoi(params_list[21], &this->seqlen_q_rounded); + absl::SimpleAtoi(params_list[22], &this->seqlen_k_rounded); + absl::SimpleAtoi(params_list[23], &this->d_rounded); + absl::SimpleAtof(params_list[24], &this->scale_softmax); + absl::SimpleAtof(params_list[25], &this->scale_softmax_log2); + absl::SimpleAtof(params_list[26], &this->p_dropout); uint32_t tmp; - absl::SimpleAtoi(params_list[26], &tmp); + absl::SimpleAtoi(params_list[27], &tmp); this->p_dropout_in_uint8_t = uint8_t(tmp); - absl::SimpleAtof(params_list[27], &this->rp_dropout); - absl::SimpleAtof(params_list[28], &this->scale_softmax_rp_dropout); - absl::SimpleAtob(params_list[29], &this->is_bf16); - absl::SimpleAtob(params_list[30], &this->is_causal); + absl::SimpleAtof(params_list[28], &this->rp_dropout); + absl::SimpleAtof(params_list[29], &this->scale_softmax_rp_dropout); + absl::SimpleAtob(params_list[30], &this->is_bf16); + absl::SimpleAtob(params_list[31], &this->is_causal); + absl::SimpleAtoi(params_list[32], &this->window_size_left); + absl::SimpleAtoi(params_list[33], &this->window_size_right); + absl::SimpleAtoi(params_list[34], &this->alibi_slopes_batch_stride); + absl::SimpleAtob(params_list[35], &this->is_seqlens_k_cumulative); + absl::SimpleAtoi(params_list[36], &this->num_splits); + absl::SimpleAtob(params_list[37], &this->enable_alibi_slopes); // backward specific params - absl::SimpleAtoi(params_list[31], &this->do_batch_stride); - absl::SimpleAtoi(params_list[32], &this->do_row_stride); - absl::SimpleAtoi(params_list[33], &this->do_head_stride); - absl::SimpleAtoi(params_list[34], &this->dq_batch_stride); - absl::SimpleAtoi(params_list[35], &this->dk_batch_stride); - absl::SimpleAtoi(params_list[36], &this->dv_batch_stride); - absl::SimpleAtoi(params_list[37], &this->dq_row_stride); - absl::SimpleAtoi(params_list[38], &this->dk_row_stride); - absl::SimpleAtoi(params_list[39], &this->dv_row_stride); - absl::SimpleAtoi(params_list[40], &this->dq_head_stride); - absl::SimpleAtoi(params_list[41], &this->dk_head_stride); - absl::SimpleAtoi(params_list[42], &this->dv_head_stride); + const int offset = 38; // FlashAttentionForwardParams has 38 variables + absl::SimpleAtoi(params_list[offset + 0], &this->do_batch_stride); + absl::SimpleAtoi(params_list[offset + 1], &this->do_row_stride); + absl::SimpleAtoi(params_list[offset + 2], &this->do_head_stride); + absl::SimpleAtoi(params_list[offset + 3], &this->dq_batch_stride); + absl::SimpleAtoi(params_list[offset + 4], &this->dk_batch_stride); + absl::SimpleAtoi(params_list[offset + 5], &this->dv_batch_stride); + absl::SimpleAtoi(params_list[offset + 6], &this->dq_row_stride); + absl::SimpleAtoi(params_list[offset + 7], &this->dk_row_stride); + absl::SimpleAtoi(params_list[offset + 8], &this->dv_row_stride); + absl::SimpleAtoi(params_list[offset + 9], &this->dq_head_stride); + absl::SimpleAtoi(params_list[offset + 10], &this->dk_head_stride); + absl::SimpleAtoi(params_list[offset + 11], &this->dv_head_stride); + absl::SimpleAtob(params_list[offset + 12], &this->deterministic); } }; void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream, const bool configure) { FP16_SWITCH(!params.is_bf16, [&] { - if (params.d <= 32) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 64) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 96) { - run_mha_bwd_(params, stream, 
configure); - } else if (params.d <= 128) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 160) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 192) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 224) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 256) { - run_mha_bwd_(params, stream, configure); - } + HEADDIM_SWITCH(params.d, + [&] { run_mha_bwd_(params, stream); }); }); } @@ -175,18 +177,21 @@ void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream, // buffers[5] = softmax_lse // buffers[6] = cu_seqlens_q // buffers[7] = cu_seqlens_k -// buffers[8] = dq // this is output -// buffers[9] = dk // this is output -// buffers[10] = dv // this is output -// buffers[11] = softmax_d // this is output +// buffers[8] = rng_state +// buffers[9] = alibi_slopes +// buffers[10] = dq // this is output +// buffers[11] = dk // this is output +// buffers[12] = dv // this is output +// buffers[13] = softmax_d // this is output template std::tuple, MemRefType, MemRefType, MemRefType> -custom_call_flash_attention_backward( +custom_call_flash_attention_backward_impl( ExecutionContext* ctx, void* stream_handle, MemRefType dout, MemRefType q, MemRefType k, MemRefType v, MemRefType out, MemRefType softmax_lse, MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType rng_state, void* alibi_slopes_ptr, void* customAttrs) { auto attr = getOrParsePDLAttr(ctx, customAttrs, "custom_call_flash_attention_backward"); @@ -236,7 +241,6 @@ custom_call_flash_attention_backward( memset(&launch_params, 0, sizeof(launch_params)); launch_params.is_bf16 = params.is_bf16; - launch_params.is_bf16 = true; // Set the pointers and strides. launch_params.q_ptr = q.data; @@ -256,6 +260,9 @@ custom_call_flash_attention_backward( launch_params.cu_seqlens_q = static_cast(seqlens_q.data); launch_params.cu_seqlens_k = static_cast(seqlens_k.data); + launch_params.alibi_slopes_ptr = alibi_slopes_ptr; + launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride; + // P = softmax(QK^T) launch_params.p_ptr = nullptr; // no softmax returned always @@ -284,6 +291,10 @@ custom_call_flash_attention_backward( launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout; launch_params.is_causal = params.is_causal; + launch_params.window_size_left = params.window_size_left; + launch_params.window_size_right = params.window_size_right; + + launch_params.is_seqlens_k_cumulative = true; launch_params.do_ptr = dout.data; launch_params.do_row_stride = params.do_row_stride; @@ -305,10 +316,19 @@ custom_call_flash_attention_backward( auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA); at::Tensor dq_accum; if (loop) { - dq_accum = - torch::empty({launch_params.b, launch_params.h, - launch_params.seqlen_q_rounded, launch_params.d_rounded}, - opts.dtype(at::kFloat)); + if (!params.deterministic) { + dq_accum = torch::empty({params.total_q + 128 * launch_params.b, + launch_params.h, launch_params.d_rounded}, + opts.dtype(at::kFloat)); + } else { + auto dprops = at::cuda::getCurrentDeviceProperties(); + const int nsplits = (dprops->multiProcessorCount + + launch_params.b * launch_params.h - 1) / + (launch_params.b * launch_params.h); + dq_accum = torch::zeros({nsplits, params.total_q + 128 * launch_params.b, + launch_params.h, launch_params.d_rounded}, + opts.dtype(at::kFloat)); + } } at::Tensor dk = torch::from_blob( @@ -344,6 +364,10 @@ custom_call_flash_attention_backward( // Softmax sum launch_params.dsoftmax_sum = 
dsoftmax.data; + launch_params.deterministic = params.deterministic; + launch_params.dq_accum_split_stride = + !launch_params.deterministic ? 0 : dq_accum.stride(0); + auto launch = &run_mha_bwd; auto gen = at::get_generator_or_default( @@ -353,11 +377,11 @@ custom_call_flash_attention_backward( int64_t counter_offset = launch_params.b * launch_params.h * 32; bool is_dropout = (1.f - launch_params.p_dropout) > 0.0; - if (is_dropout) { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - launch_params.philox_args = gen->philox_cuda_state(counter_offset); - } + // TODO(wenting.swt): According to the implementation in + // `flash_attn_varlen_func` of flash-attn v2.5.6, the forward generates + // `rng_state` which is passed as ctx to the backward. Hence, for simplifying + // the logic, the redundant branch where `rng_state` is None has been omitted. + launch_params.rng_state = reinterpret_cast(rng_state.data); launch(launch_params, gpu_stream, /*configure=*/false); @@ -378,12 +402,65 @@ custom_call_flash_attention_backward( return std::make_tuple(dq_res, dk_res, dv_res, dsoftmax); } +template +std::tuple, MemRefType, MemRefType, + MemRefType> +custom_call_flash_attention_backward_noalibi( + ExecutionContext* ctx, void* stream_handle, MemRefType dout, + MemRefType q, MemRefType k, MemRefType v, + MemRefType out, MemRefType softmax_lse, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType rng_state, void* customAttrs) { + return custom_call_flash_attention_backward_impl( + ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k, + rng_state, nullptr, customAttrs); +} + +template +std::tuple, MemRefType, MemRefType, + MemRefType> +custom_call_flash_attention_backward_alibi_v1( + ExecutionContext* ctx, void* stream_handle, MemRefType dout, + MemRefType q, MemRefType k, MemRefType v, + MemRefType out, MemRefType softmax_lse, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType rng_state, MemRefType alibi_slopes, + void* customAttrs) { + return custom_call_flash_attention_backward_impl( + ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k, + rng_state, alibi_slopes.data, customAttrs); +} + +template +std::tuple, MemRefType, MemRefType, + MemRefType> +custom_call_flash_attention_backward_alibi_v2( + ExecutionContext* ctx, void* stream_handle, MemRefType dout, + MemRefType q, MemRefType k, MemRefType v, + MemRefType out, MemRefType softmax_lse, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType rng_state, MemRefType alibi_slopes, + void* customAttrs) { + return custom_call_flash_attention_backward_impl( + ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k, + rng_state, alibi_slopes.data, customAttrs); +} + +TAO_RAL_API( + "custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward_noalibi); +TAO_RAL_API( + "custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward_alibi_v1); +TAO_RAL_API( + "custom_call_flash_attention_backward", "gpu", + custom_call_flash_attention_backward_alibi_v2); TAO_RAL_API("custom_call_flash_attention_backward", "gpu", - custom_call_flash_attention_backward); + custom_call_flash_attention_backward_noalibi); TAO_RAL_API("custom_call_flash_attention_backward", "gpu", - custom_call_flash_attention_backward); + custom_call_flash_attention_backward_alibi_v1); TAO_RAL_API("custom_call_flash_attention_backward", "gpu", - custom_call_flash_attention_backward); + 
custom_call_flash_attention_backward_alibi_v2); } // namespace ral } // namespace tao \ No newline at end of file diff --git a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc index ca281319b85..fcac32fa5c3 100644 --- a/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc +++ b/torch_xla/csrc/runtime/disc/custom_call_flash_attention_forward.cc @@ -57,6 +57,7 @@ struct FlashAttentionForwardParams { // The dimensions. int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + int total_q; int total_k; // The scaling factors for the kernel. @@ -73,10 +74,16 @@ struct FlashAttentionForwardParams { bool is_bf16; bool is_causal; + int window_size_left; + int window_size_right; + int alibi_slopes_batch_stride; + bool enable_alibi_slopes; + bool is_seqlens_k_cumulative; + int num_splits; void FromString(const std::string& str) { std::vector params_list = absl::StrSplit(str, "|"); - TORCH_CHECK(params_list.size() >= 31); // at least 31 variables + TORCH_CHECK(params_list.size() >= 38); // at least 38 variables absl::SimpleAtoi(params_list[0], &this->q_batch_stride); absl::SimpleAtoi(params_list[1], &this->k_batch_stride); absl::SimpleAtoi(params_list[2], &this->v_batch_stride); @@ -86,30 +93,37 @@ struct FlashAttentionForwardParams { absl::SimpleAtoi(params_list[6], &this->q_head_stride); absl::SimpleAtoi(params_list[7], &this->k_head_stride); absl::SimpleAtoi(params_list[8], &this->v_head_stride); - absl::SimpleAtoi(params_list[9], &this->total_k); - absl::SimpleAtoi(params_list[10], &this->h); - absl::SimpleAtoi(params_list[11], &this->h_k); - absl::SimpleAtoi(params_list[12], &this->h_h_k_ratio); - absl::SimpleAtoi(params_list[13], &this->o_batch_stride); - absl::SimpleAtoi(params_list[14], &this->o_row_stride); - absl::SimpleAtoi(params_list[15], &this->o_head_stride); - absl::SimpleAtoi(params_list[16], &this->b); - absl::SimpleAtoi(params_list[17], &this->seqlen_q); - absl::SimpleAtoi(params_list[18], &this->seqlen_k); - absl::SimpleAtoi(params_list[19], &this->d); - absl::SimpleAtoi(params_list[20], &this->seqlen_q_rounded); - absl::SimpleAtoi(params_list[21], &this->seqlen_k_rounded); - absl::SimpleAtoi(params_list[22], &this->d_rounded); - absl::SimpleAtof(params_list[23], &this->scale_softmax); - absl::SimpleAtof(params_list[24], &this->scale_softmax_log2); - absl::SimpleAtof(params_list[25], &this->p_dropout); + absl::SimpleAtoi(params_list[9], &this->total_q); + absl::SimpleAtoi(params_list[10], &this->total_k); + absl::SimpleAtoi(params_list[11], &this->h); + absl::SimpleAtoi(params_list[12], &this->h_k); + absl::SimpleAtoi(params_list[13], &this->h_h_k_ratio); + absl::SimpleAtoi(params_list[14], &this->o_batch_stride); + absl::SimpleAtoi(params_list[15], &this->o_row_stride); + absl::SimpleAtoi(params_list[16], &this->o_head_stride); + absl::SimpleAtoi(params_list[17], &this->b); + absl::SimpleAtoi(params_list[18], &this->seqlen_q); + absl::SimpleAtoi(params_list[19], &this->seqlen_k); + absl::SimpleAtoi(params_list[20], &this->d); + absl::SimpleAtoi(params_list[21], &this->seqlen_q_rounded); + absl::SimpleAtoi(params_list[22], &this->seqlen_k_rounded); + absl::SimpleAtoi(params_list[23], &this->d_rounded); + absl::SimpleAtof(params_list[24], &this->scale_softmax); + absl::SimpleAtof(params_list[25], &this->scale_softmax_log2); + absl::SimpleAtof(params_list[26], &this->p_dropout); uint32_t tmp; - absl::SimpleAtoi(params_list[26], &tmp); + 
absl::SimpleAtoi(params_list[27], &tmp); this->p_dropout_in_uint8_t = uint8_t(tmp); - absl::SimpleAtof(params_list[27], &this->rp_dropout); - absl::SimpleAtof(params_list[28], &this->scale_softmax_rp_dropout); - absl::SimpleAtob(params_list[29], &this->is_bf16); - absl::SimpleAtob(params_list[30], &this->is_causal); + absl::SimpleAtof(params_list[28], &this->rp_dropout); + absl::SimpleAtof(params_list[29], &this->scale_softmax_rp_dropout); + absl::SimpleAtob(params_list[30], &this->is_bf16); + absl::SimpleAtob(params_list[31], &this->is_causal); + absl::SimpleAtoi(params_list[32], &this->window_size_left); + absl::SimpleAtoi(params_list[33], &this->window_size_right); + absl::SimpleAtoi(params_list[34], &this->alibi_slopes_batch_stride); + absl::SimpleAtob(params_list[35], &this->is_seqlens_k_cumulative); + absl::SimpleAtoi(params_list[36], &this->num_splits); + absl::SimpleAtob(params_list[37], &this->enable_alibi_slopes); } }; @@ -122,14 +136,13 @@ struct FlashAttentionForwardParams { // result[0] = softmax_lse // this is output // result[1] = out_for_output // this is output template -std::tuple, MemRefType> -custom_call_flash_attention_forward(ExecutionContext* ctx, void* stream_handle, - MemRefType q, - MemRefType k, - MemRefType v, - MemRefType seqlens_q, - MemRefType seqlens_k, - void* customAttrs) { +std::tuple, MemRefType, + MemRefType> +custom_call_flash_attention_forward_impl( + ExecutionContext* ctx, void* stream_handle, MemRefType q, + MemRefType k, MemRefType v, + MemRefType seqlens_q, MemRefType seqlens_k, + void* alibi_slopes_ptr, void* customAttrs) { auto attr = getOrParsePDLAttr(ctx, customAttrs, "custom_call_flash_attention_forward"); if (!attr) { @@ -163,6 +176,13 @@ custom_call_flash_attention_forward(ExecutionContext* ctx, void* stream_handle, gpu_driver->alloc(ctx, output_element_count * sizeof(T_IN))); auto output = assignMemRef(output_ptr, q.sizes); + auto rng_state_ptr = + static_cast(gpu_driver->alloc(ctx, 2 * sizeof(int64_t))); + auto rng_state = + assignMemRef(rng_state_ptr, std::vector{2}); + + cudaMemsetAsync(rng_state_ptr, 0, 2 * sizeof(int64_t), gpu_stream); + FlashAttentionForwardParams params; params.FromString(std::move(backend_config)); @@ -190,6 +210,8 @@ custom_call_flash_attention_forward(ExecutionContext* ctx, void* stream_handle, launch_params.cu_seqlens_q = seqlens_q.data; launch_params.cu_seqlens_k = seqlens_k.data; + launch_params.alibi_slopes_ptr = alibi_slopes_ptr; + launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride; // P = softmax(QK^T) launch_params.p_ptr = nullptr; // no softmax returned always @@ -219,6 +241,16 @@ custom_call_flash_attention_forward(ExecutionContext* ctx, void* stream_handle, launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout; launch_params.is_causal = params.is_causal; + launch_params.window_size_left = params.window_size_left; + launch_params.window_size_right = params.window_size_right; + + launch_params.is_seqlens_k_cumulative = params.is_seqlens_k_cumulative; + + // set params splitkv + launch_params.num_splits = params.num_splits; + + // Forward kernel will populate memory with the seed and offset. 
+ launch_params.rng_state = reinterpret_cast(rng_state_ptr); if ((1.f - launch_params.p_dropout) > 0.0) { // number of times random will be generated per thread, to offset philox @@ -233,20 +265,67 @@ custom_call_flash_attention_forward(ExecutionContext* ctx, void* stream_handle, } FP16_SWITCH(!launch_params.is_bf16, [&] { - FWD_HEADDIM_SWITCH(launch_params.d, [&] { + HEADDIM_SWITCH(launch_params.d, [&] { + // TODO(wenting.swt): support split_kv run_mha_fwd_(launch_params, gpu_stream); }); }); - return std::make_tuple(softmax_lse, output); + return std::make_tuple(softmax_lse, output, rng_state); +} + +template +std::tuple, MemRefType, + MemRefType> +custom_call_flash_attention_forward_noalibi( + ExecutionContext* ctx, void* stream_handle, MemRefType q, + MemRefType k, MemRefType v, + MemRefType seqlens_q, MemRefType seqlens_k, + void* customAttrs) { + return custom_call_flash_attention_forward_impl( + ctx, stream_handle, q, k, v, seqlens_q, seqlens_k, nullptr, customAttrs); } +template +std::tuple, MemRefType, + MemRefType> +custom_call_flash_attention_forward_alibi_v1( + ExecutionContext* ctx, void* stream_handle, MemRefType q, + MemRefType k, MemRefType v, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType alibi_slopes, void* customAttrs) { + return custom_call_flash_attention_forward_impl( + ctx, stream_handle, q, k, v, seqlens_q, seqlens_k, alibi_slopes.data, + customAttrs); +} + +template +std::tuple, MemRefType, + MemRefType> +custom_call_flash_attention_forward_alibi_v2( + ExecutionContext* ctx, void* stream_handle, MemRefType q, + MemRefType k, MemRefType v, + MemRefType seqlens_q, MemRefType seqlens_k, + MemRefType alibi_slopes, void* customAttrs) { + return custom_call_flash_attention_forward_impl( + ctx, stream_handle, q, k, v, seqlens_q, seqlens_k, alibi_slopes.data, + customAttrs); +} + +TAO_RAL_API("custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward_noalibi); +TAO_RAL_API( + "custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward_alibi_v1); +TAO_RAL_API( + "custom_call_flash_attention_forward", "gpu", + custom_call_flash_attention_forward_alibi_v2); TAO_RAL_API("custom_call_flash_attention_forward", "gpu", - custom_call_flash_attention_forward); + custom_call_flash_attention_forward_noalibi); TAO_RAL_API("custom_call_flash_attention_forward", "gpu", - custom_call_flash_attention_forward); + custom_call_flash_attention_forward_alibi_v1); TAO_RAL_API("custom_call_flash_attention_forward", "gpu", - custom_call_flash_attention_forward); + custom_call_flash_attention_forward_alibi_v2); } // namespace ral } // namespace tao From 1e499d939ef13ba3caf9453741e07ebf71e38000 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Tue, 13 Aug 2024 10:45:40 +0800 Subject: [PATCH 10/12] fix build failed with NCCL (#5) * fix build failed on nccl * using nccl hdrs --- torch_xla/csrc/runtime/disc/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_xla/csrc/runtime/disc/BUILD b/torch_xla/csrc/runtime/disc/BUILD index d6d25ced337..999aa85ea64 100755 --- a/torch_xla/csrc/runtime/disc/BUILD +++ b/torch_xla/csrc/runtime/disc/BUILD @@ -26,6 +26,7 @@ ptxla_cc_library( "@disc_compiler//:disc_custom_op", "@disc_compiler//:headers", "@local_config_cuda//cuda:cuda_headers", + "@nccl_archive//:nccl_headers", "@torch//:libc10", "@torch//:libc10_cuda", "@torch//:libtorch_cuda", @@ -58,6 +59,8 @@ ptxla_cc_library( ":disc_ral", ":disc_utils", "//torch_xla/csrc/runtime:tf_logging", + "//torch_xla/csrc/runtime:sys_util", + 
"//torch_xla/csrc/runtime:env_vars", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", From aa2c40ec1aef44e67e093f77c39bc7d0f14a9f88 Mon Sep 17 00:00:00 2001 From: Yitong Huang Date: Thu, 10 Oct 2024 11:35:00 +0800 Subject: [PATCH 11/12] Use the value of DISC_DEVICE as the device type of disc backend (#8) * change the device type of disc to cuda to make amp work properly * Use the value of DISC_DEVICE as the device type of disc backend --- torch_xla/csrc/runtime/disc_computation_client.cc | 10 +++++++--- torch_xla/csrc/runtime/disc_computation_client.h | 3 +-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc index dbf5ca065c4..6465551dbde 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.cc +++ b/torch_xla/csrc/runtime/disc_computation_client.cc @@ -17,6 +17,7 @@ #include "mlir/Pass/Pass.h" // from @llvm-project #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/disc/disc_compile.h" +#include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "xla/hlo/ir/hlo_module.h" @@ -97,6 +98,10 @@ DISCComputationClient::DISCComputationClient() { world_size_ = sys_util::GetEnvInt("WORLD_SIZE", 1); local_rank_ = sys_util::GetEnvInt("LOCAL_RANK", 0); global_rank_ = sys_util::GetEnvInt("RANK", local_rank_); + device_type_ = sys_util::GetEnvString(env::kEnvDISCDevice, "CUDA"); + if (device_type_ != "CUDA") { + XLA_ERROR() << "Only CUDA device is supported by DISC backend"; + } } DISCComputationClient::~DISCComputationClient() {} @@ -362,7 +367,7 @@ std::map DISCComputationClient::GetMetrics() const { } std::string DISCComputationClient::GetDefaultDevice() const { - return absl::StrCat(DefaultDevicePrefix, std::to_string(local_rank_)); + return absl::StrCat(device_type_, ":", std::to_string(local_rank_)); } std::vector DISCComputationClient::GetLocalDevices() const { @@ -390,8 +395,7 @@ std::vector DISCComputationClient::GetAllDevices() const { std::vector all_devices; int device_count = world_size_; for (int idx = 0; idx < device_count; idx++) { - all_devices.push_back( - absl::StrCat(DefaultDevicePrefix, std::to_string(idx))); + all_devices.push_back(absl::StrCat(device_type_, ":", std::to_string(idx))); } return all_devices; } diff --git a/torch_xla/csrc/runtime/disc_computation_client.h b/torch_xla/csrc/runtime/disc_computation_client.h index 8df3a6c0e73..0701d7b3591 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.h +++ b/torch_xla/csrc/runtime/disc_computation_client.h @@ -11,8 +11,6 @@ namespace runtime { class DISCComputationClient : public ComputationClient { public: - const std::string DefaultDevicePrefix = "GPU:"; - DISCComputationClient(); ~DISCComputationClient(); @@ -142,6 +140,7 @@ class DISCComputationClient : public ComputationClient { int world_size_; int local_rank_; int global_rank_; + std::string device_type_; struct DISCData : public Data { DISCData(std::string device, xla::Shape device_shape) : Data(std::move(device), std::move(device_shape)) {} From d094086f6460390b95a5805608a283e607c1db25 Mon Sep 17 00:00:00 2001 From: Yitong Huang Date: Fri, 11 Oct 2024 13:44:59 +0800 Subject: [PATCH 12/12] disable compilation of DISC by default (#15) --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 5d9fceacfe4..3531a61e211 100644 --- 
a/setup.py +++ b/setup.py @@ -231,7 +231,7 @@ def bazel_build(self, ext): bazel_argv.extend(build_util.bazel_options_from_env()) - if not build_util.check_env_flag('ENABLE_DISC', 'true'): + if not build_util.check_env_flag('ENABLE_DISC', 'false'): bazel_argv.append('--define=enable_disc=false') self.spawn(bazel_argv) @@ -253,7 +253,7 @@ def bazel_build(self, ext): # package BladeDISC distribution files # please note, TorchBlade also create some symbolic links to 'torch_blade' dir - if build_util.check_env_flag('ENABLE_DISC', 'true'): + if build_util.check_env_flag('ENABLE_DISC', 'false'): disc_ral_so_name = 'libral_base_context.so' bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' shutil.copyfile(