Skip to content

Commit

Permalink
ResNet JAX updates #13
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Aug 30, 2022
1 parent 9677a77 commit bd8db54
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 424 deletions.
61 changes: 33 additions & 28 deletions src/raygun/jax/networks/ResNet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import jax
import haiku as hk
import jax.numpy as jnp
import functools
import torch
from raygun.jax.networks.utils import NoiseBlock, ParameterizedNoiseBlock
from raygun.jax.networks.utils import NoiseBlock

class ResnetGenerator2D(torch.nn.Module):

class ResnetGenerator2D(hk.Module):
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations, and (optionally) the injection of a feature map of random noise into the first upsampling layer.
We adapt the Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
"""

def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=torch.nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', activation=torch.nn.ReLU, add_noise=False, n_downsampling=2):
def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer= hk.BatchNorm, use_dropout=False, n_blocks=6, padding_type='reflect', activation=jax.nn.relu, add_noise=False, n_downsampling=2):
"""Construct a Resnet-based generator
Parameters:
input_nc (int) -- the number of channels in input images
Expand All @@ -24,17 +27,17 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=torch.nn.BatchNor
assert(n_blocks >= 0)
super().__init__()
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == torch.nn.InstanceNorm2d
use_bias = norm_layer.func == hk.InstanceNorm
else:
use_bias = norm_layer == torch.nn.InstanceNorm2d

use_bias = norm_layer == hk.InstanceNorm
p = 0
updown_p = 1
padder = []
if padding_type.lower() == 'reflect':
padder = [torch.nn.ReflectionPad2d(3)]
elif padding_type.lower() == 'replicate':
padder = [torch.nn.ReplicationPad2d(3)]
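# if padding_type.lower() == 'reflect': # TODO: reflection padding is disabled until a JAX parallel to torch.nn.ReflectionPad2d is found (possibly jnp.pad with mode='reflect' -- to be confirmed)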
# padder = [hk.pad.same(3)]
if padding_type.lower() == 'replicate':
padder = [hk.pad.same(3)]
elif padding_type.lower() == 'zeros':
p = 3
elif padding_type.lower() == 'valid':
Expand All @@ -43,13 +46,13 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=torch.nn.BatchNor

model = []
model += padder.copy()
model += [torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=p, bias=use_bias),
model += [hk.Conv2D(ngf, kernel_shape=7, padding=p, bias=use_bias),
norm_layer(ngf),
activation()]

for i in range(n_downsampling): # add downsampling layers
mult = 2 ** i
model += [torch.nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=updown_p, bias=use_bias),
model += [hk.Conv2D(ngf * mult, ngf * mult * 2, kernel_shape=3, stride=2, padding=updown_p, bias=use_bias),
norm_layer(ngf * mult * 2),
activation()]

Expand All @@ -59,33 +62,35 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=torch.nn.BatchNor
model += [ResnetBlock2D(ngf * mult, padding_type=padding_type.lower(), norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, activation=activation)]

if add_noise == 'param': # add noise feature if necessary
model += [ParameterizedNoiseBlock()]
# model += [ParameterizedNoiseBlock()]
pass
elif add_noise:
model += [NoiseBlock()]

for i in range(n_downsampling): # add upsampling layers
mult = 2 ** (n_downsampling - i)
model += [torch.nn.ConvTranspose2d(ngf * mult + (i==0 and (add_noise is not False)),
model += [hk.Conv2DTranspose(ngf * mult + (i==0 and (add_noise is not False)),
int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=updown_p, output_padding=updown_p,
kernel_shape=3, stride=2,
padding=updown_p,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
activation()]
model += padder.copy()
model += [torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=p)]
model += [torch.nn.Tanh()]
model += [hk.Conv2D(output_nc, kernel_shape=7, padding=p)]
model += [jax.nn.tanh()]

self.model = torch.nn.Sequential(*model)
self.model = hk.Sequential(*model)

def forward(self, input):
def __call__(self, input):
"""Standard forward"""
return self.model(input)

class ResnetBlock2D(torch.nn.Module):

class ResnetBlock2D(hk.Module):
"""Define a Resnet block"""

def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=torch.nn.ReLU):
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=jax.nn.relu):
"""Initialize the Resnet block
A resnet block is a conv block with skip connections
We construct a conv block with build_conv_block function,
Expand All @@ -96,7 +101,7 @@ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, activat
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, activation)
self.padding_type = padding_type

def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=torch.nn.ReLU):
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=jax.nn.relu):
"""Construct a convolutional block.
Parameters:
dim (int) -- the number of channels in the conv layer.
Expand All @@ -109,10 +114,10 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias,
"""
p = 0
padder = []
if padding_type == 'reflect':
padder = [torch.nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
padder = [torch.nn.ReplicationPad2d(1)]
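# if padding_type == 'reflect': # TODO: reflection padding is disabled until a JAX parallel to torch.nn.ReflectionPad2d is found (possibly jnp.pad with mode='reflect' -- to be confirmed)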
# padder = [torch.nn.ReflectionPad2d(1)]
if padding_type == 'replicate':
padder = [hk.pad.same(1)]
elif padding_type == 'zeros':
p = 1
elif padding_type == 'valid':
Expand Down
Loading

0 comments on commit bd8db54

Please sign in to comment.