Merge pull request #73 from deel-ai/refactor/power_iteration_generic
Power iteration algorithm is now generic to any linear operator
cofri authored Oct 20, 2023
2 parents b5ee010 + f09c2be commit 9781950
Showing 3 changed files with 115 additions and 137 deletions.
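
The core of the change: power iteration no longer manipulates a weight matrix directly; it only calls a pair of (operator, adjoint) callables, so the same loop serves dense kernels and convolutions alike. Below is a minimal standalone sketch of that scheme, not the library code itself; the shapes and the fixed iteration count are illustrative assumptions.

    import tensorflow as tf

    # Dense case: the operator is u -> u @ W^T and its adjoint is v -> v @ W.
    # For convolutions, the two callables would wrap conv / conv_transpose instead.
    w = tf.random.normal((32, 64))
    linear_op = lambda u: u @ tf.transpose(w)  # (1, 64) -> (1, 32)
    adjoint_op = lambda v: v @ w               # (1, 32) -> (1, 64)

    u = tf.math.l2_normalize(tf.random.uniform((1, 64)))
    for _ in range(100):  # the library uses a tf.while_loop with an eps tolerance
        u = tf.math.l2_normalize(adjoint_op(linear_op(u)))

    sigma = tf.norm(linear_op(u))  # estimate of the largest singular value of w
    print(sigma.numpy(), tf.linalg.svd(w)[0][0].numpy())  # should nearly match
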
243 changes: 110 additions & 133 deletions deel/lip/normalizers.py
@@ -158,116 +158,128 @@ def body(w, old_w):
     return w
 
 
-def _power_iteration(w, u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL):
-    """
-    Internal function that performs the power iteration algorithm.
+def _power_iteration(
+    linear_operator,
+    adjoint_operator,
+    u,
+    eps=DEFAULT_EPS_SPECTRAL,
+    maxiter=DEFAULT_MAXITER_SPECTRAL,
+    axis=None,
+):
+    """Internal function that performs the power iteration algorithm to estimate the
+    largest singular vector of a linear operator.
 
     Args:
-        w: weights matrix that we want to find eigen vector
-        u: initialization of the eigen vector
-        eps: epsilon stopping criterion: norm(ut - ut-1) must be less than eps
-        maxiter: maximum number of iterations for the algorithm
+        linear_operator (Callable): a callable object that applies the linear
+            operator.
+        adjoint_operator (Callable): a callable object that applies the adjoint of
+            the linear operator.
+        u (tf.Tensor): initialization of the singular vector.
+        eps (float, optional): stopping criterion of the algorithm, when
+            norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL.
+        maxiter (int, optional): maximum number of iterations for the algorithm.
+            Defaults to DEFAULT_MAXITER_SPECTRAL.
+        axis (int/list, optional): dimension along which to normalize, e.g. for
+            depthwise convolutions. Defaults to None.
 
     Returns:
-        u and v corresponding to the maximum eigenvalue
+        tf.Tensor: the maximum singular vector.
     """
-    # build _u and _v (_v is size of w @ tf.transpose(w), will be set on the first
-    # body iteration)
-    if u is None:
-        u = tf.linalg.l2_normalize(
-            tf.random.uniform(
-                shape=(1, w.shape[-1]), minval=0.0, maxval=1.0, dtype=w.dtype
-            )
-        )
-    _u = u
-    _v = tf.zeros((1,) + (w.shape[0],), dtype=w.dtype)
-
-    # create a fake old_w that doesn't pass the loop condition
-    # it won't affect computation as the first action done in the loop overwrites it
-    _old_u = 10 * _u
+    # Prepare while loop variables
+    u = tf.math.l2_normalize(u, axis=axis)
+    # create a fake old_u that doesn't pass the loop condition, it will be overwritten
+    old_u = u + 2 * eps
 
-    # define the loop condition
-    def cond(_u, _v, old_u):
-        return tf.linalg.norm(_u - old_u) >= eps
+    # Loop body
+    def body(u, old_u):
+        old_u = u
+        v = linear_operator(u)
+        u = adjoint_operator(v)
 
-    # define the loop body
-    def body(_u, _v, _old_u):
-        _old_u = _u
-        _v = tf.math.l2_normalize(_u @ tf.transpose(w))
-        _u = tf.math.l2_normalize(_v @ w)
-        return _u, _v, _old_u
+        u = tf.math.l2_normalize(u, axis=axis)
 
-    # apply the loop
-    _u, _v, _old_u = tf.while_loop(
+        return u, old_u
+
+    # Loop stopping condition
+    def cond(u, old_u):
+        return tf.linalg.norm(u - old_u) >= eps
+
+    # Run the while loop
+    u, _ = tf.while_loop(
         cond,
         body,
-        (_u, _v, _old_u),
-        parallel_iterations=30,
+        (u, old_u),
         maximum_iterations=maxiter,
         swap_memory=SWAP_MEMORY,
     )
 
+    # Prevent gradient from back-propagating into the while loop
     if STOP_GRAD_SPECTRAL:
-        _u = tf.stop_gradient(_u)
-        _v = tf.stop_gradient(_v)
-    return _u, _v
+        u = tf.stop_gradient(u)
+
+    return u
 
 
 def spectral_normalization(
     kernel, u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL
 ):
     """
-    Normalize the kernel to have it's max eigenvalue == 1.
+    Normalize the kernel to have its maximum singular value equal to 1.
 
     Args:
-        kernel (tf.Tensor): the kernel to normalize, assuming a 2D kernel
-        u (tf.Tensor): initialization for the max eigen vector
-        eps (float): epsilon stopping criterion: norm(ut - ut-1) must be less than eps
-        maxiter (int): maximum number of iterations for the algorithm
+        kernel (tf.Tensor): the kernel to normalize, assuming a 2D kernel.
+        u (tf.Tensor): initialization of the maximum singular vector.
+        eps (float, optional): stopping criterion of the algorithm, when
+            norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL.
+        maxiter (int, optional): maximum number of iterations for the algorithm.
+            Defaults to DEFAULT_MAXITER_SPECTRAL.
 
     Returns:
-        the normalized kernel w_bar, the maximum eigen vector, and the maximum singular
+        the normalized kernel, the maximum singular vector, and the maximum singular
         value.
     """
-    _u, _v = _power_iteration(kernel, u, eps, maxiter)
-    # compute Sigma
-    sigma = _v @ kernel
-    sigma = sigma @ tf.transpose(_u)
-    # normalize it
-    # we assume that in the worst case we converged to sigma + eps (as u and v are
-    # normalized after each iteration)
-    # in order to be sure that operator norm of W_bar is strictly less than one we
-    # use sigma + eps, which ensures stability of the bjorck even when beta=0.5
-    W_bar = kernel / (sigma + eps)
-    return W_bar, _u, sigma
+    if u is None:
+        u = tf.random.uniform(
+            shape=(1, kernel.shape[-1]), minval=0.0, maxval=1.0, dtype=kernel.dtype
+        )
+
+    def linear_op(u):
+        return u @ tf.transpose(kernel)
+
+    def adjoint_op(v):
+        return v @ kernel
+
+    u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter)
+
+    # Compute the largest singular value and the normalized kernel.
+    # We assume that in the worst case we converged to sigma + eps (as u and v are
+    # normalized after each iteration).
+    # In order to be sure that the operator norm of the normalized kernel is strictly
+    # less than one, we use sigma + eps, which ensures stability of the Björck
+    # algorithm even when beta=0.5.
+    sigma = tf.reshape(tf.norm(linear_op(u)), (1, 1))
+    normalized_kernel = kernel / (sigma + eps)
+    return normalized_kernel, u, sigma
 
 
-def _power_iteration_conv(
-    w,
-    u,
-    stride=1.0,
-    conv_first=True,
-    pad_func=None,
-    eps=DEFAULT_EPS_SPECTRAL,
-    maxiter=DEFAULT_MAXITER_SPECTRAL,
-    big_constant=-1,
-):
+def get_conv_operators(kernel, u_shape, stride=1.0, conv_first=True, pad_func=None):
     """
-    Internal function that performs the power iteration algorithm for convolution.
+    Return two functions corresponding to the linear convolution operator and its
+    adjoint.
 
     Args:
-        w: weights matrix that we want to find eigen vector
-        u: initialization of the eigen matrix, should be ||u||=1 for L2_norm
-        stride: stride parameter of the convolution
-        conv_first: RO or CO case, should be True in CO case (stride^2*C<M)
-        pad_func: function for applying padding (None is "same" padding)
-        eps: epsilon stopping criterion: norm(ut - ut-1) must be less than eps
-        maxiter: maximum number of iterations for the algorithm
-        big_constant: only for computing the minimum singular value (otherwise -1)
-    Returns:
-        u and v corresponding to the maximum eigenvalue
+        kernel (tf.Tensor): the convolution kernel to normalize.
+        u_shape (tuple): shape of a singular vector (as a 4D tensor).
+        stride (int, optional): stride parameter of convolutions. Defaults to 1.
+        conv_first (bool, optional): RO or CO case, should be True in the CO case
+            (stride^2 * C < M). Defaults to True.
+        pad_func (Callable, optional): function for applying padding (None means
+            "same" padding). Defaults to None.
+
+    Returns:
+        tuple: two functions for the linear convolution operator and its adjoint
+            operator.
     """
 
     def identity(x):
@@ -295,66 +295,28 @@ def _conv_transpose(u, w, output_shape, stride):
         w_adj = _maybe_transpose_kernel(w, True)
         return _conv(u_upscale, w_adj, stride=1)
 
-    def body(_u, _v, _old_u, _norm_u):
-        # _u is supposed to be normalized when entering the body function
-        _old_u = _u
-        u = _u
-
-        if conv_first:  # Conv, then transposed conv
-            v = _conv(u, w, stride)
-            unew = _conv_transpose(v, w, u.shape, stride)
-        else:  # Transposed conv, then conv
-            v = _conv_transpose(u, w, _v.shape, stride)
-            unew = _conv(v, w, stride)
-
-        if big_constant > 0:
-            unew = big_constant * u - unew
-
-        _norm_unew = tf.norm(unew)
-        unew = tf.math.l2_normalize(unew)
-        return unew, v, _old_u, _norm_unew
-
-    # define the loop condition
-    def cond(_u, _v, old_u, _norm_u):
-        return tf.linalg.norm(_u - old_u) >= eps
-
-    # v shape
-    if conv_first:
-        v_shape = (
-            (u.shape[0],)
-            + (u.shape[1] // stride, u.shape[2] // stride)
-            + (w.shape[-1],)
-        )
-    else:
-        v_shape = (
-            (u.shape[0],) + (u.shape[1] * stride, u.shape[2] * stride) + (w.shape[-2],)
-        )
-
-    # build _u and _v
-    _norm_u = tf.norm(u)
-    _u = tf.math.l2_normalize(u)
-    _u += tf.random.uniform(_u.shape, minval=-eps, maxval=eps)
-    _v = tf.zeros(v_shape)  # _v will be set on the first body iteration
-
-    # create a fake old_w that doesn't pass the loop condition
-    # it won't affect computation as the first action done in the loop overwrites it
-    _old_u = 10 * _u
-
-    # apply the loop
-    _u, _v, _old_u, _norm_u = tf.while_loop(
-        cond,
-        body,
-        (_u, _v, _old_u, _norm_u),
-        parallel_iterations=1,
-        maximum_iterations=maxiter,
-        swap_memory=SWAP_MEMORY,
-    )
-    if STOP_GRAD_SPECTRAL:
-        _u = tf.stop_gradient(_u)
-        _v = tf.stop_gradient(_v)
-
-    return _u, _v, _norm_u
+    if conv_first:
+
+        def linear_op(u):
+            return _conv(u, kernel, stride)
+
+        def adjoint_op(v):
+            return _conv_transpose(v, kernel, u_shape, stride)
+
+    else:
+        v_shape = (
+            (u_shape[0],)
+            + (u_shape[1] * stride, u_shape[2] * stride)
+            + (kernel.shape[-2],)
+        )
+
+        def linear_op(u):
+            return _conv_transpose(u, kernel, v_shape, stride)
+
+        def adjoint_op(v):
+            return _conv(v, kernel, stride)
+
+    return linear_op, adjoint_op
@@ -386,11 +360,14 @@ def spectral_normalization_conv(
     if eps < 0:
         return kernel, u, 1.0
 
-    _u, _v, _ = _power_iteration_conv(
-        kernel, u, stride, conv_first, pad_func, eps, maxiter
+    linear_op, adjoint_op = get_conv_operators(
+        kernel, u.shape, stride, conv_first, pad_func
     )
 
-    # Calculate Sigma
-    sigma = tf.norm(_v)
-    W_bar = kernel / (sigma + eps)
-    return W_bar, _u, sigma
+    u = tf.math.l2_normalize(u) + tf.random.uniform(u.shape, minval=-eps, maxval=eps)
+    u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter)
+
+    # Compute the largest singular value and the normalized kernel
+    sigma = tf.norm(linear_op(u))
+    normalized_kernel = kernel / (sigma + eps)
+    return normalized_kernel, u, sigma
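
As a sanity check of the refactored dense path, a hedged usage sketch (assuming the deel.lip.normalizers import path at this commit): the returned sigma should match the top singular value from a full SVD, and the normalized kernel's spectral norm should sit just below 1 (by eps).

    import tensorflow as tf
    from deel.lip.normalizers import spectral_normalization

    kernel = tf.random.normal((16, 8))
    # u=None lets the function draw a random initial singular vector
    normalized_kernel, u, sigma = spectral_normalization(kernel, u=None, eps=1e-6)

    print(sigma.numpy().squeeze())                          # ~ largest singular value
    print(tf.linalg.svd(kernel)[0][0].numpy())              # reference value from SVD
    print(tf.linalg.svd(normalized_kernel)[0][0].numpy())   # ~ 1 (slightly below)
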
7 changes: 4 additions & 3 deletions setup.cfg
@@ -9,8 +9,8 @@ per-file-ignores =
 
 [tox:tox]
 envlist =
-    py{37,38,39,310}-tf{22,23,24,25,26,27,28,29,210,211,212,213,latest}
-    py{37,38,39,310}-lint
+    py{37,38,39,310,311}-tf{22,23,24,25,26,27,28,29,210,211,212,213,214,latest}
+    py{37,38,39,310,311}-lint
 
 [testenv]
 deps =
@@ -28,11 +28,12 @@ deps =
     tf211: tensorflow ~= 2.11.0
     tf212: tensorflow ~= 2.12.0
     tf213: tensorflow ~= 2.13.0
+    tf214: tensorflow ~= 2.14.0
 
 commands =
     python -m unittest
 
-[testenv:py{37,38,39,310}-lint]
+[testenv:py{37,38,39,310,311}-lint]
 skip_install = true
 deps =
     black
2 changes: 1 addition & 1 deletion tests/test_normalizers.py
@@ -40,7 +40,7 @@ def _test_kernel(self, kernel):
         ).numpy()
         SVmax = np.max(sigmas_svd)
 
-        u = rng.normal(size=(1, kernel.shape[-1]))
+        u = rng.normal(size=(1, kernel.shape[-1])).astype("float32")
         W_bar, _u, sigma = spectral_normalization(kernel, u=u, eps=1e-6)
         # Test sigma is close to the one computed with svd first run @ 1e-1
         np.testing.assert_approx_equal(
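
The one-line test change guards against a dtype mismatch: NumPy's normal() returns float64, while the kernels under test are float32, and TensorFlow's matmul does not mix precisions. A minimal sketch of the failure mode being avoided (shapes are illustrative):

    import numpy as np
    import tensorflow as tf

    kernel = tf.random.uniform((8, 4), dtype=tf.float32)
    rng = np.random.default_rng(0)

    u64 = rng.normal(size=(1, kernel.shape[-1]))  # float64 by default
    # tf.constant(u64) @ tf.transpose(kernel)     # would raise: float64 vs float32
    u32 = u64.astype("float32")                   # the fix applied in the test
    print((tf.constant(u32) @ tf.transpose(kernel)).dtype)  # float32
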
