Skip to content

Commit 19515e1

Browse files
committed
misc(configs): prepare complementarity trainings
1 parent 08ad49b commit 19515e1

11 files changed

+186
-69
lines changed

configs/vicreg_b256_comp_1.yml

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
name: 'vicreg_b256_comp_1'
2+
encoder:
3+
type: 'thinresnet34'
4+
model:
5+
type: 'simclr'
6+
enable_mlp: true
7+
infonce_loss_factor: 1.0
8+
vic_reg_factor: 1.0
9+
representations_loss_vic: true
10+
representations_loss_nce: false
11+
embeddings_loss_vic: false
12+
embeddings_loss_nce: true
13+
training:
14+
epochs: 500
15+
batch_size: 256
16+
learning_rate: 0.001
17+
dataset:
18+
frame_length: 32000
19+
frame_split: true
20+
extract_mfcc: true
21+
train: './data/voxceleb1_train_list'
22+
val_ratio: 0.0
23+
spec_augment: false
24+
wav_augment:
25+
enable: true

configs/vicreg_b256_comp_2.yml

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
name: 'vicreg_b256_comp_2'
2+
encoder:
3+
type: 'thinresnet34'
4+
model:
5+
type: 'simclr'
6+
enable_mlp: true
7+
infonce_loss_factor: 1.0
8+
vic_reg_factor: 1.0
9+
representations_loss_vic: false
10+
representations_loss_nce: true
11+
embeddings_loss_vic: true
12+
embeddings_loss_nce: false
13+
training:
14+
epochs: 500
15+
batch_size: 256
16+
learning_rate: 0.001
17+
dataset:
18+
frame_length: 32000
19+
frame_split: true
20+
extract_mfcc: true
21+
train: './data/voxceleb1_train_list'
22+
val_ratio: 0.0
23+
spec_augment: false
24+
wav_augment:
25+
enable: true

configs/vicreg_b256_comp_3.yml

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
name: 'vicreg_b256_comp_3'
2+
encoder:
3+
type: 'thinresnet34'
4+
model:
5+
type: 'simclr'
6+
enable_mlp: true
7+
infonce_loss_factor: 1.0
8+
vic_reg_factor: 0.1
9+
representations_loss_vic: true
10+
representations_loss_nce: true
11+
embeddings_loss_vic: false
12+
embeddings_loss_nce: false
13+
training:
14+
epochs: 500
15+
batch_size: 256
16+
learning_rate: 0.001
17+
dataset:
18+
frame_length: 32000
19+
frame_split: true
20+
extract_mfcc: true
21+
train: './data/voxceleb1_train_list'
22+
val_ratio: 0.0
23+
spec_augment: false
24+
wav_augment:
25+
enable: true

configs/vicreg_b256_comp_4.yml

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
name: 'vicreg_b256_comp_4'
2+
encoder:
3+
type: 'thinresnet34'
4+
model:
5+
type: 'simclr'
6+
enable_mlp: true
7+
infonce_loss_factor: 1.0
8+
vic_reg_factor: 0.1
9+
representations_loss_vic: false
10+
representations_loss_nce: false
11+
embeddings_loss_vic: true
12+
embeddings_loss_nce: true
13+
training:
14+
epochs: 500
15+
batch_size: 256
16+
learning_rate: 0.001
17+
dataset:
18+
frame_length: 32000
19+
frame_split: true
20+
extract_mfcc: true
21+
train: './data/voxceleb1_train_list'
22+
val_ratio: 0.0
23+
spec_augment: false
24+
wav_augment:
25+
enable: true

configs/vicreg_b256_mlp_512.yml

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
name: 'vicreg_b256_mlp_512'
2+
encoder:
3+
type: 'thinresnet34'
4+
model:
5+
type: 'simclr'
6+
enable_mlp: true
7+
mlp_dim: 512
8+
infonce_loss_factor: 0.0
9+
vic_reg_factor: 1.0
10+
training:
11+
epochs: 500
12+
batch_size: 256
13+
learning_rate: 0.001
14+
dataset:
15+
frame_length: 32000
16+
frame_split: true
17+
extract_mfcc: true
18+
train: './data/voxceleb1_train_list'
19+
val_ratio: 0.0
20+
spec_augment: false
21+
wav_augment:
22+
enable: true

evaluate_label_efficient.py

-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def train(
9696
# Disable features required only by self-supervised training
9797
config.dataset.wav_augment.enable = False
9898
config.dataset.frame_split = False
99-
config.dataset.provide_clean_and_aug = False
10099

101100
gens, input_shape, nb_classes = load_dataset(config)
102101
(train_gen, val_gen) = gens

run.sh

100755100644
File mode changed.

sslforslr/configs.py

-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ class DatasetConfig:
2424
frame_length: int = 16000
2525
frame_split: bool = False
2626
max_samples: int = None
27-
provide_clean_and_aug: bool = False
2827
extract_mfcc: bool = False
2928
spec_augment: bool = False
3029
val_ratio: float = 0.1

sslforslr/dataset/AudioDatasetLoader.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def __init__(
4141
labels,
4242
indices,
4343
wav_augment=None,
44-
provide_clean_and_aug=False,
4544
extract_mfcc=False
4645
):
4746
self.epoch = 0
@@ -54,7 +53,6 @@ def __init__(
5453
self.labels = labels
5554
self.indices = indices
5655
self.wav_augment = wav_augment
57-
self.provide_clean_and_aug = provide_clean_and_aug
5856
self.extract_mfcc = extract_mfcc
5957

6058
def __len__(self):
@@ -86,16 +84,8 @@ def __getitem__(self, i):
8684
min_length=2*self.frame_length
8785
) # (1, T)
8886
frame1, frame2 = sample_frames(data, self.frame_length)
89-
if self.provide_clean_and_aug:
90-
frame1_clean = self.preprocess_data(frame1, augment=False)
91-
frame1_aug = self.preprocess_data(frame1)
92-
X1.append(np.stack((frame1_clean, frame1_aug), axis=-1))
93-
frame2_clean = self.preprocess_data(frame2, augment=False)
94-
frame2_aug = self.preprocess_data(frame2)
95-
X2.append(np.stack((frame2_clean, frame2_aug), axis=-1))
96-
else:
97-
X1.append(self.preprocess_data(frame1))
98-
X2.append(self.preprocess_data(frame2))
87+
X1.append(self.preprocess_data(frame1))
88+
X2.append(self.preprocess_data(frame2))
9989
y.append(self.labels[index])
10090
elif self.supervised_sampler:
10191
frame1 = load_audio(self.files[index[0]], self.frame_length)
@@ -184,7 +174,6 @@ def load(self, batch_size):
184174
self.labels,
185175
train_indices,
186176
self.wav_augment,
187-
self.config.provide_clean_and_aug,
188177
self.config.extract_mfcc
189178
)
190179

@@ -198,7 +187,6 @@ def load(self, batch_size):
198187
self.labels,
199188
val_indices,
200189
self.wav_augment,
201-
self.config.provide_clean_and_aug,
202190
self.config.extract_mfcc
203191
)
204192

sslforslr/models/simclr/SimCLR.py

+55-49
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@ def __init__(self,
2323
super().__init__()
2424

2525
self.enable_mlp = config.enable_mlp
26-
self.enable_mse_clean_aug = config.enable_mse_clean_aug
2726
self.infonce_loss_factor = config.infonce_loss_factor
2827
self.vic_reg_factor = config.vic_reg_factor
2928
self.barlow_twins_factor = config.barlow_twins_factor
30-
self.mse_clean_aug_factor = config.mse_clean_aug_factor
3129
self.reg = regularizers.l2(config.weight_reg)
3230

31+
self.representations_loss_vic = config.representations_loss_vic
32+
self.representations_loss_nce = config.representations_loss_nce
33+
self.embeddings_loss_vic = config.embeddings_loss_vic
34+
self.embeddings_loss_nce = config.embeddings_loss_nce
35+
3336
self.encoder = encoder
3437
self.mlp = MLP(config.mlp_dim)
3538
self.infonce_loss = InfoNCELoss()
@@ -45,56 +48,65 @@ def compile(self, optimizer, **kwargs):
4548
self.optimizer = optimizer
4649

4750
def call(self, X):
48-
if len(X.shape) == 4 and self.enable_mse_clean_aug:
49-
X, _ = self.extract_clean_and_aug(X)
5051
return self.encoder(X)
5152

5253
@tf.function
53-
def get_embeddings(self, X_1, X_2):
54-
Z_1 = self.encoder(X_1, training=True)
55-
Z_2 = self.encoder(X_2, training=True)
56-
if self.enable_mlp:
57-
Z_1 = self.mlp(Z_1, training=True)
58-
Z_2 = self.mlp(Z_2, training=True)
59-
return Z_1, Z_2
54+
def representations_loss(self, Z_1, Z_2):
55+
loss, accuracy = 0, 0
56+
if self.representations_loss_nce:
57+
loss, accuracy = self.infonce_loss((Z_1, Z_2))
58+
loss = self.infonce_loss_factor * loss
59+
if self.representations_loss_vic:
60+
loss += self.vic_reg_factor * self.vic_reg((Z_1, Z_2))
61+
return loss, accuracy
6062

6163
@tf.function
62-
def extract_clean_and_aug(self, X):
63-
X_clean, X_aug = tf.split(X, 2, axis=-1)
64-
X_clean = tf.squeeze(X_clean, axis=-1)
65-
X_aug = tf.squeeze(X_aug, axis=-1)
66-
return X_clean, X_aug
64+
def embeddings_loss(self, Z_1, Z_2):
65+
loss, accuracy = 0, 0
66+
if self.embeddings_loss_nce:
67+
loss, accuracy = self.infonce_loss((Z_1, Z_2))
68+
loss = self.infonce_loss_factor * loss
69+
if self.embeddings_loss_vic:
70+
loss += self.vic_reg_factor * self.vic_reg((Z_1, Z_2))
71+
return loss, accuracy
6772

6873
def train_step(self, data):
69-
X_1_aug, X_2_aug, _ = data
74+
X_1, X_2, _ = data
7075
# X shape: (B, H, W, C) = (B, 40, 200, 1)
7176

72-
if self.enable_mse_clean_aug:
73-
X_1_clean, X_1_aug = self.extract_clean_and_aug(X_1_aug)
74-
X_2_clean, X_2_aug = self.extract_clean_and_aug(X_2_aug)
75-
7677
with tf.GradientTape() as tape:
77-
Z_1_aug, Z_2_aug = self.get_embeddings(X_1_aug, X_2_aug)
78-
79-
loss, accuracy = self.infonce_loss((Z_1_aug, Z_2_aug))
80-
loss = self.infonce_loss_factor * loss
81-
loss += self.vic_reg_factor * self.vic_reg((Z_1_aug, Z_2_aug))
82-
loss += self.barlow_twins_factor * self.barlow_twins((Z_1_aug, Z_2_aug))
83-
84-
if self.enable_mse_clean_aug:
85-
Z_1_clean, Z_2_clean = self.get_embeddings(X_1_clean, X_2_clean)
86-
loss += self.mse_clean_aug_factor * mse_loss(Z_1_clean, Z_1_aug)
87-
loss += self.mse_clean_aug_factor * mse_loss(Z_2_clean, Z_2_aug)
88-
89-
trainable_params = self.encoder.trainable_weights
90-
if self.enable_mlp:
91-
trainable_params += self.mlp.trainable_weights
92-
93-
grads = tape.gradient(loss, trainable_params)
94-
# grads, _ = tf.clip_by_global_norm(grads, 5.0)
95-
self.optimizer.apply_gradients(zip(grads, trainable_params))
96-
97-
return { 'loss': loss, 'accuracy': accuracy }
78+
Z_1 = self.encoder(X_1, training=True)
79+
Z_2 = self.encoder(X_2, training=True)
80+
representations_loss, representations_accuracy = self.representations_loss(
81+
Z_1,
82+
Z_2
83+
)
84+
85+
if self.enable_mlp:
86+
Z_1 = self.mlp(Z_1, training=True)
87+
Z_2 = self.mlp(Z_2, training=True)
88+
embeddings_loss, embeddings_accuracy = self.embeddings_loss(
89+
Z_1,
90+
Z_2
91+
)
92+
93+
# Apply representations loss
94+
params = self.encoder.trainable_weights
95+
grads = tape.gradient(representations_loss, params)
96+
self.optimizer.apply_gradients(zip(grads, params))
97+
98+
# Apply embeddings loss
99+
params = self.encoder.trainable_weights
100+
params += self.mlp.trainable_weights
101+
grads = tape.gradient(embeddings_loss, params)
102+
self.optimizer.apply_gradients(zip(grads, params))
103+
104+
return {
105+
'representations_loss': representations_loss,
106+
'representations_accuracy': representations_accuracy,
107+
'embeddings_loss': embeddings_loss,
108+
'embeddings_accuracy': embeddings_accuracy
109+
}
98110

99111

100112
class MLP(Model):
@@ -156,10 +168,4 @@ def call(self, data):
156168
preds_acc = tf.math.equal(pred_indices, labels)
157169
accuracy = tf.math.count_nonzero(preds_acc, dtype=tf.int32) / batch_size
158170

159-
return loss, accuracy
160-
161-
162-
@tf.function
163-
def mse_loss(Z_clean, Z_aug):
164-
mse = tf.keras.metrics.mean_squared_error(Z_clean, Z_aug)
165-
return tf.math.reduce_mean(mse)
171+
return loss, accuracy

sslforslr/models/simclr/SimCLRModelConfig.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from typing import List
23

34
from sslforslr.configs import ModelConfig
45

@@ -16,10 +17,12 @@ class SimCLRModelConfig(ModelConfig):
1617

1718
barlow_twins_factor: float = 0.0
1819
barlow_twins_lambda: float = 0.05
19-
20-
enable_mse_clean_aug: bool = False
21-
mse_clean_aug_factor: float = 0.1
22-
20+
21+
representations_loss_vic: bool = False
22+
representations_loss_nce: bool = False
23+
embeddings_loss_vic: bool = True
24+
embeddings_loss_nce: bool = True
25+
2326
weight_reg: float = 1e-4
2427

2528
SimCLRModelConfig.__NAME__ = 'simclr'

0 commit comments

Comments (0)