
support for llama multipack using updated code/patches #1754

Merged: 6 commits from llama-multipack-v2 into main on Jul 16, 2024

Conversation

winglian (Collaborator) commented:

The attention monkey patch we have for llama is pretty old at this point, and maintaining it is a pain. This PR swaps to the updated unpad patch for flash attention and does a slight refactor to continue supporting the cross entropy loss and RMS norm patches.
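For context, the unpad path in transformers' llama flash attention code computes `indices`, `cu_seqlens`, and `max_seqlen_in_batch` from the attention mask before calling the varlen flash attention kernel. Multipack packs several samples into one row and stores a distinct integer ID per sample in that mask, so patching the unpad helper to split on those IDs makes each packed sample its own sequence. A minimal sketch of the idea (not the PR's exact code; the helper name and return signature are assumptions modeled on transformers' `_get_unpad_data`):

```python
import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor):
    # attention_mask holds per-token sample IDs, e.g. [1, 1, 1, 2, 2, 3, 3, 3, 0],
    # with 0 marking padding and IDs assumed contiguous from 1 within each row.
    seqlens_in_batch = torch.cat(
        [torch.bincount(row[row != 0])[1:] for row in attention_mask]
    ).to(dtype=torch.int32)
    # Flat indices of all non-padded tokens, used to unpad/re-pad hidden states.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # Cumulative sequence lengths prefixed with 0, the format the varlen flash
    # attention kernels expect, so packed samples never attend to each other.
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    return indices, cu_seqlens, max_seqlen_in_batch
```

With the helper patched along these lines, the stock flash attention code path handles packed batches correctly, which is what lets the old hand-rolled llama attention monkey patch be dropped.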

[Screenshot: training metrics comparison, 2024-07-15 8:47 AM]

As we can see, it's slightly faster, uses about 1.4% less VRAM, and has pretty similar loss and grad norm characteristics.

I also attempted to use the updated Triton RMS norm in place of the CUDA implementation of RMS norm from flash-attn, but it made things slightly worse.
[Screenshot: training metrics comparison with the Triton RMS norm, 2024-07-15 8:57 AM]
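For reference, the cross entropy and RMS norm patches mentioned above amount to swapping transformers' defaults for flash-attn's fused kernels. A rough sketch of that kind of patch (not the PR's exact code; module paths and constructor signatures can differ across flash-attn and transformers versions):

```python
from functools import partial

import transformers.models.llama.modeling_llama as modeling_llama
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.ops.rms_norm import RMSNorm


def replace_llama_rms_norm():
    class LlamaRMSNorm(RMSNorm):
        """Drop-in replacement matching LlamaRMSNorm's (hidden_size, eps) signature."""

        def __init__(self, hidden_size, eps=1e-6):
            super().__init__(hidden_size, eps=eps)

    # Patch before the model is instantiated so new layers use the fused CUDA kernel.
    modeling_llama.LlamaRMSNorm = LlamaRMSNorm


def replace_llama_cross_entropy():
    # Fused cross entropy with in-place backward to save activation memory;
    # assumes this transformers version imports CrossEntropyLoss at module level.
    modeling_llama.CrossEntropyLoss = partial(CrossEntropyLoss, inplace_backward=True)
```

Both patches need to run before the model is built so the llama modules pick up the replacement classes.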

winglian merged commit 5f58555 into main on Jul 16, 2024
8 checks passed
winglian deleted the llama-multipack-v2 branch on July 16, 2024 at 21:36