Conversation

@esoba (Contributor) commented on Oct 21, 2025:

Added MixedPrecisionArgs to allow users to configure FP8 usage in TE linear layers and in HSTU attention for the Native HSTU layer. Features include:

  • New MixedPrecisionArgs in the gin config
  • Support for TE FP8 autocast in the pipeline (see the sketch after this list)
  • Ability to enable TE FP8 linear layers independently of FP8 HSTU attention
  • Truncation/padding logic to satisfy TE FP8's divisible-by-16 requirement
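
For reference, the TE FP8 autocast wrapping roughly follows the pattern below; the recipe values and the model/inputs names are illustrative, not the ones this PR actually uses.

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Illustrative recipe; the format/history length used by the pipeline may differ.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inputs)  # TE linear layers inside this context run their GEMMs in FP8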

Minimal working example:

PYTHONPATH=${PYTHONPATH}:$(realpath ../) torchrun --nproc_per_node 2 --master_addr localhost --master_port 6000 pretrain_gr_ranking.py --gin-config-file movielens_ranking_fp8.gin

The setup currently has a bug when both the TE linear layers and HSTU attention are FP8-enabled: the loss becomes NaN at iteration 64. I have a debugging branch here that tracks the forward pass and the associated FP8 metadata for easier debugging. I tried to reproduce the issue here with some dummy inputs and it ran successfully, so my hunch is that NaN gradients are flowing back into the embedding table.
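
To chase that hunch, a small hook like the sketch below can flag non-finite gradients reaching the embedding table; the helper name is illustrative and not part of this PR.

import torch

def register_nan_grad_checks(module: torch.nn.Module) -> None:
    # Illustrative debugging helper: print a message whenever a parameter
    # (e.g. the embedding table weights) receives a NaN/Inf gradient.
    def make_hook(name):
        def hook(grad):
            if not torch.isfinite(grad).all():
                print(f"non-finite gradient flowing into {name}")
            return grad
        return hook

    for name, param in module.named_parameters():
        if param.requires_grad:
            param.register_hook(make_hook(name))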

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.


TensorModelParallelArgs.tensor_model_parallel_size = 2

# MixedPrecisionArgs.mixed_precision_dtype = "fp8"
@esoba (Contributor, Author) commented:

Uncommenting this line should trigger the NaN-loss error at iteration 64.

A Collaborator commented:

Reproduce command:

PYTHONPATH=${PYTHONPATH}:$(realpath ../) torchrun --nproc_per_node 2 --master_addr localhost --master_port 6000 pretrain_gr_ranking.py --gin-config-file movielens_ranking_fp8.gin

target_group_size=self._target_group_size,
)

# TODO: Remove this once the attention kernel outputs consistent dtype
A Collaborator commented:

@JacoCheung could you double check why we need this?

@JacoCheung (Collaborator) commented on Oct 23, 2025:

When training with FP8 enabled, the model weights can be bf16/fp16 (NetworkArgs.dtype_str); usually they are bf16, and so are the activations.

But the HSTU kernel output is fp16, so here we need a cast from fp16 to bf16. @shijieliu

@esoba do you think there is a need to move the cast into the kernel or not?

@esoba (Contributor, Author) commented:

I think casting fp16 to bf16 wouldn't result in an error, since bf16 has a larger dynamic range, but bf16 has fewer mantissa bits, so I'd expect some additional quantization error (ideally the model learns around that anyway). For consistency it would probably be better to move the cast into the kernel, but as a workaround, casting outside should be fine.
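
Roughly, the workaround cast outside the kernel would look like the sketch below; the function name and default dtype are illustrative, not the actual code in this PR.

import torch

def cast_attn_output(attn_out: torch.Tensor, target_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Workaround: the HSTU attention kernel currently returns fp16, while the
    # surrounding activations follow NetworkArgs.dtype_str (usually bf16).
    return attn_out.to(target_dtype) if attn_out.dtype != target_dtype else attn_out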

@XinboZhao commented:

I will help to review the code.


return jd

def _align_jagged_data_for_fp8(
@JacoCheung (Collaborator) commented on Oct 31, 2025:

Hi @esoba, since you pad here, there should be a discarding step in the postprocessor before the loss is computed. That is to say,

 final_loss = drop_pad_values(final_loss)
 final_loss.mean().backward()

Otherwise the padded tokens will affect both the data gradient and the weight gradient in the backward pass, even if the padded values are initialized to 0.

See our loss calculation.

And our post-processor (if it's ranking) will treat the padded tokens as normal tokens.
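
Roughly, such a discarding step could look like the sketch below, assuming a boolean valid-token mask is carried alongside the padded per-token loss; drop_pad_values' signature and the toy sizes are illustrative.

import torch

def drop_pad_values(per_token_loss: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
    # Keep only losses for real tokens, so padded positions contribute
    # neither to the data gradient nor to the weight gradient.
    return per_token_loss[valid_mask]

# Toy example: 48 padded positions, of which only the first 35 are real tokens.
per_token_loss = torch.randn(48, requires_grad=True)
valid_mask = torch.arange(48) < 35

final_loss = drop_pad_values(per_token_loss, valid_mask)
final_loss.mean().backward()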

@esoba (Contributor, Author) commented:

I believe I was also seeing the issue when I set this to truncate (cutting off the last N elements to reach the nearest multiple of 16); let me double-check whether there is any undefined behavior that way as well. Thanks for the catch!
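
For reference, a rough sketch of the two alignment modes under discussion, padding vs. truncating to the nearest multiple of 16; the function and its arguments are illustrative and do not mirror the actual _align_jagged_data_for_fp8 signature.

import torch
import torch.nn.functional as F

def align_rows_to_multiple_of_16(values: torch.Tensor, mode: str = "pad") -> torch.Tensor:
    # values: flattened jagged values of shape (total_tokens, hidden).
    n = values.size(0)
    if mode == "pad":
        extra = (-n) % 16
        # Zero-pad extra rows at the end; these must later be dropped before the loss.
        return F.pad(values, (0, 0, 0, extra))
    if mode == "truncate":
        # Drop the trailing n % 16 rows, losing those tokens entirely.
        return values[: n - (n % 16)]
    raise ValueError(f"unknown mode: {mode}")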

@JacoCheung JacoCheung mentioned this pull request Nov 4, 2025