diff --git a/README.md b/README.md
index aa1ac2347..355c87ce8 100644
--- a/README.md
+++ b/README.md
@@ -37,25 +37,7 @@ Before running examples, build and install libs under corelib following instruct
 - [HSTU attention documentation](./corelib/hstu/README.md)
 - [Dynamic Embeddings documentation](./corelib/dynamicemb/README.md)
 
-On top of those two core libs, Megatron-Core along with other libs are required. You can install them via pypi package:
-
-```bash
-pip install torchx gin-config torchmetrics==1.0.3 typing-extensions iopath megatron-core==0.9.0
-```
-
-If you fail to install the megatron-core package, usually due to the python version incompatibility, please try to clone and then install the source code.
-
-```bash
-git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git megatron-lm && \
-pip install -e ./megatron-lm
-```
-
-We provide our custom HSTU CUDA operators for enhanced performance. You need to install these operators using the following command:
-
-```bash
-cd /workspace/recsys-examples/examples/hstu && \
-python setup.py install
-```
+On top of those two core libs, Megatron-Core along with other libs and the custom HSTU CUDA operators are required. You can refer to the [Dockerfile](./docker/Dockerfile) for detailed install instructions.
 
 # Get Started
 The examples we supported:
diff --git a/corelib/dynamicemb/setup.py b/corelib/dynamicemb/setup.py
index 473296b8f..99becc1b0 100644
--- a/corelib/dynamicemb/setup.py
+++ b/corelib/dynamicemb/setup.py
@@ -83,6 +83,8 @@ def get_extensions():
             "arch=compute_80,code=sm_80",
             "-gencode",
             "arch=compute_90,code=sm_90",
+            "-gencode",
+            "arch=compute_100,code=sm_100",
             "-w",
             "-U__CUDA_NO_HALF_OPERATORS__",
             "-U__CUDA_NO_HALF_CONVERSIONS__",
diff --git a/corelib/hstu/csrc/hstu_attn/hstu_api.cpp b/corelib/hstu/csrc/hstu_attn/hstu_api.cpp
index 68c864af2..5dc33ed75 100755
--- a/corelib/hstu/csrc/hstu_attn/hstu_api.cpp
+++ b/corelib/hstu/csrc/hstu_attn/hstu_api.cpp
@@ -556,7 +556,9 @@ std::vector hstu_varlen_bwd(
     const bool deterministic) {
   auto dprops = at::cuda::getCurrentDeviceProperties();
   TORCH_CHECK(dprops->major >= 8, "HSTU only supports Ampere GPUs or newer.");
-  TORCH_CHECK(dprops->major == 8 && dprops->minor == 0, "HSTU backward does not support sm86 or sm89.");
+  if (dprops->major == 8) {
+    TORCH_CHECK(dprops->minor == 0, "For Ampere GPUs, HSTU backward does not support sm86 or sm89.");
+  }
 
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   auto q_dtype = q.dtype();
diff --git a/corelib/hstu/setup.py b/corelib/hstu/setup.py
index bea3626d2..4c5538f5b 100644
--- a/corelib/hstu/setup.py
+++ b/corelib/hstu/setup.py
@@ -242,6 +242,8 @@ def generate_cuda_sources():
     cc_flag = []
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_80,code=sm_80")
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_100,code=sm_100")
     # cc_flag.append("arch=compute_86,code=sm_86")
 
     if FORCE_CXX11_ABI:
diff --git a/docker/Dockerfile b/docker/Dockerfile
index ebacf2621..f77d5abad 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,4 +1,4 @@
-ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.11-py3
+ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:25.03-py3
 ARG DEVEL_IMAGE=devel
 
 FROM ${BASE_IMAGE} AS devel
@@ -13,9 +13,9 @@ RUN ARCH=$([ "${TARGETPLATFORM}" = "linux/arm64" ] && echo "aarch64" || echo "x8
         ln -s /usr/local/cuda-12.8/targets/x86_64-linux/lib/stubs/libnvidia-ml.so /usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1; \
     else \
         if [ ${ARCH} = "aarch64" ]; then \
-            ln -s /usr/local/cuda-12.6/targets/sbsa-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
+            ln -s /usr/local/cuda-12.8/targets/sbsa-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
         else \
-            ln -s /usr/local/cuda-12.6/targets/${ARCH}-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
+            ln -s /usr/local/cuda-12.8/targets/${ARCH}-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
         fi \
     fi
 
@@ -24,14 +24,14 @@ RUN if [ "${INFERENCEBUILD}" != "1" ]; then \
         pip install -e ./megatron-lm; \
     fi
 
-RUN pip install torchx gin-config torchmetrics==1.0.3 typing-extensions iopath pyvers
+RUN pip install torchx gin-config torchmetrics typing-extensions iopath
 
-RUN pip install --no-cache-dir setuptools==69.5.1 setuptools-git-versioning scikit-build && \
+RUN pip install --no-cache-dir setuptools==75.8.2 setuptools-git-versioning scikit-build && \
     git clone --recursive -b v1.2.0 https://github.com/pytorch/FBGEMM.git fbgemm && \
     cd fbgemm/fbgemm_gpu && \
-    python setup.py install --package_variant=cuda -DTORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 9.0"
+    python setup.py install --package_variant=cuda -DTORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 9.0 10.0"
 
-RUN pip install --no-deps tensordict orjson && \
+RUN pip install --no-deps tensordict orjson pyvers && \
     git clone --recursive -b v1.2.0 https://github.com/pytorch/torchrec.git torchrec && \
     cd torchrec && \
     pip install --no-deps .
diff --git a/examples/commons/checkpoint/checkpoint.py b/examples/commons/checkpoint/checkpoint.py
index 958bec301..87ec24852 100644
--- a/examples/commons/checkpoint/checkpoint.py
+++ b/examples/commons/checkpoint/checkpoint.py
@@ -110,7 +110,7 @@ def load(
     save_path = os.path.join(
         path, "torch_module", "model.{}.pth".format(dist.get_rank())
     )
-    state_dict = torch.load(save_path)
+    state_dict = torch.load(save_path, weights_only=False)
     unwrapped_module.load_state_dict(state_dict["model_state_dict"])
     if dense_optimizer and state_dict["optimizer_state_dict"]:
         dense_optimizer.load_state_dict(state_dict["optimizer_state_dict"])
diff --git a/examples/hstu/modules/hstu_attention.py b/examples/hstu/modules/hstu_attention.py
index 3e864b439..fe78222af 100644
--- a/examples/hstu/modules/hstu_attention.py
+++ b/examples/hstu/modules/hstu_attention.py
@@ -426,15 +426,14 @@ def create_hstu_attention(
                 linear_dim,
                 is_causal,
             )
-        elif sm_major_version == 8 and sm_minor_version == 0:
-            return FusedHSTUAttention(
-                num_heads,
-                attention_dim,
-                linear_dim,
-                is_causal,
-            )
-        print(
-            "CUTLASS backend only support H100, H20 and A100, fallback to PyTorch backend"
+        assert sm_major_version >= 8, "Ampere or Ampere next GPU is required."
+        if sm_major_version == 8:
+            assert sm_minor_version == 0, "For Ampere, only A100 is supported."
+        return FusedHSTUAttention(
+            num_heads,
+            attention_dim,
+            linear_dim,
+            is_causal,
         )
     elif kernel_backend == KernelBackend.TRITON:
         if is_causal:
diff --git a/examples/hstu/ops/fused_hstu_op.py b/examples/hstu/ops/fused_hstu_op.py
index 2a1cdddde..eb57377db 100644
--- a/examples/hstu/ops/fused_hstu_op.py
+++ b/examples/hstu/ops/fused_hstu_op.py
@@ -192,10 +192,8 @@ def _ln_linear_silu_fwd(
     sm = torch.cuda.get_device_properties(0).major
     if sm == 8:
         addmm_silu_fwd_impl = triton_addmm_silu_fwd
-    elif sm == 9:
-        addmm_silu_fwd_impl = torch_addmm_silu_fwd
     else:
-        raise ValueError(f"Unsupported SM major version: {sm}")
+        addmm_silu_fwd_impl = torch_addmm_silu_fwd
     # 2. linear & silu
     # bias is 1D
     linear_uvqk, silu_linear_uvqk = addmm_silu_fwd_impl(
@@ -276,18 +274,19 @@ def _hstu_attn_cutlass_fwd(
     alpha,
 ):
     sm_major_version = torch.cuda.get_device_properties(0).major
+    sm_minor_version = torch.cuda.get_device_properties(0).minor
     extension_args = ()
-    if sm_major_version == 8:
-        cutlass_hstu_varlen_fwd = flash_attn_cuda_ampere.varlen_fwd
-        ampere_paged_kv_args = (None, None, None, None, None)
-        extension_args = ampere_paged_kv_args
-    elif sm_major_version == 9:
+    if sm_major_version == 9:
         cutlass_hstu_varlen_fwd = flash_attn_cuda_hopper.varlen_fwd
         hopper_fp8_args = (None, None, None)
         extension_args = hopper_fp8_args
-    else:
-        raise ValueError(f"Unsupported SM major version: {sm_major_version}")
+    assert sm_major_version >= 8, "Ampere or Ampere next GPU is required."
+    if sm_major_version == 8:
+        assert sm_minor_version == 0, "For Ampere, only A100 is supported."
+        cutlass_hstu_varlen_fwd = flash_attn_cuda_ampere.varlen_fwd
+        ampere_paged_kv_args = (None, None, None, None, None)
+        extension_args = ampere_paged_kv_args
 
     assert q.dim() == 3, "q shape should be (L, num_heads, head_dim)"
     assert k.dim() == 3, "k shape should be (L, num_heads, head_dim)"
     assert v.dim() == 3, "v shape should be (L, num_heads, hidden_dim)"
@@ -392,10 +391,8 @@ def _linear_residual_fwd(
     sm = torch.cuda.get_device_properties(0).major
     if sm == 8:
         addmm_silu_fwd_impl = triton_addmm_silu_fwd
-    elif sm == 9:
-        addmm_silu_fwd_impl = torch_addmm_silu_fwd
     else:
-        raise ValueError(f"Unsupported SM major version: {sm}")
+        addmm_silu_fwd_impl = torch_addmm_silu_fwd
     y, _ = addmm_silu_fwd_impl(
         x=x,
         w=w,
@@ -602,12 +599,14 @@ def _hstu_attn_cutlass_bwd(
     dv: Optional[torch.Tensor] = None,
 ):
     sm_major_version = torch.cuda.get_device_properties(0).major
-    if sm_major_version == 8:
-        cutlass_hstu_varlen_bwd = flash_attn_cuda_ampere.varlen_bwd
-    elif sm_major_version == 9:
+    sm_minor_version = torch.cuda.get_device_properties(0).minor
+    if sm_major_version == 9:
         cutlass_hstu_varlen_bwd = flash_attn_cuda_hopper.varlen_bwd
     else:
-        raise ValueError(f"Unsupported SM major version: {sm_major_version}")
+        assert sm_major_version >= 8, "Ampere or Ampere next GPU is required."
+        if sm_major_version == 8:
+            assert sm_minor_version == 0, "For Ampere, only A100 is supported."
+        cutlass_hstu_varlen_bwd = flash_attn_cuda_ampere.varlen_bwd
     assert dout.dim() == 3
     dq, dk, dv, _ = cutlass_hstu_varlen_bwd(
         dout,
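
All of the dispatch changes above follow one compute-capability pattern: read `major`/`minor` from `torch.cuda.get_device_properties`, route Hopper (sm90) to the Hopper CUTLASS kernels, restrict Ampere to A100 (sm_80), and let the small addmm+SiLU epilogue fall back from Triton to plain PyTorch off Ampere. The sketch below is illustrative only: the helper name `select_hstu_dispatch` and its boolean return values are not part of the repo, and it assumes a CUDA build of PyTorch with at least one visible device.

```python
import torch


def select_hstu_dispatch(device_index: int = 0):
    """Illustrative sketch of the capability checks used in the hunks above."""
    props = torch.cuda.get_device_properties(device_index)
    major, minor = props.major, props.minor

    # CUTLASS HSTU kernels require Ampere (sm80) or newer.
    assert major >= 8, "Ampere or newer GPU is required."
    # On Ampere, only A100 (sm_80) is accepted; sm86/sm89 are rejected,
    # matching the relaxed TORCH_CHECK in hstu_api.cpp.
    if major == 8:
        assert minor == 0, "For Ampere, only A100 is supported."

    # Hopper (sm90) uses the dedicated Hopper CUTLASS kernels; A100 uses the
    # Ampere path. The added -gencode flags additionally compile the kernels
    # for compute capability 10.0 (sm_100) devices.
    use_hopper_kernels = major == 9

    # The fused addmm+SiLU epilogue keeps its Triton implementation only on
    # Ampere and falls back to plain PyTorch everywhere else.
    use_triton_addmm = major == 8

    return use_hopper_kernels, use_triton_addmm
```

As the hunks read, the intent of the relaxed checks is to stop hard-raising `ValueError` on newer SM major versions now that sm_100 `-gencode` targets are built, while still rejecting sm86/sm89 for the backward pass.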