@@ -392,13 +392,18 @@ def projection_box_section(x: jnp.ndarray,
392392
def _max_l2(x, marginal_b, gamma):
  """Evaluate the L2-regularized objective at the projected point.

  With ``scale = gamma * marginal_b`` and
  ``p = projection_simplex(x / scale)``, returns
  ``dot(x, p) - 0.5 * scale * dot(p, p)``.

  NOTE(review): the Danskin's-theorem comment below implies that ``p`` is the
  maximizer of this expression over the simplex, so the returned value is
  the max itself -- confirm against ``projection_simplex``.

  Args:
    x: array of scores; presumably 1-D, since it is combined with ``p`` via
      ``jnp.dot`` to produce the result -- TODO confirm.
    marginal_b: multiplier that, together with ``gamma``, forms ``scale``.
    gamma: regularization strength multiplying ``marginal_b``.

  Returns:
    The value ``dot(x, p) - 0.5 * scale * dot(p, p)``.
  """
  scale = gamma * marginal_b
  x_scale = x / scale
  p = projection_simplex(x_scale)
  # From Danskin's theorem, we do not need to backpropagate
  # through projection_simplex: treating p as a constant still yields the
  # correct gradient of the returned value.
  p = jax.lax.stop_gradient(p)
  return jnp.dot(x, p) - 0.5 * scale * jnp.dot(p, p)
397401
398402
399403def _max_ent (x , marginal_b , gamma ):
400404 return gamma * logsumexp (x / gamma ) - gamma * jnp .log (marginal_b )
401405
406+
# Batched versions of _max_l2 and of its gradient. in_axes=(1, 0, None) maps
# over axis 1 of `x` and axis 0 of `marginal_b`, while `gamma` is shared
# across the batch. jax.grad differentiates with respect to the first
# argument (`x`) by default.
_max_l2_vmap = jax.vmap(_max_l2, in_axes=(1, 0, None))
_max_l2_grad_vmap = jax.vmap(jax.grad(_max_l2), in_axes=(1, 0, None))
404409
@@ -771,4 +776,4 @@ def kl_projection_birkhoff(sim_matrix: jnp.ndarray,
771776 return kl_projection_transport (sim_matrix = sim_matrix ,
772777 marginals = (marginals_a , marginals_b ),
773778 make_solver = make_solver ,
774- use_semi_dual = use_semi_dual )
779+ use_semi_dual = use_semi_dual )
0 commit comments