FP8 Enablement for TE layers + HSTU attn #197
base: main
Conversation
…line in te amp cm
Te fp8 wrapper
…ant_mode argument name in python API
Merge to main
TensorModelParallelArgs.tensor_model_parallel_size = 2
# MixedPrecisionArgs.mixed_precision_dtype = "fp8"
Uncommenting this line should throw an error at iteration 64 related to NaN loss.
Reproduce command:
PYTHONPATH=${PYTHONPATH}:$(realpath ../) torchrun --nproc_per_node 2 --master_addr localhost --master_port 6000 pretrain_gr_ranking.py --gin-config-file movielens_ranking_fp8.gin
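For context, here is a minimal standalone sketch of what FP8 execution around a TE linear layer looks like with Transformer Engine's fp8_autocast. This is not the wrapper used in this PR, just an illustration of the underlying mechanism; the recipe values and layer sizes below are placeholders.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Placeholder recipe; the actual MixedPrecisionArgs wiring may differ.
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

layer = te.Linear(1024, 1024, bias=True, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = layer(inp)  # GEMM runs in FP8, activations stay bf16

out.float().sum().backward()
```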
target_group_size=self._target_group_size,
)
# TODO: Remove this once the attention kernel outputs consistent dtype
@JacoCheung could you double check why we need this?
When training with fp8 enabled, the model weights can be bf16/fp16 (NetworkArgs.dtype_str); usually it's bf16, and the activations use the same dtype.
But the HSTU kernel output is fp16, so here we need a cast from fp16 to bf16. @shijieliu
@esoba do you think there's a need to move the cast into the kernel or not?
I think casting fp16 to bf16 wouldn't result in an error since the dynamic range is larger, but I would assume some additional quantization error pops up (ideally that gets learned by the model anyway). For consistency it would probably be better to move the cast into the kernel, but as a workaround casting outside should be fine.
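A minimal sketch of the workaround being discussed, assuming the kernel output arrives as an fp16 tensor while the rest of the model runs in bf16; the function name is illustrative, not the code in this PR.

```python
import torch

def cast_attention_output(attn_out: torch.Tensor,
                          target_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # The HSTU attention kernel currently emits fp16 regardless of the
    # activation dtype, so cast back to the model's dtype (usually bf16).
    if attn_out.dtype != target_dtype:
        attn_out = attn_out.to(target_dtype)
    return attn_out
```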
I will help to review the code.
return jd
def _align_jagged_data_for_fp8(
Hi @esoba, since you have padding here, there should be a discarding step in the postprocessor before the loss is computed. That is to say,
final_loss = drop_pad_values(final_loss)
final_loss.mean().backward()
Otherwise, the padded tokens will impact the backward pass (both data gradients and weight gradients) even if the padded values are initialized as 0.
See our loss calculation.
And our post-processor (if it's ranking) will treat the padded tokens as normal tokens.
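A minimal sketch of that discarding step; `drop_pad_values` is the name used in the comment, and the boolean `valid_mask` marking non-padded positions is an assumption.

```python
import torch

def drop_pad_values(per_token_loss: torch.Tensor,
                    valid_mask: torch.Tensor) -> torch.Tensor:
    # Keep only the losses of real tokens so padded positions contribute
    # nothing to either data gradients or weight gradients.
    return per_token_loss[valid_mask]

# Usage:
# final_loss = drop_pad_values(final_loss, valid_mask)
# final_loss.mean().backward()
```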
I believe I was seeing the issue when I set this to truncate (cut off the last N elements to get to the nearest length divisible by 16); let me double-check this to see if there is any undefined behavior doing it that way as well. Thanks for the catch!
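For illustration, a padding-based alignment (rather than truncation) might look like the sketch below. This is an assumption about the shape handling, not the actual _align_jagged_data_for_fp8 implementation.

```python
import torch

def pad_to_multiple_of_16(values: torch.Tensor) -> torch.Tensor:
    # FP8 GEMMs require dimensions aligned to 16; pad the token dimension
    # with zeros up to the next multiple instead of truncating real tokens.
    n = values.shape[0]
    remainder = n % 16
    if remainder == 0:
        return values
    pad = torch.zeros(16 - remainder, *values.shape[1:],
                      dtype=values.dtype, device=values.device)
    return torch.cat([values, pad], dim=0)
```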
Added MixedPrecisionArgs to allow users to configure FP8 usage in TE linear layers and HSTU attention for the Native HSTU layer. Features include:
Minimal working example:
PYTHONPATH=${PYTHONPATH}:$(realpath ../) torchrun --nproc_per_node 2 --master_addr localhost --master_port 6000 pretrain_gr_ranking.py --gin-config-file movielens_ranking_fp8.gin
The setup currently has a bug when both the TE linear layer and HSTU attn are fp8 enabled: we see NaN loss at iteration 64. I have a debugging branch here that tracks the forward pass and the associated fp8 metadata for easier debugging. I tried to repro the issue here with some dummy inputs, and it ran successfully; I have a hunch that there are NaN gradients flowing back into the embedding table.
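To check the hunch about NaN gradients reaching the embedding table, a hook-based sketch like the following could be used; the function name is illustrative and not part of this PR.

```python
import torch

def register_nan_grad_hooks(model: torch.nn.Module) -> None:
    # Print the name of any parameter whose gradient contains NaN/Inf
    # during backward, e.g. the embedding table weights.
    def make_hook(name: str):
        def hook(grad: torch.Tensor):
            if not torch.isfinite(grad).all():
                print(f"[nan-check] non-finite gradient in {name}")
            return grad
        return hook

    for name, param in model.named_parameters():
        if param.requires_grad:
            param.register_hook(make_hook(name))
```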
Checklist