⚡️ Speed up method MimiEuclideanCodebook.encode by 9%
#349
+13 −2
📄 9% (0.09x) speedup for `MimiEuclideanCodebook.encode` in `src/transformers/models/mimi/modeling_mimi.py`

⏱️ Runtime: 1.23 milliseconds → 1.13 milliseconds (best of 17 runs)

📝 Explanation and details
The optimization achieves an 8% speedup by eliminating unnecessary tensor dimension manipulation in the `quantize` method.

Key optimizations (a sketch of the change follows this list):

- **Removed redundant tensor indexing:** The original code used `hidden_states[None].float()` and `self.embed[None].float()` to add singleton dimensions, then immediately indexed `[0]` to remove them. The optimized version passes 2D tensors directly to `torch.cdist`, eliminating this wasteful round-trip.
- **Added an efficient `embed` property:** Introduced a property that gives direct access to the embedding tensor, using `self._embed` if available and `self.embed_sum` otherwise. This avoids potentially repeated buffer lookups and provides a cleaner interface.
- **Direct function calls:** Replaced `dists.argmin(dim=-1)` with `torch.argmin(dists, dim=-1)` for a slightly more direct computation.
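The following is a minimal sketch of the change as described above, not the exact diff from the PR; the surrounding class is simplified, and the attribute names `_embed` and `embed_sum` are taken from the description:

```python
import torch
from torch import nn


class EuclideanCodebookSketch(nn.Module):
    """Simplified stand-in for MimiEuclideanCodebook, showing only the quantize path."""

    def __init__(self, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.register_buffer("embed_sum", torch.zeros(codebook_size, codebook_dim))
        self._embed = None

    @property
    def embed(self) -> torch.Tensor:
        # Direct access to the embedding tensor: use the cached tensor if present,
        # otherwise fall back to the registered buffer.
        return self._embed if self._embed is not None else self.embed_sum

    def quantize_original(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Original pattern: add a singleton batch dim, then immediately strip it with [0]
        dists = torch.cdist(hidden_states[None].float(), self.embed[None].float(), p=2)[0]
        return dists.argmin(dim=-1)

    def quantize_optimized(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Optimized pattern: torch.cdist accepts 2D inputs, so no extra dims are needed
        dists = torch.cdist(hidden_states.float(), self.embed.float(), p=2)
        return torch.argmin(dists, dim=-1)
```

Both variants return, for each input vector, the index of the nearest codebook row; only the intermediate tensor shapes differ.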
**Why this is faster:**

The primary speedup comes from avoiding unnecessary tensor shape manipulations. Creating singleton dimensions with `[None]` and then indexing with `[0]` forces PyTorch to allocate intermediate tensors and perform extra memory operations. The line profiler shows the `torch.cdist` call dropping from 79.5% to 65.6% of the execution time in `quantize`.

**Performance impact:**

The optimization is most effective for scenarios with empty inputs (75.7% speedup) and high-dimensional vectors (12.8% speedup), suggesting that the tensor-manipulation overhead scales with data complexity. All test cases show consistent 4-8% improvements, making this a universally beneficial optimization for vector quantization workloads.
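A rough way to reproduce this kind of comparison for just the distance computation is sketched below; the tensor sizes are made up for illustration, and the real values come from the Mimi config:

```python
import timeit

import torch

# Illustrative sizes only
embed = torch.randn(2048, 256)           # (codebook_size, codebook_dim)
hidden_states = torch.randn(1024, 256)   # flattened input vectors, (N, codebook_dim)


def with_reshape():
    # Batched call with singleton dims stripped afterwards (original pattern)
    return torch.cdist(hidden_states[None].float(), embed[None].float(), p=2)[0].argmin(dim=-1)


def direct_2d():
    # Plain 2D call (optimized pattern)
    return torch.argmin(torch.cdist(hidden_states.float(), embed.float(), p=2), dim=-1)


# Both patterns should select the same nearest codebook entries
assert torch.equal(with_reshape(), direct_2d())

for fn in (with_reshape, direct_2d):
    print(fn.__name__, timeit.timeit(fn, number=50))
```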
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
To edit these changes, run `git checkout codeflash/optimize-MimiEuclideanCodebook.encode-mi9jkcu3` and push.