MultiHeadAttention memory usage reduction via tiling #679

Open · wants to merge 1 commit into main
Conversation

maxconway

Hope this is helpful! I haven't submitted a PR here before, so let me know if I've got any etiquette wrong :). Let me know if there are any changes you'd like.

Summary

This PR limits memory usage in MultiHeadAttention by setting a maximum size (max_chunk_size_mb) for the attn_logits tensor. If the tensor would be smaller than max_chunk_size_mb, everything proceeds as before; if it would be larger, attn_logits is evaluated in chunks rather than ever being fully materialised. This keeps memory usage within what will fit on the machine and, as a bonus, sometimes also provides a performance boost.
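
Conceptually, the chunked path looks something like the sketch below. This is not the actual diff, just a simplified illustration of the description above; it assumes the split is over the query dimension (which keeps each row's softmax exact) and omits the mask and the multi-head projections.

```python
import jax
import jax.numpy as jnp

def chunked_attention(query, key, value, max_chunk_size_mb=1000):
  """Sketch of tiled softmax(Q Kᵀ / sqrt(d)) V (mask and heads omitted).

  The full attn_logits tensor has shape [..., q_len, k_len]; if it would
  exceed max_chunk_size_mb, the query axis is split so that each chunk's
  logits stay under the limit. Because the softmax runs over the key axis,
  every chunk can be computed exactly and independently.
  """
  *batch, q_len, d = query.shape
  k_len = key.shape[-2]
  logits_bytes = query.dtype.itemsize * q_len * k_len
  for b in batch:
    logits_bytes *= b
  limit = max_chunk_size_mb * 1024 * 1024
  n_chunks = max(1, -(-logits_bytes // limit))  # ceil division
  chunk = -(-q_len // n_chunks)                 # queries per chunk

  def attend(q):
    logits = jnp.einsum('...qd,...kd->...qk', q, key) / jnp.sqrt(d)
    return jnp.einsum('...qk,...kd->...qd', jax.nn.softmax(logits), value)

  outputs = [attend(query[..., i:i + chunk, :]) for i in range(0, q_len, chunk)]
  return jnp.concatenate(outputs, axis=-2)
```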

Correctness Testing

I've run this through attention_test.py and also through my own test scripts, using random tensors for query, key, value and mask and checking that the results are identical between the original implementation and this PR.
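
The ad-hoc check is essentially of this form (a sketch rather than the real script; attention_reference and attention_chunked are placeholders for the old and new code paths):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Placeholder names: attention_reference / attention_chunked stand in for the
# original and tiled implementations being compared.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, [2, 128, 64])
key = jax.random.normal(kk, [2, 256, 64])
value = jax.random.normal(kv, [2, 256, 64])
mask = jnp.ones([2, 1, 128, 256], dtype=bool)

expected = attention_reference(query, key, value, mask=mask)
actual = attention_chunked(query, key, value, mask=mask,
                           max_chunk_size_mb=1)  # small limit to force chunking
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)
```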

Performance Testing

CPU

I've run 18 different performance tests on CPU with different values of key_size, num_heads and query shape, all using self-attention. For small inputs the calculation is unchanged, so there's no difference; for sufficiently large inputs, the old calculation runs out of memory while the new one works fine. In between, we see a substantial memory usage reduction and a modest speed improvement (presumably due to better caching).

For instance, with the parameters:

key_size = 64
num_heads = 32
h = jax.random.normal(jax.random.PRNGKey(42), [1,10000,1024])
max_chunk_size_mb = 1000

This change resulted in:

  • a 27% speedup (12.1s to 3.8s) and
  • a 6x drop in memory usage (28.8GB to 4.4GB)
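
For anyone wanting to reproduce these numbers, a measurement pattern along these lines should be enough (a rough sketch, not the exact script I used; forward is a placeholder for the transformed attention call, and ru_maxrss is Linux-specific):

```python
import resource
import time

# `forward` is a placeholder for the jitted/transformed attention call,
# and `h` is the input tensor from the parameters above.
out = forward(h)
out.block_until_ready()              # warm-up and compile outside the timing

start = time.perf_counter()
out = forward(h)
out.block_until_ready()
elapsed = time.perf_counter() - start

# Peak RSS of the process; ru_maxrss is reported in KiB on Linux.
peak_gb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 ** 2
print(f'{elapsed:.1f}s, peak RSS ~ {peak_gb:.1f} GiB')
```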

GPU

I did the bulk of this profiling on CPU because I couldn't work out a good peak-memory profiling setup on GPU, but testing shows that this change also allows calculations to fit in GPU RAM that otherwise wouldn't.

For example, using the parameters:

key_size = 64
num_heads = 32
h = jax.random.normal(jax.random.PRNGKey(42), [1,30000,1024])
max_chunk_size_mb = 1000

With the new version I can calculate self-attention in 3.6s on a 12GB card. With the previous version I get an OOM error, since at least 115GB of VRAM would have been required.
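
That figure follows from the size of the attn_logits tensor alone: with these parameters it has shape [1, 32, 30000, 30000] in float32, i.e. 32 × 30000² × 4 bytes ≈ 115.2GB, before counting any other buffers.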

The speedup on the GPU seems to be a bit smaller, though both the new and original code show a run-to-run variance of around 10% in my tests, so it's difficult to pin down an exact number. If you've got any tips on how to benchmark this (or peak GPU RAM) accurately, I'd appreciate them.
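
One option might be to read the device allocator statistics after the call, roughly as below; I haven't verified how reliable this is, and memory_stats() is only available on some jaxlib backends (it may return None).

```python
import jax

# Possible approach (assumes a jaxlib backend whose devices expose
# memory_stats(); it returns None where unsupported). `forward` is again a
# placeholder for the attention call being profiled.
out = forward(h)
out.block_until_ready()
stats = jax.local_devices()[0].memory_stats()
if stats is not None:
    print(stats['peak_bytes_in_use'] / 1024 ** 3, 'GiB peak on device')
```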

Implementation notes

I've tried to keep to your code style as much as possible; let me know if there's anything that would be better changed.

The test/benchmark scripts I used with this are pretty hacky, but if there's anywhere you'd like additional tests added, let me know.

I've added a new argument, max_chunk_size_mb, with a default value of 1GB. From my testing this works well as a default: it's big enough to give a performance boost (values below 100MB or so can hurt performance) and small enough to have an effect on memory usage, though it will obviously need tuning for individual machines. Let me know if you'd prefer something different API-wise.
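
Usage-wise the new argument just slots in next to the existing ones, roughly as below. This is an illustrative sketch assuming the current hk.MultiHeadAttention constructor (and showing max_chunk_size_mb as a constructor argument); it isn't copied from the diff.

```python
import haiku as hk
import jax

def forward_fn(h):
  # max_chunk_size_mb (in MB) is the new argument from this PR; everything
  # else is the existing MultiHeadAttention interface.
  mha = hk.MultiHeadAttention(
      num_heads=32,
      key_size=64,
      w_init=hk.initializers.VarianceScaling(),
      max_chunk_size_mb=1000,
  )
  return mha(h, h, h)  # self-attention

forward = hk.transform(forward_fn)
h = jax.random.normal(jax.random.PRNGKey(42), [1, 10000, 1024])
params = forward.init(jax.random.PRNGKey(0), h)
out = forward.apply(params, None, h)
```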

The idea behind this is based on FlashAttention, but it's a much simpler implementation because it only slices the tensor up in one dimension. If I understand it correctly, they implement the whole thing in C++ and tile in multiple dimensions, which I think gives them a bigger performance boost, but my main focus here was on reducing memory usage and this version seems to do the trick.

Commit message:

When the attn_logits tensor is expected to be very large (by default above 1GB), the attention calculation is split into chunks sized below this limit.

This reduces memory consumption quite a bit and, crucially, allows for input tensor shapes that would otherwise not fit in memory. There also seems to be a small speed improvement, presumably due to better caching.

This is a similar concept to FlashAttention, though a much simpler implementation because it only tiles in one dimension.

As an example, calculating self-attention with:

key_size = 64
num_heads = 16
h = jax.random.normal(jax.random.PRNGKey(42), [1,16000,1024])

This change resulted in:

  • a ~30% speedup (6s to 4.3s) and
  • a 5x reduction in memory usage (19.3GB to 3.9GB)