diff --git a/src/raygun/jax/networks/ResNet.py b/src/raygun/jax/networks/ResNet.py index 28c114e2..e35335ab 100644 --- a/src/raygun/jax/networks/ResNet.py +++ b/src/raygun/jax/networks/ResNet.py @@ -1,13 +1,16 @@ +import jax +import haiku as hk +import jax.numpy as jnp import functools -import torch -from raygun.jax.networks.utils import NoiseBlock, ParameterizedNoiseBlock +from raygun.jax.networks.utils import NoiseBlock -class ResnetGenerator2D(torch.nn.Module): + +class ResnetGenerator2D(hk.Module): """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations, and (optionally) the injection of a feature map of random noise into the first upsampling layer. We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) """ - def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=torch.nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', activation=torch.nn.ReLU, add_noise=False, n_downsampling=2): + def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer= hk.BatchNorm, use_dropout=False, n_blocks=6, padding_type='reflect', activation=jax.nn.relu, add_noise=False, n_downsampling=2): """Construct a Resnet-based generator Parameters: input_nc (int) -- the number of channels in input images @@ -24,17 +27,17 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=torch.nn.BatchNor assert(n_blocks >= 0) super().__init__() if type(norm_layer) == functools.partial: - use_bias = norm_layer.func == torch.nn.InstanceNorm2d + use_bias = norm_layer.func == hk.InstanceNorm else: - use_bias = norm_layer == torch.nn.InstanceNorm2d - + use_bias = norm_layer == hk.InstanceNorm + p = 0 updown_p = 1 padder = [] - if padding_type.lower() == 'reflect': - padder = [torch.nn.ReflectionPad2d(3)] - elif padding_type.lower() == 'replicate': - padder = [torch.nn.ReplicationPad2d(3)] + # if padding_type.lower() == 'reflect': # TODO parallel in JAX? + # padder = [hk.pad.same(3)] + if padding_type.lower() == 'replicate': + padder = [hk.pad.same(3)] elif padding_type.lower() == 'zeros': p = 3 elif padding_type.lower() == 'valid': @@ -43,13 +46,13 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=torch.nn.BatchNor model = [] model += padder.copy() - model += [torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=p, bias=use_bias), + model += [hk.Conv2D(ngf, kernel_shape=7, padding=p, bias=use_bias), norm_layer(ngf), activation()] for i in range(n_downsampling): # add downsampling layers mult = 2 ** i - model += [torch.nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=updown_p, bias=use_bias), + model += [hk.Conv2D(ngf * mult, ngf * mult * 2, kernel_shape=3, stride=2, padding=updown_p, bias=use_bias), norm_layer(ngf * mult * 2), activation()] @@ -59,33 +62,35 @@ def __init__(self, input_nc=1, output_nc=1, ngf=64, norm_layer=torch.nn.BatchNor model += [ResnetBlock2D(ngf * mult, padding_type=padding_type.lower(), norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, activation=activation)] if add_noise == 'param': # add noise feature if necessary - model += [ParameterizedNoiseBlock()] + # model += [ParameterizedNoiseBlock()] + pass elif add_noise: model += [NoiseBlock()] for i in range(n_downsampling): # add upsampling layers mult = 2 ** (n_downsampling - i) - model += [torch.nn.ConvTranspose2d(ngf * mult + (i==0 and (add_noise is not False)), + model += [hk.Conv2DTranspose(ngf * mult + (i==0 and (add_noise is not False)), int(ngf * mult / 2), - kernel_size=3, stride=2, - padding=updown_p, output_padding=updown_p, + kernel_shape=3, stride=2, + padding=updown_p, bias=use_bias), norm_layer(int(ngf * mult / 2)), activation()] model += padder.copy() - model += [torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=p)] - model += [torch.nn.Tanh()] + model += [hk.Conv2D(output_nc, kernel_shape=7, padding=p)] + model += [jax.nn.tanh()] - self.model = torch.nn.Sequential(*model) + self.model = hk.Sequential(*model) - def forward(self, input): + def __call__(self, input): """Standard forward""" return self.model(input) -class ResnetBlock2D(torch.nn.Module): + +class ResnetBlock2D(hk.Module): """Define a Resnet block""" - def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=torch.nn.ReLU): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=jax.nn.relu): """Initialize the Resnet block A resnet block is a conv block with skip connections We construct a conv block with build_conv_block function, @@ -96,7 +101,7 @@ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, activat self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, activation) self.padding_type = padding_type - def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=torch.nn.ReLU): + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, activation=jax.nn.relu): """Construct a convolutional block. Parameters: dim (int) -- the number of channels in the conv layer. @@ -109,10 +114,10 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, """ p = 0 padder = [] - if padding_type == 'reflect': - padder = [torch.nn.ReflectionPad2d(1)] - elif padding_type == 'replicate': - padder = [torch.nn.ReplicationPad2d(1)] + # if padding_type == 'reflect': # TODO parallel in JAX? + # padder = [torch.nn.ReflectionPad2d(1)] + if padding_type == 'replicate': + padder = [hk.pad.same(1)] elif padding_type == 'zeros': p = 1 elif padding_type == 'valid': diff --git a/src/raygun/jax/networks/ResidualUNet.py b/src/raygun/jax/networks/ResidualUNet.py index ebce737b..6d3fd6b4 100644 --- a/src/raygun/jax/networks/ResidualUNet.py +++ b/src/raygun/jax/networks/ResidualUNet.py @@ -1,57 +1,38 @@ - -from funlib.learn.torch.models.conv4d import Conv4d +import jax +import haiku as hk +import jax.numpy as jnp import math -import numpy as np -import torch -import torch.nn as nn - from raygun.jax.networks.utils import NoiseBlock, ParameterizedNoiseBlock -class ConvPass(torch.nn.Module): - +class ConvPass(hk.Module): + def __init__( self, input_nc, output_nc, kernel_sizes, activation, - padding='valid', + padding='VALID', residual=False, - padding_mode='reflect', norm_layer=None, - final=False - ): - """Convolution pass block - - Args: - input_nc (int): Number of input channels - output_nc (int): Number of output channels - kernel_sizes (list(int) or array_like): Kernel sizes for convolution layers. - activation (str or callable): Name of activation function in 'torch.nn' or the function itself. - padding (str, optional): What type of padding to use in convolutions. Defaults to 'valid'. - residual (bool, optional): Whether to make the blocks calculate the residual. Defaults to False. - padding_mode (str, optional): What values to use in padding (i.e. 'zeros', 'reflect', 'wrap', etc.). Defaults to 'reflect'. - norm_layer (callable or None, optional): Whether to use a normalization layer and if so (i.e. if not None), the layer to use. Defaults to None. - final (bool, optional): Whether this block is the final output of the network (and thus should have the final activation omitted). Defaults to False. - - Returns: - ConvPass: Convolution block - """ - super(ConvPass, self).__init__() + data_format='NCDHW'): + + super().__init__() if activation is not None: if isinstance(activation, str): - self.activation = getattr(torch.nn, activation)() + self.activation = getattr(jax.nn, activation) else: - self.activation = activation() # assume is function + self.activation = activation # assume activation is a defined function else: - self.activation = nn.Identity() - + self.activation = jax.numpy.identity + + if activation is not None: + activation = getattr(jax.nn, activation) + self.residual = residual - self.padding = padding - self.final = final - + layers = [] for i, kernel_size in enumerate(kernel_sizes): @@ -59,20 +40,27 @@ def __init__( self.dims = len(kernel_size) conv = { - 2: torch.nn.Conv2d, - 3: torch.nn.Conv3d, - 4: Conv4d + 2: hk.Conv2D, + 3: hk.Conv3D, + # 4: Conv4d # TODO }[self.dims] + if data_format is None: + in_data_format = { + 2: 'NCHW', + 3: 'NCDHW' + }[self.dims] + else: + in_data_format = data_format + try: layers.append( conv( - input_nc, - output_nc, - kernel_size, - padding=padding, - padding_mode=padding_mode - )) + output_channels=output_nc, + kernel_shape=kernel_size, + padding=padding, + # padding_mode=padding_mode, + data_format=in_data_format)) if residual and i == 0: if input_nc < output_nc: groups = input_nc @@ -81,28 +69,26 @@ def __init__( self.x_init_map = conv( input_nc, output_nc, - np.ones(self.dims, dtype=int), + jnp.ones(self.dims, dtype=int), padding=padding, - padding_mode=padding_mode, + # padding_mode=padding_mode, TODO bias=False, - groups=groups + feature_group_count=groups ) except KeyError: raise RuntimeError("%dD convolution not implemented" % self.dims) - + if norm_layer is not None: layers.append(norm_layer(output_nc)) - - if not ((residual and i == (len(kernel_sizes) - 1)) or (final and i == (len(kernel_sizes) - 1))): - layers.append(self.activation) - + + if not (residual and i == (len(kernel_sizes)-1)): + layers.append(activation) + input_nc = output_nc - self.conv_pass = torch.nn.Sequential(*layers) - + self.conv_pass = hk.Sequential(layers) + def crop(self, x, shape): - '''Center-crop x to match spatial dimensions given by shape.''' - x_target_size = x.size()[:-self.dims] + shape offset = tuple( @@ -114,8 +100,8 @@ def crop(self, x, shape): for o, s in zip(offset, x_target_size)) return x[slices] - - def forward(self, x): + + def __call__(self, x): if not self.residual: return self.conv_pass(x) else: @@ -124,120 +110,96 @@ def forward(self, x): init_x = self.crop(self.x_init_map(x), res.size()[-self.dims:]) else: init_x = self.x_init_map(x) - if not self.final: - return self.activation(init_x + res) - else: - return (init_x + res) + return self.activation(init_x + res) -class ConvDownsample(torch.nn.Module): + +class ConvDownsample(hk.Module): + def __init__( - self, - input_nc, - output_nc, - kernel_sizes, - downsample_factor, - activation, - padding='valid', - padding_mode='reflect', - norm_layer=None - ): - """Convolution-based downsampling - - Args: - input_nc (int): Number of input channels. - output_nc (int): Number of output channels. - kernel_sizes (list(int) or array_like): Kernel sizes for convolution layers. - downsample_factor (int): Factor by which to downsample in all spatial dimensions. - activation (str or callable): Name of activation function in 'torch.nn' or the function itself. - padding (str, optional): What type of padding to use in convolutions. Defaults to 'valid'. - padding_mode (str, optional): What values to use in padding (i.e. 'zeros', 'reflect', 'wrap', etc.). Defaults to 'reflect'. - norm_layer (callable or None, optional): Whether to use a normalization layer and if so (i.e. if not None), the layer to use. Defaults to None. - - Returns: - Downsampling layer. - """ - - super(ConvDownsample, self).__init__() + self, + # input_nc, + output_nc, + kernel_sizes, + downsample_factor, + activation, + padding='valid', + # padding_mode='reflect', + norm_layer=None, + data_format='NCDHW'): + + super().__init__() if activation is not None: if isinstance(activation, str): - self.activation = getattr(torch.nn, activation)() + self.activation = getattr(jax.nn, activation) else: - self.activation = activation() # assume is function + self.activation = activation() # assume activation is a defined function else: - self.activation = nn.Identity() - + self.activation = jax.numpy.identity() + self.padding = padding - + layers = [] self.dims = len(kernel_sizes) + conv = { - 2: torch.nn.Conv2d, - 3: torch.nn.Conv3d, - 4: Conv4d + 2: hk.Conv2D, + 3: hk.Conv3D, + # 4: Conv4d # TODO }[self.dims] + if data_format is None: + in_data_format = { + 2: 'NCHW', + 3: 'NCDHW' + }[self.dims] + else: + in_data_format = data_format + try: layers.append( conv( - input_nc, - output_nc, - kernel_sizes, + output_channels=output_nc, + kernel_shape=kernel_sizes, stride=downsample_factor, - # padding=padding, #TODO: Make work with same padding - padding='valid', - padding_mode=padding_mode - )) - + padding=padding, + # padding_mode=padding_mode, + data_format=in_data_format)) + except KeyError: raise RuntimeError("%dD convolution not implemented" % self.dims) - + if norm_layer is not None: layers.append(norm_layer(output_nc)) - - layers.append(self.activation) - self.conv_pass = torch.nn.Sequential(*layers) - - def forward(self, x): + + layers.append(self.activation) + self.conv_pass = hk.Sequential(layers) + + def __call__(self, x): return self.conv_pass(x) -class MaxDownsample(torch.nn.Module): +class MaxDownsample(hk.Module): # TODO: check data format type + def __init__( - self, - downsample_factor, - flexible=True): - """MaxPooling-based downsampling - - Args: - downsample_factor (list(int) or array_like): Factors to downsample by in each dimension. - flexible (bool, optional): True allows torch.nn.MaxPoolNd to crop the right/bottom of tensors in order to allow pooling of tensors not evenly divisible by the downsample_factor. Alternative implementations could pass 'ceil_mode=True' or 'padding= {# > 0}' to avoid cropping of inputs. False forces inputs to be evenly divisible by the downsample_factor, which generally restricts the flexibility of model architectures. Defaults to True. + self, + downsample_factor, + flexible=True): - Returns: - Downsampling layer. - """ - - super(MaxDownsample, self).__init__() - + super().__init__() + self.dims = len(downsample_factor) self.downsample_factor = downsample_factor self.flexible = flexible - - pool = { - 2: torch.nn.MaxPool2d, - 3: torch.nn.MaxPool3d, - 4: torch.nn.MaxPool3d # only 3D pooling, even for 4D input - }[self.dims] - - self.down = pool( - downsample_factor, - stride=downsample_factor, - ) - - def forward(self, x): + + self.down = hk.MaxPool(window_shape=downsample_factor, + strides=downsample_factor, + padding='VALID') + + def __call__(self, x): if self.flexible: try: return self.down(x) @@ -246,9 +208,9 @@ def forward(self, x): else: self.check_mismatch(x.size()) return self.down(x) - + def check_mismatch(self, size): - for d in range(1, self.dims + 1): + for d in range(1, self.dims+1): if size[-d] % self.downsample_factor[-d] != 0: raise RuntimeError( "Can not downsample shape %s with factor %s, mismatch " @@ -256,57 +218,54 @@ def check_mismatch(self, size): size, self.downsample_factor, self.dims - d)) - return -class Upsample(torch.nn.Module): +class Upsample(hk.Module): + def __init__( self, scale_factor, - mode=None, - input_nc=None, + mode='transposed_conv', output_nc=None, crop_factor=None, - next_conv_kernel_sizes=None): + next_conv_kernel_sizes=None, + data_format='NCDHW'): super(Upsample, self).__init__() - + if crop_factor is not None: assert next_conv_kernel_sizes is not None, "crop_factor and next_conv_kernel_sizes have to be given together" self.crop_factor = crop_factor self.next_conv_kernel_sizes = next_conv_kernel_sizes self.dims = len(scale_factor) - + if mode == 'transposed_conv': - up = { - 2: torch.nn.ConvTranspose2d, - 3: torch.nn.ConvTranspose3d + 2: hk.Conv2DTranspose, + 3: hk.Conv3DTranspose }[self.dims] + if data_format is None: + in_data_format = { + 2: 'NCHW', + 3: 'NCDHW' + }[self.dims] + else: + in_data_format = data_format + self.up = up( - input_nc, - output_nc, - kernel_size=scale_factor, - stride=scale_factor) + output_channels=output_nc, + kernel_shape=scale_factor, + stride=scale_factor, + data_format=in_data_format) else: - - self.up = torch.nn.Upsample( - scale_factor=scale_factor, - mode=mode) - + raise RuntimeError("Unimplemented") # Not implemented in Haiku + def crop_to_factor(self, x, factor, kernel_sizes): - '''Crop feature maps to ensure translation equivariance with stride of - upsampling factor. This should be done right after upsampling, before - application of the convolutions with the given kernel sizes. - - The crop could be done after the convolutions, but it is more efficient - to do that before (feature maps will be smaller). - ''' - shape = x.size() + shape = x.shape spatial_shape = shape[-self.dims:] # the crop that will already be done due to the convolutions @@ -315,19 +274,6 @@ def crop_to_factor(self, x, factor, kernel_sizes): for d in range(self.dims) ) - # we need (spatial_shape - convolution_crop) to be a multiple of - # factor, i.e.: - # - # (s - c) = n*k - # - # we want to find the largest n for which s' = n*k + c <= s - # - # n = floor((s - c)/k) - # - # this gives us the target shape s' - # - # s' = n*k + c - ns = ( int(math.floor(float(s - c)/f)) for s, c, f in zip(spatial_shape, convolution_crop, factor) @@ -358,11 +304,11 @@ def crop_to_factor(self, x, factor, kernel_sizes): def crop(self, x, shape): '''Center-crop x to match spatial dimensions given by shape.''' - x_target_size = x.size()[:-self.dims] + shape + x_target_size = x.shape[:-self.dims] + shape offset = tuple( (a - b)//2 - for a, b in zip(x.size(), x_target_size)) + for a, b in zip(x.shape, x_target_size)) slices = tuple( slice(o, o + s) @@ -370,7 +316,7 @@ def crop(self, x, shape): return x[slices] - def forward(self, f_left, g_out): + def __call__(self, f_left, g_out): g_up = self.up(g_out) @@ -382,137 +328,32 @@ def forward(self, f_left, g_out): else: g_cropped = g_up - f_cropped = self.crop(f_left, g_cropped.size()[-self.dims:]) + f_cropped = self.crop(f_left, g_cropped.shape[-self.dims:]) + + return jax.lax.concatenate((f_cropped, g_cropped), dimension=1) - return torch.cat([f_cropped, g_cropped], dim=1) -class ResidualUNet(torch.nn.Module): +class ResidualUNet(hk.Module): def __init__( self, - input_nc, + # input_nc, ngf, fmap_inc_factor, downsample_factors, kernel_size_down=None, kernel_size_up=None, - activation='ReLU', + activation='relu', + input_nc=None, output_nc=None, num_heads=1, constant_upsample=False, downsample_method='max', - padding_type='valid', + padding_type='VALID', residual=False, norm_layer=None, add_noise=False, ): - '''Create a U-Net:: - - f_in --> f_left --------------------------->> f_right--> f_out - | ^ - v | - g_in --> g_left ------->> g_right --> g_out - | ^ - v | - ... - - where each ``-->`` is a convolution pass, each `-->>` a crop, and down - and up arrows are max-pooling and transposed convolutions, - respectively. - - The U-Net expects 3D or 4D tensors shaped like:: - - ``(batch=1, channels, [length,] depth, height, width)``. - - It will perform 4D convolutions as long as ``length`` is greater than 1. - As soon as ``length`` is 1 due to a valid convolution, the time dimension will be - dropped and tensors with ``(b, c, z, y, x)`` will be use (and returned) - from there on. - - Args: - - input_nc: - - The number of input channels. - - ngf: - - The number of feature maps in the first layer. By default, this is also the - number of output feature maps. Stored in the ``channels`` - dimension. - - fmap_inc_factor: - - By how much to multiply the number of feature maps between - layers. If layer 0 has ``k`` feature maps, layer ``l`` will - have ``k*fmap_inc_factor**l``. - - downsample_factors: - - List of tuples ``(z, y, x)`` to use to down- and up-sample the - feature maps between layers. - - kernel_size_down (optional): - - List of lists of kernel sizes. The number of sizes in a list - determines the number of convolutional layers in the - corresponding level of the build on the left side. Kernel sizes - can be given as tuples or integer. If not given, each - convolutional pass will consist of two 3x3x3 convolutions. - - kernel_size_up (optional): - - List of lists of kernel sizes. The number of sizes in a list - determines the number of convolutional layers in the - corresponding level of the build on the right side. Within one - of the lists going from left to right. Kernel sizes can be - given as tuples or integer. If not given, each convolutional - pass will consist of two 3x3x3 convolutions. - - activation: - - Which activation to use after a convolution. Accepts the name - of any tensorflow activation function (e.g., ``ReLU`` for - ``torch.nn.ReLU``). - - output_nc (optional): - - The number of feature maps in the output layer. By default, this is the same as the - number of feature maps of the input layer. Stored in the ``channels`` - dimension. - - num_heads (optional): - - Number of decoders. The resulting U-Net has one single encoder - path and num_heads decoder paths. This is useful in a - multi-task learning context. - - constant_upsample (optional): - - If set to true, perform a constant upsampling instead of a - transposed convolution in the upsampling layers. - - downsample_method (optional): - - Whether to use max pooling ('max') or strided convolution ('convolve') for downsampling layers. Default is max pooling. - - padding_type (optional): - - How to pad convolutions. Either 'same' or 'valid' (default). - - residual (optional): - - Whether to train convolutional layers to output residuals to add to inputs (as in ResNet) or to directly convolve input data to output. Either 'True' or 'False' (default). - - norm_layer (optional): - - What, if any, normalization to layer after network layers. Default is none. - - add_noise (optional): - - Whether to add gaussian noise with 0 mean and unit variance ('True'), mean and variance parameterized by the network ('param'), or no noise ('False' <- default). - - ''' super(ResidualUNet, self).__init__() @@ -522,19 +363,16 @@ def __init__( self.input_nc = input_nc self.output_nc = output_nc if output_nc else ngf self.residual = residual - self.padding_type = padding_type if activation is not None: if isinstance(activation, str): - self.activation = getattr(torch.nn, activation)() + self.activation = getattr(jax.nn, activation)() else: self.activation = activation() # assume is function else: - self.activation = nn.Identity() + self.activation = jax.numpy.identity() - if add_noise == 'param': # add noise feature if necessary - self.noise_layer = ParameterizedNoiseBlock() - elif add_noise: + if add_noise == 'noise_block': self.noise_layer = NoiseBlock() else: self.noise_layer = None @@ -544,7 +382,7 @@ def __init__( kernel_size_down = [[(3,)*self.ndims, (3,)*self.ndims]]*self.num_levels if kernel_size_up is None: kernel_size_up = [[(3,)*self.ndims, (3,)*self.ndims]]*(self.num_levels - 1) - + # compute crop factors for translation equivariance crop_factors = [] factor_product = None @@ -566,7 +404,7 @@ def __init__( # modules # left convolutional passes - self.l_conv = nn.ModuleList([ + self.l_conv = [ ConvPass( input_nc if level == 0 @@ -578,20 +416,19 @@ def __init__( residual=self.residual, norm_layer=norm_layer) for level in range(self.num_levels) - ]) - self.dims = self.l_conv[0].dims + ] # left downsample layers if downsample_method.lower() == 'max': - self.l_down = nn.ModuleList([ + self.l_down = [ MaxDownsample(downsample_factors[level]) for level in range(self.num_levels - 1) - ]) + ] elif downsample_method.lower() == 'convolve': - self.l_down = nn.ModuleList([ + self.l_down = [ ConvDownsample( ngf*fmap_inc_factor**level, ngf*fmap_inc_factor**(level + 1), @@ -601,15 +438,15 @@ def __init__( padding=padding_type, norm_layer=norm_layer) for level in range(self.num_levels - 1) - ]) + ] else: raise RuntimeError(f'Unknown downsampling method {downsample_method}. Use "max" or "convolve" instead.') # right up/crop/concatenate layers - self.r_up = nn.ModuleList([ - nn.ModuleList([ + self.r_up = [ + [ Upsample( downsample_factors[level], mode='nearest' if constant_upsample else 'transposed_conv', @@ -618,13 +455,13 @@ def __init__( crop_factor=crop_factors[level], next_conv_kernel_sizes=kernel_size_up[level]) for level in range(self.num_levels - 1) - ]) + ] for _ in range(num_heads) - ]) + ] # right convolutional passes - self.r_conv = nn.ModuleList([ - nn.ModuleList([ + self.r_conv = [ + [ ConvPass( ngf*fmap_inc_factor**level + ngf*fmap_inc_factor**(level + 1), @@ -638,9 +475,9 @@ def __init__( norm_layer=norm_layer, final=(level==0)) for level in range(self.num_levels - 1) - ]) + ] for _ in range(num_heads) - ]) + ] def rec_forward(self, level, f_in): @@ -694,11 +531,11 @@ def crop(self, x, shape): return x[slices] - def forward(self, x): + def __call__(self, x): y = self.rec_forward(self.num_levels - 1, x) if self.padding_type.lower() == 'valid': - x = self.crop(x, y[0].size()[-self.ndims:]) + x = self.crop(x=x, shape=y[0].size()[-self.ndims:]) for i in range(self.num_heads): y[i] = self.activation(x + y[i]) diff --git a/src/raygun/jax/networks/UNet.py b/src/raygun/jax/networks/UNet.py index e42a9204..184024d4 100644 --- a/src/raygun/jax/networks/UNet.py +++ b/src/raygun/jax/networks/UNet.py @@ -1,8 +1,8 @@ import math -import numpy as np +import jax.numpy as jnp import jax import haiku as hk - +from raygun.jax.networks.utils import NoiseBlock, ParameterizedNoiseBlock class ConvPass(hk.Module): @@ -68,7 +68,7 @@ def __init__( self.x_init_map = conv( input_nc, output_nc, - np.ones(self.dims, dtype=int), + jnp.ones(self.dims, dtype=int), padding=padding, # padding_mode=padding_mode, TODO bias=False, @@ -110,60 +110,7 @@ def __call__(self, x): else: init_x = self.x_init_map(x) return self.activation(init_x + res) -# class ConvPass(hk.Module): - -# def __init__( -# self, -# out_channels, -# kernel_sizes, -# activation, -# padding='VALID', -# data_format='NCDHW'): - -# super().__init__() - -# if activation is not None: -# activation = getattr(jax.nn, activation) - -# layers = [] - -# for kernel_size in kernel_sizes: - -# self.dims = len(kernel_size) - -# conv = { -# 2: hk.Conv2D, -# 3: hk.Conv3D, -# # 4: Conv4d # TODO -# }[self.dims] - -# if data_format is None: -# in_data_format = { -# 2: 'NCHW', -# 3: 'NCDHW' -# }[self.dims] -# else: -# in_data_format = data_format - -# try: -# layers.append( -# conv( -# output_channels=out_channels, -# kernel_shape=kernel_size, -# padding=padding, -# data_format=in_data_format)) -# except KeyError: -# raise RuntimeError( -# "%dD convolution not implemented" % self.dims) - -# if activation is not None: -# layers.append(activation) - -# self.conv_pass = hk.Sequential(layers) - -# def __call__(self, x): -# return self.conv_pass(x) class ConvDownsample(hk.Module): @@ -401,7 +348,8 @@ def __init__(self, padding_type='VALID', residual=False, norm_layer=None, - name=None + name=None, + add_noise = False ): super().__init__(name=name) diff --git a/src/raygun/jax/networks/utils.py b/src/raygun/jax/networks/utils.py index a6ce8ac8..e2fe1fe9 100644 --- a/src/raygun/jax/networks/utils.py +++ b/src/raygun/jax/networks/utils.py @@ -1,6 +1,8 @@ # ORIGINALLY WRITTEN BY TRI NGUYEN (HARVARD, 2021) # WRITTEN IN JAX BY BRIAN REICHER (NORTHEASTERN, 2022) +#%% +from ast import Raise import jax import haiku as hk @@ -67,7 +69,7 @@ def init_weights(net, init_type='normal', init_gain=0.02, nonlinearity='relu'): pass -class NoiseBlock(hk.Module): +class NoiseBlock(): """Definies a block for producing and appending a feature map of gaussian noise with mean=0 and stdev=1""" def __init__(self): @@ -76,18 +78,21 @@ def __init__(self): def __call__(self, x): # TODO JAX tensors? shape = list(x.shape) shape[1] = 1 # only make one noise feature - noise = jax.numpy.empty(shape).to(x.device).normal_() + key = jax.random.PRNGKey(22) + noise = jax.random.normal(key=key, shape=shape) # noise = torch.empty(shape, device=x.device).normal_() # return torch.cat([x, noise.requires_grad_()], 1) return jax.numpy.concatenate(([x, noise]),1) -class ParameterizedNoiseBlock(hk.Module): +class ParameterizedNoiseBlock(): """Definies a block for producing and appending a feature map of gaussian noise with mean and stdev defined by the first two feature maps of the incoming tensor""" - def __init__(self): - super().__init__() + # def __init__(self): + # super().__init__() - def __call__(self, x): # TODO JAX tensors? - noise = jax.random.normal(x[:,0,...], jax.nn.relu(x[:,1,...])).unsqueeze(1) - return jax.numpy.concatenate([x, noise], 1) + # def __call__(self, x): # TODO JAX tensors? + # key = jax.random.PRNGKey(22) + # noise = jax.random.normal(key=key, x[:,0,...], jax.nn.relu(x[:,1,...])).unsqueeze(1) + # return jax.numpy.concatenate([x, noise], 1) + pass \ No newline at end of file diff --git a/src/raygun/jax/tests/network_test_jax.py b/src/raygun/jax/tests/network_test_jax.py index cbc0494e..d91be8b4 100644 --- a/src/raygun/jax/tests/network_test_jax.py +++ b/src/raygun/jax/tests/network_test_jax.py @@ -1,6 +1,6 @@ #%% import raygun -from raygun.jax.networks import * +from raygun.jax.networks import ResidualUNet, UNet, NLayerDiscriminator import jax import jax.numpy as jnp from jax import jit @@ -47,11 +47,12 @@ class MyModel(hk.Module): def __init__(self, name=None): super().__init__(name=name) - self.net = UNet( - ngf=3, - fmap_inc_factor=2, - downsample_factors=[[2,2,2],[2,2,2],[2,2,2]] - ) + # self.net = ResidualUNet( + # ngf=3, + # fmap_inc_factor=2, + # downsample_factors=[[2,2,2],[2,2,2],[2,2,2]] + # ) + self.net = NLayerDiscriminator(ndims=2, ngf=3) # net = getattr(raygun.jax.networks, network_type) # self.net = net(net_kwargs) @@ -59,7 +60,7 @@ def __call__(self, x): return self.net(x) def _forward(x): # Temporary set of _forward() - net = MyModel + net = MyModel() return net(x) @@ -188,7 +189,7 @@ def init_model(self): self.data_engine() self.rng = jax.random.PRNGKey(42) - self.model_params = self.model.initialize(rng_key=self.rng, inputs=self.inputs) + self.model_params = self.model.initialize(self.rng, self.inputs) # test train loop def train(self) -> None: