FP8 Enablement for TE layers + HSTU attn #197
base: main
```diff
@@ -15,6 +15,7 @@
 import os
 from functools import partial
+import warnings

 import nvtx
 import torch
@@ -140,6 +141,7 @@ def __init__(self, config: HSTUConfig):
             attention_dim=self._attention_dim_per_head,
             linear_dim=self._linear_dim_per_head,
             is_causal=config.is_causal,
+            quant_mode=config.hstu_attn_quantization_mode,
         )
         register_setter_and_getter_for_nvtx(
             HSTULayer.forward, key_or_attr_name="values"
@@ -216,8 +218,10 @@ def forward(self, jd: JaggedData) -> JaggedData:
                 bias=self._input_layernorm_bias,
                 eps=self._eps,
             )
         with nvtx.annotate("hstu uvqk linear_silu fwd", color="BLUE"):
             tu, tv, tq, tk = self.get_user_value_query_key_tensors(normed_x)
         # TODO: remove contiguous once cutlass backend is ready
         with nvtx.annotate("hstu attn fwd", color="BLUE"):
             jagged_attn_output = self._attn_func(
@@ -230,6 +234,12 @@
                 max_seqlen=jd.max_seqlen,
                 target_group_size=self._target_group_size,
             )
+        # TODO: Remove this once the attention kernel outputs consistent dtype
```
Collaborator

@JacoCheung could you double check why we need this?

Collaborator

When training with FP8 enabled, the model weights can be bf16/fp16 (NetworkArgs.dtype_str), usually bf16, and so are the activations. But the HSTU kernel output is fp16, so a cast from fp16 to bf16 is needed here. @shijieliu @esoba do you think the cast should be moved into the kernel or not?

Contributor (Author)

I think casting fp16 to bf16 wouldn't result in an error since the dynamic range is larger, but I'd expect some additional quantization error to show up (ideally the model learns around it anyway). For consistency it would probably be better to move the cast into the kernel, but as a workaround casting outside should be fine.
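To make the dtype discussion above concrete, here is a small standalone sketch (not part of this PR, assuming only stock PyTorch) that measures the rounding error an fp16 to bf16 cast introduces: the cast cannot overflow because bf16 has the wider exponent range, but it drops 3 mantissa bits, so values pick up a relative error of up to roughly 2^-8.

```python
import torch

# Sample values, stored in fp16 as the HSTU kernel output would be.
x_fp16 = torch.randn(1 << 20).to(torch.float16)

# The workaround discussed above: cast to the expected activation dtype.
x_bf16 = x_fp16.to(torch.bfloat16)

# Measure the rounding error of the cast against a float32 copy of the fp16 values.
ref = x_fp16.float()
rel_err = ((x_bf16.float() - ref).abs() / ref.abs().clamp_min(1e-8)).max()
print(f"max relative error of fp16 -> bf16 cast: {rel_err.item():.2e}")  # ~4e-3, i.e. about 2^-8
```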
```diff
+        expected_dtype = torch.bfloat16 if self.config.bf16 else torch.float16
+        if jagged_attn_output.dtype != expected_dtype:
+            warnings.warn(f"Jagged attn output dtype mismatch: {jagged_attn_output.dtype} != {expected_dtype}. Casting to {expected_dtype}.")
+            jagged_attn_output = jagged_attn_output.to(expected_dtype, non_blocking=True)
         with nvtx.annotate("hstu norm mul dropout fwd", color="GREEN"):
             if self._debug_shortcut_output_ln_mul_dropout:
```
Hi @esoba, since you have padding here, there should be a discarding step in the postprocessor before the loss compute. Otherwise the padded tokens will impact the backward pass, both the data gradients and the weight gradients, even if the padded values are initialized to 0.
See our loss calculation. And our post-processor (if it's a ranking task) will treat the padded tokens as normal tokens.
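A hypothetical illustration of the discarding step being requested here; the function and argument names are invented for this sketch and are not the repo's actual post-processor API. The point is that padded positions must be filtered out before the loss so they contribute nothing to either the data or the weight gradients.

```python
import torch
import torch.nn.functional as F

def masked_bce_loss(logits: torch.Tensor, labels: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
    """Compute the loss only over real tokens; `valid_mask` is False at padded positions."""
    return F.binary_cross_entropy_with_logits(logits[valid_mask], labels[valid_mask].float())

# Why zero-initialized padding alone is not enough: with a padded logit of 0 and a
# padded label of 0, sigmoid(0) = 0.5 still yields a nonzero gradient, so the padded
# token would push on the weights unless it is masked out as above.
```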
I believe I was seeing the issue when I set this to truncate (cutting off the last N elements to reach the nearest length divisible by 16). Let me double check whether there is any undefined behavior doing it that way as well. Thanks for the catch!
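For reference, a hypothetical sketch (not the PR's code; the alignment value of 16 is assumed from the "divisible by 16" requirement mentioned above) contrasting the two strategies: padding up keeps every real token but requires masking downstream, while truncating down silently drops trailing tokens and can erase a short sequence entirely.

```python
ALIGN = 16  # assumed alignment requirement, per the discussion above

def pad_to_multiple(seqlen: int, align: int = ALIGN) -> int:
    # Round up: keeps all real tokens, adds padded ones that must be masked later.
    return ((seqlen + align - 1) // align) * align

def truncate_to_multiple(seqlen: int, align: int = ALIGN) -> int:
    # Round down: drops the trailing (seqlen % align) real tokens.
    return (seqlen // align) * align

for seqlen in (37, 16, 5):
    print(seqlen, "-> pad:", pad_to_multiple(seqlen), "truncate:", truncate_to_multiple(seqlen))
# 37 -> pad: 48 truncate: 32
# 16 -> pad: 16 truncate: 16
#  5 -> pad: 16 truncate: 0   (a 5-token sequence disappears entirely when truncating)
```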