diff --git a/corelib/hstu/README.md b/corelib/hstu/README.md index e4a6795cd..9b0a3d9b5 100644 --- a/corelib/hstu/README.md +++ b/corelib/hstu/README.md @@ -76,7 +76,7 @@ def hstu_attn_varlen_func( # nheads should be divisible by nhead_rab has_drab=False, # Whether to apply drab is_delta_q=False, # Whether to apply delta_q - quantization_mode=-1, # -1: no quantization, 0: cast to fp8, 1: 1xDIM&128x1 quantization, 2: per-block quantization, 3: per-head quantization, 4: per-batch quantization, 5: per-tensor quantization. + quant_mode=-1, # -1: no quantization, 0: cast to fp8, 1: 1xDIM&128x1 quantization, 2: per-block quantization, 3: per-head quantization, 4: per-batch quantization, 5: per-tensor quantization. ) ``` diff --git a/corelib/hstu/hopper/hstu_attn_interface.py b/corelib/hstu/hopper/hstu_attn_interface.py index dbfe82d00..dd2cf54fd 100755 --- a/corelib/hstu/hopper/hstu_attn_interface.py +++ b/corelib/hstu/hopper/hstu_attn_interface.py @@ -452,7 +452,13 @@ def forward( has_drab=False, func=None, quant_mode=-1, - ): + ): + # Debug: Log to file to avoid stdout buffering issues + import sys + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + if rank == 0: + sys.stderr.write(f"[HSTU Attn Forward] quant_mode: {quant_mode}, q.dtype: {q.dtype}, k.dtype: {k.dtype}, v.dtype: {v.dtype}\n") + sys.stderr.flush() vt = None descale_q = None descale_k = None @@ -512,6 +518,12 @@ def forward( ) with torch.cuda.nvtx.range("hstu_varlen_fwd_kernel"): + # Debug: Check dtypes before kernel + import sys + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + if rank == 0: + sys.stderr.write(f"[Before CUDA kernel] quant_mode={quant_mode}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}\n") + sys.stderr.flush() out, rab = _hstu_attn_varlen_forward( q, k, @@ -537,6 +549,25 @@ def forward( cu_seqlens_block_descale_q, cu_seqlens_block_descale_k, ) + + # Debug: Check dtypes after kernel + import sys + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + if rank == 0: + sys.stderr.write(f"[After CUDA kernel] out.dtype={out.dtype}, original_q.dtype={ctx.q_fp16.dtype}\n") + sys.stderr.flush() + + # Restore original dtype: CUDA kernel outputs FP16, but input may be BF16 + # This ensures dtype consistency with the rest of the model + # TODO: Replace dtype conversion when kernel outputs consistent dtype + if quant_mode >= 0 and out.dtype != ctx.q_fp16.dtype: + import sys + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + if rank == 0: + sys.stderr.write(f"[Dtype Cast] Converting {out.dtype} -> {ctx.q_fp16.dtype}\n") + sys.stderr.flush() + out = out.to(ctx.q_fp16.dtype) + ctx.save_for_backward( q, k, v, rab, cu_seqlens_q, cu_seqlens_k, num_contexts, num_targets ) diff --git a/docker/Dockerfile b/docker/Dockerfile index 7ec00b6ef..fba6c954c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -45,6 +45,8 @@ RUN apt update -y --fix-missing && \ RUN pip install --no-cache pre-commit +RUN pip install --no-cache --upgrade --no-build-isolation transformer_engine[pytorch] + FROM ${DEVEL_IMAGE} AS build WORKDIR /workspace/recsys-examples diff --git a/docker/Dockerfile.pyt_build b/docker/Dockerfile.pyt_build index dc4996b91..a3507120d 100644 --- a/docker/Dockerfile.pyt_build +++ b/docker/Dockerfile.pyt_build @@ -38,6 +38,8 @@ RUN pip install --no-deps tensordict orjson && \ cd torchrec && \ pip install --no-deps . 
+RUN pip install --no-cache --upgrade --no-build-isolation transformer_engine[pytorch] + RUN test -f /usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1 || ln -s ${CUDA_HOME}/targets/x86_64-linux/lib/stubs/libnvidia-ml.so /usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1 # Debugging - build w/DOCKER_BUILDKIT=1 to see output diff --git a/examples/hstu/configs/hstu_config.py b/examples/hstu/configs/hstu_config.py index d8146279f..d82185b78 100644 --- a/examples/hstu/configs/hstu_config.py +++ b/examples/hstu/configs/hstu_config.py @@ -111,6 +111,8 @@ class HSTUConfig(TransformerConfig): kernel_backend: KernelBackend = KernelBackend.CUTLASS hstu_layer_type: HSTULayerType = HSTULayerType.FUSED # DEBUG|NATIVE|FUSED + hstu_attn_quantization_mode: int = -1 # Default to -1, which means no fp8 for hstu attn + fp8_alignment_mode: Optional[str] = None target_group_size: int = 1 learnable_input_layernorm: bool = True @@ -144,6 +146,8 @@ def get_hstu_config( norm_epsilon=1e-5, is_causal: bool = True, kernel_backend: KernelBackend = KernelBackend.CUTLASS, + hstu_attn_quantization_mode: int = -1, + fp8_alignment_mode: str = "truncate", target_group_size: int = 1, hstu_layer_type: HSTULayerType = HSTULayerType.FUSED, learnable_input_layernorm: bool = True, @@ -154,6 +158,7 @@ def get_hstu_config( is_inference: bool = False, add_uvqk_bias: bool = True, fuse_norm_mul_dropout: bool = True, + **transformer_config_kwargs, ) -> HSTUConfig: """ Create the HSTU configuration. @@ -223,4 +228,7 @@ def get_hstu_config( add_uvqk_bias=add_uvqk_bias, is_inference=is_inference, fuse_norm_mul_dropout=fuse_norm_mul_dropout, + hstu_attn_quantization_mode=hstu_attn_quantization_mode, + fp8_alignment_mode=fp8_alignment_mode, + **transformer_config_kwargs, ) diff --git a/examples/hstu/model/ranking_gr.py b/examples/hstu/model/ranking_gr.py index 638715025..e20807660 100644 --- a/examples/hstu/model/ranking_gr.py +++ b/examples/hstu/model/ranking_gr.py @@ -27,12 +27,11 @@ from modules.multi_task_loss_module import MultiTaskLossModule from torchrec.sparse.jagged_tensor import JaggedTensor - class RankingGR(BaseModel): """ A class representing the ranking model. Inherits from BaseModel. A ranking model consists of a sparse architecture and a dense architecture. A ranking model is able to process multiple labels - and thus has multiple logit dimensions. Each label is associated with a loss functoin (e.g. BCE, CE). + and thus has multiple logit dimensions. Each label is associated with a loss function (e.g. BCE, CE). Args: hstu_config (HSTUConfig): The HSTU configuration. @@ -110,6 +109,7 @@ def get_logit_and_labels( """ # DMP embedding embeddings: Dict[str, JaggedTensor] = self._embedding_collection(batch.features) + # maybe freeze embedding for debugging embeddings = self._embedding_collection._maybe_detach(embeddings) # For model-parallel embedding, torchrec does gradient division by (tp_size * dp_size). However, we only need to divide by dp size. In such case, we need to scale the gradient by tp_size. diff --git a/examples/hstu/modules/fused_hstu_layer.py b/examples/hstu/modules/fused_hstu_layer.py index 83848b9cb..66e06178b 100644 --- a/examples/hstu/modules/fused_hstu_layer.py +++ b/examples/hstu/modules/fused_hstu_layer.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings
 import torch
 from commons.utils.nvtx_op import output_nvtx_hook, register_setter_and_getter_for_nvtx
 from configs import HSTUConfig
@@ -59,7 +60,7 @@ def __init__(self, config: HSTUConfig):
         self._alpha = 1.0 / (self._attention_dim_per_head**0.5)
         self._residual = config.residual
         self._attn_backend = config.kernel_backend
-
+        self._quantization_mode = config.hstu_attn_quantization_mode
         # stream and event are shared across all layers
         self._wgrad_stream = config.async_wgrad_stream
         self._wgrad_event = config.async_wgrad_event
@@ -155,6 +156,7 @@ def forward(self, jd: JaggedData) -> JaggedData:
             wgrad_event=self._wgrad_event,
             recompute_input_layernorm=self._recompute_input_layernorm,
             recompute_input_silu=self._recompute_input_silu,
+            quant_mode=-1,  # FusedHSTULayer goes through the C++ HSTU attention API; pass self._quantization_mode through here once the Python API is used
         )
         return JaggedData(
             values=output,
diff --git a/examples/hstu/modules/hstu_attention.py b/examples/hstu/modules/hstu_attention.py
index d8af93ba1..a857dd127 100644
--- a/examples/hstu/modules/hstu_attention.py
+++ b/examples/hstu/modules/hstu_attention.py
@@ -319,6 +319,7 @@ def __init__(
         attention_dim: int,
         linear_dim: int,
         is_causal: bool,
+        quant_mode: int = -1,
     ):
         super().__init__()
         from hopper.hstu_attn_interface import hstu_attn_varlen_func
@@ -331,7 +332,8 @@ def __init__(
         assert (
             self.linear_dim == self.attention_dim
         ), "only support linear_dim and attention_dim"
+        self.quant_mode = quant_mode
 
     @output_nvtx_hook(nvtx_tag="FusedHSTUAttnHopper")
     def forward(
         self,
@@ -390,6 +391,7 @@ def forward(
             window_size=(-1, 0) if self.is_causal else (-1, -1),
             rab=None,
             alpha=1.0 / (self.attention_dim**0.5),
+            quant_mode=self.quant_mode,
         ).view(-1, self.num_heads * self.linear_dim)
 
 
@@ -399,6 +401,7 @@ def create_hstu_attention(
     attention_dim: int,
     linear_dim: int,
     is_causal: bool,
+    quant_mode: int = -1,
 ) -> HSTUAttention:
     """
     Factory function to create an HSTUAttention module based on the kernel backend.
@@ -425,6 +428,7 @@ def create_hstu_attention(
             attention_dim,
             linear_dim,
             is_causal,
+            quant_mode,
         )
     elif sm_major_version == 8 and sm_minor_version == 0:
         return FusedHSTUAttention(
diff --git a/examples/hstu/modules/hstu_processor.py b/examples/hstu/modules/hstu_processor.py
index 0eb34e774..f7794dc1d 100644
--- a/examples/hstu/modules/hstu_processor.py
+++ b/examples/hstu/modules/hstu_processor.py
@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
+import warnings
 from typing import Dict, Optional, Union
 
 import torch
 from commons.utils.nvtx_op import output_nvtx_hook
+from commons.utils.logger import print_rank_0
 from configs.hstu_config import HSTUConfig
 from configs.inference_config import InferenceHSTUConfig
 from dataset.utils import RankingBatch
@@ -302,8 +304,111 @@ def forward(
             training=self.training,
         ).to(self._training_dtype)
 
+        # FP8 alignment: truncate or pad to the nearest multiple of 16 for TE Linear fwd + bwd compatibility
+        if not isinstance(self.config, InferenceHSTUConfig):
+            if self.config.fp8 is not None and self.training:
+                warnings.warn("Aligning JaggedData to nearest multiple of 16 for FP8 TE fwd + bwd compatibility")
+                jd = self._align_jagged_data_for_fp8(jd, fp8_alignment_mode=self.config.fp8_alignment_mode)
+
         return jd
 
+    def _align_jagged_data_for_fp8(
+        self, jd: JaggedData, fp8_alignment_mode: str = "truncate"
+    ) -> JaggedData:
+        """
+        Aligns JaggedData to have total tokens divisible by 16 for FP8 compatibility.
+ + Either truncates or pads tokens from the end of the batch (last sequence) to meet Transformer Engine divisibility requirement. + + Args: + jd: Input JaggedData + fp8_alignment_mode: Alignment mode ("truncate" or "pad") + + Returns: + Aligned JaggedData with maintained invariants + """ + num_tokens = jd.values.shape[0] + + # Check if already aligned + if num_tokens % 16 == 0: + return jd + + # Calculate aligned size based on mode + if fp8_alignment_mode == "truncate": + # Round down to previous multiple of 16 + aligned_size = (num_tokens // 16) * 16 + elif fp8_alignment_mode == "pad": + # Round up to next multiple of 16 + aligned_size = ((num_tokens + 15) // 16) * 16 + else: + raise ValueError(f"Invalid fp8 alignment mode: {fp8_alignment_mode}") + + if aligned_size < 16: + # Too few tokens - pad to minimum size + padding_needed = 16 - num_tokens + jd.values = torch.nn.functional.pad(jd.values, (0, 0, 0, padding_needed)) + return jd + + # Truncate: Remove tokens from the end + if fp8_alignment_mode == "truncate": + tokens_to_remove = num_tokens - aligned_size + jd.values = jd.values[:aligned_size] + + # Update metadata: Find which sequences are affected + # We're removing from the end, so only the last sequence(s) are affected + batch_size = jd.seqlen.shape[0] + new_seqlen = jd.seqlen.clone() + + # Work backwards from last sequence + remaining_to_remove = tokens_to_remove + for seq_idx in range(batch_size - 1, -1, -1): + if remaining_to_remove <= 0: + break + + current_len = new_seqlen[seq_idx].item() + if current_len <= remaining_to_remove: + # Remove entire sequence + new_seqlen[seq_idx] = 0 + remaining_to_remove -= current_len + else: + # Partially truncate this sequence + new_seqlen[seq_idx] = current_len - remaining_to_remove + remaining_to_remove = 0 + + # Recompute offsets + new_seqlen_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(new_seqlen) + + # Verify invariant + print_rank_0(f"Truncated tokens to {aligned_size} from {num_tokens} for FP8 TE compatibility") + assert new_seqlen_offsets[-1].item() == aligned_size, \ + f"Offset mismatch: {new_seqlen_offsets[-1].item()} != {aligned_size}" + + # TODO: verify padding logic with attention kernel + elif fp8_alignment_mode == "pad": + # Pad: Add tokens to the end to reach aligned size + # Leave seqlen unchanged, attn mask in attn kernel should ignore the rest + padding_needed = aligned_size - num_tokens + jd.values = torch.nn.functional.pad(jd.values, (0, 0, 0, padding_needed)) + + print_rank_0(f"Padded tokens to {aligned_size} from {num_tokens} for FP8 TE compatibility (as ghost tokens)") + + # No need to update seqlen or offsets - padding is invisible to the model + return jd + + return JaggedData( + values=jd.values, + seqlen=new_seqlen, + seqlen_offsets=new_seqlen_offsets, + max_seqlen=jd.max_seqlen, + max_num_candidates=jd.max_num_candidates, + num_candidates=jd.num_candidates, + num_candidates_offsets=jd.num_candidates_offsets, + contextual_max_seqlen=jd.contextual_max_seqlen, + contextual_seqlen=jd.contextual_seqlen, + contextual_seqlen_offsets=jd.contextual_seqlen_offsets, + has_interleaved_action=jd.has_interleaved_action, + ) + class HSTUBlockPostprocessor(torch.nn.Module): """ diff --git a/examples/hstu/modules/mlp.py b/examples/hstu/modules/mlp.py index d2124e014..d9de8bf76 100644 --- a/examples/hstu/modules/mlp.py +++ b/examples/hstu/modules/mlp.py @@ -104,4 +104,4 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: The output tensor. 
""" assert input.dim() == 2, "Tensor must be 2-dimensional" - return self._mlp(input) + return self._mlp(input) \ No newline at end of file diff --git a/examples/hstu/modules/native_hstu_layer.py b/examples/hstu/modules/native_hstu_layer.py index 9c85635e8..373cec79f 100644 --- a/examples/hstu/modules/native_hstu_layer.py +++ b/examples/hstu/modules/native_hstu_layer.py @@ -15,6 +15,7 @@ import os from functools import partial +import warnings import nvtx import torch @@ -140,6 +141,7 @@ def __init__(self, config: HSTUConfig): attention_dim=self._attention_dim_per_head, linear_dim=self._linear_dim_per_head, is_causal=config.is_causal, + quant_mode=config.hstu_attn_quantization_mode, ) register_setter_and_getter_for_nvtx( HSTULayer.forward, key_or_attr_name="values" @@ -216,8 +218,10 @@ def forward(self, jd: JaggedData) -> JaggedData: bias=self._input_layernorm_bias, eps=self._eps, ) + with nvtx.annotate("hstu uvqk linear_silu fwd", color="BLUE"): tu, tv, tq, tk = self.get_user_value_query_key_tensors(normed_x) + # TODO: remove contiguous once cutlass backend is ready with nvtx.annotate("hstu attn fwd", color="BLUE"): jagged_attn_output = self._attn_func( @@ -230,6 +234,12 @@ def forward(self, jd: JaggedData) -> JaggedData: max_seqlen=jd.max_seqlen, target_group_size=self._target_group_size, ) + + # TODO: Remove this once the attention kernel outputs consistent dtype + expected_dtype = torch.bfloat16 if self.config.bf16 else torch.float16 + if jagged_attn_output.dtype != expected_dtype: + warnings.warn(f"Jagged attn output dtype mismatch: {jagged_attn_output.dtype} != {expected_dtype}. Casting to {expected_dtype}.") + jagged_attn_output = jagged_attn_output.to(expected_dtype, non_blocking=True) with nvtx.annotate("hstu norm mul dropout fwd", color="GREEN"): if self._debug_shortcut_output_ln_mul_dropout: diff --git a/examples/hstu/modules/utils.py b/examples/hstu/modules/utils.py index 32e9be185..123c94ab6 100644 --- a/examples/hstu/modules/utils.py +++ b/examples/hstu/modules/utils.py @@ -14,7 +14,14 @@ # limitations under the License. import torch +import warnings +try: + import transformer_engine.pytorch as te + use_te = True +except: + warnings.warn("transformer_engine.pytorch is not installed, FP8 mixed precision will not be supported") + use_te = False def init_mlp_weights_optional_bias( m: torch.nn.Module, @@ -30,3 +37,38 @@ def init_mlp_weights_optional_bias( # Always initialize bias to zero. if m.bias is not None: m.bias.data.fill_(0.0) + +def convert_te_linear_to_torch_linear(m: torch.nn.Module) -> torch.nn.Module: + """ + Convert a Transformer Engine Linear layer to a PyTorch Linear layer. + Copies weights and biases from the TE layer to the new PyTorch layer. + + Args: + m: The module to convert. If not a TE Linear layer, returns unchanged. + + Returns: + torch.nn.Linear if m is a TE Linear layer, otherwise returns m unchanged. 
+ """ + if not use_te: + return m + + # Check if this is a Transformer Engine Linear layer + if not isinstance(m, te.Linear): + return m + + # Create new PyTorch Linear layer with same dimensions + new_layer = torch.nn.Linear( + m.in_features, + m.out_features, + bias=m.bias is not None, + device=m.weight.device, + dtype=m.weight.dtype + ) + + # Copy weights and bias + with torch.no_grad(): + new_layer.weight.copy_(m.weight) + if m.bias is not None and new_layer.bias is not None: + new_layer.bias.copy_(m.bias) + + return new_layer \ No newline at end of file diff --git a/examples/hstu/movielens_ranking_fp8.gin b/examples/hstu/movielens_ranking_fp8.gin new file mode 100644 index 000000000..11362b779 --- /dev/null +++ b/examples/hstu/movielens_ranking_fp8.gin @@ -0,0 +1,37 @@ +TrainerArgs.train_batch_size = 128 +TrainerArgs.eval_batch_size = 128 +TrainerArgs.eval_interval = 100 +TrainerArgs.log_interval = 100 +TrainerArgs.seed = 1234 +TrainerArgs.max_train_iters = 1000 +TrainerArgs.profile = True + +DatasetArgs.dataset_name = 'ml-20m' +DatasetArgs.max_sequence_length = 256 +DatasetArgs.shuffle = True +DatasetArgs.max_num_candidates = 16 + +NetworkArgs.dtype_str = "bfloat16" +NetworkArgs.num_layers = 1 +NetworkArgs.num_attention_heads = 4 +NetworkArgs.hidden_size = 128 +NetworkArgs.kv_channels = 128 +NetworkArgs.target_group_size = 1 + +# ratings 0-5 +RankingArgs.prediction_head_arch = [512, 10] +RankingArgs.prediction_head_bias = True +RankingArgs.num_tasks = 1 +RankingArgs.eval_metrics = ("AUC",) + +OptimizerArgs.optimizer_str = 'adam' +OptimizerArgs.learning_rate = 1e-3 +OptimizerArgs.adam_beta1 = 0.9 +OptimizerArgs.adam_beta2 = 0.98 + +TensorModelParallelArgs.tensor_model_parallel_size = 2 + +# MixedPrecisionArgs.mixed_precision_dtype = "fp8" +MixedPrecisionArgs.linear_recipe = "tensorwise" +MixedPrecisionArgs.linear_scaling_precision = "hybrid" +MixedPrecisionArgs.hstu_attn_quantization_mode = "fp8" \ No newline at end of file diff --git a/examples/hstu/movielens_retrieval_fp8.gin b/examples/hstu/movielens_retrieval_fp8.gin new file mode 100644 index 000000000..598a25fb8 --- /dev/null +++ b/examples/hstu/movielens_retrieval_fp8.gin @@ -0,0 +1,32 @@ +TrainerArgs.train_batch_size = 128 +TrainerArgs.eval_batch_size = 128 +TrainerArgs.eval_interval = 100 +TrainerArgs.log_interval = 100 +TrainerArgs.ckpt_save_interval = -1 +TrainerArgs.seed = 1234 + +DatasetArgs.dataset_name = 'ml-20m' +DatasetArgs.max_sequence_length = 200 +DatasetArgs.shuffle = True + +NetworkArgs.dtype_str = "bfloat16" +NetworkArgs.num_layers = 4 +NetworkArgs.num_attention_heads = 4 +NetworkArgs.hidden_size = 256 +NetworkArgs.kv_channels = 64 +NetworkArgs.hidden_dropout = 0 +NetworkArgs.norm_epsilon = 1e-6 +NetworkArgs.is_causal = True + +OptimizerArgs.optimizer_str = 'adam' +OptimizerArgs.learning_rate = 1e-3 +OptimizerArgs.adam_beta1 = 0.9 +OptimizerArgs.adam_beta2 = 0.98 + +RetrievalArgs.num_negatives = 128 +RetrievalArgs.eval_metrics = ("NDCG@10", "NDCG@20", "HR@10") + +# MixedPrecisionArgs.mixed_precision_dtype = "fp8" +MixedPrecisionArgs.linear_recipe = "tensorwise" +MixedPrecisionArgs.linear_scaling_precision = "hybrid" +MixedPrecisionArgs.hstu_attn_quantization_mode = "fp8" \ No newline at end of file diff --git a/examples/hstu/ops/fused_hstu_op.py b/examples/hstu/ops/fused_hstu_op.py index a30b21d5f..3252ceb2e 100644 --- a/examples/hstu/ops/fused_hstu_op.py +++ b/examples/hstu/ops/fused_hstu_op.py @@ -94,6 +94,7 @@ def forward( wgrad_event: Optional[torch.cuda.Event] = None, recompute_input_layernorm: bool = 
False,
         recompute_input_silu: bool = False,
+        quant_mode: int = -1,
     ) -> torch.Tensor:
         """Forward pass of the fused HSTU layer.
         Args:
@@ -125,6 +126,7 @@ def forward(
             wgrad_event (Optional[torch.cuda.Event]): CUDA event for weight gradient computation. Defaults to None.
             recompute_input_layernorm (bool): Whether to recompute the input layer norm. Defaults to False.
             recompute_input_silu (bool): Whether to recompute the input silu. Defaults to False.
+            quant_mode (int): Quantization mode for FP8 attention. Defaults to -1 (no quantization).
         Returns:
             torch.Tensor: Output tensor of shape [T, hidden_size]
         """
@@ -144,6 +146,7 @@ def forward(
         ctx.wgrad_event = wgrad_event
         ctx.recompute_input_layernorm = recompute_input_layernorm
         ctx.recompute_input_silu = recompute_input_silu
+        ctx.quant_mode = quant_mode
         saved_tensor_map = OrderedDict()
         if num_contextuals is None and attn_backend == KernelBackend.TRITON:
             num_contextuals = 0
@@ -277,6 +280,7 @@ def _hstu_attn_cutlass_fwd(
            num_targets,
            target_group_size,
            alpha,
+           quant_mode=-1,
        ):
            sm_major_version = torch.cuda.get_device_properties(0).major
            extension_args = ()
@@ -286,7 +290,7 @@ def _hstu_attn_cutlass_fwd(
                extension_args = ampere_paged_kv_args
            elif sm_major_version == 9:
                cutlass_hstu_varlen_fwd = flash_attn_cuda_hopper.varlen_fwd
-               hopper_fp8_args = (-1, None, None, None, None, None, None, None, None)
+               hopper_fp8_args = (quant_mode, None, None, None, None, None, None, None, None)  # quant_mode now comes from the caller (was hardcoded to -1)
                extension_args = hopper_fp8_args
            else:
@@ -340,6 +344,7 @@ def _hstu_attn_cutlass_fwd(
            ctx.window_size_right = 0
            ctx.has_drab = False
            ctx.is_delta_q = False
+           ctx.quant_mode = quant_mode
            return P
 
@@ -444,6 +449,7 @@ def _linear_residual_fwd(
                num_targets=num_targets,
                target_group_size=target_group_size,
                alpha=ctx.alpha,
+               quant_mode=quant_mode,
            )
        else:
            assert isinstance(
@@ -528,6 +534,7 @@ def backward(
        None,
        None,
        None,
+       None,  # gradient for quant_mode
    ]:
        def _linear_residual_bwd(
            grad_output,
@@ -603,6 +610,7 @@ def _hstu_attn_cutlass_bwd(
            dq: Optional[torch.Tensor] = None,
            dk: Optional[torch.Tensor] = None,
            dv: Optional[torch.Tensor] = None,
+           quant_mode: int = -1,
        ):
            sm_major_version = torch.cuda.get_device_properties(0).major
            assert dout.dim() == 3
@@ -653,7 +661,7 @@ def _hstu_attn_cutlass_bwd(
                window_size_left,
                window_size_right,
                alpha,
-               -1,  # quant_mode
+               quant_mode,  # from the caller (was hardcoded to -1)
                None,  # rab_padded
                False,  # has_drab
                None,  # func
@@ -838,6 +846,7 @@ def _ln_linear_silu_bwd(
                dq=pre_dq.view(-1, ctx.num_heads, ctx.attention_dim_per_head),
                dk=pre_dk.view(-1, ctx.num_heads, ctx.attention_dim_per_head),
                dv=pre_dv.view(-1, ctx.num_heads, ctx.attention_dim_per_head),
+               quant_mode=ctx.quant_mode,
            )
            grad_output = duvqk
        else:
@@ -934,6 +943,7 @@ def _ln_linear_silu_bwd(
            None,
            None,
            None,
+           None,  # gradient for quant_mode
        )
 
@@ -970,6 +980,7 @@ def fused_hstu_op(
     wgrad_event: Optional[torch.cuda.Event] = None,
     recompute_input_layernorm: bool = False,
     recompute_input_silu: bool = False,
+    quant_mode: int = -1,
 ):
     out = FusedHSTULayerFunction.apply(
         input,
@@ -1000,6 +1011,7 @@ def fused_hstu_op(
         wgrad_event,
         recompute_input_layernorm,
         recompute_input_silu,
+        quant_mode,
     )
     return out
 
diff --git a/examples/hstu/ops/pt_ops/torch_addmm.py b/examples/hstu/ops/pt_ops/torch_addmm.py
index 9969366f6..4b5dfb848 100644
--- a/examples/hstu/ops/pt_ops/torch_addmm.py
+++ b/examples/hstu/ops/pt_ops/torch_addmm.py
@@ -16,6 +16,15 @@
 import torch
 
+import warnings
+
+try:
+    import transformer_engine.pytorch as te
+    use_te = True
+except ImportError:
warnings.warn("transformer_engine.pytorch is not installed, FP8 mixed precision will not be supported") + use_te = False + def torch_addmm_silu_fwd( x: torch.Tensor, @@ -32,3 +39,25 @@ def torch_addmm_silu_fwd( else: silu_z = None return z, silu_z + +# TODO: Validate correctness +def te_addmm_silu_fwd( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, + silu: bool = False, +) -> torch.Tensor: + """ + compute z = silu(y + x @ w); silu is optional + Use transformer engine pytorch for potential OOB fp8 support + """ + linear = te.linear(x.shape[-1], w.shape[-2], bias=True) + with torch.no_grad(): + linear.weight.copy_(w) + linear.bias.copy_(y) + z = linear(x) + if silu: + silu_z = torch.nn.functional.silu(z) + else: + silu_z = None + return z, silu_z \ No newline at end of file diff --git a/examples/hstu/pipeline/train_pipeline.py b/examples/hstu/pipeline/train_pipeline.py index ccc9fc394..0942a434e 100644 --- a/examples/hstu/pipeline/train_pipeline.py +++ b/examples/hstu/pipeline/train_pipeline.py @@ -23,6 +23,7 @@ import abc import logging +import warnings from collections import deque from typing import ( Any, @@ -66,6 +67,15 @@ from torchrec.pt2.checks import is_torchdynamo_compiling from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +try: + import transformer_engine.pytorch as te + from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling, Float8BlockScaling + use_te = True + recipe_map = {'delayed': DelayedScaling, 'tensorwise': Float8CurrentScaling, 'blockwise': Float8BlockScaling} +except: + warnings.warn("transformer_engine.pytorch is not installed, FP8 mixed precision will not be supported") + use_te = False + logger: logging.Logger = logging.getLogger(__name__) # This is required to support older torch package export for older models try: @@ -126,12 +136,26 @@ def __init__( custom_model_fwd: Optional[ Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, + te_mixed_precision: bool = False, + **fp8_mp_kwargs ) -> None: self._model = model self._optimizer = optimizer self._device = device self._execute_all_batches = execute_all_batches self._apply_jit = apply_jit + self._te_mixed_precision = te_mixed_precision + self._fp8_mp_kwargs = fp8_mp_kwargs + + # Define recipe for FP8 mixed precision training based on kwargs + if self._te_mixed_precision and not use_te: + assert False, "transformer_engine.pytorch is not installed, but te_mixed_precision = True" + elif self._te_mixed_precision and len(fp8_mp_kwargs) == 0: + assert False, "fp8_mp_kwargs is empty, but te_mixed_precision = True" + elif self._te_mixed_precision and use_te: + self.recipe = recipe_map[fp8_mp_kwargs.pop('recipe')](**fp8_mp_kwargs) + else: + self.recipe = None if device.type == "cuda": # use two data streams to support two concurrent batches @@ -365,7 +389,11 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # forward with record_function("## forward ##"): - losses, output = self._model_fwd(self.batches[0]) + if use_te: + with te.fp8_autocast(enabled=self._te_mixed_precision, fp8_recipe=self.recipe): + losses, output = self._model_fwd(self.batches[0]) + else: + losses, output = self._model_fwd(self.batches[0]) if len(self.batches) >= 2: # invoke data (values, lengths, etc.) 
all_to_all comms (second part of input_dist) @@ -609,6 +637,7 @@ def __init__( custom_model_fwd: Optional[ Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, + **kwargs, ) -> None: super().__init__( model=model, @@ -619,6 +648,7 @@ def __init__( context_type=PrefetchTrainPipelineContext, pipeline_postproc=pipeline_postproc, custom_model_fwd=custom_model_fwd, + **kwargs, ) self._context = PrefetchTrainPipelineContext(version=0) self._prefetch_stream: Optional[torch.Stream] = ( @@ -675,7 +705,11 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: self._wait_sparse_data_dist() # forward with record_function("## forward ##"): - losses, output = self._model_fwd(self._batch_i) + if use_te: + with te.fp8_autocast(enabled=self._te_mixed_precision, fp8_recipe=self.recipe): + losses, output = self._model_fwd(self._batch_i) + else: + losses, output = self._model_fwd(self._batch_i) self._prefetch(self._batch_ip1) @@ -740,6 +774,7 @@ def __init__( custom_model_fwd: Optional[ Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, + **kwargs, ) -> None: super().__init__( model, @@ -750,6 +785,7 @@ def __init__( TrainPipelineContext, pipeline_postproc, custom_model_fwd, + **kwargs, ) def progress(self, dataloader_iter: Iterator[In]) -> Out: @@ -797,7 +833,11 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # forward with nvtx.annotate("## forward ##"): - losses, output = self._model_fwd(self.batches[0]) + if use_te: + with te.fp8_autocast(enabled=self._te_mixed_precision, fp8_recipe=self.recipe): + losses, output = self._model_fwd(self.batches[0]) + else: + losses, output = self._model_fwd(self.batches[0]) with nvtx.annotate("## loss postprocess ##"): collective_assert(not torch.isnan(losses).any(), "loss has nan value") local_tokens = torch.tensor(losses.size(0), device=self._device).float() @@ -843,6 +883,7 @@ def __init__( custom_model_fwd: Optional[ Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, + **kwargs, ) -> None: super().__init__( model, @@ -852,6 +893,7 @@ def __init__( apply_jit, pipeline_postproc, custom_model_fwd, + **kwargs, ) def progress(self, dataloader_iter: Iterator[In]) -> Tuple[torch.Tensor, Out]: diff --git a/examples/hstu/pretrain_gr_ranking.py b/examples/hstu/pretrain_gr_ranking.py index 9851eb6b3..8ea05249f 100644 --- a/examples/hstu/pretrain_gr_ranking.py +++ b/examples/hstu/pretrain_gr_ranking.py @@ -40,6 +40,7 @@ NetworkArgs, OptimizerArgs, TensorModelParallelArgs, + MixedPrecisionArgs, TrainerArgs, create_dynamic_optitons_dict, create_embedding_configs, @@ -51,6 +52,15 @@ train_with_pipeline, ) +try: + import transformer_engine.pytorch as te + from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling, Float8BlockScaling + use_te = True + format_map = {'e4m3': Format.E4M3, 'e5m2': Format.E5M2, 'hybrid': Format.HYBRID} +except: + warnings.warn("transformer_engine.pytorch is not installed, FP8 mixed precision will not be supported") + use_te = False + @gin.configurable @dataclass @@ -83,7 +93,18 @@ def __post_init__(self): network_args = NetworkArgs() optimizer_args = OptimizerArgs() tp_args = TensorModelParallelArgs() +mp_args = MixedPrecisionArgs() + +if mp_args.enabled and not use_te: + assert False, "FP8 mixed precision only supported with Transformer Engine" +if mp_args.enabled: + fp8_mp_kwargs = { + "recipe": mp_args.linear_recipe, + "fp8_format": format_map[mp_args.linear_scaling_precision], + } +else: + fp8_mp_kwargs = {} def create_ranking_config() -> RankingConfig: ranking_args = 
RankingArgs() @@ -110,7 +131,7 @@ def main(): print_rank_0( f"distributed env initialization done. Free cuda memory: {free_memory / (1024 ** 2):.2f} MB" ) - hstu_config = create_hstu_config(network_args, tp_args) + hstu_config = create_hstu_config(network_args, tp_args, mp_args) task_config = create_ranking_config() model = get_ranking_model(hstu_config=hstu_config, task_config=task_config) @@ -156,6 +177,8 @@ def main(): model_train, dense_optimizer, device=torch.device("cuda", torch.cuda.current_device()), + te_mixed_precision=mp_args.enabled, + **fp8_mp_kwargs ) else: pipeline = JaggedMegatronTrainNonePipeline( diff --git a/examples/hstu/pretrain_gr_retrieval.py b/examples/hstu/pretrain_gr_retrieval.py index 267cc0e9e..f64910d8e 100644 --- a/examples/hstu/pretrain_gr_retrieval.py +++ b/examples/hstu/pretrain_gr_retrieval.py @@ -38,6 +38,7 @@ NetworkArgs, OptimizerArgs, TensorModelParallelArgs, + MixedPrecisionArgs, TrainerArgs, create_dynamic_optitons_dict, create_embedding_config, @@ -49,6 +50,18 @@ train_with_pipeline, ) +# Optional Transformer Engine support (FP8) +try: + import transformer_engine.pytorch as te # type: ignore + from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling, Float8BlockScaling # type: ignore + use_te = True + format_map = {"e4m3": Format.E4M3, "e5m2": Format.E5M2, "hybrid": Format.HYBRID} +except: # noqa: E722 + warnings.warn( + "transformer_engine.pytorch is not installed, FP8 mixed precision will not be supported" + ) + use_te = False + @gin.configurable @dataclass @@ -71,6 +84,18 @@ class RetrievalArgs: network_args = NetworkArgs() optimizer_args = OptimizerArgs() tp_args = TensorModelParallelArgs() +mp_args = MixedPrecisionArgs() + +if mp_args.enabled and not use_te: + assert False, "FP8 mixed precision only supported with Transformer Engine" + +if mp_args.enabled: + fp8_mp_kwargs = { + "recipe": mp_args.linear_recipe, + "fp8_format": format_map[mp_args.linear_scaling_precision], + } +else: + fp8_mp_kwargs = {} def create_retrieval_config() -> RetrievalConfig: @@ -95,7 +120,7 @@ def main(): ) init.set_random_seed(trainer_args.seed) - hstu_config = create_hstu_config(network_args, tp_args) + hstu_config = create_hstu_config(network_args, tp_args, mp_args) task_config = create_retrieval_config() model = get_retrieval_model(hstu_config=hstu_config, task_config=task_config) @@ -128,6 +153,8 @@ def main(): model_train, dense_optimizer, device=torch.device("cuda", torch.cuda.current_device()), + te_mixed_precision=mp_args.enabled, + **fp8_mp_kwargs, ) else: pipeline = JaggedMegatronTrainNonePipeline( diff --git a/examples/hstu/training/gin_config_args.py b/examples/hstu/training/gin_config_args.py index 126aeebde..dceca36a4 100644 --- a/examples/hstu/training/gin_config_args.py +++ b/examples/hstu/training/gin_config_args.py @@ -12,8 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass
-from typing import List, Optional, Union
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Literal, Optional, Union
 
 import gin
 
@@ -175,3 +176,38 @@ class OptimizerArgs:
 @dataclass
 class TensorModelParallelArgs:
     tensor_model_parallel_size: int = 1
+
+
+# TODO: eventual support for MXFP8 and NVFP4 on Blackwell
+@gin.configurable
+@dataclass
+class MixedPrecisionArgs:
+    mixed_precision_dtype: Optional[Literal["fp8"]] = None
+    linear_recipe: Literal["delayed", "tensorwise", "blockwise"] = "tensorwise"
+    linear_scaling_precision: Literal["hybrid", "e4m3", "e5m2"] = "hybrid"
+    # Experimental: the prediction head has a variable batch size, while TE requires the batch size to be divisible by 8
+    enable_fp8_for_prediction_head: bool = False
+    fp8_alignment_mode: Literal["truncate", "pad"] = "truncate"
+    hstu_attn_quantization_mode: Literal["bf16", "fp8", "1xdim&128x1", "per-block", "per-head", "per-batch", "per-tensor"] = "bf16"
+    # Maps the strings above to the kernel's quant_mode (see corelib/hstu/README.md):
+    # -1: no quantization, 0: cast to fp8, 1: 1xDIM&128x1, 2: per-block, 3: per-head, 4: per-batch, 5: per-tensor
+    hstu_attn_quantization_map: Dict[str, int] = field(
+        default_factory=lambda: {
+            "bf16": -1,
+            "fp8": 0,
+            "1xdim&128x1": 1,
+            "per-block": 2,
+            "per-head": 3,
+            "per-batch": 4,
+            "per-tensor": 5,
+        }
+    )
+
+    def __post_init__(self):
+        if self.mixed_precision_dtype is None:
+            return  # disabled; skip validations
+        assert self.mixed_precision_dtype == "fp8", "Only 'fp8' is supported now."
+
+    @property
+    def enabled(self) -> bool:
+        return self.mixed_precision_dtype is not None
\ No newline at end of file
diff --git a/examples/hstu/training/utils.py b/examples/hstu/training/utils.py
index 8e47bce13..e627d2a94 100644
--- a/examples/hstu/training/utils.py
+++ b/examples/hstu/training/utils.py
@@ -30,7 +30,7 @@
     get_hstu_config,
 )
 from dynamicemb import DynamicEmbTableOptions
-from modules.embedding import ShardedEmbeddingConfig
+from modules.embedding import ShardedEmbeddingConfig, ShardedEmbedding
 from training.gin_config_args import (
     BenchmarkDatasetArgs,
     DatasetArgs,
@@ -39,8 +39,10 @@
     NetworkArgs,
     OptimizerArgs,
     TensorModelParallelArgs,
+    MixedPrecisionArgs,
     TrainerArgs,
 )
+from commons.utils.logger import print_rank_0
 
 
 @torch.compile
@@ -91,7 +93,9 @@ def cal_flops(hstu_config: HSTUConfig, seqlens: List[torch.Tensor]) -> int:
 
 def create_hstu_config(
-    network_args: NetworkArgs, tensor_model_parallel_args: TensorModelParallelArgs
+    network_args: NetworkArgs,
+    tensor_model_parallel_args: TensorModelParallelArgs,
+    mp_args: MixedPrecisionArgs,
 ):
     dtype = None
     if network_args.dtype_str == "bfloat16":
@@ -129,6 +131,21 @@ def create_hstu_config(
         )
     else:
         hstu_preprocessing_config = None
+
+    # HSTU attention quantization is configured separately from the TE FP8 settings for the linear layers
+    hstu_attn_quantization_mode = mp_args.hstu_attn_quantization_map[mp_args.hstu_attn_quantization_mode]
+
+    if mp_args.enabled:
+        # Match Megatron's FP8 arguments: the format string enables TE FP8 linears and selects the scaling format
+        # (see https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_config.py)
+        fp8 = mp_args.linear_scaling_precision
+        fp8_recipe = mp_args.linear_recipe
+        fp8_alignment_mode = mp_args.fp8_alignment_mode
+    else:
+        fp8 = None
+        fp8_recipe = None
+        fp8_alignment_mode = None
+
     return get_hstu_config(
         hidden_size=network_args.hidden_size,
         kv_channels=network_args.kv_channels,
@@ -139,12 +155,16 @@ def create_hstu_config(
         is_causal=network_args.is_causal,
         dtype=dtype,
         kernel_backend=kernel_backend,
+        hstu_attn_quantization_mode=hstu_attn_quantization_mode,
+        fp8_alignment_mode=fp8_alignment_mode,
        hstu_preprocessing_config=hstu_preprocessing_config,
position_encoding_config=position_encoding_config, target_group_size=network_args.target_group_size, hstu_layer_type=layer_type, recompute_input_layernorm=network_args.recompute_input_layernorm, recompute_input_silu=network_args.recompute_input_silu, + fp8 = fp8, + fp8_recipe = fp8_recipe, ) @@ -556,3 +576,195 @@ def get_dataset_and_embedding_args() -> ( ] else: raise ValueError(f"dataset {dataset_args.dataset_name} is not supported") + +def inspect_sharded_embedding_tables(embedding_collection: ShardedEmbedding, table_name_filter: str = None, tracking_state: dict = None) -> dict: + """ + Helper function to inspect all embedding tables in a ShardedEmbedding collection. + Works with both regular embeddings and dynamic embeddings. + + Args: + embedding_collection: The ShardedEmbedding instance to inspect + table_name_filter: Optional filter to only show tables containing this string + tracking_state: Optional dict to track IDs across iterations {table_name: {'all_ids': set(), 'nan_ids': set()}} + + Returns: + Updated tracking_state dict + """ + if tracking_state is None: + tracking_state = {} + print_rank_0("=" * 80) + print_rank_0("EMBEDDING COLLECTION INSPECTION") + print_rank_0("=" * 80) + + # Check model-parallel embeddings + if embedding_collection._model_parallel_embedding_collection is not None: + print_rank_0("\n[MODEL-PARALLEL EMBEDDINGS]") + mp_collection = embedding_collection._model_parallel_embedding_collection + + # Try to get dynamic embedding modules + try: + from dynamicemb.dump_load import get_dynamic_emb_module + dynamic_modules = get_dynamic_emb_module(mp_collection) + + if len(dynamic_modules) > 0: + print_rank_0(f"Found {len(dynamic_modules)} dynamic embedding module(s)") + for module_idx, dyn_module in enumerate(dynamic_modules): + print_rank_0(f"\n Dynamic Module {module_idx}:") + for table_idx, (table_name, table) in enumerate(zip(dyn_module.table_names, dyn_module.tables)): + if table_name_filter and table_name_filter not in table_name: + continue + print_rank_0(f" Table '{table_name}':") + print_rank_0(f" Type: Dynamic Embedding (KeyValueTable)") + capacity = table.capacity() + used_size = table.size() + emb_dim = table.embedding_dim() + opt_dim = table.optim_state_dim() + + print_rank_0(f" Capacity: {capacity}") + print_rank_0(f" Size (used): {used_size}") + print_rank_0(f" Embedding dim: {emb_dim}") + print_rank_0(f" Optimizer state dim: {opt_dim}") + print_rank_0(f" Effective shape: [{used_size}, {emb_dim}] (sparse)") + + # Check ALL embeddings for NaNs (sparse hash table requires full scan) + if used_size > 0: + from dynamicemb.dump_load import export_keys_values + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + try: + total_checked = 0 + total_nan_count = 0 + all_embedding_ids = [] + ids_with_nan = [] + emb_min = float('inf') + emb_max = float('-inf') + has_any_nan = False + + # Scan through entire hash table to find all stored embeddings + for keys, embeddings, opt_states, scores in export_keys_values(table, device, batch_size=65536): + batch_size_actual = embeddings.shape[0] + total_checked += batch_size_actual + + # Track all embedding IDs + all_embedding_ids.extend(keys.tolist()) + + # Check for NaNs and get the actual movie IDs with NaN + if torch.isnan(embeddings).any(): + has_any_nan = True + nan_mask = embeddings.isnan().any(dim=1) # Which embeddings have NaN + nan_per_embedding = embeddings.isnan().sum(dim=1) # Count NaNs per embedding + total_nan_count += torch.isnan(embeddings).sum().item() + ids_with_nan_in_batch = 
keys[nan_mask].tolist() # Get actual movie IDs + ids_with_nan.extend(ids_with_nan_in_batch) + + # Store NaN pattern for detailed analysis + if 'nan_patterns' not in tracking_state: + tracking_state['nan_patterns'] = {} + if table_name not in tracking_state['nan_patterns']: + # For each ID with NaN, store how many dimensions have NaN + nan_pattern = { + int(keys[i].item()): int(nan_per_embedding[i].item()) + for i in range(len(keys)) if nan_mask[i] + } + tracking_state['nan_patterns'][table_name] = nan_pattern + else: + emb_min = min(emb_min, embeddings.min().item()) + emb_max = max(emb_max, embeddings.max().item()) + + print_rank_0(f" Scanned {total_checked}/{used_size} embeddings") + print_rank_0(f" Has NaN: {has_any_nan}") + if has_any_nan: + print_rank_0(f" Total NaN count: {total_nan_count}") + total_elements = total_checked * emb_dim + print_rank_0(f" NaN percentage: {100.0 * total_nan_count / total_elements:.2f}%") + print_rank_0(f" Number of embeddings with NaN: {len(ids_with_nan)}") + print_rank_0(f" IDs with NaN: {ids_with_nan}") + + # Print NaN pattern analysis + if 'nan_patterns' in tracking_state and table_name in tracking_state['nan_patterns']: + nan_pattern = tracking_state['nan_patterns'][table_name] + fully_nan = sum(1 for count in nan_pattern.values() if count == emb_dim) + partially_nan = len(nan_pattern) - fully_nan + print_rank_0(f" NaN Pattern:") + print_rank_0(f" Fully NaN embeddings: {fully_nan}/{len(nan_pattern)} (all {emb_dim} dims)") + print_rank_0(f" Partially NaN embeddings: {partially_nan}/{len(nan_pattern)}") + if partially_nan > 0: + # Show examples of partial NaN + partial_examples = [(id_, count) for id_, count in list(nan_pattern.items())[:5] if count < emb_dim] + if partial_examples: + print_rank_0(f" Partial NaN examples (ID: NaN_count): {partial_examples}") + + # Store for comparison across iterations + if table_name not in tracking_state: + tracking_state[table_name] = {'all_ids': set(), 'nan_ids': set()} + + prev_all_ids = tracking_state[table_name]['all_ids'] + prev_nan_ids = tracking_state[table_name]['nan_ids'] + current_all_ids = set(all_embedding_ids) + current_nan_ids = set(ids_with_nan) + + new_ids = current_all_ids - prev_all_ids + new_ids_with_nan = current_nan_ids & new_ids + old_ids_with_nan = current_nan_ids - new_ids + + print_rank_0(f" New embeddings this iteration: {len(new_ids)}") + print_rank_0(f" New embeddings with NaN: {len(new_ids_with_nan)} (IDs: {list(new_ids_with_nan)[:10]})") + print_rank_0(f" Old embeddings corrupted: {len(old_ids_with_nan)} (IDs: {list(old_ids_with_nan)[:10]})") + + # Update tracking + tracking_state[table_name]['all_ids'] = current_all_ids + tracking_state[table_name]['nan_ids'] = current_nan_ids + else: + print_rank_0(f" Min/Max: {emb_min:.4f} / {emb_max:.4f}") + + # Track IDs even when no NaN for future comparison + if table_name not in tracking_state: + tracking_state[table_name] = {'all_ids': set(), 'nan_ids': set()} + tracking_state[table_name]['all_ids'] = set(all_embedding_ids) + except Exception as e: + print_rank_0(f" Error scanning embeddings: {e}") + except ImportError: + print_rank_0(" Dynamic embeddings module not available") + + # Check regular (non-dynamic) embeddings via state_dict + print_rank_0("\n Regular embeddings in state_dict:") + for name, tensor in mp_collection.state_dict().items(): + if table_name_filter and table_name_filter not in name: + continue + print_rank_0(f" '{name}':") + + if hasattr(tensor, "local_shards"): + print_rank_0(f" Type: ShardedTensor (model-parallel)") + for 
shard_idx, shard in enumerate(tensor.local_shards()): + shard_tensor = shard.tensor + print_rank_0(f" Shard {shard_idx}:") + print_rank_0(f" Shape: {shard_tensor.shape}") + print_rank_0(f" Offsets: {shard.metadata.shard_offsets}") + print_rank_0(f" Sizes: {shard.metadata.shard_sizes}") + print_rank_0(f" Has NaN: {torch.isnan(shard_tensor).any()}") + if torch.isnan(shard_tensor).any(): + print_rank_0(f" NaN count: {torch.isnan(shard_tensor).sum()}") + else: + print_rank_0(f" Type: Regular Tensor") + print_rank_0(f" Shape: {tensor.shape}") + print_rank_0(f" Has NaN: {torch.isnan(tensor).any()}") + if torch.isnan(tensor).any(): + print_rank_0(f" NaN count: {torch.isnan(tensor).sum()}") + + # Check data-parallel embeddings + if embedding_collection._data_parallel_embedding_collection is not None: + print_rank_0("\n[DATA-PARALLEL EMBEDDINGS]") + dp_collection = embedding_collection._data_parallel_embedding_collection + + if hasattr(dp_collection, 'embedding_weights'): + for table_name, weight_tensor in dp_collection.embedding_weights.items(): + if table_name_filter and table_name_filter not in table_name: + continue + print_rank_0(f" Table '{table_name}':") + print_rank_0(f" Shape: {weight_tensor.shape}") + print_rank_0(f" Has NaN: {torch.isnan(weight_tensor).any()}") + if torch.isnan(weight_tensor).any(): + print_rank_0(f" NaN count: {torch.isnan(weight_tensor).sum()}") + + print_rank_0("=" * 80) + return tracking_state \ No newline at end of file
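
Reviewer note: the sketch below condenses how the new `MixedPrecisionArgs` gin settings are intended to reach Transformer Engine, based on `pretrain_gr_ranking.py` and `train_pipeline.py` in this diff. It is a minimal illustration, not the real training loop: it assumes `transformer_engine` is installed on a Hopper-class GPU, and the `te.Linear` stand-in and tensor shapes are placeholders for the HSTU model. Note also that in both example gin files `MixedPrecisionArgs.mixed_precision_dtype` is still commented out, so as committed only the attention-side FP8 (`hstu_attn_quantization_mode`) takes effect until that line is enabled.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling, Format

# gin: MixedPrecisionArgs.linear_recipe = "tensorwise"        -> Float8CurrentScaling (recipe_map)
#      MixedPrecisionArgs.linear_scaling_precision = "hybrid" -> Format.HYBRID (format_map)
recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)

# gin: MixedPrecisionArgs.hstu_attn_quantization_mode = "fp8" maps to quant_mode=0,
# which create_hstu_attention forwards to hstu_attn_varlen_func(..., quant_mode=0) on SM90.

# Placeholder for the TE-backed dense layers inside the model (not the actual HSTU stack).
layer = te.Linear(128, 128, bias=True, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)  # token count divisible by 16

# Mirrors the fp8_autocast wrapper added around self._model_fwd in train_pipeline.py.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
y.float().sum().backward()
```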
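The 16-token FP8 alignment added to `HSTUBlockPreprocessor` is easiest to sanity-check in isolation. Below is a minimal, hypothetical re-implementation of the truncate/pad rule on a plain `values`/`seqlen` pair; the `align_for_fp8` helper and the toy shapes are illustrative only, and the real `_align_jagged_data_for_fp8` additionally recomputes `seqlen_offsets` via `fbgemm.asynchronous_complete_cumsum` and rebuilds a `JaggedData`.

```python
import torch

def align_for_fp8(values: torch.Tensor, seqlen: torch.Tensor, mode: str = "truncate"):
    """Make the total token count divisible by 16, as TE FP8 GEMMs require."""
    num_tokens = values.shape[0]
    if num_tokens % 16 == 0:
        return values, seqlen

    if mode == "truncate":
        aligned = (num_tokens // 16) * 16
        values = values[:aligned]
        removed = num_tokens - aligned
        seqlen = seqlen.clone()
        # Walk backwards, shrinking the last sequences until `removed` tokens are gone.
        for i in range(seqlen.numel() - 1, -1, -1):
            if removed == 0:
                break
            take = min(int(seqlen[i]), removed)
            seqlen[i] -= take
            removed -= take
    elif mode == "pad":
        aligned = ((num_tokens + 15) // 16) * 16
        values = torch.nn.functional.pad(values, (0, 0, 0, aligned - num_tokens))
        # seqlen stays unchanged: the padded rows are "ghost tokens" invisible to the model.
    else:
        raise ValueError(f"unknown mode {mode}")
    return values, seqlen

values = torch.randn(133, 8)
seqlen = torch.tensor([50, 50, 33])
v_t, s_t = align_for_fp8(values, seqlen, "truncate")  # 128 tokens, seqlen -> [50, 50, 28]
v_p, s_p = align_for_fp8(values, seqlen, "pad")       # 144 tokens, seqlen unchanged
assert v_t.shape[0] % 16 == 0 and int(s_t.sum()) == v_t.shape[0]
assert v_p.shape[0] % 16 == 0 and int(s_p.sum()) == 133
```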