Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable cudnn attention dropout #913

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

hanzhi713
Copy link
Member

This was separated from #905 to allow it to be merged before jax upgrade.

To be merged after jax is upgraded to >= 0.4.34.

@hanzhi713 hanzhi713 requested review from ruomingp, markblee and a team as code owners January 8, 2025 23:33
@hanzhi713 hanzhi713 changed the title Enable cudnn dropout Enable cudnn attention dropout Jan 8, 2025
@kelvin-zou
Copy link
Contributor

@hanzhi713 is cudnn kernel dropout also blocked by Jax upgrade? I remember only triton/pallas kernel is broken with jax 0.4.33?

@hanzhi713
Copy link
Member Author

Yes. Jax segfaults before 0.4.34

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants