
Commit c8c5067

Merge pull request #66 from matchms/fix_and_expand_architecture
Fix and expand architecture choices
2 parents 58f270b + 6abf1fe commit c8c5067

3 files changed: +40 / -12 lines changed

CHANGELOG.md

Lines changed: 7 additions & 3 deletions
@@ -7,12 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## Changed
+
+- Allow users to define L1 and L2 regularization of `SiameseModel` [#67](https://github.com/matchms/ms2deepscore/issues/67)
+- Allow users to define number and size of `SiameseModel` [#64](https://github.com/matchms/ms2deepscore/pull/64)
+
 ## [0.1.2] - 2021-03-05
 
 ## Added
 
 - `create_confusion_matrix_plot` in `plotting` [#58](https://github.com/matchms/ms2deepscore/pull/58)
-- Allow users to define number and size of `SiameseModel` [#64](https://github.com/matchms/ms2deepscore/pull/64)
 
 ## [0.1.1] - 2021-02-09
 
@@ -32,6 +36,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - This is the initial version of MS2DeepScore
 
 [Unreleased]: https://github.com/matchms/ms2deepscore/compare/0.1.2...HEAD
-[0.1.2]: https://github.com/matchms/ms2deepscore/releases/tag/0.1.1...0.1.2
-[0.1.1]: https://github.com/matchms/ms2deepscore/releases/tag/0.1.0...0.1.1
+[0.1.2]: https://github.com/matchms/ms2deepscore/compare/0.1.1...0.1.2
+[0.1.1]: https://github.com/matchms/ms2deepscore/compare/0.1.0...0.1.1
 [0.1.0]: https://github.com/matchms/ms2deepscore/releases/tag/0.1.0

ms2deepscore/models/SiameseModel.py

Lines changed: 19 additions & 9 deletions
@@ -45,6 +45,8 @@ def __init__(self,
                  base_dims: Tuple[int, ...] = (600, 500, 500),
                  embedding_dim: int = 400,
                  dropout_rate: float = 0.5,
+                 l1_reg: float = 1e-6,
+                 l2_reg: float = 1e-6,
                  keras_model: keras.Model = None):
         """
         Construct SiameseModel
@@ -59,7 +61,11 @@ def __init__(self,
         embedding_dim
             Dimension of the embedding (i.e. the output of the base model)
         dropout_rate
-            Dropout rate to be used in the base model
+            Dropout rate to be used in the base model.
+        l1_reg
+            L1 regularization rate. Default is 1e-6.
+        l2_reg
+            L2 regularization rate. Default is 1e-6.
         keras_model
             When provided, this keras model will be used to construct the SiameseModel instance.
             Default is None.
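The new parameters documented above are set when the model is constructed. A minimal usage sketch (not part of this commit; it assumes `SiameseModel` is imported from `ms2deepscore.models` and that `spectrum_binner` is an already fitted `SpectrumBinner`):

```python
from ms2deepscore.models import SiameseModel

# Illustrative values only: `spectrum_binner` is assumed to be a fitted
# ms2deepscore SpectrumBinner; the architecture and regularization rates
# below are examples, not recommendations.
model = SiameseModel(spectrum_binner,
                     base_dims=(500, 500),  # number and size of dense layers (#64)
                     embedding_dim=200,
                     dropout_rate=0.2,
                     l1_reg=1e-6,           # L1/L2 penalties on the first dense layer (#67)
                     l2_reg=1e-6)
```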
@@ -75,7 +81,9 @@ def __init__(self,
             self.base = self._get_base_model(input_dim=self.input_dim,
                                              dims=base_dims,
                                              embedding_dim=embedding_dim,
-                                             dropout_rate=dropout_rate)
+                                             dropout_rate=dropout_rate,
+                                             l1_reg=l1_reg,
+                                             l2_reg=l2_reg)
             # Create head model
             self.model = self._get_head_model(input_dim=self.input_dim,
                                               base_model=self.base)
@@ -100,22 +108,24 @@ def save(self, filename: Union[str, Path]):
     def _get_base_model(input_dim: int,
                         dims: Tuple[int, ...] = (600, 500, 500),
                         embedding_dim: int = 400,
-                        dropout_rate: float = 0.25):
+                        dropout_rate: float = 0.25,
+                        l1_reg: float = 1e-6,
+                        l2_reg: float = 1e-6):
+        # pylint: disable=too-many-arguments
         model_input = keras.layers.Input(shape=input_dim, name='base_input')
         for i, dim in enumerate(dims):
             if i == 0:
                 model_layer = keras.layers.Dense(dim, activation='relu', name='dense'+str(i+1),
-                                                 kernel_regularizer=keras.regularizers.l1_l2(l1=1e-6, l2=1e-6))(
+                                                 kernel_regularizer=keras.regularizers.l1_l2(l1=l1_reg, l2=l2_reg))(
                     model_input)
             else:
-                model_layer = keras.layers.Dense(dim, activation='relu', name='dense'+str(i+1),
-                                                 kernel_regularizer=keras.regularizers.l1_l2(l1=1e-6, l2=1e-6))(
-                    model_layer)
+                model_layer = keras.layers.Dense(dim, activation='relu',
+                                                 name='dense'+str(i+1))(model_layer)
             model_layer = keras.layers.BatchNormalization(name='normalization'+str(i+1))(model_layer)
             model_layer = keras.layers.Dropout(dropout_rate, name='dropout'+str(i+1))(model_layer)
 
-        embedding = keras.layers.Dense(embedding_dim, activation='relu', name='embedding')(
-            model_layer)
+        embedding = keras.layers.Dense(embedding_dim, activation='relu',
+                                       name='embedding')(model_layer)
         return keras.Model(model_input, embedding, name='base')
 
     @staticmethod
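For reference, `keras.regularizers.l1_l2(l1, l2)` adds `l1 * sum(|w|) + l2 * sum(w**2)` over the layer's kernel weights to the training loss, and after this change it is applied only to the first dense layer of the base model. A small self-contained check of that penalty, independent of ms2deepscore (TensorFlow 2.x Keras assumed):

```python
import numpy as np
from tensorflow import keras

# Standalone sketch (not ms2deepscore code): the penalty l1_l2 contributes
# for a Dense layer's kernel, mirroring the regularizer used in _get_base_model.
layer = keras.layers.Dense(8, activation="relu",
                           kernel_regularizer=keras.regularizers.l1_l2(l1=1e-6, l2=1e-6))
_ = layer(np.ones((1, 16), dtype="float32"))  # calling the layer once builds its kernel

w = layer.kernel.numpy()
expected = 1e-6 * np.abs(w).sum() + 1e-6 * (w ** 2).sum()  # l1*sum(|w|) + l2*sum(w^2)
penalty = float(layer.losses[0])  # Keras collects regularization terms in layer.losses

print(np.isclose(penalty, expected))  # expected to print True
```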

tests/test_models.py

Lines changed: 14 additions & 0 deletions
@@ -44,6 +44,10 @@ def test_siamese_model():
     assert len(model.model.layers[2].layers) == len(model.base.layers) == 11, \
         "Expected different number of layers"
     assert model.model.input_shape == [(None, 339), (None, 339)], "Expected different input shape"
+    np.testing.assert_array_almost_equal(model.base.layers[1].kernel_regularizer.l1, 1e-6), \
+        "Expected different L1 regularization rate"
+    np.testing.assert_array_almost_equal(model.base.layers[1].kernel_regularizer.l2, 1e-6), \
+        "Expected different L2 regularization rate"
 
     # Test base model inference
     X, y = test_generator.__getitem__(0)
@@ -67,6 +71,16 @@ def test_siamese_model_different_architecture():
     assert model.base.output_shape == (None, 100), "Expected different output shape of base model"
 
 
+def test_siamese_model_different_regularization_rates():
+    spectrum_binner, test_generator = get_test_binner_and_generator()
+    model = SiameseModel(spectrum_binner, base_dims=(200,),
+                         embedding_dim=100, l1_reg=1e-7, l2_reg=1e-5)
+    np.testing.assert_array_almost_equal(model.base.layers[1].kernel_regularizer.l1, 1e-7), \
+        "Expected different L1 regularization rate"
+    np.testing.assert_array_almost_equal(model.base.layers[1].kernel_regularizer.l2, 1e-5), \
+        "Expected different L2 regularization rate"
+
+
 def test_load_model():
     """Test loading a model from file."""
     spectrum_binner, test_generator = get_test_binner_and_generator()
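Because `_get_base_model` now applies the kernel regularizer only to the first dense layer, a possible follow-up assertion (a sketch only, not part of this commit; it reuses `get_test_binner_and_generator` and `SiameseModel` from this test module) would be:

```python
def test_only_first_dense_layer_is_regularized():
    # Sketch of a possible extra check: after this change, only 'dense1' carries
    # a kernel regularizer, while later dense layers are left unregularized.
    spectrum_binner, _ = get_test_binner_and_generator()
    model = SiameseModel(spectrum_binner, base_dims=(200, 200),
                         embedding_dim=100, l1_reg=1e-7, l2_reg=1e-5)
    dense_layers = [layer for layer in model.base.layers if layer.name.startswith("dense")]
    assert dense_layers[0].kernel_regularizer is not None, "Expected a regularizer on dense1"
    assert all(layer.kernel_regularizer is None for layer in dense_layers[1:]), \
        "Expected no regularizer on later dense layers"
```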
