Depthwise convolution in flax.nnx.Conv is significantly slower than PyTorch and TensorFlow #4207

Open
YushaArif99 opened this issue Sep 18, 2024 · 0 comments


Description

I have noticed that depthwise convolution with `nnx.Conv` (i.e. `feature_group_count` equal to the number of input channels) is much slower than the equivalent layers in PyTorch and TensorFlow. In a quick benchmark across the three frameworks on the machine described below, the JAX version was roughly 47x slower than TensorFlow and roughly 75x slower than PyTorch.
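For context, setting `feature_group_count=dim` (with `in_features=out_features=dim`) splits the input into `dim` groups of one channel each, so every channel is convolved with its own 3x3 filter. A minimal sketch of that computation in plain JAX, checked against a naive NumPy loop. It assumes NHWC input and an HWIO kernel of shape `(3, 3, 1, C)`, which is the kernel layout I expect `nnx.Conv` to allocate for this configuration; the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def depthwise_conv_lax(x, kernel):
    """Depthwise 3x3 conv, stride 1, padding 1 (hypothetical helper).

    x: (N, H, W, C) input in NHWC layout.
    kernel: (3, 3, 1, C) in HWIO layout; with feature_group_count=C each
    input channel is paired with its own single-channel filter.
    """
    c = x.shape[-1]
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding=((1, 1), (1, 1)),
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=c,
    )


def depthwise_conv_reference(x, kernel):
    """Naive NumPy reference:
    out[n, i, j, c] = sum over (di, dj) of xpad[n, i+di, j+dj, c] * kernel[di, dj, 0, c].
    """
    n, h, w, c = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))  # zero padding of 1
    out = np.zeros((n, h, w, c), dtype=x.dtype)
    for di in range(3):
        for dj in range(3):
            # Per-channel scale broadcast over the last axis.
            out += xp[:, di:di + h, dj:dj + w, :] * kernel[di, dj, 0, :]
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 8, 8, 4)).astype(np.float32)
k = rng.standard_normal((3, 3, 1, 4)).astype(np.float32)

out = np.asarray(depthwise_conv_lax(jnp.asarray(x), jnp.asarray(k)))
ref = depthwise_conv_reference(x, k)
print(out.shape, np.allclose(out, ref, atol=1e-4))
```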

Here's a minimal example I used for benchmarking:

```python
import time

import flax.nnx as nnx
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import torch

dim = 256

# Depthwise 3x3 convolution (stride 1, padding 1, bias), one filter per channel.
dwconv_pt = torch.nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
dwconv_jax = nnx.Conv(
    in_features=dim,
    out_features=dim,
    kernel_size=3,
    strides=1,
    padding=1,
    use_bias=True,
    feature_group_count=dim,  # dim groups of one channel each -> depthwise
    input_dilation=1,
    kernel_dilation=1,
    rngs=nnx.Rngs(0),
)
dwconv_tf = tf.keras.layers.DepthwiseConv2D(
    kernel_size=(3, 3),
    strides=(1, 1),
    padding='same',
    depth_multiplier=1,
    use_bias=True,
)

# Same random data for all three; PyTorch takes NCHW, Flax and Keras take NHWC.
np_input = np.random.rand(1, dim, 124, 124).astype('float32')
jax_input = jnp.array(np.transpose(np_input, (0, 2, 3, 1)))
tf_input = tf.convert_to_tensor(np.transpose(np_input, (0, 2, 3, 1)))
torch_input = torch.from_numpy(np_input)

# Warm-up: absorbs one-time costs (e.g. Keras builds its weights on the first call).
for _ in range(10):
    dwconv_pt(torch_input)
    dwconv_jax(jax_input)
    dwconv_tf(tf_input)

# Timings: total wall-clock time for 50 forward passes per framework.
start_time = time.time()
for _ in range(50):
    dwconv_pt(torch_input)
pt_time = time.time() - start_time

# Note: the Flax module is called eagerly (not under nnx.jit), and no
# block_until_ready() is issued before the clock is stopped.
start_time = time.time()
for _ in range(50):
    dwconv_jax(jax_input)
jax_time = time.time() - start_time

start_time = time.time()
for _ in range(50):
    dwconv_tf(tf_input)
tf_time = time.time() - start_time

print(f'TensorFlow time: {tf_time}s')
print(f'PyTorch time: {pt_time}s')
print(f'JAX time: {jax_time}s')
```
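The loop above calls `nnx.Conv` eagerly, and JAX dispatches work asynchronously, so the JAX number mixes op-by-op dispatch overhead with compute. A minimal sketch of a jitted measurement for comparison, assuming the same shapes but calling `jax.lax.conv_general_dilated` directly instead of `nnx.Conv` (kernel and bias are random stand-ins, and `dwconv` / `mean_call_time` are hypothetical helper names). It reports mean time per call and blocks on every result; no timings from this variant are included in this report.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

dim = 256


@jax.jit
def dwconv(x, kernel, bias):
    # Depthwise 3x3 conv matching the benchmark layer: NHWC input, HWIO kernel of
    # shape (3, 3, 1, dim), one group per channel, stride 1, padding 1, plus bias.
    y = jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding=((1, 1), (1, 1)),
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=x.shape[-1],
    )
    return y + bias


def mean_call_time(fn, *args, warmup=3, iters=10):
    """Mean wall-clock seconds per call; the warm-up calls absorb jit compilation."""
    for _ in range(warmup):
        fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(iters):
        # JAX dispatches asynchronously; block so each call is timed to completion.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / iters


rng = np.random.default_rng(0)
x = jnp.asarray(rng.random((1, 124, 124, dim), dtype=np.float32))
kernel = jnp.asarray(rng.standard_normal((3, 3, 1, dim), dtype=np.float32))
bias = jnp.zeros((dim,), dtype=jnp.float32)

per_call = mean_call_time(dwconv, x, kernel, bias)
print(f"jitted depthwise conv: {per_call * 1e3:.2f} ms/call "
      f"(~{per_call * 50:.2f}s per 50 calls)")
```

Multiplying the per-call figure by 50 makes it roughly comparable with the totals reported below.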

Results (total wall-clock time for 50 forward passes per framework, after 10 warm-up calls):

TensorFlow time: 0.77s
PyTorch time: 0.49s
JAX (Flax) time: 36.57s
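If the slowdown turns out to come from how XLA handles grouped convolutions on CPU (an untested guess, not something this benchmark establishes), one possible workaround is to avoid the grouped convolution entirely: a 3x3 depthwise conv with stride 1 and padding 1 equals a sum of nine shifted copies of the padded input, each scaled per channel by one kernel tap. A minimal sketch of that alternative, checked against the grouped form on a small random input. `depthwise3x3_shifted` and `depthwise3x3_grouped` are hypothetical names, and NHWC input with an HWIO kernel of shape `(3, 3, 1, C)` is assumed.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def depthwise3x3_shifted(x, kernel, bias):
    """Hypothetical workaround: depthwise 3x3 conv (stride 1, padding 1) written
    as nine shifted per-channel multiply-adds instead of a grouped convolution.

    x: (N, H, W, C) NHWC; kernel: (3, 3, 1, C) HWIO; bias: (C,).
    """
    h, w = x.shape[1], x.shape[2]
    xp = jnp.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))  # zero "same" padding
    out = jnp.zeros_like(x)
    # The Python loops unroll at trace time into slice/multiply/add ops.
    for di in range(3):
        for dj in range(3):
            out = out + xp[:, di:di + h, dj:dj + w, :] * kernel[di, dj, 0, :]
    return out + bias


@jax.jit
def depthwise3x3_grouped(x, kernel, bias):
    """Reference: the grouped-convolution form requested by feature_group_count=C."""
    y = jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding=((1, 1), (1, 1)),
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=x.shape[-1],
    )
    return y + bias


rng = np.random.default_rng(0)
x = jnp.asarray(rng.standard_normal((2, 32, 32, 16), dtype=np.float32))
kernel = jnp.asarray(rng.standard_normal((3, 3, 1, 16), dtype=np.float32))
bias = jnp.asarray(rng.standard_normal((16,), dtype=np.float32))

shifted = depthwise3x3_shifted(x, kernel, bias)
grouped = depthwise3x3_grouped(x, kernel, bias)
print("max abs difference:", float(jnp.max(jnp.abs(shifted - grouped))))
```

Whether this form is actually faster on a given CPU would still need to be measured on the benchmark's shape (1 x 124 x 124 x 256), blocking on each result before stopping the clock.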

System info (python version, jaxlib version, accelerator, etc.)

Python version: 3.10.12
JAX version: 0.4.26
Flax version: 0.8.4

```
Architecture:             x86_64
  CPU op-mode(s):         32-bit, 64-bit
  Address sizes:          46 bits physical, 48 bits virtual
  Byte Order:             Little Endian
CPU(s):                   2
  On-line CPU(s) list:    0,1
Vendor ID:                GenuineIntel
  Model name:             Intel(R) Xeon(R) CPU @ 2.20GHz
    CPU family:           6
    Model:                79
    Thread(s) per core:   2
    Core(s) per socket:   1
    Socket(s):            1
    Stepping:             0
    BogoMIPS:             4399.99
    Flags:                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat
                          pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm
                          constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq
                          pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt
                          aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch
                          invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2
                          smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear
                          arch_capabilities
```