[Hardware][AMD][CI/Build][Doc] Upgrade to ROCm 6.1, Dockerfile improvements, test fixes #5422

Merged
Changes from all commits (24 commits)
32b9f67
Add ROCm tuned num_threads and partition_size for pagedattn
mawong-amd Jun 11, 2024
c3d3671
Use PYTORCH_ROCM_ARCH env variable to set target list
mawong-amd Jun 11, 2024
1b3b2a7
Upgrade Dockerfile to ROCm6.1 properly and improve logic, build FA an…
mawong-amd Jun 11, 2024
a4679ed
ROCm 6.1.2, reorder dependency installation for vLLM concurrent build
mawong-amd Jun 12, 2024
e5a3428
Standardize use of PyTorch 2.3.0 on ROCm 6.x
mawong-amd Jun 13, 2024
454cb21
Ensure use of docker buildkit for AMD container
mawong-amd Jun 13, 2024
84fe59d
Use nightly torch for ROCm 6.1, relax req to torch version >= 2.3.0
mawong-amd Jun 13, 2024
50b909b
Fix up rebase errors, upgrade to torch 2.4.0, install amdsmi
mawong-amd Jun 20, 2024
ef1e4cb
Fix test_cuda_device_count_stateless on ROCm
mawong-amd Jun 20, 2024
3bb3875
Use OpenAI triton instead of ROCm
mawong-amd Jun 20, 2024
90ab818
Revert ray workdir init from https://github.com/vllm-project/vllm/pul…
mawong-amd Jun 20, 2024
ba8130c
Use 240612 torch wheels, prevent runtime interference
mawong-amd Jun 20, 2024
2583953
Use CUDA_VISIBLE_DEVICES for ROCm tests
mawong-amd Jun 21, 2024
4786050
Reviewer comments on Dockerfile
mawong-amd Jun 21, 2024
85cb7d0
Unify setting CUDA_VISIBLE_DEVICES and carrying it over to HIP_VISIBL…
mawong-amd Jun 21, 2024
ae9e82c
Remove sccache so it does not shadow ccache
mawong-amd Jun 21, 2024
e49260d
Propagate CUDA_VISIBLE_DEVICES to HIP_VISIBLE_DEVICES if set before init
mawong-amd Jun 21, 2024
554cff1
Revert "Add ROCm tuned num_threads and partition_size for pagedattn"
mawong-amd Jun 21, 2024
32fb98b
Revert "Ensure use of docker buildkit for AMD container"
mawong-amd Jun 21, 2024
e239bfc
Fix HIP_VISIBLE_DEVICE updating in device_count_stateless
mawong-amd Jun 21, 2024
016a553
Robustify cuda_device_count_stateless logic on ROCm
mawong-amd Jun 23, 2024
40c33ec
Skip xfail tests on ROCm to conserve CI resources
mawong-amd Jun 24, 2024
c293b3a
Address reviewer comments
mawong-amd Jun 24, 2024
353c0b2
Add amdsmi support for GPU memory check
mawong-amd Jun 25, 2024
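
A recurring theme in the commits above is unifying device selection on ROCm by carrying CUDA_VISIBLE_DEVICES over to HIP_VISIBLE_DEVICES before the device runtime is initialized. As a rough illustration of that idea (a minimal sketch, not the code added in this PR), the propagation could look like:

import os

def propagate_visible_devices() -> None:
    # Mirror CUDA_VISIBLE_DEVICES into HIP_VISIBLE_DEVICES on ROCm.
    # This must happen before torch (and hence the HIP runtime) is
    # initialized, otherwise the setting is ignored. Illustrative sketch only.
    devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if devices is not None and "HIP_VISIBLE_DEVICES" not in os.environ:
        os.environ["HIP_VISIBLE_DEVICES"] = devices

propagate_visible_devices()
import torch  # imported only after the environment is set up
print(torch.cuda.device_count())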
20 changes: 6 additions & 14 deletions CMakeLists.txt
@@ -32,8 +32,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")

#
# Try to find python package with an executable that exactly matches
@@ -98,18 +97,11 @@ elseif(HIP_FOUND)
# .hip extension automatically, HIP must be enabled explicitly.
enable_language(HIP)

# ROCm 5.x
if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
"expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.")
endif()

# ROCm 6.x
if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
"expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.")
# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif()
else()
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
209 changes: 145 additions & 64 deletions Dockerfile.rocm
@@ -1,34 +1,35 @@
# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

FROM $BASE_IMAGE

ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

RUN echo "Base image is $BASE_IMAGE"

ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

# Default ROCm 6.1 base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"

# Tested and supported base rocm/pytorch images
ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \
ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \
ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"

# Default ROCm ARCHes to build vLLM for.
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"

# Whether to build CK-based flash-attention
# If 0, will not build flash attention
# This is useful for gfx target where flash-attention is not supported
# (i.e. those that do not appear in `FA_GFX_ARCHS`)
# Triton FA is used by default on ROCm now so this is unnecessary.
ARG BUILD_FA="1"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"

ARG FA_BRANCH="ae7928c"
RUN echo "FA_BRANCH is $FA_BRANCH"

# whether to build flash-attention
# if 0, will not build flash attention
# this is useful for gfx target where flash-attention is not supported
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"

# whether to build triton on rocm
# Whether to build triton on rocm
ARG BUILD_TRITON="1"
ARG TRITON_BRANCH="0ef1848"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
### Base image build stage
FROM $BASE_IMAGE AS base

# Import arg(s) defined before this build stage
ARG PYTORCH_ROCM_ARCH

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
RUN apt-get update && apt-get install -y \
curl \
ca-certificates \
@@ -39,79 +40,159 @@ RUN apt-get update && apt-get install -y \
build-essential \
wget \
unzip \
nvidia-cuda-toolkit \
tmux \
ccache \
&& rm -rf /var/lib/apt/lists/*

### Mount Point ###
# When launching the container, mount the code directory to /app
# When launching the container, mount the code directory to /vllm-workspace
ARG APP_MOUNT=/vllm-workspace
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
RUN pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
# Install torch == 2.4.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-5.7"*) \
pip uninstall -y torch \
&& pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
--index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
*"rocm-6.0"*) \
Collaborator:
Do you know whether the base image rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging works for the pytorch rocm5.7 and rocm6.0 pytorch wheels?

mawong-amd (Contributor Author), Jun 21, 2024:
Not sure if I understand this question, or the spirit behind it: is the question whether "pytorch_staging" works on ROCm 5.7 or 6.0 without wheels, or whether these wheels work on a non-"pytorch_staging" image, or something else? You'd never install rocm5.7 or rocm6.0 wheels on a rocm6.1.2 image.

For the first question, the answer is "no": you need at least PyTorch 2.4.0 for torch.cuda._device_count_amdsmi, while pytorch_staging goes up to PyTorch 2.3.0. For the second, the answer is "yes": the relevant wheels for each ROCm version are designed for (and work on) any official base image from that version.

Collaborator:
Sorry for the confusing question. I was thinking about whether we should use the "staging" image as the base image, since the name suggests testing status, and also whether rocm-6.0 + torch 2.4 or rocm-5.7 + torch 2.4 works out of the box for certain users on certain devices (especially Navi3x users, for example), because in the past the supported base docker images had torch 2.0.1 or torch 2.1.1.

mawong-amd (Contributor Author), Jun 22, 2024:
"Staging" in this case refers to the PyTorch version (2.3.0, which we are targeting for stable release sometime), so it's a moot point when we're upgrading to a PyTorch 2.4.0 nightly. This particular nightly has not raised problems in a wide range of internal testing on multiple GPU architectures, including MI300, MI250, and Navi32.

Let's keep the broader picture in mind: whether this PR is an improvement to the status quo, as opposed to whether this PR is "perfect". The status quo has multiple failing tests, even among the tests AMD supports, and, most importantly, does not have ROCm 6.1 support, which brings a host of performance improvements. Upstream vLLM has also moved such that PyTorch 2.3+/2.4+ on ROCm is needed for a full complement of tests.

We can offer this support: even though on paper it is "experimental", it has not raised issues in testing. Unless someone provides a specific failing example that arises as a result of this PR, quibbling about potential issues that may arise because of the "experimental" or "nightly" status, without doing the legwork to verify that such issues actually exist, seems to me to be an example of making the perfect the enemy of the good.

pip uninstall -y torch \
&& pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
*"rocm-6.1"*) \
pip uninstall -y torch \
&& pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
*) ;; esac

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

# Install ROCm flash-attention
RUN if [ "$BUILD_FA" = "1" ]; then \
mkdir libs \
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ENV CCACHE_DIR=/root/.cache/ccache


### AMD-SMI build stage
FROM base AS build_amdsmi
# Build amdsmi wheel always
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=/install


### Flash-Attention wheel build stage
FROM base AS build_fa
ARG BUILD_FA
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_FA" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& export GPU_ARCHS=${FA_GFX_ARCHS} \
&& if [ "$BASE_IMAGE" = "$ROCm_5_7_BASE" ]; then \
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
&& python3 setup.py install \
&& cd ..; \
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-5.7"*) \
export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
&& patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
*) ;; esac \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
# Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \
fi

# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually removed it so that later steps of numpy upgrade can continue
RUN if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

# build triton
RUN if [ "$BUILD_TRITON" = "1" ]; then \
### Triton wheel build stage
FROM base AS build_triton
ARG BUILD_TRITON
ARG TRITON_BRANCH
# Build triton wheel if `BUILD_TRITON = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_TRITON" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& pip uninstall -y triton \
&& git clone https://github.com/ROCm/triton.git \
&& cd triton/python \
&& pip3 install . \
&& cd ../..; \
&& git clone https://github.com/OpenAI/triton.git \
&& cd triton \
&& git checkout "${TRITON_BRANCH}" \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=/install; \
# Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \
fi

WORKDIR /vllm-workspace

### Final vLLM build stage
FROM base AS final
# Import the vLLM development directory from the build context
COPY . .

#RUN python3 -m pip install pynvml # to be removed eventually
RUN python3 -m pip install --upgrade pip numba
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually remove it so that later steps of numpy upgrade can continue
RUN case "$(which python3)" in \
*"/opt/conda/envs/py_3.9"*) \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
*) ;; esac

# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --upgrade numba scipy huggingface-hub[cli]

# make sure punica kernels are built (for LoRA)
# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
# Silences the HF Tokenizers warning
ENV TOKENIZERS_PARALLELISM=false

ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so

ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
RUN --mount=type=cache,target=${CCACHE_DIR} \
--mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
&& if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch; fi \
&& python3 setup.py install \
&& export VLLM_PYTHON_VERSION=$(python -c "import sys; print(str(sys.version_info.major) + str(sys.version_info.minor))") \
&& cp build/lib.linux-x86_64-cpython-${VLLM_PYTHON_VERSION}/vllm/*.so vllm/ \
&& cd ..
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.0"*) \
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
*"rocm-6.1"*) \
# Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
&& cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
# Prevent interference if torch bundles its own HIP runtime
&& rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
*) ;; esac \
&& python3 setup.py clean --all \
&& python3 setup.py develop

# Copy amdsmi wheel into final image
RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
mkdir -p libs \
&& cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y amdsmi;

# Copy triton wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y triton; fi

# Copy flash-attn wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y flash-attn; fi

# Install wheels that were built to the final image
RUN --mount=type=cache,target=/root/.cache/pip \
if ls libs/*.whl; then \
pip install libs/*.whl; fi

CMD ["/bin/bash"]
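
The build_amdsmi stage above packages the amdsmi Python bindings shipped in /opt/rocm/share/amd_smi into a wheel that the final image installs; the review discussion above also notes that torch.cuda._device_count_amdsmi requires PyTorch 2.4.0. As a hedged sketch of using those bindings directly (the API names come from the amdsmi package, not from this PR), enumerating GPUs without initializing the HIP context could look roughly like:

import amdsmi

amdsmi.amdsmi_init()
try:
    # One handle per visible AMD GPU; counting handles avoids importing torch.
    handles = amdsmi.amdsmi_get_processor_handles()
    print(f"Detected {len(handles)} AMD GPU(s)")
finally:
    amdsmi.amdsmi_shut_down()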
20 changes: 12 additions & 8 deletions cmake/utils.cmake
@@ -147,27 +147,31 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
if (${GPU_LANG} STREQUAL "HIP")
#
# `GPU_ARCHES` controls the `--offload-arch` flags.
# `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled
# via the `PYTORCH_ROCM_ARCH` env variable.
#

# If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
# if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
# "rocm_agent_enumerator" in "enable_language(HIP)"
# (in file Modules/CMakeDetermineHIPCompiler.cmake)
#
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
hongxiayang (Collaborator), Jun 19, 2024:
Reusing the PYTORCH_ROCM_ARCH environment variable in the vLLM environment exposes some potential problems. This environment variable was used by PyTorch on ROCm when building its ROCm backend (https://github.com/pytorch/pytorch/blob/0707811286d1846209676435f4f86f2b4b3d1a17/cmake/Dependencies.cmake#L1075). Reusing it will hide this information and may potentially make debugging more difficult.

mawong-amd (Contributor Author):
I'm not really sure this is a concern:

  1. This just restores the behavior to what it was prior to the CMake-based build system (#2830).
  2. We also now install torch wheels, so PYTORCH_ROCM_ARCH in the base image is decoupled from whatever PyTorch used when building its ROCm backend.
  3. The debugging info provided by this environment variable seems to be very limited.

hongxiayang (Collaborator):
Yes, I am concerned about the name clash. Can you change it to VLLM_ROCM_SUPPORTED_ARCHS? Though the Dockerfile in this PR installs wheels, the wheels were built using the same environment variable and the same logic.

mawong-amd (Contributor Author), Jun 22, 2024:
It was working fine prior to #2830, and the workflow people are used to is to specify PYTORCH_ROCM_ARCH to decide which architectures to build vLLM for: this has been the case at ROCm/vllm and here, it just hasn't been working here for a while. I haven't yet seen a good reason to change this workflow. If this change is made, it should be made in a separate PR with proper discussion of the change.

Also, if the wheels are all built using the same env var, then it can be easily remembered as a constant.

set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
else()
set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
endif()
#
# Find the intersection of the supported + detected architectures to
# set the module architecture flags.
#

set(VLLM_ROCM_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")

set(${GPU_ARCHES})
foreach (_ARCH ${VLLM_ROCM_SUPPORTED_ARCHS})
foreach (_ARCH ${HIP_ARCHITECTURES})
if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
list(APPEND ${GPU_ARCHES} ${_ARCH})
endif()
endforeach()

if(NOT ${GPU_ARCHES})
message(FATAL_ERROR
"None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is"
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
endif()

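With this change, the set of architectures vLLM is built for on ROCm can again be narrowed through the PYTORCH_ROCM_ARCH environment variable, falling back to CMAKE_HIP_ARCHITECTURES (populated by rocm_agent_enumerator via enable_language(HIP)) when it is unset. A minimal sketch of driving the build that way, assuming a vLLM checkout with the ROCm requirements already installed:

import os
import subprocess

# Build only for MI200/MI300-class GPUs; adjust the arch list as needed.
env = dict(os.environ, PYTORCH_ROCM_ARCH="gfx90a;gfx942")
subprocess.run(["python3", "setup.py", "develop"], check=True, env=env)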
6 changes: 3 additions & 3 deletions docs/source/getting_started/amd-installation.rst
@@ -88,7 +88,7 @@ Option 2: Build from source
- `Pytorch <https://pytorch.org/>`_
- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_

For installing PyTorch, you can start from a fresh docker image, e.g, `rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.

Alternatively, you can install pytorch using pytorch wheels. You can check Pytorch installation guild in Pytorch `Getting Started <https://pytorch.org/get-started/locally/>`_

@@ -126,12 +126,12 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl

$ cd vllm
$ pip install -U -r requirements-rocm.txt
$ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
$ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation


.. tip::

- You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation.
- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
- To use CK flash-attention, please use this flag ``export VLLM_USE_FLASH_ATTN_TRITON=0`` to turn off triton flash attention.
- To use CK flash-attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
- The ROCm version of pytorch, ideally, should match the ROCm driver version.
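
For illustration, here is a hedged Python sketch of the tips above: disable the Triton flash-attention backend (falling back to CK flash-attention) before vLLM initializes, and enforce eager mode if you hit hangs. The model name is just an arbitrary example.

import os
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"  # use CK flash-attention instead of Triton

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)  # mirrors --enforce-eager
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)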
4 changes: 2 additions & 2 deletions tests/async_engine/test_openapi_server_ray.py
@@ -4,15 +4,15 @@
# and debugging.
import ray

from ..utils import VLLM_PATH, RemoteOpenAIServer
from ..utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"


@pytest.fixture(scope="module")
def ray_ctx():
ray.init(runtime_env={"working_dir": VLLM_PATH})
ray.init()
yield
ray.shutdown()
