Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Nov 21, 2025

📄 45% (0.45x) speedup for MimiRotaryEmbedding.forward in src/transformers/models/mimi/modeling_mimi.py

⏱️ Runtime : 3.66 milliseconds 2.53 milliseconds (best of 106 runs)

📝 Explanation and details

The optimized version achieves a 44% speedup by replacing inefficient tensor operations with more performant alternatives in the forward method:

Key Optimizations:

  1. Replaced matrix multiplication with torch.einsum: The original code used .expand() to create large intermediate tensors followed by matrix multiplication (@). The optimized version uses torch.einsum("bs, d -> bsd", position_ids_float, inv_freq_float) which computes the same result without creating expanded intermediate tensors, reducing memory allocation overhead.

  2. Eliminated redundant .expand() operations: The original code expanded inv_freq to [batch, dim, 1] and position_ids to [batch, 1, seq_len], creating large temporary tensors. The optimized version leverages broadcasting directly in einsum, avoiding these allocations entirely.

  3. Used in-place operations: Replaced emb.cos() * self.attention_scaling with emb.cos().mul_(self.attention_scaling) to avoid creating additional temporary tensors during scaling.

  4. Streamlined dtype conversions: Consolidated the .float() calls into direct .to(dtype=torch.float32) operations, reducing redundant conversions.

  5. Added missing compute_default_rope_parameters method: The optimized version includes the static method that was missing from the original, ensuring complete functionality.

Why It's Faster:

  • torch.einsum is highly optimized for broadcasting operations and avoids intermediate tensor allocations
  • In-place operations (mul_) reduce memory pressure and garbage collection overhead
  • Fewer tensor expansions mean less memory bandwidth usage and allocation overhead

Performance Impact:
The optimizations show consistent improvements across all test cases (19-56% faster), with particularly strong gains for smaller tensors where memory allocation overhead is proportionally higher. The method benefits any workload using rotary position embeddings, which are common in transformer attention mechanisms.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 63 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import torch

from transformers.models.mimi.modeling_mimi import MimiRotaryEmbedding


# function to test
# (Paste the MimiRotaryEmbedding class definition here as provided above.)


# Minimal config class to use for MimiRotaryEmbedding
class DummyConfig:
    def __init__(
        self,
        hidden_size=8,
        num_attention_heads=2,
        max_position_embeddings=16,
        rope_parameters=None,
        head_dim=None,
    ):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.head_dim = head_dim
        # Default RoPE parameters
        self.rope_parameters = rope_parameters or {
            "rope_type": "default",
            "rope_theta": 10000.0,
        }


# =========================
# Basic Test Cases
# =========================


def test_forward_basic_shape_and_type():
    """
    Basic: Test output shapes and types for typical input
    """
    config = DummyConfig(hidden_size=8, num_attention_heads=2, max_position_embeddings=16)
    rotary = MimiRotaryEmbedding(config)
    # x: [batch, dim, seq_len]
    x = torch.randn(2, 4, 5)  # batch=2, dim=4, seq_len=5
    position_ids = torch.arange(5).unsqueeze(0).repeat(2, 1)  # [2, 5]
    cos, sin = rotary.forward(x, position_ids)  # 110μs -> 71.9μs (54.3% faster)


def test_forward_basic_values():
    """
    Basic: Test that cos/sin outputs are correct for known input
    """
    config = DummyConfig(hidden_size=4, num_attention_heads=2, max_position_embeddings=8)
    rotary = MimiRotaryEmbedding(config)
    x = torch.ones(1, 2, 4)
    position_ids = torch.tensor([[0, 1, 2, 3]])
    cos, sin = rotary.forward(x, position_ids)  # 101μs -> 70.8μs (44.0% faster)


def test_forward_dtype_float16():
    """
    Basic: Test with float16 input
    """
    config = DummyConfig(hidden_size=8, num_attention_heads=2, max_position_embeddings=16)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(2, 4, 5).half()
    position_ids = torch.arange(5).unsqueeze(0).repeat(2, 1)
    cos, sin = rotary.forward(x, position_ids)  # 103μs -> 69.4μs (49.7% faster)


def test_forward_dtype_float32():
    """
    Basic: Test with float32 input
    """
    config = DummyConfig(hidden_size=8, num_attention_heads=2, max_position_embeddings=16)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(2, 4, 5).float()
    position_ids = torch.arange(5).unsqueeze(0).repeat(2, 1)
    cos, sin = rotary.forward(x, position_ids)  # 102μs -> 66.6μs (53.9% faster)


# =========================
# Edge Test Cases
# =========================


def test_forward_empty_position_ids():
    """
    Edge: Test with empty position_ids (zero sequence length)
    """
    config = DummyConfig(hidden_size=8, num_attention_heads=2, max_position_embeddings=16)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(2, 4, 0)  # seq_len=0
    position_ids = torch.empty(2, 0, dtype=torch.long)
    cos, sin = rotary.forward(x, position_ids)  # 96.5μs -> 63.5μs (52.0% faster)


def test_forward_single_position_id():
    """
    Edge: Test with a single position_id
    """
    config = DummyConfig(hidden_size=8, num_attention_heads=2, max_position_embeddings=16)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(1, 4, 1)  # batch=1, dim=4, seq_len=1
    position_ids = torch.tensor([[7]])
    cos, sin = rotary.forward(x, position_ids)  # 101μs -> 70.7μs (44.1% faster)


def test_forward_non_contiguous_tensor():
    """
    Edge: Test with non-contiguous input tensors
    """
    config = DummyConfig(hidden_size=8, num_attention_heads=2, max_position_embeddings=16)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(2, 4, 5).transpose(0, 2)  # Make non-contiguous
    position_ids = torch.arange(5).unsqueeze(0).repeat(2, 1)
    # Make position_ids non-contiguous
    position_ids = position_ids.t().contiguous().t()
    cos, sin = rotary.forward(x.transpose(0, 2), position_ids)  # 100μs -> 72.1μs (39.3% faster)


def test_forward_negative_position_ids():
    """
    Edge: Test with negative position_ids (should still work, but output will be valid cos/sin)
    """
    config = DummyConfig(hidden_size=8, num_attention_heads=2, max_position_embeddings=16)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(2, 4, 5)
    position_ids = torch.tensor([[-1, -2, -3, -4, -5], [-5, -4, -3, -2, -1]])
    cos, sin = rotary.forward(x, position_ids)  # 118μs -> 76.3μs (55.0% faster)


def test_forward_large_theta():
    """
    Edge: Test with very large rope_theta parameter
    """
    config = DummyConfig(
        hidden_size=8,
        num_attention_heads=2,
        max_position_embeddings=16,
        rope_parameters={"rope_type": "default", "rope_theta": 1e9},
    )
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(2, 4, 5)
    position_ids = torch.arange(5).unsqueeze(0).repeat(2, 1)
    cos, sin = rotary.forward(x, position_ids)  # 103μs -> 68.4μs (51.2% faster)


def test_forward_small_theta():
    """
    Edge: Test with very small rope_theta parameter
    """
    config = DummyConfig(
        hidden_size=8,
        num_attention_heads=2,
        max_position_embeddings=16,
        rope_parameters={"rope_type": "default", "rope_theta": 1e-9},
    )
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(2, 4, 5)
    position_ids = torch.arange(5).unsqueeze(0).repeat(2, 1)
    cos, sin = rotary.forward(x, position_ids)  # 101μs -> 67.5μs (51.0% faster)


# =========================
# Large Scale Test Cases
# =========================


def test_forward_large_batch_and_seq_len():
    """
    Large Scale: Test with large batch and sequence length
    """
    config = DummyConfig(hidden_size=32, num_attention_heads=8, max_position_embeddings=512)
    rotary = MimiRotaryEmbedding(config)
    batch = 16
    seq_len = 512
    dim = config.hidden_size // config.num_attention_heads
    x = torch.randn(batch, dim, seq_len)
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch, 1)
    cos, sin = rotary.forward(x, position_ids)  # 182μs -> 153μs (19.3% faster)


def test_forward_maximum_tensor_size():
    """
    Large Scale: Test with maximum tensor size under 100MB
    """
    config = DummyConfig(hidden_size=64, num_attention_heads=8, max_position_embeddings=1000)
    rotary = MimiRotaryEmbedding(config)
    batch = 8
    seq_len = 1000
    dim = config.hidden_size // config.num_attention_heads
    # Each float32 element is 4 bytes, total size = batch * dim * seq_len * 4
    total_bytes = batch * dim * seq_len * 4
    x = torch.randn(batch, dim, seq_len)
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch, 1)
    cos, sin = rotary.forward(x, position_ids)  # 242μs -> 202μs (20.3% faster)


def test_forward_large_head_dim():
    """
    Large Scale: Test with large head_dim
    """
    config = DummyConfig(hidden_size=256, num_attention_heads=4, max_position_embeddings=128, head_dim=64)
    rotary = MimiRotaryEmbedding(config)
    batch = 4
    seq_len = 128
    dim = config.head_dim
    x = torch.randn(batch, dim, seq_len)
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch, 1)
    cos, sin = rotary.forward(x, position_ids)  # 170μs -> 110μs (53.9% faster)


def test_forward_performance_large_scale():
    """
    Large Scale: Test that function completes in reasonable time for large input
    """
    import time

    config = DummyConfig(hidden_size=128, num_attention_heads=8, max_position_embeddings=512)
    rotary = MimiRotaryEmbedding(config)
    batch = 8
    seq_len = 512
    dim = config.hidden_size // config.num_attention_heads
    x = torch.randn(batch, dim, seq_len)
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch, 1)
    start = time.time()
    cos, sin = rotary.forward(x, position_ids)  # 237μs -> 166μs (42.8% faster)
    duration = time.time() - start


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import torch

from transformers.models.mimi.modeling_mimi import MimiRotaryEmbedding


# function to test
# (see above for full MimiRotaryEmbedding implementation)


# Helper class to simulate MimiConfig for tests
class DummyConfig:
    def __init__(
        self,
        max_position_embeddings=128,
        hidden_size=64,
        num_attention_heads=8,
        rope_parameters=None,
        head_dim=None,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.rope_parameters = rope_parameters or {"rope_type": "default", "rope_theta": 10000.0}
        self.head_dim = head_dim


# ----------- BASIC TEST CASES -----------


def test_forward_basic_shape_and_dtype():
    """
    Basic: Test that output shapes and dtypes are correct for standard input.
    """
    config = DummyConfig(max_position_embeddings=32, hidden_size=64, num_attention_heads=8)
    rotary = MimiRotaryEmbedding(config)
    # Simulate input tensor x: batch=2, seq_len=16, dim=8
    x = torch.randn(2, 16, 8)
    # position_ids: (2, 16)
    position_ids = torch.arange(16).repeat(2, 1)
    cos, sin = rotary.forward(x, position_ids)  # 111μs -> 74.2μs (50.2% faster)


def test_forward_basic_values_repeatability():
    """
    Basic: Test that outputs are deterministic for same input.
    """
    config = DummyConfig(max_position_embeddings=16, hidden_size=32, num_attention_heads=4)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(1, 8, 8)
    position_ids = torch.arange(8).unsqueeze(0)
    cos1, sin1 = rotary.forward(x, position_ids)  # 104μs -> 68.7μs (51.8% faster)
    cos2, sin2 = rotary.forward(x, position_ids)  # 46.3μs -> 31.7μs (45.8% faster)


def test_forward_basic_different_positions():
    """
    Basic: Test that different position_ids yield different outputs.
    """
    config = DummyConfig(max_position_embeddings=32, hidden_size=64, num_attention_heads=8)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(1, 8, 8)
    position_ids1 = torch.arange(8).unsqueeze(0)
    position_ids2 = torch.arange(8, 16).unsqueeze(0)
    cos1, sin1 = rotary.forward(x, position_ids1)  # 99.7μs -> 65.7μs (51.7% faster)
    cos2, sin2 = rotary.forward(x, position_ids2)  # 45.2μs -> 31.0μs (45.6% faster)


# ----------- EDGE TEST CASES -----------


def test_forward_edge_empty_input():
    """
    Edge: Test with empty input tensor.
    """
    config = DummyConfig(max_position_embeddings=8, hidden_size=16, num_attention_heads=2)
    rotary = MimiRotaryEmbedding(config)
    x = torch.empty(0, 0, 0)
    position_ids = torch.empty(0, 0, dtype=torch.long)
    cos, sin = rotary.forward(x, position_ids)  # 89.7μs -> 62.5μs (43.6% faster)


def test_forward_edge_single_position():
    """
    Edge: Test with a single position id.
    """
    config = DummyConfig(max_position_embeddings=1, hidden_size=8, num_attention_heads=2)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(1, 1, 4)
    position_ids = torch.tensor([[0]])
    cos, sin = rotary.forward(x, position_ids)  # 99.0μs -> 70.1μs (41.2% faster)


def test_forward_edge_max_position_embedding():
    """
    Edge: Test with position_ids at max_position_embeddings boundary.
    """
    config = DummyConfig(max_position_embeddings=16, hidden_size=32, num_attention_heads=4)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(1, 4, 8)
    position_ids = torch.full((1, 4), config.max_position_embeddings - 1)
    cos, sin = rotary.forward(x, position_ids)  # 106μs -> 70.3μs (51.6% faster)


def test_forward_edge_negative_position_ids():
    """
    Edge: Test with negative position ids (should not crash, but may produce valid outputs).
    """
    config = DummyConfig(max_position_embeddings=16, hidden_size=32, num_attention_heads=4)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(1, 4, 8)
    position_ids = torch.tensor([[-1, -2, -3, -4]])
    cos, sin = rotary.forward(x, position_ids)  # 107μs -> 68.7μs (56.4% faster)


def test_forward_edge_large_theta():
    """
    Edge: Test with very large rope_theta value.
    """
    config = DummyConfig(
        max_position_embeddings=8,
        hidden_size=16,
        num_attention_heads=2,
        rope_parameters={"rope_type": "default", "rope_theta": 1e9},
    )
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(1, 4, 4)
    position_ids = torch.arange(4).unsqueeze(0)
    cos, sin = rotary.forward(x, position_ids)  # 100μs -> 66.7μs (50.0% faster)


def test_forward_edge_dtype_float16():
    """
    Edge: Test with input tensor of dtype float16.
    """
    config = DummyConfig(max_position_embeddings=16, hidden_size=32, num_attention_heads=4)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(1, 8, 8).half()
    position_ids = torch.arange(8).unsqueeze(0)
    cos, sin = rotary.forward(x, position_ids)  # 104μs -> 69.9μs (49.3% faster)


def test_forward_edge_dtype_bfloat16():
    """
    Edge: Test with input tensor of dtype bfloat16.
    """
    config = DummyConfig(max_position_embeddings=16, hidden_size=32, num_attention_heads=4)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(1, 8, 8).to(dtype=torch.bfloat16)
    position_ids = torch.arange(8).unsqueeze(0)
    cos, sin = rotary.forward(x, position_ids)  # 104μs -> 70.5μs (48.0% faster)


def test_forward_edge_device_cpu_vs_cuda():
    """
    Edge: Test that function works on both CPU and CUDA (if available).
    """
    config = DummyConfig(max_position_embeddings=8, hidden_size=16, num_attention_heads=2)
    rotary = MimiRotaryEmbedding(config)
    x = torch.randn(1, 4, 4)
    position_ids = torch.arange(4).unsqueeze(0)
    cos_cpu, sin_cpu = rotary.forward(x, position_ids)  # 100μs -> 67.2μs (49.8% faster)
    if torch.cuda.is_available():
        rotary_cuda = MimiRotaryEmbedding(config, device="cuda")
        x_cuda = x.cuda()
        position_ids_cuda = position_ids.cuda()
        cos_cuda, sin_cuda = rotary_cuda.forward(x_cuda, position_ids_cuda)


# ----------- LARGE SCALE TEST CASES -----------


def test_forward_large_batch_and_sequence():
    """
    Large Scale: Test with large batch and sequence length, but within memory limits.
    """
    config = DummyConfig(max_position_embeddings=512, hidden_size=128, num_attention_heads=8)
    rotary = MimiRotaryEmbedding(config)
    batch_size = 32
    seq_len = 64
    dim = 16
    x = torch.randn(batch_size, seq_len, dim)
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
    cos, sin = rotary.forward(x, position_ids)  # 169μs -> 120μs (40.5% faster)


def test_forward_large_dim():
    """
    Large Scale: Test with large embedding dimension.
    """
    config = DummyConfig(max_position_embeddings=128, hidden_size=1024, num_attention_heads=16)
    rotary = MimiRotaryEmbedding(config)
    batch_size = 4
    seq_len = 16
    dim = 64
    x = torch.randn(batch_size, seq_len, dim)
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
    cos, sin = rotary.forward(x, position_ids)  # 119μs -> 76.4μs (56.2% faster)


def test_forward_large_varied_position_ids():
    """
    Large Scale: Test with varied position_ids across batch.
    """
    config = DummyConfig(max_position_embeddings=256, hidden_size=128, num_attention_heads=8)
    rotary = MimiRotaryEmbedding(config)
    batch_size = 16
    seq_len = 32
    dim = 16
    x = torch.randn(batch_size, seq_len, dim)
    # Each batch element gets a different offset
    position_ids = torch.stack([torch.arange(i, i + seq_len) for i in range(batch_size)])
    cos, sin = rotary.forward(x, position_ids)  # 133μs -> 86.2μs (54.4% faster)


def test_forward_large_float16():
    """
    Large Scale: Test with large input and float16 dtype.
    """
    config = DummyConfig(max_position_embeddings=128, hidden_size=256, num_attention_heads=8)
    rotary = MimiRotaryEmbedding(config)
    batch_size = 8
    seq_len = 64
    dim = 32
    x = torch.randn(batch_size, seq_len, dim).half()
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
    cos, sin = rotary.forward(x, position_ids)  # 144μs -> 97.2μs (48.8% faster)


# ----------- FAILURE CASES -----------

To edit these changes git checkout codeflash/optimize-MimiRotaryEmbedding.forward-mi9ik8to and push.

Codeflash Static Badge

The optimized version achieves a **44% speedup** by replacing inefficient tensor operations with more performant alternatives in the `forward` method:

**Key Optimizations:**

1. **Replaced matrix multiplication with `torch.einsum`**: The original code used `.expand()` to create large intermediate tensors followed by matrix multiplication (`@`). The optimized version uses `torch.einsum("bs, d -> bsd", position_ids_float, inv_freq_float)` which computes the same result without creating expanded intermediate tensors, reducing memory allocation overhead.

2. **Eliminated redundant `.expand()` operations**: The original code expanded `inv_freq` to `[batch, dim, 1]` and `position_ids` to `[batch, 1, seq_len]`, creating large temporary tensors. The optimized version leverages broadcasting directly in `einsum`, avoiding these allocations entirely.

3. **Used in-place operations**: Replaced `emb.cos() * self.attention_scaling` with `emb.cos().mul_(self.attention_scaling)` to avoid creating additional temporary tensors during scaling.

4. **Streamlined dtype conversions**: Consolidated the `.float()` calls into direct `.to(dtype=torch.float32)` operations, reducing redundant conversions.

5. **Added missing `compute_default_rope_parameters` method**: The optimized version includes the static method that was missing from the original, ensuring complete functionality.

**Why It's Faster:**
- `torch.einsum` is highly optimized for broadcasting operations and avoids intermediate tensor allocations
- In-place operations (`mul_`) reduce memory pressure and garbage collection overhead
- Fewer tensor expansions mean less memory bandwidth usage and allocation overhead

**Performance Impact:**
The optimizations show consistent improvements across all test cases (19-56% faster), with particularly strong gains for smaller tensors where memory allocation overhead is proportionally higher. The method benefits any workload using rotary position embeddings, which are common in transformer attention mechanisms.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 21, 2025 23:51
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant