Skip to content

Commit

Permalink
Denoising task in JAX #13 #14
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Aug 29, 2022
1 parent 8e93222 commit 02d0728
Showing 1 changed file with 81 additions and 51 deletions.
132 changes: 81 additions & 51 deletions src/raygun/jax/tests/network_test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import jmp
import time
from typing import Tuple, Any, NamedTuple, Dict
from skimage import data
import matplotlib.pyplot as plt
from tqdm import trange

from raygun.jax.networks.NLayerDiscriminator import NLayerDiscriminator3D


# PARAMETERS
Expand All @@ -21,7 +26,6 @@ class Params(NamedTuple):
loss_scale: jmp.LossScale


# should be the same as gunpowder.jax.GenericJaxModel
# replicated here to reduce dependency
class GenericJaxModel():

Expand All @@ -48,12 +52,11 @@ class MyModel(hk.Module):

def __init__(self, name=None):
super().__init__(name=name)
# self.net = UNet(
# ngf=2,
# fmap_inc_factor=2,
# downsample_factors=[[2,2,2],[2,2,2],[2,2,2]]
# )
self.net = NLayerDiscriminator2D()
self.net = UNet(
ngf=5,
fmap_inc_factor=2,
downsample_factors=[[2,2,2],[2,2,2],[2,2,2]]
)

def __call__(self, x):
return self.net(x)
Expand Down Expand Up @@ -133,48 +136,75 @@ def initialize(self, rng_key, inputs, is_training=True):
loss_scale = jmp.NoOpLossScale()
return Params(weight, opt_state, loss_scale)

def split(arr, n_devices):
"""Splits the first axis of `arr` evenly across the number of devices."""
return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])

def create_network():
# returns a model that Gunpowder `Predict` and `Train` node can use
return Model()


my_model = Model()

n_devices = jax.local_device_count()
batch_size = 4*n_devices

raw = jnp.ones([batch_size, 1, 132, 132, 132])
gt = jnp.zeros([batch_size, 2, 40, 40, 40])
mask = jnp.ones([batch_size, 2, 40, 40, 40])
inputs = {
'raw': raw,
'gt': gt,
'mask': mask,
}
rng= jax.random.PRNGKey(42)

# init model
model_params = my_model.initialize(rng, inputs, is_training=True)

# #%%
# # test forward
# y = jit(my_model.forward)(model_params, {'raw': raw})
# #%%
# assert y['affs'].shape == (batch_size, 2, 40, 40, 40)


# test train loop
for _ in range(10):
t0 = time.time()
class NetworkTestJAX():

model_params, outputs, loss = jax.jit(
my_model.train_step,
donate_argnums=(0,),
static_argnums=(2,))(
model_params, inputs, False)

print(f'Loss: {loss}, took {time.time()-t0}s')
def __init__(self, task=None, im='astronaut', batch_size=None, noise_factor=3, model=Model(), num_epochs=15) -> None:
self.task = task
self.im = im
n_devices = jax.local_device_count()
if batch_size is None:
self.batch_size = 4*n_devices
else:
self.batch_size = batch_size
self.noise_factor = noise_factor

self.model = model
self.num_epochs = num_epochs

# TODO
self.inputs = None
self.model_params = None

def im2batch(self):
im = jnp.expand_dims(im, 0)
batch = []
for i in range(self.batch_size):
batch.append(jnp.expand_dims(im, 0))
return jnp.concatenate(batch)


def data_engine(self):
if self.task is None:
self.inputs = {
'raw': jnp.ones([self.batch_size, 1, 132, 132, 132]),
'gt': jnp.zeros([self.batch_size, 3, 40, 40, 40]),
'mask': jnp.ones([self.batch_size, 3, 40, 40, 40])
}
else:
gt_import = getattr(data, self.im)()
if len(gt_import.shape) > 2: # Strips to use only one image
gt_import = gt_import[...,0]
gt = self.im2batch(im= jnp.asarray(gt_import), batch_size=self.batch_size)

noise_key = jax.random.PRNGKey(22)
noise = self.im2batch(im=jax.random.uniform(key=noise_key, shape=gt_import.shape), batch_size=batch_size)

raw = (gt*noise) / self.noise_factor + (gt/self.noise_factor)

self.inputs = {
'raw': raw,
'gt': gt
}

# init model
def init_model(self):
if self.inputs is None: # Create data engine if it does not exist
self.data_engine()

rng, inputs = jax.random.PRNGKey(42), self.inputs
self.model_params = self.model.initialize(rng, inputs, is_training=True)

# test train loop
def train(self) -> None:
if self.model_params is None: # Init model if not created
self.init_model()
for _ in range(self.num_epochs):
t0 = time.time()

model_params, outputs, loss = jax.jit(
self.model.train_step,
donate_argnums=(0,),
static_argnums=(2,))(
self.model_params, self.inputs, False)

print(f'Loss: {loss}, took {time.time()-t0}s')

0 comments on commit 02d0728

Please sign in to comment.