custom objects through config
SiLiKhon committed Nov 3, 2020
1 parent b0a0cd4 commit efb8fd1
Showing 4 changed files with 100 additions and 3 deletions.
86 changes: 86 additions & 0 deletions models/configs/baseline_fc_8x16_kinked_trainable.yaml
@@ -0,0 +1,86 @@
latent_dim: 32
batch_size: 32
lr: 1.e-4
lr_schedule_rate: 0.999

num_disc_updates: 8
gp_lambda: 10.
gpdata_lambda: 0.
cramer: False
stochastic_stepping: True

save_every: 50
num_epochs: 10000

feature_noise_power: NULL
feature_noise_decay: NULL

data_version: 'data_v4'
pad_range: [-3, 5]
time_range: [-7, 9]
scaler: 'logarithmic'

architecture:
  generator:
    - block_type: 'fully_connected'
      arguments:
        units: [32, 64, 64, 64, 128]
        activations: [
          'elu', 'elu', 'elu', 'elu', 'custom_objects["TrainableActivation"]()'
        ]
        kernel_init: 'glorot_uniform'
        input_shape: [37,]
        output_shape: [8, 16]
        name: 'generator'

  discriminator:
    - block_type: 'connect'
      arguments:
        vector_shape: [5,]
        img_shape: [8, 16]
        vector_bypass: False
        concat_outputs: True
        name: 'discriminator_tail'
        block:
          block_type: 'conv'
          arguments:
            filters: [16, 16, 32, 32, 64, 64]
            kernel_sizes: [3, 3, 3, 3, 3, 2]
            paddings: ['same', 'same', 'same', 'same', 'valid', 'valid']
            activations: ['elu', 'elu', 'elu', 'elu', 'elu', 'elu']
            poolings: [NULL, [1, 2], NULL, 2, NULL, NULL]
            kernel_init: glorot_uniform
            input_shape: NULL
            output_shape: [64,]
            dropouts: [0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
            name: discriminator_conv_block
    - block_type: 'fully_connected'
      arguments:
        units: [128, 1]
        activations: ['elu', NULL]
        kernel_init: 'glorot_uniform'
        input_shape: [69,]
        output_shape: NULL
        name: 'discriminator_head'

custom_objects: |
  class TrainableActivation(tf.keras.layers.Layer):
      def __init__(self, val=np.log10(2)):
          super().__init__(autocast=False)
          self.v = tf.Variable(0., dtype='float32', trainable=True)
          self.val = val
      def call(self, x):
          val = self.val
          slope = (tf.nn.elu(self.v * 50. + 50) + 1.)
          return tf.where(
              x >= 0,
              val + x,
              tf.exp(-tf.abs(x) * (slope + 1e-10)) * val
          )
      def get_config(self):
          config = super().get_config().copy()
          config.update(dict(val=self.val))
          return config
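
As a quick, standalone illustration of the layer defined in the custom_objects block above, the following sketch executes that code string the same way the updated nn.py does and evaluates the activation on a few sample inputs. It is not part of the commit; it assumes TensorFlow, NumPy and PyYAML are installed and that it is run from the repository root so the new config file can be found.

import numpy as np   # the exec'd code references np.log10
import tensorflow as tf
import yaml

# Read the code string from the config added in this commit and exec it,
# collecting the names it defines into a plain dict (as nn.py does).
with open('models/configs/baseline_fc_8x16_kinked_trainable.yaml') as f:
    code = yaml.safe_load(f)['custom_objects']

custom_objects = {}
exec(code, globals(), custom_objects)

act = custom_objects["TrainableActivation"]()
x = tf.constant([-2.0, 0.0, 3.0])
# With self.v initialised to 0, slope = elu(50) + 1 = 51, so the layer returns
# val * exp(-51 * |x|) for x < 0 and val + x for x >= 0, with val = log10(2).
print(act(x).numpy())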
6 changes: 4 additions & 2 deletions models/model_v4.py
@@ -77,8 +77,10 @@ def __init__(self, config):
         self.latent_dim = config['latent_dim']
 
         architecture_descr = config['architecture']
-        self.generator = nn.build_architecture(architecture_descr['generator'])
-        self.discriminator = nn.build_architecture(architecture_descr['discriminator'])
+        self.generator = nn.build_architecture(architecture_descr['generator'],
+                                                custom_objects_code=config.get('custom_objects', None))
+        self.discriminator = nn.build_architecture(architecture_descr['discriminator'],
+                                                    custom_objects_code=config.get('custom_objects', None))
 
         self.step_counter = tf.Variable(0, dtype='int32', trainable=False)
 
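One note on the config.get('custom_objects', None) default used above: configs that do not define a custom_objects key keep working unchanged, since build_architecture() (see nn.py below) only executes the code string when it is non-empty. A trivial, made-up illustration:

# Hypothetical config dicts, just to show the default value in action.
config_with = {'custom_objects': 'class Foo(object):\n    pass'}
config_without = {}

print(config_with.get('custom_objects', None) is None)     # False -> the code string gets exec'd
print(config_without.get('custom_objects', None) is None)  # True  -> the exec step is skipped
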
10 changes: 9 additions & 1 deletion models/nn.py
@@ -2,6 +2,9 @@
 import numpy as np
 
 
+custom_objects = {}
+
+
 def get_activation(activation):
     try:
         activation = tf.keras.activations.get(activation)
@@ -188,7 +191,12 @@ def build_block(block_type, arguments):
     return block
 
 
-def build_architecture(block_descriptions, name=None):
+def build_architecture(block_descriptions, name=None, custom_objects_code=None):
+    if custom_objects_code:
+        print("build_architecture(): got custom objects code, executing:")
+        print(custom_objects_code)
+        exec(custom_objects_code, globals(), custom_objects)
+
     blocks = [build_block(**descr)
               for descr in block_descriptions]
 
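How an activations entry such as 'custom_objects["TrainableActivation"]()' in the config is turned into a callable is not visible in this diff (the body of get_activation() is truncated above). The mechanism suggested by the changes is that exec() fills the module-level custom_objects dict and the string is then evaluated against it, presumably as a fallback when tf.keras.activations.get() fails. The sketch below reproduces that idea in isolation; the eval fallback and the Doubler layer are assumptions for illustration, not code from this repository.

import tensorflow as tf

custom_objects = {}

def get_activation(activation):
    # Assumed behaviour: try Keras' activation registry first, otherwise
    # evaluate the string so that entries referring to custom_objects resolve.
    try:
        return tf.keras.activations.get(activation)
    except ValueError:
        return eval(activation)

code = """
class Doubler(tf.keras.layers.Layer):
    def call(self, x):
        return 2 * x
"""
exec(code, globals(), custom_objects)

act = get_activation('custom_objects["Doubler"]()')
print(act(tf.constant([1.0, 2.0])).numpy())  # -> [2. 4.]
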
1 change: 1 addition & 0 deletions run_docker_jlab.sh
@@ -6,6 +6,7 @@ else
 fi
 
 docker run -it \
+    --rm \
     -u $(id -u):$(id -g) \
     --env HOME=`pwd` \
     -p 127.0.0.1:$PORT:8888/tcp \
