⚡️ Speed up function repeat_kv by 8%
#348
Open
📄 8% (0.08x) speedup for `repeat_kv` in `src/transformers/models/mimi/modeling_mimi.py`

⏱️ Runtime: 2.31 milliseconds → 2.14 milliseconds (best of 88 runs)

📝 Explanation and details
The optimization replaces advanced indexing (`[:, :, None, :, :]`) with the dedicated PyTorch `unsqueeze(2)` method for adding a dimension to the tensor. This change provides a 7% speedup by leveraging PyTorch's optimized dimension-manipulation API instead of relying on implicit advanced indexing.

**Key Changes:**

- Replaced `hidden_states[:, :, None, :, :]` with `hidden_states.unsqueeze(2)` (see the sketch after this list).

**Why This is Faster:**

- `unsqueeze()` is a dedicated PyTorch operation optimized specifically for dimension manipulation.
- Indexing with `None` requires PyTorch to interpret and process the slice notation, which involves more overhead.
- `unsqueeze()` allows PyTorch's internal optimizations to work more effectively.
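For reference, a minimal sketch of the optimized function, assuming the standard `repeat_kv` shape convention used across the transformers attention modules (`(batch, num_key_value_heads, seq_len, head_dim)` in, `(batch, num_key_value_heads * n_rep, seq_len, head_dim)` out); the original indexing form is kept as a comment for comparison:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times along the head dimension,
    equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1)."""
    batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Before: hidden_states = hidden_states[:, :, None, :, :].expand(...)
    # After: insert the repeat axis with the dedicated unsqueeze(2) call.
    hidden_states = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)
```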
**Performance Impact:**

The function is called in the hot path of the attention mechanisms (both `MimiAttention` and `MimiSdpaAttention`), where `repeat_kv` is used to expand the key and value states for multi-head attention. Given that attention computations are performed repeatedly during model inference and training, this 7% improvement can compound significantly.
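As a hypothetical illustration of that call site, reusing the `repeat_kv` sketch above (the head counts and shapes below are made up, not taken from Mimi's actual configuration):

```python
import torch

# Hypothetical shapes for illustration only.
batch, num_key_value_heads, seq_len, head_dim = 2, 4, 128, 64
num_attention_heads = 16
n_rep = num_attention_heads // num_key_value_heads  # query heads per KV head

key_states = torch.randn(batch, num_key_value_heads, seq_len, head_dim)
value_states = torch.randn(batch, num_key_value_heads, seq_len, head_dim)

# Expand KV heads so they line up with the query heads before the attention matmul.
key_states = repeat_kv(key_states, n_rep)      # -> (2, 16, 128, 64)
value_states = repeat_kv(value_states, n_rep)  # -> (2, 16, 128, 64)
```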
**Test Case Performance:**

The optimization shows consistent improvements across the various test scenarios, and it maintains identical functionality while providing measurable performance gains in this critical attention pathway.
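To sanity-check the reported gap locally, a rough microbenchmark sketch using `torch.utils.benchmark` is shown below; the tensor shape and repeat factor are arbitrary assumptions, not the shapes behind the 88 reported runs:

```python
import torch
from torch.utils.benchmark import Timer

# Arbitrary shape: (batch=2, num_key_value_heads=4, seq_len=512, head_dim=64), n_rep=4.
hidden_states = torch.randn(2, 4, 512, 64)

indexing = Timer(
    stmt="hidden_states[:, :, None, :, :].expand(2, 4, 4, 512, 64).reshape(2, 16, 512, 64)",
    globals={"hidden_states": hidden_states},
)
unsqueeze = Timer(
    stmt="hidden_states.unsqueeze(2).expand(2, 4, 4, 512, 64).reshape(2, 16, 512, 64)",
    globals={"hidden_states": hidden_states},
)

print("indexing :", indexing.timeit(1000))
print("unsqueeze:", unsqueeze.timeit(1000))
```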
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
To edit these changes, run `git checkout codeflash/optimize-repeat_kv-mi9j4zlf` and push.