Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dev] Add cuda nccl example #15

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,18 @@ coverage --test_tag_filters=-cpplint
# as recommended in https://github.com/bazelbuild/bazel/issues/6319
try-import %workspace%/ci.bazelrc
try-import %workspace%/user.bazelrc


# cuda flags from: https://github.com/bazel-contrib/rules_cuda/blob/main/examples/.bazelrc
# Convenient flag shortcuts.
build --flag_alias=enable_cuda=@rules_cuda//cuda:enable
build --flag_alias=cuda_archs=@rules_cuda//cuda:archs
build --flag_alias=cuda_compiler=@rules_cuda//cuda:compiler
build --flag_alias=cuda_copts=@rules_cuda//cuda:copts
build --flag_alias=cuda_runtime=@rules_cuda//cuda:runtime

build --enable_cuda=True

# Use --config=clang to build with clang instead of gcc and nvcc.
build:clang --repo_env=CC=clang
build:clang --@rules_cuda//cuda:compiler=clang
4 changes: 4 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ repositories()
load("//bazel:init_deps.bzl", "init_deps")
init_deps()

load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
bazel_skylib_workspace()


# Load the LLVM toolchain
load("@toolchains_llvm//toolchain:deps.bzl", "bazel_toolchain_dependencies")
bazel_toolchain_dependencies()
Expand Down
17 changes: 15 additions & 2 deletions bazel/workspace.bzl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# external dependencies that can be loaded in WORKSPACE files.

load("//third_party/toolchains_llvm:workspace.bzl", toolchains_llvm = "repo")

load("//third_party/bazel_skylib:workspace.bzl", bazel_skylib = "repo")
load("//third_party/bazel_gazelle:workspace.bzl", bazel_gazelle = "repo")

load("//third_party/rules_python:workspace.bzl", rules_python = "repo")
load("//third_party/rules_foreign_cc:workspace.bzl", rules_foreign_cc = "repo")
load("//third_party/pybind11:workspace.bzl", pybind11 = "repo")
load("//third_party/rules_proto:workspace.bzl", rules_proto = "repo")
load("//third_party/rules_go:workspace.bzl", rules_go = "repo")
load("//third_party/rules_rust:workspace.bzl", rules_rust = "repo")
load("//third_party/rules_cuda:workspace.bzl", rules_cuda = "repo")

load("//third_party/rules_oci:workspace.bzl", rules_oci = "repo")
load("//third_party/rules_pkg:workspace.bzl", rules_pkg = "repo")
Expand Down Expand Up @@ -49,8 +51,14 @@ load("//third_party/mcap:workspace.bzl", mcap = "repo")
load("//third_party/onnxruntime:workspace.bzl", onnxruntime = "repo")
load("//third_party/foxglove_schemas:workspace.bzl", foxglove_schemas = "repo")

# cuda
load("//third_party/rules_cuda:workspace.bzl", rules_cuda = "repo")
load("//third_party/nccl:workspace.bzl", nccl = "repo")

def init_language_repos():
toolchains_llvm()

bazel_skylib()
bazel_gazelle()

rules_proto()
Expand All @@ -62,7 +70,7 @@ def init_language_repos():
rules_pkg()
rules_rust()
rules_oci()
rules_cuda()


def init_compression_libs():
boringssl()
Expand Down Expand Up @@ -100,8 +108,13 @@ def init_third_parties():
foxglove_schemas()
onnxruntime()

def init_cuda_repos():
    """Calls the `repo` macros of //third_party/rules_cuda and //third_party/nccl."""
    rules_cuda()
    nccl()

def repositories():
    """Defines all external repositories required by this workspace.

    Runs every dependency-group initializer in turn: language repos,
    compression libs, third parties, and the CUDA repos.
    """
    init_language_repos()
    init_compression_libs()
    init_third_parties()
    init_cuda_repos()
1 change: 1 addition & 0 deletions third_party/bazel_skylib/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package(default_visibility = ["//visibility:public"])
2 changes: 2 additions & 0 deletions third_party/bazel_skylib/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
## bazel-skylib
See: https://github.com/bazelbuild/bazel-skylib
11 changes: 11 additions & 0 deletions third_party/bazel_skylib/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Loads the bazel_skylib library"""

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

def repo():
    """Declares the `bazel_skylib` repository from a pinned upstream commit archive."""

    # Single source of truth for the pinned commit: both the archive URL and
    # the directory prefix inside the tarball are derived from it.
    commit = "de3035d605b4c89a62d6da060188e4ab0c5034b9"

    http_archive(
        name = "bazel_skylib",
        sha256 = "2e6fa9a61db799266072df115a719a14a9af0e8a630b1f770ef0bd757e68cd71",
        strip_prefix = "bazel-skylib-{}".format(commit),
        urls = ["https://github.com/bazelbuild/bazel-skylib/archive/{}.tar.gz".format(commit)],
    )
24 changes: 24 additions & 0 deletions third_party/nccl/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package(default_visibility = ["//visibility:public"])

# Re-exports the NCCL shared library target from the @nccl repository.
filegroup(
    name = "nccl_shared",
    srcs = ["@nccl//:nccl_shared"],
)

# Bundles every `*_perf` benchmark binary exposed by the @nccl-tests repository.
filegroup(
    name = "perf_binaries",
    srcs = [
        "@nccl-tests//:all_gather_perf",
        "@nccl-tests//:all_reduce_perf",
        "@nccl-tests//:alltoall_perf",
        "@nccl-tests//:broadcast_perf",
        "@nccl-tests//:gather_perf",
        "@nccl-tests//:hypercube_perf",
        "@nccl-tests//:reduce_perf",
        "@nccl-tests//:reduce_scatter_perf",
        "@nccl-tests//:scatter_perf",
        "@nccl-tests//:sendrecv_perf",
    ],
)
2 changes: 2 additions & 0 deletions third_party/nccl/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
## NCCL
See: https://github.com/bazel-contrib/rules_cuda/blob/main/examples/WORKSPACE.bazel
52 changes: 52 additions & 0 deletions third_party/nccl/nccl-tests.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
load("@rules_cuda//cuda:defs.bzl", "cuda_library")
load("@ai_playground//third_party/nccl:nccl-tests.bzl", "nccl_tests_binary")

# NOTE: all paths in this file relative to @nccl-tests repo root.

# Header-only target: exposes the headers directly under src/ and puts src/
# on the include path of dependents.
cc_library(
    name = "nccl_tests_include",
    hdrs = glob(["src/*.h"]),
    includes = ["src"],
)

# CUDA sources shared by all benchmarks; every header in the repo is listed
# in srcs alongside the two .cu files.
cuda_library(
    name = "common_cuda",
    srcs = [
        "src/common.cu",
        "verifiable/verifiable.cu",
    ] + glob(["**/*.h"]),
    deps = [
        ":nccl_tests_include",
        "@nccl",
    ],
)

# Host-side timer helper shared by all benchmarks.
cc_library(
    name = "common_cc",
    srcs = ["src/timer.cc"],
    hdrs = ["src/timer.h"],
    alwayslink = 1,
)

# One `<name>_perf` binary per benchmark.
# :common_cuda, :common_cc and @nccl//:nccl_shared are implicitly hardcoded in `nccl_tests_binary`
[
    nccl_tests_binary(name = benchmark)
    for benchmark in [
        "all_reduce",
        "all_gather",
        "broadcast",
        "reduce_scatter",
        "reduce",
        "alltoall",
        "scatter",
        "gather",
        "sendrecv",
        "hypercube",
    ]
]
21 changes: 21 additions & 0 deletions third_party/nccl/nccl-tests.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
load("@rules_cuda//cuda:defs.bzl", "cuda_library")

# NOTE: all paths in this file relative to @nccl-tests repo root.

def nccl_tests_binary(name, cc_deps = [], cuda_deps = []):
    """Defines one nccl-tests benchmark: a CUDA library and its `<name>_perf` binary.

    Args:
      name: Benchmark name. `src/<name>.cu` is compiled into a `cuda_library`
        called `name`; the executable target is called `<name>_perf`.
      cc_deps: Extra deps appended to the `cc_binary` target, after the
        built-in `:common_cc` and the benchmark library.
      cuda_deps: Extra deps appended to the `cuda_library` target, after the
        built-in `@nccl//:nccl_shared` and `:common_cuda`.
    """
    cuda_library(
        name = name,
        srcs = ["src/{}.cu".format(name)],
        # `cuda_deps` was previously accepted but silently ignored.
        deps = [
            "@nccl//:nccl_shared",
            ":common_cuda",
        ] + cuda_deps,
        alwayslink = 1,
    )

    native.cc_binary(
        name = name + "_perf",
        # `cc_deps` was previously accepted but silently ignored.
        deps = [":common_cc", ":" + name] + cc_deps,
        visibility = ["//visibility:public"],
    )
165 changes: 165 additions & 0 deletions third_party/nccl/nccl.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("@rules_cuda//cuda:defs.bzl", "cuda_library", "cuda_objects")
load("@ai_playground//third_party/nccl:nccl.bzl", "if_cuda_clang", "if_cuda_nvcc", "nccl_primitive")

# NOTE: all paths in this file relative to @nccl repo root.

# NCCL version substituted into the generated nccl.h.
_NCCL_MAJOR = 2
_NCCL_MINOR = 18
_NCCL_PATCH = 3

# Generates src/include/nccl.h from the src/nccl.h.in template.
expand_template(
    name = "nccl_h",
    out = "src/include/nccl.h",
    substitutions = {
        "${nccl:Major}": str(_NCCL_MAJOR),
        "${nccl:Minor}": str(_NCCL_MINOR),
        "${nccl:Patch}": str(_NCCL_PATCH),
        "${nccl:Suffix}": "",
        # NCCL_VERSION(X,Y,Z) ((X) * 10000 + (Y) * 100 + (Z))
        "${nccl:Version}": str(_NCCL_MAJOR * 10000 + _NCCL_MINOR * 100 + _NCCL_PATCH),
    },
    template = "src/nccl.h.in",
)

# Public and internal NCCL headers, including the generated nccl.h.
cc_library(
    name = "nccl_include",
    hdrs = [":nccl_h"] + glob([
        "src/include/**/*.h",
        "src/include/**/*.hpp",
    ]),
    includes = [
        # This adds both nccl/src/include in the repo and
        # bazel-out/<compilation_mode>/bin/nccl/src/include to the include
        # paths, which is exactly where :nccl_h generates nccl.h.
        "src/include",
    ],
)

# Device code shared by the per-collective objects below.
cuda_objects(
    name = "nccl_device_common",
    srcs = [
        "src/collectives/device/functions.cu",
        "src/collectives/device/onerank_reduce.cu",
    ] + glob(["src/collectives/device/**/*.h"]),
    # The extra flag is only passed through the nvcc-specific helper.
    copts = if_cuda_nvcc(["--extended-lambda"]),
    ptxasopts = ["-maxrregcount=96"],
    deps = [":nccl_include"],
)

# Must be manually disabled if the CUDA version is lower than 11.
USE_BF16 = True

# Files from src/collectives/device that are handed to every nccl_primitive
# target as `hdrs` (the list also carries gen_rules.sh).
filegroup(
    name = "collective_dev_hdrs",
    srcs = [
        "src/collectives/device/all_gather.h",
        "src/collectives/device/all_reduce.h",
        "src/collectives/device/broadcast.h",
        "src/collectives/device/common.h",
        "src/collectives/device/common_kernel.h",
        "src/collectives/device/gen_rules.sh",
        "src/collectives/device/op128.h",
        "src/collectives/device/primitives.h",
        "src/collectives/device/prims_ll.h",
        "src/collectives/device/prims_ll128.h",
        "src/collectives/device/prims_simple.h",
        "src/collectives/device/reduce.h",
        "src/collectives/device/reduce_kernel.h",
        "src/collectives/device/reduce_scatter.h",
        "src/collectives/device/sendrecv.h",
    ],
)

# cuda_objects for each type of primitive; all of them share the same
# headers, bf16 setting and deps, and differ only by name.
[
    nccl_primitive(
        name = primitive,
        hdrs = ["collective_dev_hdrs"],
        use_bf16 = USE_BF16,
        deps = [":nccl_device_common"],
    )
    for primitive in [
        "all_gather",
        "all_reduce",
        "broadcast",
        "reduce",
        "reduce_scatter",
        "sendrecv",
    ]
]

# Device link: combines the per-primitive objects into one library
# (built with rdc = 1).
cuda_library(
    name = "collectives",
    rdc = 1,
    deps = [
        ":all_gather",
        ":all_reduce",
        ":broadcast",
        ":reduce",
        ":reduce_scatter",
        ":sendrecv",
    ],
    alwayslink = 1,
)

# The NCCL library itself, built as a shared object (linkshared = 1) from the
# host sources plus the device-linked :collectives.
cc_binary(
    name = "nccl",
    srcs = glob(
        include = [
            "src/*.cc",
            "src/collectives/*.cc",
            "src/graph/*.cc",
            "src/graph/*.h",
            "src/misc/*.cc",
            "src/transport/*.cc",
        ],
        exclude = [
            # https://github.com/NVIDIA/nccl/issues/658
            "src/enhcompat.cc",
        ],
    ),
    # The extra flag is only passed through the clang-specific helper.
    copts = if_cuda_clang(["-xcu"]),
    linkshared = 1,
    linkstatic = 1,
    visibility = ["//visibility:public"],
    deps = [
        ":collectives",
        ":nccl_include",
        "@rules_cuda//cuda:runtime",
    ],
)

# Downstream targets cannot link against the `cc_binary` shared library
# directly, so it is wrapped in a `cc_import` that they can depend on.
# See https://groups.google.com/g/bazel-discuss/c/RtbidPdVFyU/m/TsUDOVHIAwAJ
cc_import(
    name = "nccl_shared",
    shared_library = ":nccl",
    visibility = ["//visibility:public"],
)
Loading
Loading