Skip to content

Commit d5702ef

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Install rocm nightly (#64)
Summary: We are building rocm nightly docker on-demand, also update the third-party kernel inventories with their dependencies. Pull Request resolved: #64 Test Plan: CI Fixes #49 Reviewed By: FindHao Differential Revision: D66305346 Pulled By: xuzhao9 fbshipit-source-id: 15a5277d9cc34f88b18ab510371cdf411b06f120
1 parent 45d195c commit d5702ef

9 files changed

+301
-82
lines changed

.github/workflows/docker-rocm.yaml

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
name: TritonBench Nightly ROCM Docker Build
2+
on:
3+
pull_request:
4+
paths:
5+
- .github/workflows/docker-rocm.yaml
6+
- docker/tritonbench-rocm-nightly.dockerfile
7+
workflow_dispatch:
8+
inputs:
9+
nightly_date:
10+
description: "PyTorch nightly version"
11+
required: false
12+
env:
13+
CONDA_ENV: "tritonbench"
14+
SETUP_SCRIPT: "/workspace/setup_instance.sh"
15+
16+
jobs:
17+
build-push-docker:
18+
if: ${{ github.repository_owner == 'pytorch-labs' }}
19+
runs-on: 32-core-ubuntu
20+
environment: docker-s3-upload
21+
steps:
22+
- name: Checkout
23+
uses: actions/checkout@v3
24+
with:
25+
path: tritonbench
26+
- name: Login to GitHub Container Registry
27+
if: github.event_name != 'pull_request'
28+
uses: docker/login-action@v2
29+
with:
30+
registry: ghcr.io
31+
username: pytorch-labs
32+
password: ${{ secrets.TRITONBENCH_ACCESS_TOKEN }}
33+
- name: Build TritonBench nightly docker
34+
run: |
35+
set -x
36+
export NIGHTLY_DATE="${{ github.event.inputs.nightly_date }}"
37+
cd tritonbench/docker
38+
# branch name is github.head_ref when triggered by pull_request
39+
# and it is github.ref_name when triggered by workflow_dispatch
40+
branch_name=${{ github.head_ref || github.ref_name }}
41+
docker build . --build-arg TRITONBENCH_BRANCH="${branch_name}" --build-arg FORCE_DATE="${NIGHTLY_DATE}" \
42+
-f tritonbench-rocm-nightly.dockerfile -t ghcr.io/pytorch-labs/tritonbench:rocm-latest
43+
# Extract pytorch version from the docker
44+
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:rocm-latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
45+
export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}")
46+
docker tag ghcr.io/pytorch-labs/tritonbench:rocm-latest ghcr.io/pytorch-labs/tritonbench:rocm-${DOCKER_TAG}
47+
- name: Push docker to remote
48+
if: github.event_name != 'pull_request'
49+
run: |
50+
# Extract pytorch version from the docker
51+
PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"')
52+
export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}")
53+
docker push ghcr.io/pytorch-labs/tritonbench:rocm-${DOCKER_TAG}
54+
docker push ghcr.io/pytorch-labs/tritonbench:rocm-latest
55+
concurrency:
56+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
57+
cancel-in-progress: true

.github/workflows/docker.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on:
66
pull_request:
77
paths:
88
- .github/workflows/docker.yaml
9-
- docker/*.dockerfile
9+
- docker/tritonbench-nightly.dockerfile
1010
workflow_dispatch:
1111
inputs:
1212
nightly_date:

README.md

+8-6
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ $ python run.py --op gemm
3636

3737
We depend on the following projects as a source of customized Triton or CUTLASS kernels:
3838

39-
* (Required) [FBGEMM](https://github.com/pytorch/FBGEMM)
40-
* (Required) [kernels](https://github.com/triton-lang/kernels)
41-
* (Required) [generative-recommenders](https://github.com/facebookresearch/generative-recommenders)
42-
* (Optional) [ThunderKittens](https://github.com/HazyResearch/ThunderKittens)
43-
* (Optional) [cutlass-kernels](https://github.com/ColfaxResearch/cutlass-kernels)
44-
* (Optional) [flash-attention](https://github.com/Dao-AILab/flash-attention)
39+
* (CUDA, HIP) [kernels](https://github.com/triton-lang/kernels)
40+
* (CUDA, HIP) [generative-recommenders](https://github.com/facebookresearch/generative-recommenders)
41+
* (CUDA, HIP) [Liger-Kernel](https://github.com/linkedin/Liger-Kernel)
42+
* (CUDA) [xformers](https://github.com/facebookresearch/xformers)
43+
* (CUDA) [flash-attention](https://github.com/Dao-AILab/flash-attention)
44+
* (CUDA) [FBGEMM](https://github.com/pytorch/FBGEMM)
45+
* (CUDA) [ThunderKittens](https://github.com/HazyResearch/ThunderKittens)
46+
* (CUDA) [cutlass-kernels](https://github.com/ColfaxResearch/cutlass-kernels)
4547

4648

4749
## License

docker/tritonbench-nightly.dockerfile

+8-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ ENV SETUP_SCRIPT=/workspace/setup_instance.sh
99
ARG TRITONBENCH_BRANCH=${TRITONBENCH_BRANCH:-main}
1010
ARG FORCE_DATE=${FORCE_DATE}
1111

12+
# Install deps
13+
RUN sudo apt install -y patch
14+
1215
# Checkout TritonBench and submodules
1316
RUN git clone --recurse-submodules -b "${TRITONBENCH_BRANCH}" --single-branch \
1417
https://github.com/pytorch-labs/tritonbench /workspace/tritonbench
@@ -22,13 +25,13 @@ RUN cd /workspace/tritonbench && \
2225

2326
RUN cd /workspace/tritonbench && \
2427
. ${SETUP_SCRIPT} && \
25-
sudo python tools/cuda_utils.py --setup-cuda-softlink
28+
sudo python -m tools.cuda_utils --setup-cuda-softlink
2629

2730
# Install PyTorch nightly and verify the date is correct
2831
RUN cd /workspace/tritonbench && \
2932
. ${SETUP_SCRIPT} && \
30-
python tools/cuda_utils.py --install-torch-deps && \
31-
python tools/cuda_utils.py --install-torch-nightly
33+
python -m tools.cuda_utils --install-torch-deps && \
34+
python -m tools.cuda_utils --install-torch-nightly
3235

3336
# Check the installed version of nightly if needed
3437
RUN cd /workspace/tritonbench && \
@@ -37,9 +40,9 @@ RUN cd /workspace/tritonbench && \
3740
echo "torch version check skipped"; \
3841
elif [ -z "${FORCE_DATE}" ]; then \
3942
FORCE_DATE=$(date '+%Y%m%d') \
40-
python tools/cuda_utils.py --check-torch-nightly-version --force-date "${FORCE_DATE}"; \
43+
python -m tools.cuda_utils --check-torch-nightly-version --force-date "${FORCE_DATE}"; \
4144
else \
42-
python tools/cuda_utils.py --check-torch-nightly-version --force-date "${FORCE_DATE}"; \
45+
python -m tools.cuda_utils --check-torch-nightly-version --force-date "${FORCE_DATE}"; \
4346
fi
4447

4548
# Tritonbench library build and test require libcuda.so.1
+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Build ROCM base docker file
2+
# We are not building AMD CI in a short term, but this could be useful
3+
# for sharing benchmark results with AMD.
4+
ARG BASE_IMAGE=rocm/pytorch:latest
5+
6+
FROM ${BASE_IMAGE}
7+
8+
ENV CONDA_ENV=pytorch
9+
ENV CONDA_ENV_TRITON_MAIN=triton-main
10+
ENV SETUP_SCRIPT=/workspace/setup_instance.sh
11+
ARG TRITONBENCH_BRANCH=${TRITONBENCH_BRANCH:-main}
12+
ARG FORCE_DATE=${FORCE_DATE}
13+
14+
RUN mkdir -p /workspace; touch "${SETUP_SCRIPT}"
15+
16+
RUN echo "\
17+
. /opt/conda/etc/profile.d/conda.sh\n\
18+
conda activate base\n\
19+
export CONDA_HOME=/opt/conda\n" > "${SETUP_SCRIPT}"
20+
21+
RUN echo ". /workspace/setup_instance.sh\n" >> ${HOME}/.bashrc
22+
23+
# Checkout TritonBench and submodules
24+
RUN git clone --recurse-submodules -b "${TRITONBENCH_BRANCH}" --single-branch \
25+
https://github.com/pytorch-labs/tritonbench /workspace/tritonbench
26+
27+
# Setup conda env
28+
RUN cd /workspace/tritonbench && \
29+
. ${SETUP_SCRIPT} && \
30+
python tools/python_utils.py --create-conda-env ${CONDA_ENV} && \
31+
echo "if [ -z \${CONDA_ENV} ]; then export CONDA_ENV=${CONDA_ENV}; fi" >> "${SETUP_SCRIPT}" && \
32+
echo "conda activate \${CONDA_ENV}" >> "${SETUP_SCRIPT}"
33+
34+
35+
# Install PyTorch nightly and verify the date is correct
36+
RUN cd /workspace/tritonbench && \
37+
. ${SETUP_SCRIPT} && \
38+
python -m tools.rocm_utils --install-torch-deps && \
39+
python -m tools.rocm_utils --install-torch-nightly
40+
41+
42+
# Install Tritonbench
43+
RUN cd /workspace/tritonbench && \
44+
bash .ci/tritonbench/install.sh

install.py

+11
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from tools.cuda_utils import CUDA_VERSION_MAP, DEFAULT_CUDA_VERSION
99
from tools.git_utils import checkout_submodules
1010
from tools.python_utils import pip_install_requirements
11+
from tools.torch_utils import is_hip
1112

1213
logging.basicConfig(level=logging.INFO)
1314
logger = logging.getLogger(__name__)
@@ -77,6 +78,13 @@ def install_liger():
7778
subprocess.check_call(cmd)
7879

7980

81+
def setup_hip(args: argparse.Namespace):
82+
# We have to disable all third-parties that donot support hip/rocm
83+
args.all = False
84+
args.liger = True
85+
args.hstu = True
86+
87+
8088
if __name__ == "__main__":
8189
parser = argparse.ArgumentParser(allow_abbrev=False)
8290
parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
@@ -105,6 +113,9 @@ def install_liger():
105113
parser.add_argument("--test", action="store_true", help="Run tests")
106114
args = parser.parse_args()
107115

116+
if args.all and is_hip():
117+
setup_hip(args)
118+
108119
# install framework dependencies
109120
pip_install_requirements("requirements.txt")
110121
# checkout submodules

tools/cuda_utils.py

+9-70
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
import importlib
1+
import argparse
22
import os
33
import re
44
import subprocess
55
from pathlib import Path
66

7-
from typing import Optional
8-
97
# defines the default CUDA version to compile against
108
DEFAULT_CUDA_VERSION = "12.4"
11-
REPO_ROOT = Path(__file__).parent.parent
129

1310
CUDA_VERSION_MAP = {
1411
"12.4": {
@@ -17,10 +14,6 @@
1714
"jax": "jax[cuda12]",
1815
},
1916
}
20-
PIN_CMAKE_VERSION = "3.22.*"
21-
22-
TORCH_NIGHTLY_PACKAGES = ["torch"]
23-
BUILD_REQUIREMENTS_FILE = REPO_ROOT.joinpath("utils", "build_requirements.txt")
2417

2518

2619
def _nvcc_output_match(nvcc_output, target_cuda_version):
@@ -94,6 +87,8 @@ def setup_cuda_softlink(cuda_version: str):
9487

9588

9689
def install_pytorch_nightly(cuda_version: str, env, dryrun=False):
90+
from .torch_utils import TORCH_NIGHTLY_PACKAGES
91+
9792
uninstall_torch_cmd = ["pip", "uninstall", "-y"]
9893
uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES)
9994
if dryrun:
@@ -137,68 +132,7 @@ def install_torch_deps(cuda_version: str):
137132
subprocess.check_call(cmd)
138133

139134

140-
def install_torch_build_deps(cuda_version: str):
141-
install_torch_deps(cuda_version=cuda_version)
142-
# Pin cmake version to stable
143-
# See: https://github.com/pytorch/builder/pull/1269
144-
torch_build_deps = [
145-
"cffi",
146-
"sympy",
147-
"typing_extensions",
148-
"future",
149-
"six",
150-
"dataclasses",
151-
"tabulate",
152-
"tqdm",
153-
"mkl",
154-
"mkl-include",
155-
f"cmake={PIN_CMAKE_VERSION}",
156-
]
157-
cmd = ["conda", "install", "-y"] + torch_build_deps
158-
subprocess.check_call(cmd)
159-
build_deps = ["ffmpeg"]
160-
cmd = ["conda", "install", "-y"] + build_deps
161-
subprocess.check_call(cmd)
162-
# pip build deps
163-
cmd = ["pip", "install", "-r"] + [str(BUILD_REQUIREMENTS_FILE.resolve())]
164-
subprocess.check_call(cmd)
165-
# conda forge deps
166-
# ubuntu 22.04 comes with libstdcxx6 12.3.0
167-
# we need to install the same library version in conda to maintain ABI compatibility
168-
conda_deps = ["libstdcxx-ng=12.3.0"]
169-
cmd = ["conda", "install", "-y", "-c", "conda-forge"] + conda_deps
170-
subprocess.check_call(cmd)
171-
172-
173-
def get_torch_nightly_version(pkg_name: str):
174-
pkg = importlib.import_module(pkg_name)
175-
version = pkg.__version__
176-
regex = ".*dev([0-9]+).*"
177-
date_str = re.match(regex, version).groups()[0]
178-
pkg_ver = {"version": version, "date": date_str}
179-
return (pkg_name, pkg_ver)
180-
181-
182-
def check_torch_nightly_version(force_date: Optional[str] = None):
183-
pkg_versions = dict(map(get_torch_nightly_version, TORCH_NIGHTLY_PACKAGES))
184-
pkg_dates = [x[1]["date"] for x in pkg_versions.items()]
185-
if not len(set(pkg_dates)) == 1:
186-
raise RuntimeError(
187-
f"Found more than 1 dates in the torch nightly packages: {pkg_versions}."
188-
)
189-
if force_date and not pkg_dates[0] == force_date:
190-
raise RuntimeError(
191-
f"Force date value {force_date}, but found torch packages {pkg_versions}."
192-
)
193-
force_date_str = f"User force date {force_date}" if force_date else ""
194-
print(
195-
f"Installed consistent torch nightly packages: {pkg_versions}. {force_date_str}"
196-
)
197-
198-
199135
if __name__ == "__main__":
200-
import argparse
201-
202136
parser = argparse.ArgumentParser()
203137
parser.add_argument(
204138
"--cudaver",
@@ -240,9 +174,14 @@ def check_torch_nightly_version(force_date: Optional[str] = None):
240174
if args.install_torch_deps:
241175
install_torch_deps(cuda_version=args.cudaver)
242176
if args.install_torch_build_deps:
243-
install_torch_build_deps(cuda_version=args.cudaver)
177+
from .torch_utils import install_torch_build_deps
178+
179+
install_torch_deps(cuda_version=args.cudaver)
180+
install_torch_build_deps()
244181
if args.install_torch_nightly:
245182
install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ)
246183
if args.check_torch_nightly_version:
184+
from .torch_utils import check_torch_nightly_version
185+
247186
assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command."
248187
check_torch_nightly_version(args.force_date)

0 commit comments

Comments
 (0)