Skip to content

Commit

Permalink
Issue #44 put _create_2d_block out of class
Browse files Browse the repository at this point in the history
  • Loading branch information
thompson318 committed Jul 27, 2022
1 parent 58b6203 commit e086b36
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions sksurgerytf/models/rgb_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,25 +291,6 @@ def _load_data(self):

self.validate_generator = zip(validate_image_generator, validate_mask_generator)

# pylint: disable=no-self-use
def _create_2d_block(self, input_tensor, num_filters, kernel_size, batch_norm=True):

model = keras.layers.Conv2D(num_filters, kernel_size, padding='same', kernel_initializer='he_normal')(input_tensor)

if batch_norm:
model = keras.layers.BatchNormalization()(model)

model = keras.layers.Activation('relu')(model)

model = keras.layers.Conv2D(num_filters, kernel_size, padding='same', kernel_initializer='he_normal')(model)

if batch_norm:
model = keras.layers.BatchNormalization()(model)

model = keras.layers.Activation('relu')(model)

return model

def _build_model(self):
"""
Constructs the neural network, and compiles it.
Expand All @@ -325,45 +306,45 @@ def _build_model(self):
inputs = keras.Input(self.input_size)

# Left side of UNet
conv1 = self._create_2d_block(inputs, num_filters * 1, kernel_size=kernel_size, batch_norm=batch_norm)
conv1 = _create_2d_block(inputs, num_filters * 1, kernel_size=kernel_size, batch_norm=batch_norm)
pool1 = keras.layers.MaxPooling2D((pooling_size, pooling_size))(conv1)
pool1 = keras.layers.Dropout(dropout)(pool1)

conv2 = self._create_2d_block(pool1, num_filters * 2, kernel_size=kernel_size, batch_norm=batch_norm)
conv2 = _create_2d_block(pool1, num_filters * 2, kernel_size=kernel_size, batch_norm=batch_norm)
pool2 = keras.layers.MaxPooling2D((pooling_size, pooling_size))(conv2)
pool2 = keras.layers.Dropout(dropout)(pool2)

conv3 = self._create_2d_block(pool2, num_filters * 4, kernel_size=kernel_size, batch_norm=batch_norm)
conv3 = _create_2d_block(pool2, num_filters * 4, kernel_size=kernel_size, batch_norm=batch_norm)
pool3 = keras.layers.MaxPooling2D((pooling_size, pooling_size))(conv3)
pool3 = keras.layers.Dropout(dropout)(pool3)

conv4 = self._create_2d_block(pool3, num_filters * 8, kernel_size=kernel_size, batch_norm=batch_norm)
conv4 = _create_2d_block(pool3, num_filters * 8, kernel_size=kernel_size, batch_norm=batch_norm)
pool4 = keras.layers.MaxPooling2D((pooling_size, pooling_size))(conv4)
pool4 = keras.layers.Dropout(dropout)(pool4)

# Bottom of UNet
conv5 = self._create_2d_block(pool4, num_filters * 16, kernel_size=kernel_size, batch_norm=batch_norm)
conv5 = _create_2d_block(pool4, num_filters * 16, kernel_size=kernel_size, batch_norm=batch_norm)

# Right side of UNet
up6 = keras.layers.Conv2DTranspose(num_filters * 8, 3, strides=(2, 2), padding='same', kernel_initializer='he_normal')(conv5)
up6 = keras.layers.concatenate([up6, conv4])
up6 = keras.layers.Dropout(dropout)(up6)
conv6 = self._create_2d_block(up6, num_filters * 8, kernel_size=3, batch_norm=batch_norm)
conv6 = _create_2d_block(up6, num_filters * 8, kernel_size=3, batch_norm=batch_norm)

up7 = keras.layers.Conv2DTranspose(num_filters * 4, 3, strides=(2, 2), padding='same', kernel_initializer='he_normal')(conv6)
up7 = keras.layers.concatenate([up7, conv3])
up7 = keras.layers.Dropout(dropout)(up7)
conv7 = self._create_2d_block(up7, num_filters * 4, kernel_size=3, batch_norm=batch_norm)
conv7 = _create_2d_block(up7, num_filters * 4, kernel_size=3, batch_norm=batch_norm)

up8 = keras.layers.Conv2DTranspose(num_filters * 2, 3, strides=(2, 2), padding='same', kernel_initializer='he_normal')(conv7)
up8 = keras.layers.concatenate([up8, conv2])
up8 = keras.layers.Dropout(dropout)(up8)
conv8 = self._create_2d_block(up8, num_filters * 2, kernel_size=3, batch_norm=batch_norm)
conv8 = _create_2d_block(up8, num_filters * 2, kernel_size=3, batch_norm=batch_norm)

up9 = keras.layers.Conv2DTranspose(num_filters * 1, 3, strides=(2, 2), padding='same', kernel_initializer='he_normal')(conv8)
up9 = keras.layers.concatenate([up9, conv1])
up9 = keras.layers.Dropout(dropout)(up9)
conv9 = self._create_2d_block(up9, num_filters * 1, kernel_size=3, batch_norm=batch_norm)
conv9 = _create_2d_block(up9, num_filters * 1, kernel_size=3, batch_norm=batch_norm)

conv10 = keras.layers.Conv2D(1, 1, padding='same', activation='sigmoid')(conv9)

Expand Down Expand Up @@ -625,3 +606,22 @@ def _copy_images(src_dir, dst_dir):
os.path.dirname(src_dir)) + "_" +
os.path.basename(image_file))
os.symlink(image_file, destination)


def _create_2d_block(input_tensor, num_filters, kernel_size, batch_norm=True):

model = keras.layers.Conv2D(num_filters, kernel_size, padding='same', kernel_initializer='he_normal')(input_tensor)

if batch_norm:
model = keras.layers.BatchNormalization()(model)

model = keras.layers.Activation('relu')(model)

model = keras.layers.Conv2D(num_filters, kernel_size, padding='same', kernel_initializer='he_normal')(model)

if batch_norm:
model = keras.layers.BatchNormalization()(model)

model = keras.layers.Activation('relu')(model)

return model

0 comments on commit e086b36

Please sign in to comment.