Flash Attention for Neuron #939
base: main
Conversation
Maybe wait until this PR is checked in. From what I can tell, your PR also does not fix the remat bug. #942 (review)
def _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate):
    # Get the batch size, sequence lengths, number of heads, and hidden dimension
Nit: end comments with . (here and everywhere)
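For context, a hedged sketch of the shape unpacking that the comment above describes; the [batch, num_heads, seq_len, per_head_dim] layout is an assumption for illustration, not necessarily the layout this kernel uses:

# Hypothetical layout: query [batch, num_heads, q_seq_len, per_head_dim],
# key/value [batch, num_heads, kv_seq_len, per_head_dim]; the PR's kernel
# may order these dimensions differently.
batch_size, num_heads, q_seq_len, per_head_dim = query.shape
_, _, kv_seq_len, _ = key.shape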
Force-pushed from 8a92182 to 73a2808
key: Tensor,
value: Tensor,
bias: Tensor,
causal: bool = False,
Can we support segment IDs? Or even better, a more general masking fn (with optimized handling).
If not, I am fine with leaving a TODO here, but it is a hard blocker for enabling it for our internal training.
Can we do segment IDs in a separate PR? That involves non-trivial work and needs some time.
Sure. In that regard, I may ask for more: let's do a general mask then, since we want things beyond causal.
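For context on what segment-ID support would entail, a hedged sketch of the mask it implies; the shape conventions here are assumptions for illustration, not from this PR:

import jax.numpy as jnp

def make_segment_id_mask(q_segment_ids, kv_segment_ids):
    # Allow attention only within the same segment; segment IDs are assumed to
    # have shape [batch, seq_len], and the mask is broadcast over the heads axis.
    mask = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
    return mask[:, None, :, :]

A causal or sliding-window constraint would then be combined with such a mask, which is why a general masking fn subsumes the causal flag.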
Thanks for all the reviews @ruomingp @kelvin-zou. I resolved all the comments; please let me know if any more changes are needed.
seed = jnp.array([1])

# Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads.
if (num_heads % 2) == 0 and (num_heads // 2 > 0):
# even num_heads except 0
if num_heads > 0 and num_heads % 2 == 0:
input_dtype: jnp.dtype,
attention_bias_type: bool,
):
softmax_scale = 1.0 / (per_head_dim**0.5)
Maybe just
per_head_dim**-0.5
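Both forms compute the standard attention scale 1 / sqrt(per_head_dim); the suggested form simply avoids the explicit division.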
Force-pushed from 73a2808 to c226d03
I rebased the PR to avoid merge conflicts; can I please get a new approval? Thank you!
from axlearn.common.flash_attention.utils import mha_reference

if jax.default_backend() != "neuron":
    pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.")
Looks like a number of CI steps are failing -- I think we can either do something like

if jax.default_backend() != "neuron":
    pytest.skip(reason=..., allow_module_level=True)

or update run_tests.sh to exclude tests marked with neuron.
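For reference, a minimal self-contained sketch of the module-level skip suggested above (assuming pytest >= 7.0, where pytest.skip accepts the reason keyword):

import jax
import pytest

# Skip the whole module when the tests are not running on AWS Neuron hardware.
if jax.default_backend() != "neuron":
    pytest.skip(reason="Incompatible hardware, AWS Neuron only test.", allow_module_level=True)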
@apoorvtintin I see quite a few unit tests failed; can you take a look?
Force-pushed from c226d03 to 42720ad
Force-pushed from 42720ad to f7f06fd
This PR adds support for a flash attention kernel for Neuron, implemented through the Neuron Kernel Interface (NKI).
The flash attention kernel works on TRN1 and TRN2.
This PR is a newer version of #883 from a different fork; all comments from the previous PR are addressed in this one, and it adds dropout support.
Segment ID support in the flash attention kernel is in progress and will be available at a later date.
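For illustration, a hedged sketch of the kind of correctness check the PR's tests perform; neuron_flash_attention is a hypothetical name for the kernel entry point (not necessarily the one in this PR), and the plain-JAX reference is written out here rather than importing the repo's mha_reference so the sketch stays self-contained:

import jax
import jax.numpy as jnp

def reference_attention(q, k, v, softmax_scale):
    # Plain-JAX reference: softmax(q @ k^T * scale) @ v, assuming a
    # [batch, num_heads, seq_len, per_head_dim] layout for this sketch.
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * softmax_scale
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", probs, v)

batch, num_heads, seq_len, per_head_dim = 1, 2, 128, 64
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch, num_heads, seq_len, per_head_dim))
k = jax.random.normal(k2, (batch, num_heads, seq_len, per_head_dim))
v = jax.random.normal(k3, (batch, num_heads, seq_len, per_head_dim))

ref = reference_attention(q, k, v, softmax_scale=per_head_dim**-0.5)
# Hypothetical kernel call, shown commented out since the entry point name is assumed:
# out = neuron_flash_attention(q, k, v, bias=None, causal=False,
#                              softmax_scale=per_head_dim**-0.5, dropout_rate=0.0)
# assert jnp.allclose(out, ref, atol=2e-2)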