Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge disc backend to acc 2.3 #3

Merged
merged 12 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,6 @@ build:linux --copt="-Wno-error=unused-but-set-variable"
# Only include debug info for files not under XLA.
build:dbg -c dbg
build:dbg --per_file_copt=external/xla/.*@-g0,-DNDEBUG

# build with DISC backend
build --define enable_disc=true
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
[submodule "third_party/flash-attention"]
path = third_party/flash-attention
url = https://github.com/Dao-AILab/flash-attention.git
[submodule "third_party/BladeDISC"]
path = third_party/BladeDISC
url = https://github.com/alibaba/BladeDISC.git
7 changes: 7 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ load(
"if_cuda_is_configured",
)

load(
"//bazel:rules_def.bzl",
"if_enable_disc",
)

cc_binary(
name = "_XLAC.so",
copts = [
Expand All @@ -28,5 +33,7 @@ cc_binary(
"@torch//:libtorch_python",
] + if_cuda_is_configured([
"@xla//xla/stream_executor:cuda_platform",
]) + if_enable_disc([
"//torch_xla/csrc/runtime/disc:disc_ral",
]),
)
7 changes: 7 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,10 @@ new_local_repository(
build_file = "//bazel:flash_attn.BUILD",
path = "third_party/flash-attention/",
)
################################ BladeDISC Setup ################################

new_local_repository(
name = "disc_compiler",
build_file = "//bazel:disc.BUILD",
path = "third_party/BladeDISC/",
)
49 changes: 49 additions & 0 deletions bazel/disc.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@

package(
default_visibility = [
"//visibility:public",
],
)

cc_library(
name = "headers",
hdrs = glob(
[
"mlir/ral/*.h",
"mlir/ral/context/base/cuda/*.h",
"mlir/ral/context/base/cuda/cuda_context_impl.h",
"mlir/ral/device/cpu/*.h",
"mlir/ral/device/gpu/*.h",
],
),
includes = [
"tao_compiler",
"tao_compiler/mlir",
],
strip_include_prefix = "external/disc_compiler/tao_compiler/mlir",
)

cc_import(
name="disc_ral_cuda",
shared_library = ":libral_base_context.so",
)

cc_import(
name="disc_custom_op",
shared_library = ":libdisc_custom_ops.so",
)

genrule(
name = "build_disc",
outs = ["libral_base_context.so", "libdisc_custom_ops.so", "disc_compiler_main", "torch-mlir-opt"],
local = True,
cmd = ';'.join(['export PATH=/root/bin:/usr/local/cuda/bin:$${PATH}',
'pushd external/disc_compiler/pytorch_blade/',
'python ../scripts/python/common_setup.py',
'TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0,8.6,9.0" TORCH_CUDA_ARCH_LIST="7.0 8.0 8.6 9.0" python setup.py bdist_wheel',
'popd',
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/ral/libral_base_context.so $(location libral_base_context.so)',
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/custom_ops/libdisc_custom_ops.so $(location libdisc_custom_ops.so)',
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/disc/disc_compiler_main $(location disc_compiler_main)',
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/tests/mhlo/torch-mlir-opt/torch-mlir-opt $(location torch-mlir-opt)']),
)
7 changes: 3 additions & 4 deletions bazel/flash_attn.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ genrule(
name = "build_flash_attn",
srcs = ["setup.py"],
outs = ["flash_attn_cuda.so"],
cmd = ';'.join(['pushd third_party/flash-attention/',
'MAX_JOBS=50 FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel 2>&1 | tee build.log',
'cp build/*/*.so flash_attn_cuda.so',
cmd = ';'.join(['pushd external/flash_attn/',
'FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel',
'popd',
'cp third_party/flash-attention/flash_attn_cuda.so $(OUTS)']),
'cp external/flash_attn/build/*/*.so $(location flash_attn_cuda.so)']),
)
6 changes: 6 additions & 0 deletions bazel/rules_def.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,9 @@ def ptxla_cc_test(
],
**kwargs
)

def if_enable_disc(if_true, if_false=[]):
return select({
"//torch_xla/csrc/runtime:enable_disc": if_true,
"//conditions:default": if_false
})
8 changes: 8 additions & 0 deletions bazel/torch.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ cc_import(
shared_library = "build/lib/libtorch_cpu.so",
)

cc_import(
name = "libtorch_cuda",
shared_library = "build/lib/libtorch_cuda.so",
)
cc_import(
name = "libtorch_python",
shared_library = "build/lib/libtorch_python.so",
Expand All @@ -64,3 +68,7 @@ cc_import(
name = "libc10",
shared_library = "build/lib/libc10.so",
)
cc_import(
name = "libc10_cuda",
shared_library = "build/lib/libc10_cuda.so",
)
24 changes: 21 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ def bazel_build(self, ext):

bazel_argv.extend(build_util.bazel_options_from_env())

if not build_util.check_env_flag('ENABLE_DISC', 'false'):
bazel_argv.append('--define=enable_disc=false')

self.spawn(bazel_argv)

ext_bazel_bin_path = os.path.join(self.build_temp, 'bazel-bin', ext.relpath,
Expand All @@ -244,9 +247,24 @@ def bazel_build(self, ext):

# copy flash attention cuda so file
flash_attn_so_name = 'flash_attn_cuda.so'
shutil.copyfile(
'/'.join(['third_party/flash-attention', flash_attn_so_name]),
'/'.join([ext_dest_dir, flash_attn_so_name]))
bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/flash_attn/'
shutil.copyfile('/'.join([bazel_bin_path, flash_attn_so_name]),
'/'.join([ext_dest_dir, flash_attn_so_name]))

# package BladeDISC distribution files
# please note, TorchBlade also create some symbolic links to 'torch_blade' dir
if build_util.check_env_flag('ENABLE_DISC', 'false'):
disc_ral_so_name = 'libral_base_context.so'
bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler'
shutil.copyfile(
os.path.join(bazel_bin_path, disc_ral_so_name),
'/'.join([ext_dest_dir, disc_ral_so_name]))

disc_customop_so_name = 'libdisc_custom_ops.so'
bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler'
shutil.copyfile(
os.path.join(bazel_bin_path, disc_customop_so_name),
'/'.join([ext_dest_dir, disc_customop_so_name]))


class Develop(develop.develop):
Expand Down
Empty file modified test/test_flash_attention_backward.py
100644 → 100755
Empty file.
1 change: 1 addition & 0 deletions third_party/BladeDISC
Submodule BladeDISC added at fbe39b
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load(
"//bazel:rules_def.bzl",
"ptxla_cc_library",
"ptxla_cc_test",
)

genrule(
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/flash_attention_forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(int batch_size, int num_heads, int seqlen_q,
xla::PrimitiveType::F32, {batch_size, num_heads, seqlen_q});
xla::Shape out_shape = GetXlaShape(q);
xla::Shape rng_state_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::U64, {2});
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2});
return xla::ShapeUtil::MakeTupleShape(
{softmax_lse_shape, out_shape, rng_state_shape});
}
Expand Down
59 changes: 58 additions & 1 deletion torch_xla/csrc/runtime/BUILD
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,20 @@ load(

load(
"//bazel:rules_def.bzl",
"ptxla_cc_library",
"ptxla_cc_test",
"if_enable_disc",
)

licenses(["notice"]) # Apache 2.0

package(default_visibility = ["//visibility:public"])

config_setting(
name = "enable_disc",
define_values = {"enable_disc": "true"},
)

cc_library(
name = "runtime",
srcs = [
Expand All @@ -31,7 +38,12 @@ cc_library(
":pjrt_computation_client",
":ifrt_computation_client",
"@tsl//tsl/platform:stacktrace",
],
] + if_enable_disc([
":disc_computation_client",
]),
copts = if_enable_disc([
"-DTORCHACC_ENABLE_DISC",
]),
)

cc_library(
Expand Down Expand Up @@ -137,6 +149,29 @@ cc_library(
],
)

cc_library(
name = "disc_computation_client",
srcs = [
"disc_computation_client.cc",
],
hdrs = [
"disc_computation_client.h",
],
deps = [
":computation_client",
":stablehlo_helper",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@xla//xla/client:xla_computation",
"//torch_xla/csrc/runtime/disc:disc_ral",
"//torch_xla/csrc/runtime/disc:disc_compile",
"@xla//xla/service:float_normalization",
"@xla//xla/service/gpu:gpu_float_support",
],
)

cc_library(
name = "cache",
hdrs = ["cache.h"],
Expand Down Expand Up @@ -506,3 +541,25 @@ ptxla_cc_test(
"@tsl//tsl/platform:test_main",
],
)

ptxla_cc_test(
name = "disc_computation_client_test",
srcs = ["disc_computation_client_test.cc"],
deps = [
":disc_computation_client",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla/client:xla_builder",
"@xla//xla/client:xla_computation",
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
"@stablehlo//:register",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)
81 changes: 81 additions & 0 deletions torch_xla/csrc/runtime/disc/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
load(
"@tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)

load(
"//bazel:rules_def.bzl",
"ptxla_cc_library",
"ptxla_cc_test",
)


ptxla_cc_library(
name = "disc_ral",
srcs = [
"disc_ral.cc",
"custom_call_flash_attention_forward.cc",
"custom_call_flash_attention_backward.cc"
],
hdrs = [
"disc_ral.h",
],
deps = [
":disc_utils",
"@disc_compiler//:disc_ral_cuda",
"@disc_compiler//:disc_custom_op",
"@disc_compiler//:headers",
"@local_config_cuda//cuda:cuda_headers",
"@nccl_archive//:nccl_headers",
"@torch//:libc10",
"@torch//:libc10_cuda",
"@torch//:libtorch_cuda",
"@flash_attn//:headers",
"@flash_attn//:flash_attn_cuda",
],
copts = [
"-DGOOGLE_CUDA",
]
)

ptxla_cc_library(
name = "disc_utils",
srcs = ["disc_utils.cc"],
hdrs = [
"disc_utils.h",
],
deps = [
"//torch_xla/csrc/runtime:tf_logging",
]
)

ptxla_cc_library(
name = "disc_compile",
srcs = ["disc_compile.cc"],
hdrs = [
"disc_compile.h",
],
deps = [
":disc_ral",
":disc_utils",
"//torch_xla/csrc/runtime:tf_logging",
"//torch_xla/csrc/runtime:sys_util",
"//torch_xla/csrc/runtime:env_vars",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
],
copts = [
"-DGOOGLE_CUDA",
]
)

ptxla_cc_test(
name = "disc_ral_test",
srcs = ["disc_ral_test.cc"],
deps = [
":disc_ral",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
]
)
Loading
Loading