Skip to content

Commit

Permalink
[Dev] Add nccl and nccl-test build example
Browse files Browse the repository at this point in the history
  • Loading branch information
lshmouse committed Sep 30, 2024
1 parent 1e49482 commit 9bd5bde
Show file tree
Hide file tree
Showing 13 changed files with 380 additions and 2 deletions.
15 changes: 15 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,18 @@ coverage --test_tag_filters=-cpplint
# as recommended in https://github.com/bazelbuild/bazel/issues/6319
try-import %workspace%/ci.bazelrc
try-import %workspace%/user.bazelrc


# CUDA flags adapted from: https://github.com/bazel-contrib/rules_cuda/blob/main/examples/.bazelrc
# Convenient flag shortcuts: each alias lets the corresponding @rules_cuda//cuda:*
# build setting be passed as a short flag (e.g. --enable_cuda, --cuda_archs).
build --flag_alias=enable_cuda=@rules_cuda//cuda:enable
build --flag_alias=cuda_archs=@rules_cuda//cuda:archs
build --flag_alias=cuda_compiler=@rules_cuda//cuda:compiler
build --flag_alias=cuda_copts=@rules_cuda//cuda:copts
build --flag_alias=cuda_runtime=@rules_cuda//cuda:runtime

# Set the enable_cuda alias defined above to True for every build.
build --enable_cuda=True

# Use --config=clang to build with clang instead of gcc and nvcc.
# This sets CC=clang in the repository environment and selects clang as the rules_cuda compiler.
build:clang --repo_env=CC=clang
build:clang --@rules_cuda//cuda:compiler=clang
4 changes: 4 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ repositories()
load("//bazel:init_deps.bzl", "init_deps")
init_deps()

load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
bazel_skylib_workspace()


# Load the LLVM toolchain
load("@toolchains_llvm//toolchain:deps.bzl", "bazel_toolchain_dependencies")
bazel_toolchain_dependencies()
Expand Down
17 changes: 15 additions & 2 deletions bazel/workspace.bzl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# external dependencies that can be loaded in WORKSPACE files.

load("//third_party/toolchains_llvm:workspace.bzl", toolchains_llvm = "repo")

load("//third_party/bazel_skylib:workspace.bzl", bazel_skylib = "repo")
load("//third_party/bazel_gazelle:workspace.bzl", bazel_gazelle = "repo")

load("//third_party/rules_python:workspace.bzl", rules_python = "repo")
load("//third_party/rules_foreign_cc:workspace.bzl", rules_foreign_cc = "repo")
load("//third_party/pybind11:workspace.bzl", pybind11 = "repo")
load("//third_party/rules_proto:workspace.bzl", rules_proto = "repo")
load("//third_party/rules_go:workspace.bzl", rules_go = "repo")
load("//third_party/rules_rust:workspace.bzl", rules_rust = "repo")
load("//third_party/rules_cuda:workspace.bzl", rules_cuda = "repo")

load("//third_party/rules_oci:workspace.bzl", rules_oci = "repo")
load("//third_party/rules_pkg:workspace.bzl", rules_pkg = "repo")
Expand Down Expand Up @@ -49,8 +51,14 @@ load("//third_party/mcap:workspace.bzl", mcap = "repo")
load("//third_party/onnxruntime:workspace.bzl", onnxruntime = "repo")
load("//third_party/foxglove_schemas:workspace.bzl", foxglove_schemas = "repo")

# cuda
load("//third_party/rules_cuda:workspace.bzl", rules_cuda = "repo")
load("//third_party/nccl:workspace.bzl", nccl = "repo")

def init_language_repos():
toolchains_llvm()

bazel_skylib()
bazel_gazelle()

rules_proto()
Expand All @@ -62,7 +70,7 @@ def init_language_repos():
rules_pkg()
rules_rust()
rules_oci()
rules_cuda()


def init_compression_libs():
boringssl()
Expand Down Expand Up @@ -100,8 +108,13 @@ def init_third_parties():
foxglove_schemas()
onnxruntime()

def init_cuda_repos():
    """Calls the `repo` functions loaded from //third_party/rules_cuda and //third_party/nccl."""
    rules_cuda()
    nccl()

# Define all external repositories; intended to be called from the WORKSPACE file.
def repositories():
    """Runs the repository-group initializers: language, compression, third-party, then CUDA."""
    init_language_repos()
    init_compression_libs()
    init_third_parties()
    init_cuda_repos()
1 change: 1 addition & 0 deletions third_party/bazel_skylib/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package(default_visibility = ["//visibility:public"])
2 changes: 2 additions & 0 deletions third_party/bazel_skylib/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
## bazel-skylib
See: https://github.com/bazelbuild/bazel-skylib
11 changes: 11 additions & 0 deletions third_party/bazel_skylib/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Loads the bazel_skylib library"""

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

def repo():
    """Declares the `bazel_skylib` external repository with `http_archive`."""
    http_archive(
        name = "bazel_skylib",
        # The archive is pinned to a single commit: the same hash appears in
        # `strip_prefix` and in the download URL.
        sha256 = "2e6fa9a61db799266072df115a719a14a9af0e8a630b1f770ef0bd757e68cd71",
        strip_prefix = "bazel-skylib-de3035d605b4c89a62d6da060188e4ab0c5034b9",
        urls = ["https://github.com/bazelbuild/bazel-skylib/archive/de3035d605b4c89a62d6da060188e4ab0c5034b9.tar.gz"],
    )
24 changes: 24 additions & 0 deletions third_party/nccl/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package(default_visibility = ["//visibility:public"])

# Local alias-like group wrapping the `nccl_shared` target of the external @nccl repository.
filegroup(
    name = "nccl_shared",
    srcs = [
        "@nccl//:nccl_shared",
    ],
)

# Groups the `*_perf` binaries from the external @nccl-tests repository under one label.
filegroup(
    name = "perf_binaries",
    srcs = [
        "@nccl-tests//:all_gather_perf",
        "@nccl-tests//:all_reduce_perf",
        "@nccl-tests//:alltoall_perf",
        "@nccl-tests//:broadcast_perf",
        "@nccl-tests//:gather_perf",
        "@nccl-tests//:hypercube_perf",
        "@nccl-tests//:reduce_perf",
        "@nccl-tests//:reduce_scatter_perf",
        "@nccl-tests//:scatter_perf",
        "@nccl-tests//:sendrecv_perf",
    ],
)
2 changes: 2 additions & 0 deletions third_party/nccl/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
## NCCL
See: https://github.com/bazel-contrib/rules_cuda/blob/main/examples/WORKSPACE.bazel
52 changes: 52 additions & 0 deletions third_party/nccl/nccl-tests.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
load("@rules_cuda//cuda:defs.bzl", "cuda_library")
load("@ai_playground//third_party/nccl:nccl-tests.bzl", "nccl_tests_binary")

# NOTE: all paths in this file relative to @nccl-tests repo root.

# Header-only library: the headers directly under src/, with src/ added to the include path.
cc_library(
    name = "nccl_tests_include",
    hdrs = glob(["src/*.h"]),
    includes = ["src"],
)

# CUDA sources shared by the test targets (common.cu and verifiable.cu),
# compiled together with every header found anywhere in the repository.
cuda_library(
    name = "common_cuda",
    srcs = [
        "src/common.cu",
        "verifiable/verifiable.cu",
    ] + glob([
        "**/*.h",
    ]),
    deps = [
        ":nccl_tests_include",
        "@nccl",
    ],
)

# Host-side timer sources shared by the test binaries.
# alwayslink = 1 makes dependents link in all of this library's object files,
# even ones they do not reference directly.
cc_library(
    name = "common_cc",
    srcs = ["src/timer.cc"],
    hdrs = ["src/timer.h"],
    alwayslink = 1,
)

# One `nccl_tests_binary` call per test. Each call compiles src/<name>.cu into a
# cuda_library called <name> and defines a public cc_binary called <name>_perf.
# :common_cuda, :common_cc and @nccl//:nccl_shared are implicitly hardcoded in `nccl_tests_binary`
nccl_tests_binary(name = "all_reduce")

nccl_tests_binary(name = "all_gather")

nccl_tests_binary(name = "broadcast")

nccl_tests_binary(name = "reduce_scatter")

nccl_tests_binary(name = "reduce")

nccl_tests_binary(name = "alltoall")

nccl_tests_binary(name = "scatter")

nccl_tests_binary(name = "gather")

nccl_tests_binary(name = "sendrecv")

nccl_tests_binary(name = "hypercube")
21 changes: 21 additions & 0 deletions third_party/nccl/nccl-tests.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
load("@rules_cuda//cuda:defs.bzl", "cuda_library")

# NOTE: all paths in this file relative to @nccl-tests repo root.

def nccl_tests_binary(name, cc_deps = [], cuda_deps = []):
    """Defines one nccl-tests target pair: a CUDA library and its `<name>_perf` binary.

    Args:
        name: Test name. `src/<name>.cu` is compiled into a `cuda_library`
            called `name`; the binary is called `<name>_perf`.
        cc_deps: Extra deps appended to the `cc_binary`, after the
            always-present `:common_cc` and `:<name>`.
        cuda_deps: Extra deps appended to the `cuda_library`, after the
            always-present `@nccl//:nccl_shared` and `:common_cuda`.
    """
    cuda_library(
        name = name,
        srcs = ["src/{}.cu".format(name)],
        # Previously `cuda_deps` was accepted but never used; forward it here.
        deps = [
            "@nccl//:nccl_shared",
            ":common_cuda",
        ] + cuda_deps,
        alwayslink = 1,
    )

    bin_name = name + "_perf"
    native.cc_binary(
        name = bin_name,
        # Previously `cc_deps` was accepted but never used; forward it here.
        deps = [":common_cc", ":" + name] + cc_deps,
        visibility = ["//visibility:public"],
    )
165 changes: 165 additions & 0 deletions third_party/nccl/nccl.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("@rules_cuda//cuda:defs.bzl", "cuda_library", "cuda_objects")
load("@ai_playground//third_party/nccl:nccl.bzl", "if_cuda_clang", "if_cuda_nvcc", "nccl_primitive")

# NOTE: all paths in this file relative to @nccl repo root.

# Generates src/include/nccl.h from src/nccl.h.in by filling in the version
# placeholders (2.18.3, empty suffix).
expand_template(
    name = "nccl_h",
    out = "src/include/nccl.h",
    substitutions = {
        "${nccl:Major}": "2",
        "${nccl:Minor}": "18",
        "${nccl:Patch}": "3",
        "${nccl:Suffix}": "",
        # NCCL_VERSION(X,Y,Z) ((X) * 10000 + (Y) * 100 + (Z))
        # i.e. 2 * 10000 + 18 * 100 + 3 = 21803
        "${nccl:Version}": "21803",
    },
    template = "src/nccl.h.in",
)

# Public and internal NCCL headers, including the generated nccl.h.
cc_library(
    name = "nccl_include",
    hdrs = [
        ":nccl_h",
    ] + glob([
        "src/include/**/*.h",
        "src/include/**/*.hpp",
    ]),
    includes = [
        # This adds both nccl/src/include in the repo and
        # bazel-out/<compilation_mode>/bin/nccl/src/include to the include paths.
        # The latter is exactly where the expand_template above generates nccl.h.
        "src/include",
    ],
)

# Device code shared by all collectives: functions.cu, onerank_reduce.cu and
# the headers under src/collectives/device.
cuda_objects(
    name = "nccl_device_common",
    srcs = [
        "src/collectives/device/functions.cu",
        "src/collectives/device/onerank_reduce.cu",
    ] + glob([
        "src/collectives/device/**/*.h",
    ]),
    # NOTE(review): if_cuda_nvcc is defined in nccl.bzl; presumably it yields
    # these copts only when nvcc is the CUDA compiler -- confirm there.
    copts = if_cuda_nvcc(["--extended-lambda"]),
    ptxasopts = ["-maxrregcount=96"],
    deps = [":nccl_include"],
)

# Passed as `use_bf16` to every nccl_primitive target below.
# Must be manually set to False if the CUDA version is lower than 11.
USE_BF16 = True

# Files under src/collectives/device that are handed to each nccl_primitive
# through its `hdrs` attribute. Besides headers, the list also contains gen_rules.sh.
filegroup(
    name = "collective_dev_hdrs",
    srcs = [
        "src/collectives/device/all_gather.h",
        "src/collectives/device/all_reduce.h",
        "src/collectives/device/broadcast.h",
        "src/collectives/device/common.h",
        "src/collectives/device/common_kernel.h",
        "src/collectives/device/gen_rules.sh",
        "src/collectives/device/op128.h",
        "src/collectives/device/primitives.h",
        "src/collectives/device/prims_ll.h",
        "src/collectives/device/prims_ll128.h",
        "src/collectives/device/prims_simple.h",
        "src/collectives/device/reduce.h",
        "src/collectives/device/reduce_kernel.h",
        "src/collectives/device/reduce_scatter.h",
        "src/collectives/device/sendrecv.h",
    ],
)

# cuda_objects for each type of primitive (the `nccl_primitive` macro is loaded from nccl.bzl).
# All six targets share the same headers, bf16 setting and :nccl_device_common dep;
# only the name differs.
nccl_primitive(
    name = "all_gather",
    hdrs = ["collective_dev_hdrs"],
    use_bf16 = USE_BF16,
    deps = [":nccl_device_common"],
)

nccl_primitive(
    name = "all_reduce",
    hdrs = ["collective_dev_hdrs"],
    use_bf16 = USE_BF16,
    deps = [":nccl_device_common"],
)

nccl_primitive(
    name = "broadcast",
    hdrs = ["collective_dev_hdrs"],
    use_bf16 = USE_BF16,
    deps = [":nccl_device_common"],
)

nccl_primitive(
    name = "reduce",
    hdrs = ["collective_dev_hdrs"],
    use_bf16 = USE_BF16,
    deps = [":nccl_device_common"],
)

nccl_primitive(
    name = "reduce_scatter",
    hdrs = ["collective_dev_hdrs"],
    use_bf16 = USE_BF16,
    deps = [":nccl_device_common"],
)

nccl_primitive(
    name = "sendrecv",
    hdrs = ["collective_dev_hdrs"],
    use_bf16 = USE_BF16,
    deps = [":nccl_device_common"],
)

# Device link: combines the six per-primitive targets into a single library
# built with rdc = 1. It has no sources of its own, only deps.
cuda_library(
    name = "collectives",
    rdc = 1,
    deps = [
        ":all_gather",
        ":all_reduce",
        ":broadcast",
        ":reduce",
        ":reduce_scatter",
        ":sendrecv",
    ],
    alwayslink = 1,
)

# The NCCL library itself, built as a shared object (linkshared = 1) from the
# host sources plus the device-linked :collectives library.
cc_binary(
    name = "nccl",
    srcs = glob(
        [
            "src/*.cc",
            "src/collectives/*.cc",
            "src/graph/*.cc",
            "src/graph/*.h",
            "src/misc/*.cc",
            "src/transport/*.cc",
        ],
        exclude = [
            # Excluded on purpose, see https://github.com/NVIDIA/nccl/issues/658
            "src/enhcompat.cc",
        ],
    ),
    # NOTE(review): if_cuda_clang is defined in nccl.bzl; presumably "-xcu" is
    # added only when clang is the CUDA compiler -- confirm there.
    copts = if_cuda_clang(["-xcu"]),
    linkshared = 1,
    linkstatic = 1,
    visibility = ["//visibility:public"],
    deps = [
        ":collectives",
        ":nccl_include",
        "@rules_cuda//cuda:runtime",
    ],
)

# To allow downstream targets to link with the nccl shared library, we need to `cc_import` it again.
# See https://groups.google.com/g/bazel-discuss/c/RtbidPdVFyU/m/TsUDOVHIAwAJ
# `shared_library` points at the `:nccl` cc_binary target defined in this file.
cc_import(
    name = "nccl_shared",
    shared_library = ":nccl",
    visibility = ["//visibility:public"],
)
Loading

0 comments on commit 9bd5bde

Please sign in to comment.