20 changes: 1 addition & 19 deletions README.md
@@ -37,25 +37,7 @@ Before running examples, build and install libs under corelib following instruct
- [HSTU attention documentation](./corelib/hstu/README.md)
- [Dynamic Embeddings documentation](./corelib/dynamicemb/README.md)

On top of those two core libs, Megatron-Core along with other libs are required. You can install them via pypi package:

```bash
pip install torchx gin-config torchmetrics==1.0.3 typing-extensions iopath megatron-core==0.9.0
```

If you fail to install the megatron-core package, usually due to the python version incompatibility, please try to clone and then install the source code.

```bash
git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git megatron-lm && \
pip install -e ./megatron-lm
```

We provide our custom HSTU CUDA operators for enhanced performance. You need to install these operators using the following command:

```bash
cd /workspace/recsys-examples/examples/hstu && \
python setup.py install
```
On top of those two core libs, Megatron-Core along with other libs and the custom HSTU CUDA operators are required. Refer to the [dockerfile](./docker/Dockerfile) for detailed installation instructions.
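
For illustration, building that image from the repo root might look like the sketch below; the `recsys-examples:dev` tag is an assumption, not a documented name.

```bash
# Hypothetical build command using the Dockerfile referenced above.
docker build -t recsys-examples:dev -f docker/Dockerfile .
```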

# Get Started
The examples we supported:
2 changes: 2 additions & 0 deletions corelib/dynamicemb/setup.py
@@ -83,6 +83,8 @@ def get_extensions():
"arch=compute_80,code=sm_80",
"-gencode",
"arch=compute_90,code=sm_90",
"-gencode",
"arch=compute_100,code=sm_100",
"-w",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
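The `compute_100`/`sm_100` pair added here targets Blackwell-generation GPUs and needs a sufficiently new toolchain, consistent with the CUDA 12.8 base image used in the updated Dockerfile below. As a standalone illustration (the `.cu` file name is a placeholder):

```bash
# Hypothetical: compile a single CUDA source for the newly added Blackwell target.
nvcc -gencode arch=compute_100,code=sm_100 -c my_kernel.cu -o my_kernel.o
```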
4 changes: 3 additions & 1 deletion corelib/hstu/csrc/hstu_attn/hstu_api.cpp
@@ -556,7 +556,9 @@ std::vector<at::Tensor> hstu_varlen_bwd(
const bool deterministic) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major >= 8, "HSTU only supports Ampere GPUs or newer.");
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0, "HSTU backward does not support sm86 or sm89.");
if (dprops->major == 8) {
TORCH_CHECK(dprops->minor == 0, "For Ampere GPUs, HSTU backward does not support sm86 or sm89.");
}
auto stream = at::cuda::getCurrentCUDAStream().stream();

auto q_dtype = q.dtype();
2 changes: 2 additions & 0 deletions corelib/hstu/setup.py
@@ -242,6 +242,8 @@ def generate_cuda_sources():
cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_100,code=sm_100")
# cc_flag.append("arch=compute_86,code=sm_86")

if FORCE_CXX11_ABI:
14 changes: 7 additions & 7 deletions docker/Dockerfile
@@ -1,4 +1,4 @@
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.11-py3
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:25.03-py3
ARG DEVEL_IMAGE=devel

FROM ${BASE_IMAGE} AS devel
@@ -13,9 +13,9 @@ RUN ARCH=$([ "${TARGETPLATFORM}" = "linux/arm64" ] && echo "aarch64" || echo "x8
ln -s /usr/local/cuda-12.8/targets/x86_64-linux/lib/stubs/libnvidia-ml.so /usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1; \
else \
if [ ${ARCH} = "aarch64" ]; then \
ln -s /usr/local/cuda-12.6/targets/sbsa-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
ln -s /usr/local/cuda-12.8/targets/sbsa-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
else \
ln -s /usr/local/cuda-12.6/targets/${ARCH}-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
ln -s /usr/local/cuda-12.8/targets/${ARCH}-linux/lib/stubs/libnvidia-ml.so /usr/lib/${ARCH}-linux-gnu/libnvidia-ml.so.1; \
fi \
fi

@@ -24,14 +24,14 @@ RUN if [ "${INFERENCEBUILD}" != "1" ]; then \
pip install -e ./megatron-lm; \
fi

RUN pip install torchx gin-config torchmetrics==1.0.3 typing-extensions iopath pyvers
RUN pip install torchx gin-config torchmetrics typing-extensions iopath

RUN pip install --no-cache-dir setuptools==69.5.1 setuptools-git-versioning scikit-build && \
RUN pip install --no-cache-dir setuptools==75.8.2 setuptools-git-versioning scikit-build && \
git clone --recursive -b v1.2.0 https://github.com/pytorch/FBGEMM.git fbgemm && \
cd fbgemm/fbgemm_gpu && \
python setup.py install --package_variant=cuda -DTORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 9.0"
python setup.py install --package_variant=cuda -DTORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 9.0 10.0"

RUN pip install --no-deps tensordict orjson && \
RUN pip install --no-deps tensordict orjson pyvers && \
git clone --recursive -b v1.2.0 https://github.com/pytorch/torchrec.git torchrec && \
cd torchrec && \
pip install --no-deps .
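As a rough smoke test of an image built from this Dockerfile, something like the sketch below could be used; the `recsys-examples:dev` tag matches the hypothetical build command above, and `megatron.core` is only present when `INFERENCEBUILD` is not set to `1`.

```bash
# Hypothetical check that the bumped base image and rebuilt packages import cleanly.
docker run --rm --gpus all recsys-examples:dev \
  python -c "import torch, fbgemm_gpu, torchrec; print(torch.__version__, torch.version.cuda)"
```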
2 changes: 1 addition & 1 deletion examples/commons/checkpoint/checkpoint.py
@@ -110,7 +110,7 @@ def load(
save_path = os.path.join(
path, "torch_module", "model.{}.pth".format(dist.get_rank())
)
state_dict = torch.load(save_path)
state_dict = torch.load(save_path, weights_only=False)
unwrapped_module.load_state_dict(state_dict["model_state_dict"])
if dense_optimizer and state_dict["optimizer_state_dict"]:
dense_optimizer.load_state_dict(state_dict["optimizer_state_dict"])
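
Context for the `weights_only=False` change: PyTorch 2.6 and later default `torch.load` to `weights_only=True`, which restricts unpickling and can reject checkpoints that carry full optimizer state. A minimal sketch, with an illustrative rank-0 path:

```python
import torch

# Path layout follows the save logic above: <path>/torch_module/model.<rank>.pth
ckpt_path = "checkpoint/torch_module/model.0.pth"  # illustrative path, rank 0

# weights_only=True (the newer default) limits unpickling to tensors and a small allowlist of
# types; weights_only=False restores full unpickling for trusted checkpoints like this one.
state_dict = torch.load(ckpt_path, weights_only=False)
model_state = state_dict["model_state_dict"]
optimizer_state = state_dict["optimizer_state_dict"]
```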
17 changes: 8 additions & 9 deletions examples/hstu/modules/hstu_attention.py
@@ -426,15 +426,14 @@ def create_hstu_attention(
linear_dim,
is_causal,
)
elif sm_major_version == 8 and sm_minor_version == 0:
return FusedHSTUAttention(
num_heads,
attention_dim,
linear_dim,
is_causal,
)
print(
"CUTLASS backend only support H100, H20 and A100, fallback to PyTorch backend"
assert sm_major_version >= 8, "Ampere or Ampere next GPU is required."
if sm_major_version == 8:
assert sm_minor_version == 0, "For Ampere, only A100 is supported."
return FusedHSTUAttention(
num_heads,
attention_dim,
linear_dim,
is_causal,
)
elif kernel_backend == KernelBackend.TRITON:
if is_causal:
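For reference, the compute-capability values that the dispatch above keys on can be queried as in the sketch below (assuming a visible CUDA device):

```python
import torch

# sm_major_version / sm_minor_version as used in create_hstu_attention:
props = torch.cuda.get_device_properties(0)
sm_major_version, sm_minor_version = props.major, props.minor

# A100 reports (8, 0); H100/H20 report (9, 0); sm86/sm89 parts such as A10/L40 report (8, 6)/(8, 9).
print(sm_major_version, sm_minor_version)
```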
33 changes: 16 additions & 17 deletions examples/hstu/ops/fused_hstu_op.py
@@ -192,10 +192,8 @@ def _ln_linear_silu_fwd(
sm = torch.cuda.get_device_properties(0).major
if sm == 8:
addmm_silu_fwd_impl = triton_addmm_silu_fwd
elif sm == 9:
addmm_silu_fwd_impl = torch_addmm_silu_fwd
else:
raise ValueError(f"Unsupported SM major version: {sm}")
addmm_silu_fwd_impl = torch_addmm_silu_fwd
# 2. linear & silu
# bias is 1D
linear_uvqk, silu_linear_uvqk = addmm_silu_fwd_impl(
@@ -276,18 +274,19 @@
alpha,
):
sm_major_version = torch.cuda.get_device_properties(0).major
sm_minor_version = torch.cuda.get_device_properties(0).minor
extension_args = ()
if sm_major_version == 8:
cutlass_hstu_varlen_fwd = flash_attn_cuda_ampere.varlen_fwd
ampere_paged_kv_args = (None, None, None, None, None)
extension_args = ampere_paged_kv_args
elif sm_major_version == 9:
if sm_major_version == 9:
cutlass_hstu_varlen_fwd = flash_attn_cuda_hopper.varlen_fwd
hopper_fp8_args = (None, None, None)
extension_args = hopper_fp8_args

else:
raise ValueError(f"Unsupported SM major version: {sm_major_version}")
assert sm_major_version >= 8, "Ampere or Ampere next GPU is required."
if sm_major_version == 8:
assert sm_minor_version == 0, "For Ampere, only A100 is supported."
cutlass_hstu_varlen_fwd = flash_attn_cuda_ampere.varlen_fwd
ampere_paged_kv_args = (None, None, None, None, None)
extension_args = ampere_paged_kv_args
assert q.dim() == 3, "q shape should be (L, num_heads, head_dim)"
assert k.dim() == 3, "k shape should be (L, num_heads, head_dim)"
assert v.dim() == 3, "v shape should be (L, num_heads, hidden_dim)"
@@ -392,10 +391,8 @@ def _linear_residual_fwd(
sm = torch.cuda.get_device_properties(0).major
if sm == 8:
addmm_silu_fwd_impl = triton_addmm_silu_fwd
elif sm == 9:
addmm_silu_fwd_impl = torch_addmm_silu_fwd
else:
raise ValueError(f"Unsupported SM major version: {sm}")
addmm_silu_fwd_impl = torch_addmm_silu_fwd
y, _ = addmm_silu_fwd_impl(
x=x,
w=w,
@@ -602,12 +599,14 @@ def _hstu_attn_cutlass_bwd(
dv: Optional[torch.Tensor] = None,
):
sm_major_version = torch.cuda.get_device_properties(0).major
if sm_major_version == 8:
cutlass_hstu_varlen_bwd = flash_attn_cuda_ampere.varlen_bwd
elif sm_major_version == 9:
sm_minor_version = torch.cuda.get_device_properties(0).minor
if sm_major_version == 9:
cutlass_hstu_varlen_bwd = flash_attn_cuda_hopper.varlen_bwd
else:
raise ValueError(f"Unsupported SM major version: {sm_major_version}")
assert sm_major_version >= 8, "Ampere or Ampere next GPU is required."
if sm_major_version == 8:
assert sm_minor_version == 0, "For Ampere, only A100 is supported."
cutlass_hstu_varlen_bwd = flash_attn_cuda_ampere.varlen_bwd
assert dout.dim() == 3
dq, dk, dv, _ = cutlass_hstu_varlen_bwd(
dout,
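For orientation, the addmm+SiLU step whose implementation the hunks above select (Triton on Ampere, a plain PyTorch path elsewhere) conceptually computes something like the sketch below; the real `triton_addmm_silu_fwd` / `torch_addmm_silu_fwd` signatures and weight layout may differ, so treat this as an assumption.

```python
import torch
import torch.nn.functional as F

def addmm_silu_reference(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
    """Reference for the fused bias + matmul + SiLU forward: returns both the
    pre-activation output and its SiLU, mirroring the two values unpacked by the callers."""
    linear_out = torch.addmm(b, x, w)  # b + x @ w
    return linear_out, F.silu(linear_out)
```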