Skip to content

Commit

Permalink
Achieved successful JAX training #13
Browse files: browse the repository at this point in the history
  • Loading branch information
brianreicher committed Aug 25, 2022
1 parent a6d1ee4 commit e4e7159
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 38 deletions.
4 changes: 3 additions & 1 deletion raygun/jax/networks/ResidualUNet.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,9 @@
import torch
import torch.nn as nn

from utils import NoiseBlock, ParameterizedNoiseBlock
# from utils import NoiseBlock, ParameterizedNoiseBlock
from raygun.torch.networks.utils import NoiseBlock, ParameterizedNoiseBlock


class ConvPass(torch.nn.Module):

Expand Down
3 changes: 2 additions & 1 deletion raygun/jax/networks/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,7 @@
JAX network architectures
"""
from .UNet import UNet
from .ResidualUNet import ResidualUNet
# from .ResidualUNet import ResidualUNet
from .MIRNet2D import MIRNet
# from .ResNet import ResNet
from .utils import *
50 changes: 14 additions & 36 deletions raygun/jax/tests/network_test_jax.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,8 +49,8 @@ class MyModel(hk.Module):
def __init__(self, name=None):
super().__init__(name=name)
self.unet = UNet(
ngf=24,
fmap_inc_factor=3,
ngf=2,
fmap_inc_factor=2,
downsample_factors=[[2,2,2],[2,2,2],[2,2,2]]
)

Expand Down Expand Up @@ -147,53 +147,31 @@ def create_network():
batch_size = 4*n_devices

raw = jnp.ones([batch_size, 1, 132, 132, 132])
gt = jnp.zeros([batch_size, 3, 40, 40, 40])
mask = jnp.ones([batch_size, 3, 40, 40, 40])
gt = jnp.zeros([batch_size, 2, 40, 40, 40])
mask = jnp.ones([batch_size, 2, 40, 40, 40])
inputs = {
'raw': raw,
'gt': gt,
'mask': mask,
}
rng= jax.random.PRNGKey(42)
#%%

# init model
if n_devices > 1:
# split input for pmap
raw = split(raw, n_devices)
gt = split(gt, n_devices)
mask = split(mask, n_devices)
single_device_inputs = {
'raw': raw,
'gt': gt,
'mask': mask
}
rng = jnp.broadcast_to(rng, (n_devices,) + rng.shape)
model_params = jax.pmap(my_model.initialize)(rng, single_device_inputs)

else:
model_params = my_model.initialize(rng, inputs, is_training=True)
model_params = my_model.initialize(rng, inputs, is_training=True)
#%%
# test forward
y = jit(my_model.forward)(model_params, {'raw': raw})

assert y['affs'].shape == (batch_size, 3, 40, 40, 40)
#%%
assert y['affs'].shape == (batch_size, 2, 40, 40, 40)
#%%
# test train loop
for _ in range(10):
t0 = time.time()

if n_devices > 1:
model_params, outputs, loss = jax.pmap(
my_model.train_step,
axis_name='num_devices',
donate_argnums=(0,),
static_broadcasted_argnums=(2,))(
model_params, inputs, True)
else:
model_params, outputs, loss = jax.jit(
my_model.train_step,
donate_argnums=(0,),
static_argnums=(2,))(
model_params, inputs, False)

model_params, outputs, loss = jax.jit(
my_model.train_step,
donate_argnums=(0,),
static_argnums=(2,))(
model_params, inputs, False)

print(f'Loss: {loss}, took {time.time()-t0}s')

0 comments on commit e4e7159

Please sign in to comment.