Skip to content

Commit 34e1d68

Browse files
author
Flax Authors
committed
Merge pull request #4983 from google:propagate-eager-sharding
PiperOrigin-RevId: 811550398
2 parents b9cc05f + 7959380 commit 34e1d68

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

flax/nnx/variablelib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def __init__(
285285
# shard the value if applicable
286286
do_eager_sharding = config.flax_always_shard_variable
287287
if 'eager_sharding' in metadata:
288-
do_eager_sharding = metadata.pop('eager_sharding')
288+
do_eager_sharding = metadata['eager_sharding']
289289
if do_eager_sharding and 'sharding_names' in metadata:
290290
value = core_spmd.shard_value(
291291
value, metadata['sharding_names'], metadata.get('sharding_rules', None),

0 commit comments

Comments
 (0)