-
Notifications
You must be signed in to change notification settings - Fork 268
/
Copy pathregularizations.py
35 lines (26 loc) · 940 Bytes
/
regularizations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Inspired by the KL-divergence term of the VAE loss.
def fc_selu_reg(x, mu):
    """Penalise each column of ``x`` for a non-zero mean and a mean square far from 1.

    ``x`` is reduced along axis 0 first, so every index of the remaining
    axis gets its own statistics. For a conv kernel flattened to
    ``(-1, num_filters)`` that means one set of statistics per filter;
    for a Dense kernel presumably one per output unit (assumes the kernel
    is laid out as (fan_in, units) -- TODO confirm for direct callers).

    For each column j:
        mean_j    = mean over axis 0 of x[:, j]
        tau_sqr_j = mean over axis 0 of x[:, j] ** 2
    and the result is
        mu * ( avg_j(mean_j ** 2)
               + avg_j(tau_sqr_j - log(tau_sqr_j + eps)) )
    The first term is minimised at mean_j == 0. The second term mirrors the
    variance part of a VAE's KL term and is minimised at tau_sqr_j ~= 1,
    where it equals ~1 (not 0). That constant offset does not affect
    gradients.

    NOTE(review): the target is mean(x**2) == 1, i.e. sum(x**2) == fan_in per
    column. Confirm this is the intended normalisation for SELU networks.

    Args:
        x: weight tensor; the reduction assumes rank 2 with units/filters on
           the last axis.
        mu: multiplier applied to the summed penalty.

    Returns:
        Scalar tensor holding the weighted penalty.
    """
    # Per-column statistics: reduce over axis 0 (fan-in / flattened filter
    # footprint), leaving one value per unit or filter.
    mean = K.mean(x, axis=0)
    tau_sqr = K.mean(K.square(x), axis=0)
    # Average the per-column penalties over all units/filters.
    mean_loss = K.mean(K.square(mean))
    # K.epsilon() keeps the log finite when a column is entirely zero.
    tau_loss = K.mean(tau_sqr - K.log(tau_sqr + K.epsilon()))
    return mu * (mean_loss + tau_loss)
class SeluDenseRegularizer(Regularizer):
    """Keras regularizer that applies ``fc_selu_reg`` to a tensor as-is.

    The tensor is handed to ``fc_selu_reg`` without reshaping, so the
    per-column statistics are taken along axis 0 of the tensor.
    """

    def __init__(self, mu=0.001):
        """Store the penalty multiplier, converted with ``K.cast_to_floatx``."""
        weight = K.cast_to_floatx(mu)
        self.mu = weight

    def __call__(self, x):
        """Return the weighted SELU penalty computed on ``x``."""
        penalty = fc_selu_reg(x, self.mu)
        return penalty

    def get_config(self):
        """Return the constructor keyword arguments as plain Python values."""
        config = {}
        config['mu'] = float(self.mu)
        return config
class SeluConv2DRegularizer(SeluDenseRegularizer):
    """SELU weight regularizer for convolution kernels.

    Every axis except the last is collapsed, giving a
    ``(-1, num_filters)`` tensor. The Dense regularizer then treats each
    filter (last axis) as one column.

    ``__init__`` (``mu=0.001``) and ``get_config`` (``{'mu': float(mu)}``)
    are inherited from ``SeluDenseRegularizer``, so the two regularizers
    cannot drift apart in their constructor or serialised config.
    """

    def __call__(self, x):
        """Return the weighted SELU penalty for kernel ``x``.

        Only the size of the last axis must be statically known
        (``K.int_shape(x)[-1]``). The leading axes are collapsed via -1, so
        kernels of any rank with filters on the last axis are accepted.
        """
        num_filters = K.int_shape(x)[-1]
        flat = K.reshape(x, shape=[-1, num_filters])
        # Delegate so the penalty stays identical to the Dense variant.
        return super(SeluConv2DRegularizer, self).__call__(flat)