Skip to content

Commit 745ab09

Browse files
author
JAXopt authors
committed
Merge pull request #300 from mblondel:ot_stop_gradient
PiperOrigin-RevId: 470727718
2 parents b742705 + 46cac2c commit 745ab09

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

jaxopt/_src/projection.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,18 @@ def projection_box_section(x: jnp.ndarray,
392392

393393
def _max_l2(x, marginal_b, gamma):
394394
scale = gamma * marginal_b
395-
p = projection_simplex(x / scale)
395+
x_scale = x / scale
396+
p = projection_simplex(x_scale)
397+
# From Danskin's theorem, we do not need to backpropagate
398+
# through projection_simplex.
399+
p = jax.lax.stop_gradient(p)
396400
return jnp.dot(x, p) - 0.5 * scale * jnp.dot(p, p)
397401

398402

399403
def _max_ent(x, marginal_b, gamma):
400404
return gamma * logsumexp(x / gamma) - gamma * jnp.log(marginal_b)
401405

406+
402407
_max_l2_vmap = jax.vmap(_max_l2, in_axes=(1, 0, None))
403408
_max_l2_grad_vmap = jax.vmap(jax.grad(_max_l2), in_axes=(1, 0, None))
404409

@@ -771,4 +776,4 @@ def kl_projection_birkhoff(sim_matrix: jnp.ndarray,
771776
return kl_projection_transport(sim_matrix=sim_matrix,
772777
marginals=(marginals_a, marginals_b),
773778
make_solver=make_solver,
774-
use_semi_dual=use_semi_dual)
779+
use_semi_dual=use_semi_dual)

0 commit comments

Comments
 (0)