Skip to content

Commit f55b796

Browse files
Jake VanderPlasGoogle-ML-Automation
authored andcommitted
Code update
PiperOrigin-RevId: 747453575
1 parent 24fa190 commit f55b796

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax_tpu_embedding/sparsecore/lib/nn/table_stacking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def _unstack_and_unshard_stacked_table(
455455

456456
# increase a rank and the first dimension is the number of sparse cores.
457457
stacked_table_3d = jax.jit(
458-
fun=lambda x: x.reshape(num_sparse_cores, -1, stack_embedding_dim),
458+
lambda x: x.reshape(num_sparse_cores, -1, stack_embedding_dim),
459459
in_shardings=stacked_table_sharding,
460460
out_shardings=stacked_table_sharding,
461461
)(stacked_table)

0 commit comments

Comments
 (0)