Skip to content

Commit

Permalink
internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560165866
  • Loading branch information
The precondition Authors committed Aug 25, 2023
1 parent f0ef94c commit fad765b
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions precondition/tearfree/sketchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,16 +448,18 @@ def _all_nan(y):
intercept = jnp.mean(vals) - slope * jnp.mean(ranks)
log_ranks = jnp.log(jnp.arange(k + 1, d + 1))
fitted_vals = slope * log_ranks + intercept
tail = jnp.exp(jax.scipy.special.logsumexp(fitted_vals * 2)) / (d - k)
tail = jnp.exp(jax.scipy.special.logsumexp(fitted_vals * 2))
undeflated = jnp.square(jnp.maximum(top_eigs, 0.0))
else:
tail = axis_state.tail * decay + cutoff**2
# Avoid numerical error from the sqrt computation and from subtracting
# and re-adding cutoff^2 (mathematically, undeflated == deflated^2 + tail).
undeflated = jnp.square(jnp.maximum(top_eigs, 0.0)) + axis_state.tail * decay
# Avoid numerical error from the sqrt computation and from subtracting
# and re-adding cutoff^2 (mathematically, undeflated == deflated^2 + tail).
undeflated = (
jnp.square(jnp.maximum(top_eigs, 0.0)) + axis_state.tail * decay
)
eigvecs = u[:, :k]

mask = deflated > 0
# Would be nice to statically assert deflated == 0 implies undeflated == 0.

alpha = jnp.asarray(-1.0 / (2 * update.ndim), dtype=jnp.float32)
eigvecs *= mask
Expand Down

0 comments on commit fad765b

Please sign in to comment.