first version of flash_attention for jax #19743

Open
wants to merge 1 commit into base: master

Conversation

@vulkomilev

This is my first version of the flash attention implementation. It is just for JAX.
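
For reference, a minimal sketch of the blocked online-softmax computation that flash-attention-style implementations perform, written in plain JAX. This is illustrative only, not the PR's code; the function name, shapes, and `block_size` default are assumptions, and it omits masking and the batch/head dimensions.

```python
import jax
import jax.numpy as jnp

def flash_attention(q, k, v, block_size=128):
    # Illustrative sketch, not the PR's implementation.
    # q, k, v: (seq_len, head_dim); seq_len must divide evenly by block_size.
    seq_len, head_dim = q.shape
    scale = 1.0 / jnp.sqrt(head_dim)
    k_blocks = k.reshape(-1, block_size, head_dim)
    v_blocks = v.reshape(-1, block_size, head_dim)

    def body(carry, kv):
        acc, m, l = carry                         # running output, row max, row sum
        k_blk, v_blk = kv
        s = (q @ k_blk.T) * scale                 # (seq_len, block_size) scores
        m_new = jnp.maximum(m, s.max(-1, keepdims=True))
        p = jnp.exp(s - m_new)                    # softmax numerator for this block
        corr = jnp.exp(m - m_new)                 # rescale stats from earlier blocks
        acc = acc * corr + p @ v_blk
        l = l * corr + p.sum(-1, keepdims=True)
        return (acc, m_new, l), None

    init = (jnp.zeros_like(q),                    # accumulated unnormalized output
            jnp.full((seq_len, 1), -jnp.inf),     # running row max
            jnp.zeros((seq_len, 1)))              # running softmax denominator
    (acc, _, l), _ = jax.lax.scan(body, init, (k_blocks, v_blocks))
    return acc / l
```

The point of the blocking is that the full (seq_len, seq_len) score matrix is never materialized; only one (seq_len, block_size) tile lives in memory at a time.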

@fchollet (Member)

Thanks for the PR! Have you tried to time it on GPU compared to regular attention? I was under the impression that we were going to need a custom Pallas kernel for this.
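
One common way to time it on GPU in JAX, accounting for compilation and asynchronous dispatch (the shapes and the `regular_attention` baseline here are illustrative, not from the PR):

```python
import time
import jax
import jax.numpy as jnp

def benchmark(fn, *args, warmup=3, iters=10):
    fn = jax.jit(fn)
    # JAX dispatch is asynchronous: block_until_ready() makes sure we
    # measure the kernel itself, not just the host-side dispatch.
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    jax.block_until_ready(out)
    return (time.perf_counter() - start) / iters

def regular_attention(q, k, v):
    s = (q @ k.T) / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(s, axis=-1) @ v

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (4096, 64))
k = jax.random.normal(kk, (4096, 64))
v = jax.random.normal(kv, (4096, 64))
print("regular:", benchmark(regular_attention, q, k, v))
```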

@gbaned gbaned requested a review from fchollet May 27, 2024 06:04
@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation May 27, 2024
@vulkomilev (Author)

I used the /keras/src/layers/attention/ directory as a template for implementing flash attention, but I don't understand how the mask is generated in the benchmark. I need one, but I don't see it.
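
For illustration, a causal mask of the kind attention benchmarks commonly generate can be built like this in JAX. Whether this benchmark wants causal masking (as opposed to, say, a padding mask) is an assumption:

```python
import jax.numpy as jnp

def causal_mask(seq_len):
    # True where query position i may attend to key position j (j <= i).
    # Assumption: the benchmark uses standard causal masking.
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

def mask_scores(scores, mask):
    # Masked positions get the most negative representable value,
    # so the subsequent softmax drives their weights to ~0.
    return jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
```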
