Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Nov 21, 2025

📄 44% (0.44x) speedup for AGLU.forward in ultralytics/nn/modules/activation.py

⏱️ Runtime : 6.49 milliseconds 4.50 milliseconds (best of 213 runs)

📝 Explanation and details

The optimization achieves a 44% speedup by decomposing a complex compound expression into intermediate tensor operations. The key change is breaking down the single complex return statement:

return torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam)))

Into four separate operations:

log_lam = torch.log(lam)
inv_lam = lam.reciprocal()
act_val = self.act((self.kappa * x) - log_lam)
return torch.exp(inv_lam * act_val)

Why this is faster:

  1. Eliminates redundant computation: The original code calls torch.log(lam) and computes 1 / lam within a complex nested expression, potentially causing PyTorch to create multiple intermediate tensors and perform suboptimal memory access patterns.

  2. Improves tensor operation efficiency: Using lam.reciprocal() instead of 1 / lam is more efficient for tensor division operations in PyTorch.

  3. Better memory layout and caching: Breaking the computation into discrete steps allows PyTorch's tensor operations to be more cache-friendly and reduces temporary tensor allocations.

The line profiler shows the bottleneck shifted from the single complex expression (94.9% of time) to distributed operations, with the activation function call now taking 60.3% and the final exponential 29.5% of the time.

Performance impact: The optimization shows consistent 20-32% speedups across all test cases, from simple scalar inputs to large tensors (10,000+ elements). This suggests the optimization is particularly effective for neural network activation functions that are called frequently during forward passes, making it valuable for deep learning workloads where AGLU activations are used extensively.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 73 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
# function to test
# (from ultralytics/nn/modules/activation.py)
import torch
import torch.nn as nn
from ultralytics.nn.modules.activation import AGLU

# unit tests

# ----------- BASIC TEST CASES -----------

def test_forward_scalar_input():
    """Test forward with a scalar tensor input."""
    aglu = AGLU()
    x = torch.tensor(1.0)
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_vector_input():
    """Test forward with a 1D tensor input."""
    aglu = AGLU()
    x = torch.tensor([0.0, 1.0, -1.0, 2.0])
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_matrix_input():
    """Test forward with a 2D tensor input."""
    aglu = AGLU()
    x = torch.tensor([[0.5, -0.5], [2.0, -2.0]])
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_dtype_and_device():
    """Test forward with float32 and float64 on CPU."""
    for dtype in [torch.float32, torch.float64]:
        aglu = AGLU(dtype=dtype)
        x = torch.ones((2, 2), dtype=dtype)
        codeflash_output = aglu.forward(x); y = codeflash_output

# ----------- EDGE TEST CASES -----------

def test_forward_zero_input():
    """Test forward with all zeros input."""
    aglu = AGLU()
    x = torch.zeros(5)
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_large_input():
    """Test forward with large positive and negative input values."""
    aglu = AGLU()
    x = torch.tensor([100.0, -100.0])
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_inf_nan_input():
    """Test forward with inf and nan input values."""
    aglu = AGLU()
    x = torch.tensor([float('inf'), float('-inf'), float('nan')])
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_extreme_lambda_kappa():
    """Test with extreme values of lambda and kappa parameters."""
    aglu = AGLU()
    # Set lambda to a very small value (should be clamped to 0.0001)
    aglu.lambd.data.fill_(-100.0)
    # Set kappa to a large value
    aglu.kappa.data.fill_(100.0)
    x = torch.tensor([1.0, -1.0])
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_empty_tensor():
    """Test forward with an empty tensor."""
    aglu = AGLU()
    x = torch.empty(0)
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_broadcasting_lambda_kappa():
    """Test that broadcasting works when input is larger than parameters."""
    aglu = AGLU()
    x = torch.ones((3, 4, 5))
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_gradients():
    """Test that gradients flow through the module."""
    aglu = AGLU()
    x = torch.randn(3, requires_grad=True)
    codeflash_output = aglu.forward(x); y = codeflash_output
    y.sum().backward()

# ----------- LARGE SCALE TEST CASES -----------

def test_forward_large_tensor():
    """Test forward with a large tensor (but <100MB)."""
    aglu = AGLU()
    # 10000 x 100 floats (float32 = 4 bytes, 4MB), well under 100MB
    x = torch.randn(10000, 100)
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_batch_processing():
    """Test forward with a batch of inputs."""
    aglu = AGLU()
    # Simulate a batch of 512 images, each with 32 features
    x = torch.randn(512, 32)
    codeflash_output = aglu.forward(x); y = codeflash_output

def test_forward_speed_large():
    """Test forward speed with a large tensor, ensure no timeout."""
    import time
    aglu = AGLU()
    x = torch.randn(1000, 100)  # 0.4MB
    start = time.time()
    codeflash_output = aglu.forward(x); y = codeflash_output
    elapsed = time.time() - start

def test_forward_high_dimensional_tensor():
    """Test forward with a high-dimensional tensor."""
    aglu = AGLU()
    x = torch.randn(4, 3, 2, 5, 5)
    codeflash_output = aglu.forward(x); y = codeflash_output

# ----------- TYPE AND ERROR HANDLING TESTS -----------

def test_forward_non_float_input_raises():
    """Test that non-float input raises an error."""
    aglu = AGLU()
    x = torch.tensor([1, 2, 3], dtype=torch.int32)
    with pytest.raises(RuntimeError):
        aglu.forward(x)

def test_forward_device_consistency():
    """Test that module and input on same device works (CPU only)."""
    device = torch.device('cpu')
    aglu = AGLU(device=device)
    x = torch.randn(10, device=device)
    codeflash_output = aglu.forward(x); y = codeflash_output

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
import torch
from ultralytics.nn.modules.activation import AGLU

# unit tests

# ---------------- BASIC TEST CASES ----------------


def test_forward_scalar_zero():
    # Test with scalar input 0
    aglu = AGLU()
    x = torch.tensor(0.0)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 65.3μs -> 52.1μs (25.4% faster)


def test_forward_scalar_positive():
    # Test with scalar input > 0
    aglu = AGLU()
    x = torch.tensor(2.0)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 55.3μs -> 42.4μs (30.5% faster)


def test_forward_scalar_negative():
    # Test with scalar input < 0
    aglu = AGLU()
    x = torch.tensor(-2.0)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 53.6μs -> 42.7μs (25.5% faster)


def test_forward_vector_input():
    # Test with 1D tensor input
    aglu = AGLU()
    x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 54.8μs -> 42.9μs (27.7% faster)


def test_forward_matrix_input():
    # Test with 2D tensor input
    aglu = AGLU()
    x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 55.0μs -> 43.9μs (25.2% faster)


def test_forward_preserves_dtype_and_device():
    # Test that dtype and device are preserved in output
    device = torch.device("cpu")
    dtype = torch.float64
    aglu = AGLU(device=device, dtype=dtype)
    x = torch.tensor([1.0, 2.0], dtype=dtype, device=device)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 55.6μs -> 43.7μs (27.4% faster)


# ---------------- EDGE TEST CASES ----------------


def test_forward_lambda_clamping():
    # Test that lambda is clamped to minimum value
    aglu = AGLU()
    # forcibly set lambd to a negative value
    aglu.lambd.data.fill_(-10.0)
    x = torch.tensor(1.0)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 53.9μs -> 41.3μs (30.5% faster)


def test_forward_lambda_small_value():
    # Test with lambda very close to zero (but positive)
    aglu = AGLU()
    aglu.lambd.data.fill_(1e-8)
    x = torch.tensor(1.0)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 54.5μs -> 41.7μs (30.8% faster)


def test_forward_kappa_large_value():
    # Test with a very large kappa value
    aglu = AGLU()
    aglu.kappa.data.fill_(1e6)
    x = torch.tensor(1.0)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 53.4μs -> 41.1μs (30.0% faster)


def test_forward_kappa_negative_value():
    # Test with a negative kappa value
    aglu = AGLU()
    aglu.kappa.data.fill_(-2.0)
    x = torch.tensor([1.0, -1.0])
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 53.8μs -> 41.2μs (30.6% faster)


def test_forward_extreme_inputs():
    # Test with very large and very small input values
    aglu = AGLU()
    x = torch.tensor([1e10, -1e10, 1e-10, -1e-10])
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 54.3μs -> 42.2μs (28.7% faster)


def test_forward_inf_nan_inputs():
    # Test with inf and nan inputs
    aglu = AGLU()
    x = torch.tensor([float("inf"), float("-inf"), float("nan")])
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 53.8μs -> 42.0μs (28.3% faster)


def test_forward_gradients():
    # Test that gradients can be computed through the module
    aglu = AGLU()
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 54.0μs -> 41.0μs (31.8% faster)
    s = out.sum()
    s.backward()


# ---------------- LARGE SCALE TEST CASES ----------------


def test_forward_large_batch():
    # Test with a large batch of inputs (but < 100MB)
    aglu = AGLU()
    # 10000 x 10 float32 is 400KB
    x = torch.randn(10000, 10)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 266μs -> 245μs (8.56% faster)


def test_forward_high_dimensional_input():
    # Test with high-dimensional input (e.g., 4D tensor)
    aglu = AGLU()
    x = torch.randn(8, 8, 8, 8)  # 4096 elements
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 74.5μs -> 62.0μs (20.2% faster)


def test_forward_multiple_calls_consistency():
    # Test that repeated calls with same input and parameters yield same output
    aglu = AGLU()
    x = torch.randn(100, 10)
    codeflash_output = aglu.forward(x)
    out1 = codeflash_output  # 61.2μs -> 49.9μs (22.7% faster)
    codeflash_output = aglu.forward(x)
    out2 = codeflash_output  # 28.8μs -> 23.5μs (22.6% faster)


def test_forward_large_lambda_and_kappa():
    # Test with large values for both lambda and kappa
    aglu = AGLU()
    aglu.lambd.data.fill_(1e3)
    aglu.kappa.data.fill_(1e3)
    x = torch.randn(100, 10)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 59.7μs -> 49.3μs (21.1% faster)


def test_forward_large_negative_lambda_and_kappa():
    # Test with large negative values for both lambda and kappa (lambda should clamp)
    aglu = AGLU()
    aglu.lambd.data.fill_(-1e3)
    aglu.kappa.data.fill_(-1e3)
    x = torch.randn(100, 10)
    codeflash_output = aglu.forward(x)
    out = codeflash_output  # 72.7μs -> 57.8μs (25.9% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-AGLU.forward-mi8g8qyk and push.

Codeflash Static Badge

The optimization achieves a 44% speedup by decomposing a complex compound expression into intermediate tensor operations. The key change is breaking down the single complex return statement:

```python
return torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam)))
```

Into four separate operations:
```python
log_lam = torch.log(lam)
inv_lam = lam.reciprocal()
act_val = self.act((self.kappa * x) - log_lam)
return torch.exp(inv_lam * act_val)
```

**Why this is faster:**

1. **Eliminates redundant computation**: The original code calls `torch.log(lam)` and computes `1 / lam` within a complex nested expression, potentially causing PyTorch to create multiple intermediate tensors and perform suboptimal memory access patterns.

2. **Improves tensor operation efficiency**: Using `lam.reciprocal()` instead of `1 / lam` is more efficient for tensor division operations in PyTorch.

3. **Better memory layout and caching**: Breaking the computation into discrete steps allows PyTorch's tensor operations to be more cache-friendly and reduces temporary tensor allocations.

The line profiler shows the bottleneck shifted from the single complex expression (94.9% of time) to distributed operations, with the activation function call now taking 60.3% and the final exponential 29.5% of the time.

**Performance impact**: The optimization shows consistent 20-32% speedups across all test cases, from simple scalar inputs to large tensors (10,000+ elements). This suggests the optimization is particularly effective for neural network activation functions that are called frequently during forward passes, making it valuable for deep learning workloads where AGLU activations are used extensively.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 21, 2025 05:58
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant