diff --git a/checkpoint/orbax/checkpoint/metadata/sharding_test.py b/checkpoint/orbax/checkpoint/metadata/sharding_test.py index 1309cc1f..60ec562f 100644 --- a/checkpoint/orbax/checkpoint/metadata/sharding_test.py +++ b/checkpoint/orbax/checkpoint/metadata/sharding_test.py @@ -55,7 +55,9 @@ def test_convert_between_jax_positional_sharding_and_sharding_metadata( [1, -1] ) expected_positional_sharding_metadata = ( - sharding_metadata.PositionalShardingMetadata(jax_sharding.shape) + sharding_metadata.PositionalShardingMetadata( + jax_sharding.shape, jax_sharding.memory_kind + ) ) converted_positional_sharding_metadata = ( sharding_metadata.from_jax_sharding(jax_sharding)