Merge disc backend to acc 2.3 (#3)
* build with BladeDISC (#8)

* [to #53687860] feat: DISC client header, implement DISCComputation and DISCData

POC implemented in: https://code.alibaba-inc.com/torchx/xla/codereview/14984824

Link: https://code.alibaba-inc.com/torchx/xla/codereview/14987956

* Disc computation (#2)

Support DISC as a backend
Co-authored-by: yancey.yx <[email protected]>
Co-authored-by: wangang.wa <[email protected]>

* add bazel flag to disable disc backend (#23)

* add flag to disable disc backend in bazel workspace

* support disc debug mode to dump mhlo and logs (#25)

support disc backend debug mode to dump DISC compilation logs

* support flash attention in disc (pytorch#34)

* fix disc flag when compiling python (pytorch#39)

* fix bazel flag when compiling python

* fix lint.

* support bf16 on disc backend (pytorch#40)

add float-norm pass to support bf16 amp training

* Support Flash Attention 2.5.6 for disc backend (#4)

* fix build failure with NCCL (#5)

* fix build failure on NCCL

* use NCCL headers

* Use the value of DISC_DEVICE as the device type of disc backend (#8)

* change the device type of disc to cuda to make amp work properly

* Use the value of DISC_DEVICE as the device type of disc backend

* disable compilation of DISC by default (#15)

---------

Co-authored-by: Yan Xu <[email protected]>
Co-authored-by: wenting.swt <[email protected]>
Co-authored-by: Dalong <[email protected]>
Co-authored-by: Baole Ai <[email protected]>
Co-authored-by: Yan Xu <[email protected]>
6 people authored Oct 11, 2024
1 parent 063384c commit fab18e0
Showing 32 changed files with 2,419 additions and 12 deletions.
3 changes: 3 additions & 0 deletions .bazelrc
@@ -223,3 +223,6 @@ build:linux --copt="-Wno-error=unused-but-set-variable"
# Only include debug info for files not under XLA.
build:dbg -c dbg
build:dbg --per_file_copt=external/xla/.*@-g0,-DNDEBUG

# build with DISC backend
build --define enable_disc=true
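
To illustrate how this define is meant to be used (a sketch based on this commit, not part of the diff): .bazelrc turns DISC on by default, while setup.py later in this commit appends --define=enable_disc=false whenever the ENABLE_DISC environment flag is unset. A direct Bazel invocation would presumably look like:

    # Build the extension with the DISC backend compiled in (the .bazelrc default)
    bazel build //:_XLAC.so --define enable_disc=true

    # Explicitly opt out, matching what setup.py does when ENABLE_DISC is unset
    bazel build //:_XLAC.so --define enable_disc=false

The //:_XLAC.so target is the cc_binary in the top-level BUILD file below; any additional --config flags needed for a CUDA build are omitted here.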
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "third_party/flash-attention"]
path = third_party/flash-attention
url = https://github.com/Dao-AILab/flash-attention.git
[submodule "third_party/BladeDISC"]
path = third_party/BladeDISC
url = https://github.com/alibaba/BladeDISC.git
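
Because BladeDISC comes in as a git submodule, a fresh checkout presumably needs the submodule initialized before the new_local_repository added to WORKSPACE below can resolve third_party/BladeDISC; a minimal sketch using standard git commands:

    # Fetch the BladeDISC sources referenced by .gitmodules
    git submodule update --init --recursive third_party/BladeDISC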
7 changes: 7 additions & 0 deletions BUILD
@@ -3,6 +3,11 @@ load(
"if_cuda_is_configured",
)

load(
"//bazel:rules_def.bzl",
"if_enable_disc",
)

cc_binary(
name = "_XLAC.so",
copts = [
@@ -28,5 +33,7 @@ cc_binary(
"@torch//:libtorch_python",
] + if_cuda_is_configured([
"@xla//xla/stream_executor:cuda_platform",
]) + if_enable_disc([
"//torch_xla/csrc/runtime/disc:disc_ral",
]),
)
7 changes: 7 additions & 0 deletions WORKSPACE
@@ -89,3 +89,10 @@ new_local_repository(
build_file = "//bazel:flash_attn.BUILD",
path = "third_party/flash-attention/",
)
################################ BladeDISC Setup ################################

new_local_repository(
name = "disc_compiler",
build_file = "//bazel:disc.BUILD",
path = "third_party/BladeDISC/",
)
49 changes: 49 additions & 0 deletions bazel/disc.BUILD
@@ -0,0 +1,49 @@

package(
default_visibility = [
"//visibility:public",
],
)

cc_library(
name = "headers",
hdrs = glob(
[
"mlir/ral/*.h",
"mlir/ral/context/base/cuda/*.h",
"mlir/ral/context/base/cuda/cuda_context_impl.h",
"mlir/ral/device/cpu/*.h",
"mlir/ral/device/gpu/*.h",
],
),
includes = [
"tao_compiler",
"tao_compiler/mlir",
],
strip_include_prefix = "external/disc_compiler/tao_compiler/mlir",
)

cc_import(
name="disc_ral_cuda",
shared_library = ":libral_base_context.so",
)

cc_import(
name="disc_custom_op",
shared_library = ":libdisc_custom_ops.so",
)

genrule(
name = "build_disc",
outs = ["libral_base_context.so", "libdisc_custom_ops.so", "disc_compiler_main", "torch-mlir-opt"],
local = True,
cmd = ';'.join(['export PATH=/root/bin:/usr/local/cuda/bin:$${PATH}',
'pushd external/disc_compiler/pytorch_blade/',
'python ../scripts/python/common_setup.py',
'TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0,8.6,9.0" TORCH_CUDA_ARCH_LIST="7.0 8.0 8.6 9.0" python setup.py bdist_wheel',
'popd',
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/ral/libral_base_context.so $(location libral_base_context.so)',
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/custom_ops/libdisc_custom_ops.so $(location libdisc_custom_ops.so)',
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/disc/disc_compiler_main $(location disc_compiler_main)',
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/tests/mhlo/torch-mlir-opt/torch-mlir-opt $(location torch-mlir-opt)']),
)
7 changes: 3 additions & 4 deletions bazel/flash_attn.BUILD
@@ -23,9 +23,8 @@ genrule(
name = "build_flash_attn",
srcs = ["setup.py"],
outs = ["flash_attn_cuda.so"],
- cmd = ';'.join(['pushd third_party/flash-attention/',
- 'MAX_JOBS=50 FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel 2>&1 | tee build.log',
- 'cp build/*/*.so flash_attn_cuda.so',
cmd = ';'.join(['pushd external/flash_attn/',
'FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel',
'popd',
- 'cp third_party/flash-attention/flash_attn_cuda.so $(OUTS)']),
'cp external/flash_attn/build/*/*.so $(location flash_attn_cuda.so)']),
)
6 changes: 6 additions & 0 deletions bazel/rules_def.bzl
@@ -39,3 +39,9 @@ def ptxla_cc_test(
],
**kwargs
)

def if_enable_disc(if_true, if_false=[]):
return select({
"//torch_xla/csrc/runtime:enable_disc": if_true,
"//conditions:default": if_false
})
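
if_enable_disc wraps a select() keyed on the enable_disc define; the matching config_setting is added to torch_xla/csrc/runtime/BUILD later in this commit. One hedged way to see which branch is taken in a given configuration (assuming a standard Bazel setup; the target and define come from this commit):

    # With the define on, //torch_xla/csrc/runtime/disc:disc_ral should appear in the extension's direct deps
    bazel cquery 'deps(//:_XLAC.so, 1)' --define enable_disc=true
    # With it off, the DISC runtime dependency should be absent
    bazel cquery 'deps(//:_XLAC.so, 1)' --define enable_disc=false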
8 changes: 8 additions & 0 deletions bazel/torch.BUILD
@@ -55,6 +55,10 @@ cc_import(
shared_library = "build/lib/libtorch_cpu.so",
)

cc_import(
name = "libtorch_cuda",
shared_library = "build/lib/libtorch_cuda.so",
)
cc_import(
name = "libtorch_python",
shared_library = "build/lib/libtorch_python.so",
@@ -64,3 +68,7 @@ cc_import(
name = "libc10",
shared_library = "build/lib/libc10.so",
)
cc_import(
name = "libc10_cuda",
shared_library = "build/lib/libc10_cuda.so",
)
24 changes: 21 additions & 3 deletions setup.py
@@ -231,6 +231,9 @@ def bazel_build(self, ext):

bazel_argv.extend(build_util.bazel_options_from_env())

if not build_util.check_env_flag('ENABLE_DISC', 'false'):
bazel_argv.append('--define=enable_disc=false')

self.spawn(bazel_argv)

ext_bazel_bin_path = os.path.join(self.build_temp, 'bazel-bin', ext.relpath,
@@ -244,9 +247,24 @@ def bazel_build(self, ext):

# copy flash attention cuda so file
flash_attn_so_name = 'flash_attn_cuda.so'
- shutil.copyfile(
- '/'.join(['third_party/flash-attention', flash_attn_so_name]),
- '/'.join([ext_dest_dir, flash_attn_so_name]))
bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/flash_attn/'
shutil.copyfile('/'.join([bazel_bin_path, flash_attn_so_name]),
'/'.join([ext_dest_dir, flash_attn_so_name]))

# package BladeDISC distribution files
# please note, TorchBlade also creates some symbolic links to the 'torch_blade' dir
if build_util.check_env_flag('ENABLE_DISC', 'false'):
disc_ral_so_name = 'libral_base_context.so'
bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler'
shutil.copyfile(
os.path.join(bazel_bin_path, disc_ral_so_name),
'/'.join([ext_dest_dir, disc_ral_so_name]))

disc_customop_so_name = 'libdisc_custom_ops.so'
bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler'
shutil.copyfile(
os.path.join(bazel_bin_path, disc_customop_so_name),
'/'.join([ext_dest_dir, disc_customop_so_name]))


class Develop(develop.develop):
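
Taken together, the setup.py changes gate everything on the ENABLE_DISC environment flag: when it is unset the build passes --define=enable_disc=false to Bazel, and when it is set the two DISC shared libraries are also copied next to the extension. A hedged sketch of the two build modes (the flag and file names come from this diff; the exact setup.py command is an assumption):

    # Default build: ENABLE_DISC unset, DISC compilation disabled
    python setup.py bdist_wheel

    # DISC-enabled build: also packages libral_base_context.so and libdisc_custom_ops.so
    ENABLE_DISC=true python setup.py bdist_wheel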
Empty file modified test/test_flash_attention_backward.py
100644 → 100755
1 change: 1 addition & 0 deletions third_party/BladeDISC
Submodule BladeDISC added at fbe39b
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
@@ -1,6 +1,7 @@
load(
"//bazel:rules_def.bzl",
"ptxla_cc_library",
"ptxla_cc_test",
)

genrule(
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/flash_attention_forward.cpp
@@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(int batch_size, int num_heads, int seqlen_q,
xla::PrimitiveType::F32, {batch_size, num_heads, seqlen_q});
xla::Shape out_shape = GetXlaShape(q);
xla::Shape rng_state_shape =
- xla::ShapeUtil::MakeShape(xla::PrimitiveType::U64, {2});
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2});
return xla::ShapeUtil::MakeTupleShape(
{softmax_lse_shape, out_shape, rng_state_shape});
}
59 changes: 58 additions & 1 deletion torch_xla/csrc/runtime/BUILD
100644 → 100755
@@ -10,13 +10,20 @@

load(
"//bazel:rules_def.bzl",
"ptxla_cc_library",
"ptxla_cc_test",
"if_enable_disc",
)

licenses(["notice"]) # Apache 2.0

package(default_visibility = ["//visibility:public"])

config_setting(
name = "enable_disc",
define_values = {"enable_disc": "true"},
)

cc_library(
name = "runtime",
srcs = [
@@ -31,7 +38,12 @@ cc_library(
":pjrt_computation_client",
":ifrt_computation_client",
"@tsl//tsl/platform:stacktrace",
- ],
] + if_enable_disc([
":disc_computation_client",
]),
copts = if_enable_disc([
"-DTORCHACC_ENABLE_DISC",
]),
)

cc_library(
@@ -137,6 +149,29 @@ cc_library(
],
)

cc_library(
name = "disc_computation_client",
srcs = [
"disc_computation_client.cc",
],
hdrs = [
"disc_computation_client.h",
],
deps = [
":computation_client",
":stablehlo_helper",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
"//torch_xla/csrc/runtime/disc:disc_ral",
"//torch_xla/csrc/runtime/disc:disc_compile",
"@xla//xla/service:float_normalization",
"@xla//xla/service/gpu:gpu_float_support",
],
)

cc_library(
name = "cache",
hdrs = ["cache.h"],
@@ -519,3 +554,25 @@ ptxla_cc_test(
"@tsl//tsl/platform:test_main",
],
)

ptxla_cc_test(
name = "disc_computation_client_test",
srcs = ["disc_computation_client_test.cc"],
deps = [
":disc_computation_client",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla/client:xla_builder",
"@xla//xla/client:xla_computation",
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
"@stablehlo//:register",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)
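
The new disc_computation_client_test target defined above can presumably be run through Bazel once DISC is enabled; a sketch (the target name comes from this BUILD file, the define and any CUDA configuration flags are assumptions):

    # Run the DISC computation client test with the DISC backend compiled in
    bazel test //torch_xla/csrc/runtime:disc_computation_client_test --define enable_disc=true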
81 changes: 81 additions & 0 deletions torch_xla/csrc/runtime/disc/BUILD
@@ -0,0 +1,81 @@
load(
"@tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)

load(
"//bazel:rules_def.bzl",
"ptxla_cc_library",
"ptxla_cc_test",
)


ptxla_cc_library(
name = "disc_ral",
srcs = [
"disc_ral.cc",
"custom_call_flash_attention_forward.cc",
"custom_call_flash_attention_backward.cc"
],
hdrs = [
"disc_ral.h",
],
deps = [
":disc_utils",
"@disc_compiler//:disc_ral_cuda",
"@disc_compiler//:disc_custom_op",
"@disc_compiler//:headers",
"@local_config_cuda//cuda:cuda_headers",
"@nccl_archive//:nccl_headers",
"@torch//:libc10",
"@torch//:libc10_cuda",
"@torch//:libtorch_cuda",
"@flash_attn//:headers",
"@flash_attn//:flash_attn_cuda",
],
copts = [
"-DGOOGLE_CUDA",
]
)

ptxla_cc_library(
name = "disc_utils",
srcs = ["disc_utils.cc"],
hdrs = [
"disc_utils.h",
],
deps = [
"//torch_xla/csrc/runtime:tf_logging",
]
)

ptxla_cc_library(
name = "disc_compile",
srcs = ["disc_compile.cc"],
hdrs = [
"disc_compile.h",
],
deps = [
":disc_ral",
":disc_utils",
"//torch_xla/csrc/runtime:tf_logging",
"//torch_xla/csrc/runtime:sys_util",
"//torch_xla/csrc/runtime:env_vars",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
],
copts = [
"-DGOOGLE_CUDA",
]
)

ptxla_cc_test(
name = "disc_ral_test",
srcs = ["disc_ral_test.cc"],
deps = [
":disc_ral",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
]
)