Skip to content

Commit

Permalink
orthoconv: applied black on project files
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaut.boissin committed Apr 5, 2022
1 parent 12d088a commit 241cc2a
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion deel/lip/compute_layer_sv.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _compute_sv_conv2d(w, Ks, N, padding="circular"):

# Minimum Singular Value

bigConstant = 1.1 * sigma_max**2
bigConstant = 1.1 * sigma_max ** 2
u = tf.random.uniform((batch_size,) + input_shape, minval=-1.0, maxval=1.0)

u, v = _power_iteration_conv(
Expand Down
2 changes: 1 addition & 1 deletion deel/lip/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@ def init_spectral_norm(self):
if stride > 1:
N = int(0.5 + N / stride)

if C * stride**2 > M:
if C * stride ** 2 > M:
self.spectral_input_shape = (N, N, M)
self.ro_case = True
else:
Expand Down
10 changes: 5 additions & 5 deletions deel/lip/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def _get_kernel_shape(self):
def _compute_delta(self):
(R, C, M) = self._get_kernel_shape()
if not self.flag_deconv:
delta = M - (self.stride**self.dim) * C
delta = M - (self.stride ** self.dim) * C
else:
delta = C - (self.stride**self.dim) * M
delta = C - (self.stride ** self.dim) * M
delta = max(0, delta)
if delta > 0:
print(
Expand All @@ -69,8 +69,8 @@ def _compute_delta(self):
def _check_if_orthconv_exists(self):
(R, C, M) = self._get_kernel_shape()
# RO case
if C * self.stride**self.dim >= M:
if M > C * (R**self.dim):
if C * self.stride ** self.dim >= M:
if M > C * (R ** self.dim):
raise RuntimeError(
"Impossible RO configuration for orthogonal convolution"
)
Expand All @@ -80,7 +80,7 @@ def _check_if_orthconv_exists(self):
"Impossible CO configuration for orthogonal convolution"
)

if C * (self.stride**self.dim) == M:
if C * (self.stride ** self.dim) == M:
warnings.warn(
"LorthRegularizer: Warning configuration C*S^2=M is hard to optimize"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def set_spectral_input_shape(self, kernel):
if stride > 1:
N = int(0.5 + N / stride)

if C * stride**2 > M:
if C * stride ** 2 > M:
self.spectral_input_shape = (N, N, M)
self.RO_case = True
else:
Expand Down

0 comments on commit 241cc2a

Please sign in to comment.