Currently, this repo provides general architectures
and functions
that are useful for GANs and classification.
I will continue to add useful things to other areas.
Also, your pull requests and issues are always welcome.
You can also open an issue describing what you would like implemented, and I'll implement it.
Functional API [Template code]
- Faster than Sequential
- Easier to create a flexible model architecture
- Easy to use layer operations such as `concatenate`, `add`, ...
- Define `tf.keras.layers.Input`
  - You have to know the shape of the input tensor
- Define `tf.keras.Model`
  - You have to create the `Model` constructor at the end
Sequential API [Template code]
- Simple to use
- Similar to PyTorch style
- Hard to create a flexible model architecture
ops_functional.py
- Functional API operations
- from ops_functional import *
ops_sequential.py
- Sequential API operations
- from ops_sequential import *
utils.py
- Image processing and other useful functions (e.g. automatic_gpu_usage)
  - automatic_gpu_usage: automatically manages GPU memory
  - multiple_gpu_usage: lets you set a GPU memory limit
- from utils import *
from ops_functional import *
from utils import *
automatic_gpu_usage()  # for efficient gpu use

input_shape = [img_height, img_width, img_ch]
inputs = tf.keras.layers.Input(input_shape, name='input')

# architecture
x = conv(inputs, channels=64, kernel=3, stride=2, pad=1, pad_type='reflect', use_bias=False, sn=False, name='conv')
x = batch_norm(x, name='batch_norm')
x = relu(x)
x = global_avg_pooling(x)
x = fully_connected(x, units=10, sn=False, name='fc')

# The model's outputs are the last tensor produced by the architecture above (`x`).
model = tf.keras.Model(inputs, x, name='model')

optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
from ops_sequential import *
from utils import *
automatic_gpu_usage()  # for efficient gpu use

# Collect the layers in order, then wrap them in a Sequential model.
layer_list = []
layer_list += [Conv(channels=64, kernel=3, stride=2, pad=1, pad_type='reflect', use_bias=False, sn=False, name='conv')]
# `name` must be passed as a keyword with a string value, like every other layer here.
layer_list += [BatchNorm(name='batch_norm')]
layer_list += [Relu()]
layer_list += [Global_Avg_Pooling()]
layer_list += [FullyConnected(units=10, sn=False, name='fc')]

model = Sequential(layer_list, name='model')

optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
img_class = Image_data(img_height, img_width, img_ch, dataset_path, augment_flag)
img_class.preprocess()

img_slice = tf.data.Dataset.from_tensor_slices(img_class.dataset)
gpu_device = '/gpu:0'

# This is script-level code, not a method, so there is no `self`:
# `batch_size` is a plain setting like `img_height` / `dataset_num`.
img_slice = (
    img_slice
    .apply(shuffle_and_repeat(dataset_num))
    .apply(map_and_batch(img_class.image_processing, batch_size,
                         num_parallel_batches=AUTOTUNE,
                         drop_remainder=True))
    .apply(prefetch_to_device(gpu_device, AUTOTUNE))
)

dataset_iter = iter(img_slice)
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=2)

# Resume from the newest checkpoint if one exists; otherwise start at step 0.
start_iteration = 0
latest_ckpt_path = manager.latest_checkpoint
if not latest_ckpt_path:
    print('Not restoring from saved checkpoint')
else:
    ckpt.restore(latest_ckpt_path)
    # The resume step is parsed from the text after the last '-' in the path.
    start_iteration = int(latest_ckpt_path.rsplit('-', 1)[-1])
    print('Latest checkpoint restored!!')
def train_step(img):
    """Run one optimization step on `img` and return the loss.

    Template: replace `your_loss` with the task loss computed from `logit`,
    e.g. for classification ``your_loss = cross_entropy_loss(logit, label)``.
    The model's regularization loss is added on top.
    """
    with tf.GradientTape() as tape:
        logit = model(img)
        loss = your_loss + regularization_loss(model)

    variables = model.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss
def train(save_freq=1):
    """Train from `start_iteration` up to `total_iteration`.

    Each step pulls a batch from `dataset_iter`, updates the network with
    `train_step`, and logs the loss to TensorBoard. Every `save_freq` steps it
    writes a checkpoint through `manager`. A final checkpoint is always
    written at `total_iteration`.

    Args:
        save_freq: checkpoint interval in iterations. The default of 1 saves
            after every step; larger values avoid writing to disk on every
            iteration.

    Raises:
        ValueError: if `save_freq` is smaller than 1.
    """
    if save_freq < 1:
        raise ValueError('save_freq must be >= 1, got {}'.format(save_freq))

    # setup tensorboard
    summary_writer = tf.summary.create_file_writer(log_dir)

    # Enter the writer's default context once rather than on every iteration.
    with summary_writer.as_default():
        for idx in range(start_iteration, total_iteration):
            img = next(dataset_iter)

            # update network
            loss = train_step(img)

            # save to tensorboard
            tf.summary.scalar('loss', loss, step=idx)

            # save ckpt
            if (idx + 1) % save_freq == 0:
                manager.save(checkpoint_number=idx + 1)

    # save model for final step
    manager.save(checkpoint_number=total_iteration)
5-2. Multi-GPU training [Template code]
strategy = tf.distribute.MirroredStrategy()
NUM_GPUS = strategy.num_replicas_in_sync
# Split the iteration budget across the replicas.
total_iteration = iteration // NUM_GPUS

with strategy.scope():
    # copy & paste here:
    # 2. Network template
    # 3. Data pipeline
    # 4. Restore
    pass  # a `with` body needs at least one statement; replace with the pasted code
def train_step(img):
    """Identical to the single-GPU `train_step`; paste its body here."""
    pass
def distribute_train_step(img):
    """Run `train_step` on every replica and average the per-replica losses.

    Args:
        img: the batch handed to `train_step` on each replica.

    Returns:
        The per-replica losses reduced with ``tf.distribute.ReduceOp.MEAN``.
    """
    # `Strategy.run` replaced `experimental_run_v2` in TF 2.2; fall back for older TF.
    run = getattr(strategy, 'run', None) or strategy.experimental_run_v2
    with strategy.scope():
        # `args` must be a tuple of positional arguments: `(img)` is just `img`,
        # not a one-element tuple, so it would not arrive as a single argument.
        per_replica_loss = run(train_step, args=(img,))
        loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)

    return loss
def train(save_freq=1):
    """Multi-GPU training loop; same as the single-GPU loop but uses `distribute_train_step`.

    Each step pulls a batch from `dataset_iter`, updates the network on all
    replicas, and logs the averaged loss to TensorBoard. Every `save_freq`
    steps it writes a checkpoint through `manager`. A final checkpoint is
    always written at `total_iteration`.

    Args:
        save_freq: checkpoint interval in iterations (default 1: every step).

    Raises:
        ValueError: if `save_freq` is smaller than 1.
    """
    if save_freq < 1:
        raise ValueError('save_freq must be >= 1, got {}'.format(save_freq))

    # setup tensorboard
    summary_writer = tf.summary.create_file_writer(log_dir)

    with summary_writer.as_default():
        for idx in range(start_iteration, total_iteration):
            img = next(dataset_iter)

            # update network
            loss = distribute_train_step(img)

            # save to tensorboard
            tf.summary.scalar('loss', loss, step=idx)

            # save ckpt
            if (idx + 1) % save_freq == 0:
                manager.save(checkpoint_number=idx + 1)

    # save model for final step
    manager.save(checkpoint_number=total_iteration)
# Default weight initializer: normal distribution with mean 0.0 and stddev 0.02.
weight_initializer = tf.initializers.RandomNormal(stddev=0.02, mean=0.0)

# L2 weight-decay regularizers, both with factor 0.0001.
_l2_factor = 0.0001
weight_regularizer = tf.keras.regularizers.l2(_l2_factor)
weight_regularizer_fully = tf.keras.regularizers.l2(_l2_factor)
Xavier
: tf.initializers.GlorotUniform() or tf.initializers.GlorotNormal()
He
: tf.initializers.VarianceScaling(scale=2.0) (the default scale=1.0 gives LeCun initialization, not He)
Normal
: tf.initializers.RandomNormal(mean=0.0, stddev=0.02)
Truncated normal
: tf.initializers.TruncatedNormal(mean=0.0, stddev=0.02)
Orthogonal
: tf.initializers.Orthogonal(0.02)
l2_decay
: tf.keras.regularizers.l2(0.0001)
orthogonal_regularizer
: orthogonal_regularizer(0.0001) # orthogonal_regularizer_fully(0.0001)
padding='SAME'
- pad = ceil((kernel - stride) / 2)
pad_type
- 'zero' or 'reflect'
sn
- whether to use spectral normalization or not
# RNN example: LSTM cells, bidirectional, 2 layers of 128 hidden units.
x = various_rnn(
    x,
    n_hidden=128,
    n_layer=2,
    dropout_rate=0.5,
    training=True,
    bidirectional=True,
    rnn_type='lstm',
    name='rnn',
)
# Convolution with explicit reflect padding and spectral normalization (sn=True).
x = conv(x, channels=64, kernel=3, stride=2,
         pad=1, pad_type='reflect',
         use_bias=True, sn=True, name='conv')
partial conv (NVIDIA Partial Convolution)
x = partial_conv(x, channels=64, kernel=3, stride=2, use_bias=True, padding='SAME', sn=True, name='partial_conv')
x = dilate_conv(x, channels=64, kernel=3, rate=2, use_bias=True, padding='VALID', sn=True, name='dilate_conv')
x = deconv(x, channels=64, kernel=3, stride=1, padding='SAME', use_bias=True, sn=True, name='deconv')
# Like every other op in this list, the layer name is passed as `name=`.
x = fully_connected(x, units=64, use_bias=True, sn=True, name='fully_connected')
# Pixel shuffle: `down` divides height/width by `scale_factor`, `up` multiplies them.
x = conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=True, name='pixel_shuffle_down')
x = conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=True, name='pixel_shuffle_up')
- down: [height, width] -> [height // scale_factor, width // scale_factor]
- up: [height, width] -> [height * scale_factor, width * scale_factor]
# Residual blocks: same resolution, halved resolution (down) and doubled resolution (up).
x = resblock(x, channels=64, is_training=is_training,
             use_bias=True, sn=True, name='residual_block')
x = resblock_down(x, channels=64, is_training=is_training,
                  use_bias=True, sn=True, name='residual_block_down')
x = resblock_up(x, channels=64, is_training=is_training,
                use_bias=True, sn=True, name='residual_block_up')
- down: [height, width] -> [height // 2, width // 2]
- up: [height, width] -> [height * 2, width * 2]
# Dense block example; `n_db` sets the number of dense blocks.
x = denseblock(x, channels=64, n_db=6,
               is_training=is_training,
               use_bias=True, sn=True, name='denseblock')
- n_db: the number of dense blocks
# Residual dense block example: `n_rdb` blocks, `n_rdb_conv` conv layers in each.
x = res_denseblock(x, channels=64,
                   n_rdb=20, n_rdb_conv=6,
                   is_training=is_training,
                   use_bias=True, sn=True, name='res_denseblock')
- n_rdb: the number of RDBs (residual dense blocks)
- n_rdb_conv: the number of conv layers per RDB
# Attention-style modules.
x = self_attention(x, use_bias=True, sn=True, name='self_attention')
x = self_attention_with_pooling(x, use_bias=True, sn=True,
                                name='self_attention_version_2')
x = squeeze_excitation(x, ratio=16, use_bias=True, sn=True,
                       name='squeeze_excitation')
x = convolution_block_attention(x, ratio=16, use_bias=True, sn=True,
                                name='convolution_block_attention')
x = global_context_block(x, use_bias=True, sn=True, name='gc_block')
# srm_block takes `is_training` and is shown here with use_bias=False.
x = srm_block(x, use_bias=False, is_training=is_training, name='srm_block')
# Normalization layers.
x = batch_norm(x, training=training, name='batch_norm')
x = layer_norm(x, name='layer_norm')
x = instance_norm(x, name='instance_norm')
x = group_norm(x, groups=32, name='group_norm')
x = pixel_norm(x)
x = batch_instance_norm(x, name='batch_instance_norm')
x = layer_instance_norm(x, name='layer_instance_norm')
# NOTE(review): every other op here takes `name=`; confirm whether switch_norm
# really expects `scope=` in ops_functional.
x = switch_norm(x, scope='switch_norm')
# Conditional / adaptive normalizations take extra inputs (`z`, or `gamma` and `beta`).
# This is a call statement, so it must not end with ':'.
x = condition_batch_norm(x, z, training=training, name='condition_batch_norm')
x = adaptive_instance_norm(x, gamma, beta)
x = adaptive_layer_instance_norm(x, gamma, beta, smoothing=True, name='adaLIN')
- See this for how to use `condition_batch_norm`
- See this for how to use `adaptive_instance_norm`
- See this for how to use `adaptive_layer_instance_norm` & `layer_instance_norm`
# Activations
x = relu(x)
x = lrelu(x, alpha=0.2)
x = tanh(x)
x = sigmoid(x)
x = swish(x)
x = elu(x)

# Up / down sampling by a factor of 2
x = nearest_up_sample(x, scale_factor=2)
x = bilinear_up_sample(x, scale_factor=2)
x = nearest_down_sample(x, scale_factor=2)
x = bilinear_down_sample(x, scale_factor=2)

# Pooling
x = max_pooling(x, pool_size=2)
x = avg_pooling(x, pool_size=2)
x = global_max_pooling(x)
x = global_avg_pooling(x)

# Flattening
x = flatten(x)
x = hw_flatten(x)
# The classification loss returns both the loss and the accuracy.
loss, accuracy = classification_loss(logit, label)
# A positional argument may not follow a keyword argument (SyntaxError), so all
# three are passed positionally in the order shown: n_classes, logit, label.
# NOTE(review): confirm this order against dice_loss's signature.
loss = dice_loss(10, logit, label)  # n_classes=10
# Regularization loss of the model; add it to your loss if you use regularizers.
model_reg_loss = regularization_loss(model)
- If you use a
regularizer
, you have to add `regularization_loss(model)` to your loss yourself
# Distance losses between x and y
loss = L1_loss(x, y)
loss = L2_loss(x, y)
loss = huber_loss(x, y)

# Color / style losses between x and y
loss = histogram_loss(x, y)
loss = gram_style_loss(x, y)
loss = color_consistency_loss(x, y)
- histogram_loss: the difference in the color distribution of the image pixel values.
- gram_style_loss: the difference between styles, computed with Gram matrices.
- color_consistency_loss: the color difference between the generated image and the input image.
# Relativistic (Ra=True) WGAN-GP losses for the discriminator and the generator.
d_loss = discriminator_loss(Ra=True, gan_type='wgan-gp',
                            real_logit=real_logit,
                            fake_logit=fake_logit)
g_loss = generator_loss(Ra=True, gan_type='wgan-gp',
                        real_logit=real_logit,
                        fake_logit=fake_logit)
Ra
- whether to use the relativistic GAN formulation or not
loss_func
# Bottleneck loss = vdb_loss on the real statistics + vdb_loss on the fake ones
# (vdb presumably = variational discriminator bottleneck; confirm in ops_functional).
real_vdb = vdb_loss(real_mu, real_logvar, i_c)
fake_vdb = vdb_loss(fake_mu, fake_logvar, i_c)
d_bottleneck_loss = real_vdb + fake_vdb
# KL loss computed from a mean and a log-variance.
loss = kl_loss(mean, logvar)