hk.Conv2DTranspose takes FOREVER to initialize and compile #724

@sokrypton

I'm not sure whether this is a JAX issue or a dm-haiku issue, but I've recently been trying to use hk.Conv2DTranspose in my model, and even for a very simple case it takes forever (more than four minutes each) to initialize and to compile.

Here is an example:

import haiku as hk
import jax
from jax import random
import time

def toy_model(x):
  x = hk.Conv2DTranspose(32, 32, stride=16, padding="VALID")(x)
  return x

# Transform the model once into a pair of pure functions (init, apply)
toy_model_t = hk.transform(toy_model)
toy_model_init = toy_model_t.init
toy_model_apply = toy_model_t.apply

# Generate random input and params
key = random.PRNGKey(42)
x = random.normal(key, (1, 8, 8, 128))

# Time the model initialization
start_time = time.time()
params = toy_model_init(key, x)
end_time = time.time()
print(f"initialization Time: {end_time - start_time:.6f} seconds")

# Time the model compilation
start_time = time.time()
compiled_apply = jax.jit(toy_model_apply)
# Warm-up call (this compiles the function)
_ = compiled_apply(params, None, x)
end_time = time.time()
print(f"Compilation Time: {end_time - start_time:.6f} seconds")

# Time the model run (printing is kept outside the timed region)
start_time = time.time()
o = compiled_apply(params, None, x)
end_time = time.time()
print("input_shape", x.shape)
print("output_shape", o.shape)
print(f"Run Time: {end_time - start_time:.6f} seconds")

Output:

initialization Time: 251.865844 seconds
Compilation Time: 255.010969 seconds
input_shape (1, 8, 8, 128)
output_shape (1, 144, 144, 32)
Run Time: 0.000671 seconds
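
To narrow down where the time goes, it may help to time tracing/lowering, XLA compilation, and execution separately. The following is only a minimal sketch using JAX's ahead-of-time `lower`/`compile` API on a plain `jax.lax.conv_transpose` call. The function name `upsample` and the deliberately tiny shapes are assumptions chosen for illustration; they are not the configuration from the example above, and the sketch does not go through Haiku. Because JAX dispatches work asynchronously, `block_until_ready()` is used so the run timing covers the actual computation.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


def upsample(x, w):
    # NHWC input with an HWIO kernel (layout chosen for this sketch).
    return lax.conv_transpose(
        x, w, strides=(2, 2), padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


# Tiny, hypothetical shapes so the sketch runs quickly.
x = jnp.ones((1, 8, 8, 4))   # batch, height, width, in_channels
w = jnp.ones((4, 4, 4, 2))   # kernel_h, kernel_w, in_channels, out_channels

t0 = time.perf_counter()
lowered = jax.jit(upsample).lower(x, w)   # trace and lower
t1 = time.perf_counter()
compiled = lowered.compile()              # XLA compilation only
t2 = time.perf_counter()
out = compiled(x, w).block_until_ready()  # execution only
t3 = time.perf_counter()

print(f"lower:   {t1 - t0:.6f} s")
print(f"compile: {t2 - t1:.6f} s")
print(f"run:     {t3 - t2:.6f} s")
print("output_shape", out.shape)
```

Swapping in the larger kernel size and stride from the example above should show which of the three stages dominates.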

For comparison, in PyTorch:

Initialization Time: 0.033582 seconds
input_shape torch.Size([1, 128, 8, 8])
output_shape torch.Size([1, 32, 144, 144])
Run Time: 0.047478 seconds
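
As a sanity check that both frameworks compute the same spatial size (144 from an input of 8, kernel 32, stride 16), here is the output-length arithmetic as a small sketch. The helper names are hypothetical. The PyTorch formula is the one given in the `ConvTranspose2d` documentation; the JAX "VALID" formula is my understanding of how `jax.lax.conv_transpose` pads, so treat it as an assumption.

```python
def jax_valid_transpose_len(size, kernel, stride):
    # Assumed output length for padding="VALID" in jax.lax.conv_transpose.
    return size * stride + max(kernel - stride, 0)


def torch_transpose_len(size, kernel, stride, padding=0,
                        output_padding=0, dilation=1):
    # Output length formula from the torch.nn.ConvTranspose2d documentation.
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)


jax_len = jax_valid_transpose_len(8, kernel=32, stride=16)
torch_len = torch_transpose_len(8, kernel=32, stride=16)
print(jax_len, torch_len)  # 144 144
```

The two formulas agree here because the kernel is at least as large as the stride; they can differ when the kernel is smaller than the stride.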

Google Colab notebook reproducing the test:
https://colab.research.google.com/drive/15YkOuK0EjqZdBNaXpF2wpYexGqtjZjLr
