MultiHeadAttention memory usage reduction via tiling #679
Hope this is helpful; I haven't submitted a PR here before, so let me know if I've got any etiquette wrong :).
Let me know if there are any changes you'd like.
Summary
This PR limits memory usage in `MultiHeadAttention`. It sets a maximum size (`max_chunk_size_mb`) for the `attn_logits` tensor. If the tensor is below `max_chunk_size_mb`, everything proceeds as before; if it is above this size, `attn_logits` is evaluated in chunks rather than ever being fully materialised. This allows memory usage to be limited to what will fit on the machine, and as a bonus it also sometimes provides a performance boost.
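The gist of the chunked path, in a deliberately simplified form (single head, no mask, and a fixed row count `chunk_size` in place of the byte budget; the function name and arguments below are made up for illustration, this is a sketch of the idea rather than the PR's code):

```python
# Minimal sketch of chunked attention: only one [chunk, kv_len] slice of the
# logits tensor is ever materialised at a time.
import jax
import jax.numpy as jnp


def chunked_attention(query, key, value, chunk_size=128):
  """Computes softmax(q @ k.T / sqrt(d)) @ v one query chunk at a time."""
  scale = query.shape[-1] ** -0.5
  outputs = []
  for start in range(0, query.shape[0], chunk_size):
    q_chunk = query[start:start + chunk_size]   # [chunk, d]
    logits = (q_chunk @ key.T) * scale          # [chunk, kv_len] slice only
    weights = jax.nn.softmax(logits, axis=-1)
    outputs.append(weights @ value)             # [chunk, d_v]
  return jnp.concatenate(outputs, axis=0)       # [q_len, d_v]
```

In the PR itself the chunking is driven by the byte budget (`max_chunk_size_mb`) rather than a fixed row count, and the existing unchunked path is used whenever `attn_logits` fits under the limit.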
Correctness Testing
I've run this through attention_test.py and my own test scripts, using random tensors for query, key, value and mask, and checking that the results are identical between the original implementation and the PR.
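For concreteness, here's a self-contained sketch of that kind of check with made-up shapes (the function names and shapes are mine; the real tests exercise the actual MultiHeadAttention module rather than these toy functions):

```python
# Sketch of the equivalence check: full vs. chunked attention on random
# query/key/value/mask tensors. Shapes here are arbitrary.
import jax
import jax.numpy as jnp


def attention(q, k, v, mask):
  logits = (q @ k.T) * q.shape[-1] ** -0.5
  logits = jnp.where(mask, logits, -1e30)       # mask out disallowed positions
  return jax.nn.softmax(logits, axis=-1) @ v


def attention_chunked(q, k, v, mask, chunk_size=64):
  outs = []
  for i in range(0, q.shape[0], chunk_size):
    s = slice(i, i + chunk_size)
    logits = (q[s] @ k.T) * q.shape[-1] ** -0.5
    logits = jnp.where(mask[s], logits, -1e30)
    outs.append(jax.nn.softmax(logits, axis=-1) @ v)
  return jnp.concatenate(outs, axis=0)


keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (512, 64))
k = jax.random.normal(keys[1], (384, 64))
v = jax.random.normal(keys[2], (384, 64))
mask = jax.random.bernoulli(keys[3], 0.9, (512, 384))
assert jnp.allclose(attention(q, k, v, mask),
                    attention_chunked(q, k, v, mask), atol=1e-5)
```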
Performance Testing
CPU
I've run 18 different performance tests on CPU with different values of key_size, num_heads and query shape, all using self-attention. For small inputs the calculation is unchanged, so there's no difference; for sufficiently large inputs, the old calculation runs out of memory while the new one works fine. In between, we see a substantial memory usage reduction and a modest speed improvement (presumably due to better caching).
For instance, with the parameters:
This change resulted in:
GPU
I did the bulk of this profiling on CPU because I couldn't work out a good peak-memory profiling setup on GPU, but testing shows that this also allows calculations to fit in GPU RAM that otherwise wouldn't.
For example, using the parameters:
I can calculate self attention in 3.6s on a 12GB card using the new version.
In the previous version I get an OOM error, since at least 115GB of VRAM would have been required.
The speedup on GPU seems to be a bit smaller, though both the new and the original code show a variance of around 10% in my tests, so it's difficult to pin down an exact number. If you've got any tips on how to benchmark this (or peak GPU RAM) accurately, I'd appreciate them.
Implementation notes
I've tried to keep to your code style as much as possible; let me know if there's anything that would be better changed.
The test/benchmark scripts I used with this are pretty hacky, but if there is anywhere you'd like additional tests added let me know.
I've added a new argument, `max_chunk_size_mb`, with a default value of 1GB. This default works well in my testing: it's big enough to give a performance boost (values below roughly 100MB can hurt performance) and small enough to have an effect on memory usage. Obviously, though, it will need tuning for individual machines, so let me know if you'd prefer something different API-wise.

The idea is based on FlashAttention, but this is a much simpler implementation because it only slices the tensor along one dimension. If I understand it correctly, they implement the whole thing in C++ and slice the tensor along multiple dimensions, which I think gives them a bigger performance boost, but my main focus here was on reducing memory usage and this version seems to do the trick.
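To give a feel for the arithmetic, here's a hypothetical helper (not the PR's code; the name `query_chunk_size` and its signature are made up for illustration) showing how a byte budget like `max_chunk_size_mb` can translate into a chunk size along the query axis:

```python
# Hypothetical helper: pick how many query positions to process per chunk so
# that one [batch, heads, q_chunk, kv_len] slice of attn_logits stays under
# the max_chunk_size_mb budget. A sketch of the idea, not the PR's exact code.
import numpy as np


def query_chunk_size(batch, num_heads, q_len, kv_len,
                     dtype=np.float32, max_chunk_size_mb=1024):
  bytes_per_query_row = batch * num_heads * kv_len * np.dtype(dtype).itemsize
  budget_bytes = max_chunk_size_mb * 1024 ** 2
  # At least one query position per chunk, at most the whole query axis.
  return int(np.clip(budget_bytes // bytes_per_query_row, 1, q_len))


# E.g. batch=8, heads=16, q_len=kv_len=16384, float32: the full logits tensor
# would be 8 * 16 * 16384 * 16384 * 4 B = 128 GiB, so a 1 GiB budget gives
# chunks of 128 query positions (128 chunks in total).
print(query_chunk_size(8, 16, 16384, 16384))  # -> 128
```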