
Hang and cuBLASMp error in pmatmul.cu #241

Open
GuoxiaWang opened this issue Dec 19, 2024 · 4 comments
@GuoxiaWang

GuoxiaWang commented Dec 19, 2024

Reproduction Environment:

  • Machine: H20
  • Configuration: Single machine with 8 GPUs
  • Docker Image: nvcr.io/nvidia/nvhpc:24.11-devel-cuda_multi-ubuntu20.04
  • NVIDIA-SMI 535.183.06
  • Driver Version: 535.183.06
  • CUDA Version: 12.2
  • cuBLASMp 0.3.1
  • nvshmem-3.1.7
wget https://developer.download.nvidia.com/compute/cublasmp/0.3.1/local_installers/cublasmp-local-repo-ubuntu2004-0.3.1_0.3.1-1_amd64.deb
dpkg -i cublasmp-local-repo-ubuntu2004-0.3.1_0.3.1-1_amd64.deb
cp /var/cublasmp-local-repo-ubuntu2004-0.3.1/cublasmp-*-keyring.gpg /usr/share/keyrings/
apt-get update
apt-get -y install cublasmp
apt-get -y install cublasmp-cuda-12
apt-get install libudev1=245.4-4ubuntu3.22 -y
apt-get install udev rdma-core -y

export NVSHMEM_BUILD_TESTS=0
export NVSHMEM_BUILD_EXAMPLE=0
export NVSHMEM_IBDEVX_SUPPORT=0
export NVSHMEM_IBGDA_SUPPORT=0
export NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY=0
export NVSHMEM_IBDEVX_SUPPORT=0
export NVSHMEM_IBRC_SUPPORT=1
export NVSHMEM_LIBFABRIC_SUPPORT=0
export NVSHMEM_MPI_SUPPORT=1
export NVSHMEM_USE_GDRCOPY=0
export NVSHMEM_TORCH_SUPPORT=1
export NVSHMEM_ENABLE_ALL_DEVICE_INLINING=1
export NVSHMEM_USE_NCCL=1
export NVSHMEM_DEFAULT_PMIX=1
export NVSHMEM_DEBUG=1
export NVSHMEM_TRACE=1
export NVSHMEM_PREFIX=/test/nvshmem_examples/nvshmem-3.1.7
export NCCL_HOME=/opt/nvidia/hpc_sdk/Linux_x86_64/24.11/comm_libs/12.6/nccl
export CUDA_HOME=/usr/local/cuda/
export MPI_HOME=/opt/nvidia/hpc_sdk/Linux_x86_64/24.11/comm_libs/12.6/hpcx/hpcx-2.20/ompi/

rm -rf nvshmem_src/build
cd nvshmem_src/

mkdir -p build
cd build

export PATH=/test/nvshmem_examples/cmake-3.31.2-linux-x86_64/bin:$PATH
export PATH=/opt/compiler/gcc-12/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/lib:/home/opt/nvidia_lib:/usr/lib64:/usr/local/lib

cmake .. -DCMAKE_CUDA_ARCHITECTURES="90" -DCMAKE_CUDA_COMPILER=/opt/nvidia/hpc_sdk/Linux_x86_64/24.11/cuda/12.6/bin/nvcc -DCMAKE_C_COMPILER=/opt/compiler/gcc-12/bin/gcc -DCMAKE_CXX_COMPILER=/opt/compiler/gcc-12/bin/g++ -DCMAKE_VERBOSE_MAKEFILE=ON

make -j64 install

Build and Execution Commands:

rm -rf build
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=Release \
  -DCMAKE_CUDA_ARCHITECTURES="90" \
  -DCMAKE_CUDA_COMPILER=/opt/nvidia/hpc_sdk/Linux_x86_64/24.11/cuda/12.6/bin/nvcc \
  -DCMAKE_C_COMPILER=/opt/compiler/gcc-12/bin/gcc \
  -DCMAKE_CXX_COMPILER=/opt/compiler/gcc-12/bin/g++ \
  -DCUBLASMP_INCLUDE_DIRECTORIES=/usr/include/libcublasmp/12 \
  -DCUBLASMP_LIBRARIES=/usr/lib/x86_64-linux-gnu/libcublasmp/12/libcublasmp.so \
  -DCAL_INCLUDE_DIRECTORIES=/opt/nvidia/hpc_sdk/Linux_x86_64/24.11/math_libs/12.6/targets/x86_64-linux/include/ \
  -DCAL_LIBRARIES=/opt/nvidia/hpc_sdk/Linux_x86_64/24.11/math_libs/12.6/targets/x86_64-linux/lib/libcal.so \
  -DNVSHMEM_INCLUDE_DIRECTORIES=/test/nvshmem_examples/nvshmem-3.1.7/include/ \
  -DNVSHMEM_HOST_LIBRARIES=/test/nvshmem_examples/nvshmem-3.1.7/lib/libnvshmem_host.so \
  -DNVSHMEM_DEVICE_LIBRARIES=/test/nvshmem_examples/nvshmem-3.1.7/lib/libnvshmem_device.a
make -j

mpirun --allow-run-as-root -bind-to none -tag-output -timestamp-output -n 8 ./pmatmul

Issue 1: Hang

AG + Matmul

https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASMp/pmatmul.cu#L151-L153

With the problem sizes at these lines changed to the following, the run hangs:

        const int64_t m = 512 * nranks;
        const int64_t n = 512 * nranks;
        const int64_t k = 512;

Issue 2: cuBLASMp error

AG + Matmul

https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASMp/pmatmul.cu#L151-L153

With the problem sizes changed to the following, the sample fails with cuBLASMp error 6 on ranks 1-7 and CAL error 5 on rank 0:

        const int64_t m = 1024 * nranks;
        const int64_t n = 1024 * nranks;
        const int64_t k = 1024;
Thu Dec 19 17:26:19 2024[1,2]<stderr>:cuBLASMp error at /test/nvshmem_examples/CUDALibrarySamples/cuBLASMp/pmatmul.cu:237 : 6
Thu Dec 19 17:26:19 2024[1,3]<stderr>:cuBLASMp error at /test/nvshmem_examples/CUDALibrarySamples/cuBLASMp/pmatmul.cu:237 : 6
Thu Dec 19 17:26:19 2024[1,4]<stderr>:cuBLASMp error at /test/nvshmem_examples/CUDALibrarySamples/cuBLASMp/pmatmul.cu:237 : 6
Thu Dec 19 17:26:19 2024[1,5]<stderr>:cuBLASMp error at /test/nvshmem_examples/CUDALibrarySamples/cuBLASMp/pmatmul.cu:237 : 6
Thu Dec 19 17:26:19 2024[1,1]<stderr>:cuBLASMp error at /test/nvshmem_examples/CUDALibrarySamples/cuBLASMp/pmatmul.cu:237 : 6
Thu Dec 19 17:26:19 2024[1,7]<stderr>:cuBLASMp error at /test/nvshmem_examples/CUDALibrarySamples/cuBLASMp/pmatmul.cu:237 : 6
Thu Dec 19 17:26:19 2024[1,6]<stderr>:cuBLASMp error at /test/nvshmem_examples/CUDALibrarySamples/cuBLASMp/pmatmul.cu:237 : 6
Thu Dec 19 17:26:19 2024[1,0]<stderr>:CAL error at /test/nvshmem_examples/CUDALibrarySamples/cuBLASMp/pmatmul.cu:266 : 5
@anderson101866

Dear cuBLASMp experts,

A customer reports that, with their setup, Issue 1 sometimes hangs.

Also, Issue 2 crashes with an illegal memory access.

@almogsegal
Contributor

Hi @GuoxiaWang,

Thanks for reaching out!

I couldn't reproduce the first issue; I suspect it is somehow related to the environment. Can you please try the following:

docker container run -it --gpus all --cap-add CAP_SYS_PTRACE --shm-size="8g" nvcr.io/nvidia/pytorch:24.12-py3

pip install nvidia-cublasmp-cu12

git clone https://github.com/NVIDIA/CUDALibrarySamples.git
cd CUDALibrarySamples/cuBLASMp
mkdir -p build
cd build
export CUBLASMP_HOME=/usr/local/lib/python3.12/dist-packages/nvidia/cublasmp/cu12
export CAL_HOME=/usr/local/lib/python3.12/dist-packages/nvidia/libcal/cu12
export NVSHMEM_HOME=/usr/local/lib/python3.12/dist-packages/nvidia/nvshmem
cmake .. -DCMAKE_BUILD_TYPE=Release \
  -DCMAKE_CUDA_ARCHITECTURES="70;80;90" \
  -DCUBLASMP_INCLUDE_DIRECTORIES=${CUBLASMP_HOME}/include \
  -DCUBLASMP_LIBRARIES=${CUBLASMP_HOME}/lib/libcublasmp.so.0 \
  -DCAL_INCLUDE_DIRECTORIES=${CAL_HOME}/include \
  -DCAL_LIBRARIES=${CAL_HOME}/lib/libcal.so.0 \
  -DNVSHMEM_INCLUDE_DIRECTORIES=${NVSHMEM_HOME}/include \
  -DNVSHMEM_HOST_LIBRARIES=${NVSHMEM_HOME}/lib/libnvshmem_host.so.3
make -j

UCX_TLS=tcp NVSHMEM_BOOTSTRAP_UID_PLUGIN=nvshmem_bootstrap_uid.so.3 mpirun --allow-run-as-root -np 8 ./pmatmul

As for the second issue, I could reproduce it. I need to debug it, but I have a workaround to unblock you: to achieve better performance with AG+GEMM, B can be allocated using nvshmem_malloc, and doing so also resolves the issue. I.e., make the following changes:

// lines 174-175:
d_W0 = (input_t*)nvshmem_malloc(loc_b_m * loc_b_n * sizeof(input_t));
d_X1 = (output_t*)nvshmem_malloc(loc_c_m * loc_c_n * sizeof(output_t));

// line 283:
nvshmem_free(d_W0);

// line 414
nvshmem_free(d_X1);

Please let me know if the above solves both issues for you.

Thanks,
Almog

@almogsegal
Contributor

Hi @GuoxiaWang,

I've debugged the issue and found that the problem is in the sample itself: the C matrix allocation was wrong and too small. I've fixed the sample. FYI, the performance improvement suggested in my previous comment is still valid.
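
For context, a hedged sketch of how a local tile is typically sized under the 2D block-cyclic distribution cuBLASMp uses; the helper and the variable names below are illustrative, not the sample's actual code:

#include <cstdint>

// ScaLAPACK-style NUMROC: number of rows (or columns) of a global dimension n,
// distributed in blocks of nb over nprocs processes starting at process isrcproc,
// that are owned by process iproc.
int64_t numroc(int64_t n, int64_t nb, int iproc, int isrcproc, int nprocs) {
    const int mydist = (nprocs + iproc - isrcproc) % nprocs;
    const int64_t nblocks = n / nb;            // number of full blocks in this dimension
    int64_t loc = (nblocks / nprocs) * nb;     // evenly distributed full blocks
    const int64_t extra = nblocks % nprocs;
    if (mydist < extra) loc += nb;             // this process owns one additional full block
    else if (mydist == extra) loc += n % nb;   // this process owns the trailing partial block
    return loc;
}

// Sizing the local C tile with this rule (names such as mbC, myprow, nprow are hypothetical):
//   loc_c_m = numroc(m, mbC, myprow, 0, nprow);
//   loc_c_n = numroc(n, nbC, mypcol, 0, npcol);
//   allocate loc_c_m * loc_c_n * sizeof(output_t) bytes for the local C tile.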

Please give it a try and keep me posted.

Thanks,
Almog

@GuoxiaWang
Author

GuoxiaWang commented Dec 28, 2024

@almogsegal
Thank you very much for the support. I have tested it, and the sample passes successfully.

However, I have found an issue related to the PaddlePaddle weight layout: cuBLASMp only supports the forward computation and lacks APIs for the backward gradients. Implementing the backward pass ourselves is complicated by layout constraints. For example, the backward computation of the Matmul+ReduceScatter variant requires Transpose operations, and calculating dW and dX needs either two AllGather+Matmuls, or one AllGather(dY) plus a Matmul and one AllGather+Matmul.
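
For reference, a hedged sketch of the gradient structure being described, assuming a row-parallel layout where the forward pass is a local Matmul followed by ReduceScatter (the notation is illustrative, not a cuBLASMp API):

    Forward:   Y = ReduceScatter( X_local · W_local )      (W sharded along k)
    Backward:  dX_local = AllGather(dY) · W_local^T
               dW_local = X_local^T · AllGather(dY)

Under these assumptions, the sharded dY must be gathered before either gradient can be formed, which is why the backward pass needs either one explicit AllGather(dY) followed by two plain Matmuls, or two fused AllGather+Matmul calls.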
