2 changes: 1 addition & 1 deletion corelib/hstu/README.md
@@ -76,7 +76,7 @@ def hstu_attn_varlen_func(
# nheads should be divisible by nhead_rab
has_drab=False, # Whether to apply drab
is_delta_q=False, # Whether to apply delta_q
quantization_mode=-1, # -1: no quantization, 0: cast to fp8, 1: 1xDIM&128x1 quantization, 2: per-block quantization, 3: per-head quantization, 4: per-batch quantization, 5: per-tensor quantization.
quant_mode=-1, # -1: no quantization, 0: cast to fp8, 1: 1xDIM&128x1 quantization, 2: per-block quantization, 3: per-head quantization, 4: per-batch quantization, 5: per-tensor quantization.
)
```
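
As an editorial aside, here is a minimal sketch of the documented `quant_mode` values collected into a lookup; `QUANT_MODES` and `check_quant_mode` are illustrative names, not part of the library.

```python
# Sketch only: the quant_mode values documented above, collected into a lookup
# so a call site can validate the setting before invoking the kernel.
QUANT_MODES = {
    -1: "no quantization",
    0: "cast to fp8",
    1: "1xDIM & 128x1 quantization",
    2: "per-block quantization",
    3: "per-head quantization",
    4: "per-batch quantization",
    5: "per-tensor quantization",
}

def check_quant_mode(mode: int) -> int:
    """Raise early on an unsupported quant_mode instead of failing inside the kernel."""
    if mode not in QUANT_MODES:
        raise ValueError(f"Unsupported quant_mode {mode}; expected one of {sorted(QUANT_MODES)}")
    return mode
```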

33 changes: 32 additions & 1 deletion corelib/hstu/hopper/hstu_attn_interface.py
@@ -452,7 +452,13 @@ def forward(
has_drab=False,
func=None,
quant_mode=-1,
):
):
# Debug: log to stderr to avoid stdout buffering issues
import sys
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
if rank == 0:
sys.stderr.write(f"[HSTU Attn Forward] quant_mode: {quant_mode}, q.dtype: {q.dtype}, k.dtype: {k.dtype}, v.dtype: {v.dtype}\n")
sys.stderr.flush()
vt = None
descale_q = None
descale_k = None
@@ -512,6 +518,12 @@ def forward(
)

with torch.cuda.nvtx.range("hstu_varlen_fwd_kernel"):
# Debug: Check dtypes before kernel
import sys
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
if rank == 0:
sys.stderr.write(f"[Before CUDA kernel] quant_mode={quant_mode}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}\n")
sys.stderr.flush()
out, rab = _hstu_attn_varlen_forward(
q,
k,
@@ -537,6 +549,25 @@
cu_seqlens_block_descale_q,
cu_seqlens_block_descale_k,
)

# Debug: Check dtypes after kernel
import sys
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
if rank == 0:
sys.stderr.write(f"[After CUDA kernel] out.dtype={out.dtype}, original_q.dtype={ctx.q_fp16.dtype}\n")
sys.stderr.flush()

# Restore original dtype: CUDA kernel outputs FP16, but input may be BF16
# This ensures dtype consistency with the rest of the model
# TODO: Remove this dtype conversion once the kernel outputs a consistent dtype
if quant_mode >= 0 and out.dtype != ctx.q_fp16.dtype:
import sys
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
if rank == 0:
sys.stderr.write(f"[Dtype Cast] Converting {out.dtype} -> {ctx.q_fp16.dtype}\n")
sys.stderr.flush()
out = out.to(ctx.q_fp16.dtype)

ctx.save_for_backward(
q, k, v, rab, cu_seqlens_q, cu_seqlens_k, num_contexts, num_targets
)
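
Editorial note: the rank-0 stderr logging above is repeated three times in this `forward()`; a small helper along these lines (the name `_debug_rank0` is illustrative, not in the PR) would keep the kernel path readable.

```python
import sys

import torch

def _debug_rank0(msg: str) -> None:
    # Only rank 0 writes, and stderr is flushed to avoid the stdout buffering
    # issues the debug comments above mention.
    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    if rank == 0:
        sys.stderr.write(msg + "\n")
        sys.stderr.flush()

# e.g. _debug_rank0(f"[HSTU Attn Forward] quant_mode={quant_mode}, q.dtype={q.dtype}")
```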
2 changes: 2 additions & 0 deletions docker/Dockerfile
@@ -45,6 +45,8 @@ RUN apt update -y --fix-missing && \

RUN pip install --no-cache pre-commit

RUN pip install --no-cache --upgrade --no-build-isolation transformer_engine[pytorch]

FROM ${DEVEL_IMAGE} AS build

WORKDIR /workspace/recsys-examples
2 changes: 2 additions & 0 deletions docker/Dockerfile.pyt_build
@@ -38,6 +38,8 @@ RUN pip install --no-deps tensordict orjson && \
cd torchrec && \
pip install --no-deps .

RUN pip install --no-cache --upgrade --no-build-isolation transformer_engine[pytorch]

RUN test -f /usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1 || ln -s ${CUDA_HOME}/targets/x86_64-linux/lib/stubs/libnvidia-ml.so /usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1

# Debugging - build w/DOCKER_BUILDKIT=1 to see output
8 changes: 8 additions & 0 deletions examples/hstu/configs/hstu_config.py
@@ -111,6 +111,8 @@ class HSTUConfig(TransformerConfig):

kernel_backend: KernelBackend = KernelBackend.CUTLASS
hstu_layer_type: HSTULayerType = HSTULayerType.FUSED # DEBUG|NATIVE|FUSED
hstu_attn_quantization_mode: int = -1 # Default -1 disables FP8 for HSTU attention
fp8_alignment_mode: Optional[str] = None

target_group_size: int = 1
learnable_input_layernorm: bool = True
@@ -144,6 +146,8 @@ def get_hstu_config(
norm_epsilon=1e-5,
is_causal: bool = True,
kernel_backend: KernelBackend = KernelBackend.CUTLASS,
hstu_attn_quantization_mode: int = -1,
fp8_alignment_mode: str = "truncate",
target_group_size: int = 1,
hstu_layer_type: HSTULayerType = HSTULayerType.FUSED,
learnable_input_layernorm: bool = True,
@@ -154,6 +158,7 @@
is_inference: bool = False,
add_uvqk_bias: bool = True,
fuse_norm_mul_dropout: bool = True,
**transformer_config_kwargs,
) -> HSTUConfig:
"""
Create the HSTU configuration.
@@ -223,4 +228,7 @@
add_uvqk_bias=add_uvqk_bias,
is_inference=is_inference,
fuse_norm_mul_dropout=fuse_norm_mul_dropout,
hstu_attn_quantization_mode=hstu_attn_quantization_mode,
fp8_alignment_mode=fp8_alignment_mode,
**transformer_config_kwargs,
)
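
For orientation, a sketch of how the two new options added here might be supplied; all other `get_hstu_config` arguments are assumed to come from an existing call site, so only the additions are spelled out.

```python
# Sketch only: the FP8-related options introduced in this PR.
fp8_options = dict(
    hstu_attn_quantization_mode=0,   # 0: cast QKV to fp8 in the HSTU attention kernel; -1 keeps it off
    fp8_alignment_mode="truncate",   # or "pad": align the token count to a multiple of 16 for TE
)

# hstu_config = get_hstu_config(..., **fp8_options)  # other arguments unchanged
```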
4 changes: 2 additions & 2 deletions examples/hstu/model/ranking_gr.py
@@ -27,12 +27,11 @@
from modules.multi_task_loss_module import MultiTaskLossModule
from torchrec.sparse.jagged_tensor import JaggedTensor


class RankingGR(BaseModel):
"""
A class representing the ranking model. Inherits from BaseModel. A ranking model consists of
a sparse architecture and a dense architecture. A ranking model is able to process multiple labels
and thus has multiple logit dimensions. Each label is associated with a loss functoin (e.g. BCE, CE).
and thus has multiple logit dimensions. Each label is associated with a loss function (e.g. BCE, CE).

Args:
hstu_config (HSTUConfig): The HSTU configuration.
@@ -110,6 +109,7 @@ def get_logit_and_labels(
"""
# DMP embedding
embeddings: Dict[str, JaggedTensor] = self._embedding_collection(batch.features)

# maybe freeze embedding for debugging
embeddings = self._embedding_collection._maybe_detach(embeddings)
# For model-parallel embedding, torchrec does gradient division by (tp_size * dp_size). However, we only need to divide by dp size. In such case, we need to scale the gradient by tp_size.
4 changes: 3 additions & 1 deletion examples/hstu/modules/fused_hstu_layer.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import torch
from commons.utils.nvtx_op import output_nvtx_hook, register_setter_and_getter_for_nvtx
from configs import HSTUConfig
@@ -59,7 +60,7 @@ def __init__(self, config: HSTUConfig):
self._alpha = 1.0 / (self._attention_dim_per_head**0.5)
self._residual = config.residual
self._attn_backend = config.kernel_backend

self._quantization_mode = config.hstu_attn_quantization_mode
# stream and event are shared across all layers
self._wgrad_stream = config.async_wgrad_stream
self._wgrad_event = config.async_wgrad_event
@@ -155,6 +156,7 @@ def forward(self, jd: JaggedData) -> JaggedData:
wgrad_event=self._wgrad_event,
recompute_input_layernorm=self._recompute_input_layernorm,
recompute_input_silu=self._recompute_input_silu,
quant_mode=-1, # FusedHSTULayer uses the C++ API for HSTU attention; pass quant_mode through once the Python API is used
)
return JaggedData(
values=output,
6 changes: 5 additions & 1 deletion examples/hstu/modules/hstu_attention.py
@@ -319,6 +319,7 @@ def __init__(
attention_dim: int,
linear_dim: int,
is_causal: bool,
quant_mode: int = -1,
):
super().__init__()
from hopper.hstu_attn_interface import hstu_attn_varlen_func
@@ -331,7 +332,7 @@
assert (
self.linear_dim == self.attention_dim
), "only support linear_dim and attention_dim"

self.quant_mode = quant_mode
@output_nvtx_hook(nvtx_tag="FusedHSTUAttnHopper")
def forward(
self,
@@ -390,6 +391,7 @@ def forward(
window_size=(-1, 0) if self.is_causal else (-1, -1),
rab=None,
alpha=1.0 / (self.attention_dim**0.5),
quant_mode=self.quant_mode,
).view(-1, self.num_heads * self.linear_dim)


@@ -399,6 +401,7 @@ def create_hstu_attention(
attention_dim: int,
linear_dim: int,
is_causal: bool,
quant_mode: int = -1,
) -> HSTUAttention:
"""
Factory function to create an HSTUAttention module based on the kernel backend.
@@ -425,6 +428,7 @@
attention_dim,
linear_dim,
is_causal,
quant_mode,
)
elif sm_major_version == 8 and sm_minor_version == 0:
return FusedHSTUAttention(
105 changes: 105 additions & 0 deletions examples/hstu/modules/hstu_processor.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import warnings
from typing import Dict, Optional, Union

import torch
from commons.utils.nvtx_op import output_nvtx_hook
from commons.utils.logger import print_rank_0
from configs.hstu_config import HSTUConfig
from configs.inference_config import InferenceHSTUConfig
from dataset.utils import RankingBatch
@@ -302,8 +304,111 @@ def forward(
training=self.training,
).to(self._training_dtype)

# FP8 alignment: truncate or pad to the nearest multiple of 16 for TE Linear fwd + bwd compatibility
if not isinstance(self.config, InferenceHSTUConfig):
if self.config.fp8 is not None and self.training:
warnings.warn("Aligning JaggedData to nearest multiple of 16 for FP8 TE fwd + bwd compatibility")
jd = self._align_jagged_data_for_fp8(jd, fp8_alignment_mode=self.config.fp8_alignment_mode)

return jd

def _align_jagged_data_for_fp8(
Review comment from @JacoCheung (Collaborator), Oct 31, 2025:

Hi @esoba, since you have padding here, there should be a discarding step in the postprocessor before the loss is computed. That being said:

 final_loss = drop_pad_values(final_loss)
 final_loss.mean().backward()

Otherwise the padded tokens will affect both the data gradients and the weight gradients in the backward pass, even if the padded values are initialized to 0.

See our loss calculation.

And our post-processor (if it's a ranking task) will treat the padded tokens as normal tokens.

Reply from the PR author (Contributor):

I believe I was seeing the issue when I set this to truncate (cutting off the last N elements to reach the nearest multiple of 16). Let me double-check whether there is any undefined behavior doing it that way as well. Thanks for the catch!
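
A minimal, self-contained sketch of the discarding step discussed above; `drop_pad_values` is the reviewer's placeholder and the token counts are made up, so treat this as an illustration rather than repo code.

```python
import torch

def drop_pad_values(per_token_loss: torch.Tensor, num_valid_tokens: int) -> torch.Tensor:
    """Keep only the losses of real tokens; the padded tail is dropped before reduction."""
    return per_token_loss[:num_valid_tokens]

per_token_loss = torch.randn(1008, requires_grad=True)              # e.g. padded from 1000 to 1008 tokens
final_loss = drop_pad_values(per_token_loss, num_valid_tokens=1000)
final_loss.mean().backward()                                        # padded positions receive zero gradient
```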

self, jd: JaggedData, fp8_alignment_mode: str = "truncate"
) -> JaggedData:
"""
Aligns JaggedData to have total tokens divisible by 16 for FP8 compatibility.

Either truncates or pads tokens at the end of the batch (the last sequence) to meet the Transformer Engine divisibility requirement.

Args:
jd: Input JaggedData
fp8_alignment_mode: Alignment mode ("truncate" or "pad")

Returns:
Aligned JaggedData with maintained invariants
"""
num_tokens = jd.values.shape[0]

# Check if already aligned
if num_tokens % 16 == 0:
return jd

# Calculate aligned size based on mode
if fp8_alignment_mode == "truncate":
# Round down to previous multiple of 16
aligned_size = (num_tokens // 16) * 16
elif fp8_alignment_mode == "pad":
# Round up to next multiple of 16
aligned_size = ((num_tokens + 15) // 16) * 16
else:
raise ValueError(f"Invalid fp8 alignment mode: {fp8_alignment_mode}")

if aligned_size < 16:
# Too few tokens - pad to minimum size
padding_needed = 16 - num_tokens
jd.values = torch.nn.functional.pad(jd.values, (0, 0, 0, padding_needed))
return jd

# Truncate: Remove tokens from the end
if fp8_alignment_mode == "truncate":
tokens_to_remove = num_tokens - aligned_size
jd.values = jd.values[:aligned_size]

# Update metadata: Find which sequences are affected
# We're removing from the end, so only the last sequence(s) are affected
batch_size = jd.seqlen.shape[0]
new_seqlen = jd.seqlen.clone()

# Work backwards from last sequence
remaining_to_remove = tokens_to_remove
for seq_idx in range(batch_size - 1, -1, -1):
if remaining_to_remove <= 0:
break

current_len = new_seqlen[seq_idx].item()
if current_len <= remaining_to_remove:
# Remove entire sequence
new_seqlen[seq_idx] = 0
remaining_to_remove -= current_len
else:
# Partially truncate this sequence
new_seqlen[seq_idx] = current_len - remaining_to_remove
remaining_to_remove = 0

# Recompute offsets
new_seqlen_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(new_seqlen)

# Verify invariant
print_rank_0(f"Truncated tokens to {aligned_size} from {num_tokens} for FP8 TE compatibility")
assert new_seqlen_offsets[-1].item() == aligned_size, \
f"Offset mismatch: {new_seqlen_offsets[-1].item()} != {aligned_size}"

# TODO: verify padding logic with attention kernel
elif fp8_alignment_mode == "pad":
# Pad: Add tokens to the end to reach aligned size
# Leave seqlen unchanged, attn mask in attn kernel should ignore the rest
padding_needed = aligned_size - num_tokens
jd.values = torch.nn.functional.pad(jd.values, (0, 0, 0, padding_needed))

print_rank_0(f"Padded tokens to {aligned_size} from {num_tokens} for FP8 TE compatibility (as ghost tokens)")

# No need to update seqlen or offsets - padding is invisible to the model
return jd

return JaggedData(
values=jd.values,
seqlen=new_seqlen,
seqlen_offsets=new_seqlen_offsets,
max_seqlen=jd.max_seqlen,
max_num_candidates=jd.max_num_candidates,
num_candidates=jd.num_candidates,
num_candidates_offsets=jd.num_candidates_offsets,
contextual_max_seqlen=jd.contextual_max_seqlen,
contextual_seqlen=jd.contextual_seqlen,
contextual_seqlen_offsets=jd.contextual_seqlen_offsets,
has_interleaved_action=jd.has_interleaved_action,
)
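
A quick, dependency-free check of the rounding used above (the under-16-token edge case is handled separately in the method):

```python
def aligned_size(num_tokens: int, mode: str) -> int:
    """Round a token count down ("truncate") or up ("pad") to a multiple of 16."""
    if mode == "truncate":
        return (num_tokens // 16) * 16
    if mode == "pad":
        return ((num_tokens + 15) // 16) * 16
    raise ValueError(f"Invalid fp8 alignment mode: {mode}")

assert aligned_size(1000, "truncate") == 992   # drops 8 tokens from the last sequence(s)
assert aligned_size(1000, "pad") == 1008       # adds 8 ghost tokens after the last sequence
assert aligned_size(1008, "truncate") == 1008  # already aligned: returned unchanged
```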


class HSTUBlockPostprocessor(torch.nn.Module):
"""
2 changes: 1 addition & 1 deletion examples/hstu/modules/mlp.py
@@ -104,4 +104,4 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
torch.Tensor: The output tensor.
"""
assert input.dim() == 2, "Tensor must be 2-dimensional"
return self._mlp(input)
return self._mlp(input)
10 changes: 10 additions & 0 deletions examples/hstu/modules/native_hstu_layer.py
@@ -15,6 +15,7 @@

import os
from functools import partial
import warnings

import nvtx
import torch
@@ -140,6 +141,7 @@ def __init__(self, config: HSTUConfig):
attention_dim=self._attention_dim_per_head,
linear_dim=self._linear_dim_per_head,
is_causal=config.is_causal,
quant_mode=config.hstu_attn_quantization_mode,
)
register_setter_and_getter_for_nvtx(
HSTULayer.forward, key_or_attr_name="values"
@@ -216,8 +218,10 @@ def forward(self, jd: JaggedData) -> JaggedData:
bias=self._input_layernorm_bias,
eps=self._eps,
)

with nvtx.annotate("hstu uvqk linear_silu fwd", color="BLUE"):
tu, tv, tq, tk = self.get_user_value_query_key_tensors(normed_x)

# TODO: remove contiguous once cutlass backend is ready
with nvtx.annotate("hstu attn fwd", color="BLUE"):
jagged_attn_output = self._attn_func(
Expand All @@ -230,6 +234,12 @@ def forward(self, jd: JaggedData) -> JaggedData:
max_seqlen=jd.max_seqlen,
target_group_size=self._target_group_size,
)

# TODO: Remove this once the attention kernel outputs consistent dtype
Review comment from a Collaborator:

@JacoCheung could you double check why we need this?

Reply from @JacoCheung (Collaborator), Oct 23, 2025:

When training with fp8 enabled, the model weights can be bf16/fp16 (NetworkArgs.dtype_str); usually it's bf16, and so are the activations.

But the HSTU kernel output is fp16, so here we need a cast from fp16 to bf16. @shijieliu

@esoba do you think there is a need to move the cast into the kernel or not?

Reply from the PR author (Contributor):

I think casting fp16 to bf16 wouldn't cause an error since the dynamic range is larger, but I would expect some additional quantization error to appear (ideally the model learns around it anyway). For consistency it would probably be better to move the cast into the kernel, but as a workaround casting outside should be fine.
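
A small demonstration of the point above (editorial, not part of the PR): bf16's wider exponent range means the cast cannot overflow, but its shorter mantissa rounds some values that fp16 represents exactly.

```python
import torch

x_fp16 = torch.tensor([1.0 + 2**-10], dtype=torch.float16)  # exactly representable in fp16
x_bf16 = x_fp16.to(torch.bfloat16)                           # rounds to 1.0 (bf16 has a 7-bit mantissa)
print(x_fp16.item(), x_bf16.item())                          # 1.0009765625 1.0
```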

expected_dtype = torch.bfloat16 if self.config.bf16 else torch.float16
if jagged_attn_output.dtype != expected_dtype:
warnings.warn(f"Jagged attn output dtype mismatch: {jagged_attn_output.dtype} != {expected_dtype}. Casting to {expected_dtype}.")
jagged_attn_output = jagged_attn_output.to(expected_dtype, non_blocking=True)

with nvtx.annotate("hstu norm mul dropout fwd", color="GREEN"):
if self._debug_shortcut_output_ln_mul_dropout: