@@ -23,13 +23,16 @@ def __init__(self,
         super().__init__()

         self.enable_mlp = config.enable_mlp
-        self.enable_mse_clean_aug = config.enable_mse_clean_aug
         self.infonce_loss_factor = config.infonce_loss_factor
         self.vic_reg_factor = config.vic_reg_factor
         self.barlow_twins_factor = config.barlow_twins_factor
-        self.mse_clean_aug_factor = config.mse_clean_aug_factor
         self.reg = regularizers.l2(config.weight_reg)

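+        # Per-stage loss selection: each flag independently enables the
+        # InfoNCE / VICReg term on representations (encoder output) or
+        # embeddings (MLP output)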
+        self.representations_loss_vic = config.representations_loss_vic
+        self.representations_loss_nce = config.representations_loss_nce
+        self.embeddings_loss_vic = config.embeddings_loss_vic
+        self.embeddings_loss_nce = config.embeddings_loss_nce
+
         self.encoder = encoder
         self.mlp = MLP(config.mlp_dim)
         self.infonce_loss = InfoNCELoss()
@@ -45,56 +48,65 @@ def compile(self, optimizer, **kwargs):
         self.optimizer = optimizer

     def call(self, X):
-        if len(X.shape) == 4 and self.enable_mse_clean_aug:
-            X, _ = self.extract_clean_and_aug(X)
         return self.encoder(X)

     @tf.function
-    def get_embeddings(self, X_1, X_2):
-        Z_1 = self.encoder(X_1, training=True)
-        Z_2 = self.encoder(X_2, training=True)
-        if self.enable_mlp:
-            Z_1 = self.mlp(Z_1, training=True)
-            Z_2 = self.mlp(Z_2, training=True)
-        return Z_1, Z_2
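+    # Loss on the raw encoder representations; accuracy stays 0 unless the
+    # InfoNCE term is enabled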
+    def representations_loss(self, Z_1, Z_2):
+        loss, accuracy = 0, 0
+        if self.representations_loss_nce:
+            loss, accuracy = self.infonce_loss((Z_1, Z_2))
+            loss = self.infonce_loss_factor * loss
+        if self.representations_loss_vic:
+            loss += self.vic_reg_factor * self.vic_reg((Z_1, Z_2))
+        return loss, accuracy

     @tf.function
-    def extract_clean_and_aug(self, X):
-        X_clean, X_aug = tf.split(X, 2, axis=-1)
-        X_clean = tf.squeeze(X_clean, axis=-1)
-        X_aug = tf.squeeze(X_aug, axis=-1)
-        return X_clean, X_aug
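+    # Same term selection as representations_loss, applied after the
+    # (optional) MLP projection head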
+    def embeddings_loss(self, Z_1, Z_2):
+        loss, accuracy = 0, 0
+        if self.embeddings_loss_nce:
+            loss, accuracy = self.infonce_loss((Z_1, Z_2))
+            loss = self.infonce_loss_factor * loss
+        if self.embeddings_loss_vic:
+            loss += self.vic_reg_factor * self.vic_reg((Z_1, Z_2))
+        return loss, accuracy

     def train_step(self, data):
-        X_1_aug, X_2_aug, _ = data
+        X_1, X_2, _ = data
         # X shape: (B, H, W, C) = (B, 40, 200, 1)

-        if self.enable_mse_clean_aug:
-            X_1_clean, X_1_aug = self.extract_clean_and_aug(X_1_aug)
-            X_2_clean, X_2_aug = self.extract_clean_and_aug(X_2_aug)
-
-        with tf.GradientTape() as tape:
-            Z_1_aug, Z_2_aug = self.get_embeddings(X_1_aug, X_2_aug)
-
-            loss, accuracy = self.infonce_loss((Z_1_aug, Z_2_aug))
-            loss = self.infonce_loss_factor * loss
-            loss += self.vic_reg_factor * self.vic_reg((Z_1_aug, Z_2_aug))
-            loss += self.barlow_twins_factor * self.barlow_twins((Z_1_aug, Z_2_aug))
-
-            if self.enable_mse_clean_aug:
-                Z_1_clean, Z_2_clean = self.get_embeddings(X_1_clean, X_2_clean)
-                loss += self.mse_clean_aug_factor * mse_loss(Z_1_clean, Z_1_aug)
-                loss += self.mse_clean_aug_factor * mse_loss(Z_2_clean, Z_2_aug)
-
-        trainable_params = self.encoder.trainable_weights
-        if self.enable_mlp:
-            trainable_params += self.mlp.trainable_weights
-
-        grads = tape.gradient(loss, trainable_params)
-        # grads, _ = tf.clip_by_global_norm(grads, 5.0)
-        self.optimizer.apply_gradients(zip(grads, trainable_params))
-
-        return {'loss': loss, 'accuracy': accuracy}
+        # persistent=True is required: tape.gradient() is called twice below
+        with tf.GradientTape(persistent=True) as tape:
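+            # Encode both augmented views; the representation-level loss is
+            # taken before the projection head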
+            Z_1 = self.encoder(X_1, training=True)
+            Z_2 = self.encoder(X_2, training=True)
+            representations_loss, representations_accuracy = self.representations_loss(
+                Z_1,
+                Z_2
+            )
+
+            if self.enable_mlp:
+                Z_1 = self.mlp(Z_1, training=True)
+                Z_2 = self.mlp(Z_2, training=True)
+            embeddings_loss, embeddings_accuracy = self.embeddings_loss(
+                Z_1,
+                Z_2
+            )
+
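+        # Two optimizer steps from one persistent tape: representation
+        # gradients update the encoder alone, embedding gradients update
+        # the encoder and the MLP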
+        # Apply representations loss
+        params = self.encoder.trainable_weights
+        grads = tape.gradient(representations_loss, params)
+        self.optimizer.apply_gradients(zip(grads, params))
+
+        # Apply embeddings loss
+        params = self.encoder.trainable_weights
+        params += self.mlp.trainable_weights
+        grads = tape.gradient(embeddings_loss, params)
+        self.optimizer.apply_gradients(zip(grads, params))
+        del tape  # release the persistent tape once both gradients are taken
+
+        return {
+            'representations_loss': representations_loss,
+            'representations_accuracy': representations_accuracy,
+            'embeddings_loss': embeddings_loss,
+            'embeddings_accuracy': embeddings_accuracy
+        }


 class MLP(Model):
@@ -156,10 +168,4 @@ def call(self, data):
         preds_acc = tf.math.equal(pred_indices, labels)
         accuracy = tf.math.count_nonzero(preds_acc, dtype=tf.int32) / batch_size

-        return loss, accuracy
-
-
-@tf.function
-def mse_loss(Z_clean, Z_aug):
-    mse = tf.keras.metrics.mean_squared_error(Z_clean, Z_aug)
-    return tf.math.reduce_mean(mse)
+        return loss, accuracy