Skip to content

Commit

Permalink
Add tests for writing chunked intermediate data
Browse files Browse the repository at this point in the history
  • Loading branch information
yousefmoazzam committed Jun 14, 2024
1 parent b3d9acc commit 3bde4b9
Showing 1 changed file with 121 additions and 0 deletions.
121 changes: 121 additions & 0 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,124 @@ def test_save_intermediate_data_mpi(tmp_path: Path):
np.testing.assert_array_equal(file["test_file.h5"], [0, 0])
assert "data_dims" in file
np.testing.assert_array_equal(file["data_dims"]["detector_x_y"], [10, 20])


@pytest.mark.parametrize("frames_per_chunk", [0, 1, 5])
def test_save_intermediate_data_frames_per_chunk(
tmp_path: Path,
frames_per_chunk: int,
):
FILE_NAME = "test_file.h5"
DATA_PATH = "/data"
GLOBAL_SHAPE = (10, 10, 10)
global_data = np.arange(np.prod(GLOBAL_SHAPE), dtype=np.float32).reshape(
GLOBAL_SHAPE
)
aux_data = AuxiliaryData(angles=np.ones(GLOBAL_SHAPE[0], dtype=np.float32))
block = DataSetBlock(
data=global_data,
aux_data=aux_data,
slicing_dim=0,
block_start=0,
chunk_start=0,
chunk_shape=GLOBAL_SHAPE,
global_shape=GLOBAL_SHAPE,
)

with h5py.File(tmp_path / FILE_NAME, "w") as f:
save_intermediate_data(
data=block.data,
global_shape=block.global_shape,
global_index=block.global_index,
slicing_dim=block.slicing_dim,
file=f,
frames_per_chunk=frames_per_chunk,
path=DATA_PATH,
detector_x=block.global_shape[2],
detector_y=block.global_shape[1],
angles=block.angles,
)

# Define the expected chunk shape, based on the `frames_per_chunk` value and the slicing
# dim of the data that was saved
expected_chunk_shape = [0, 0, 0]
expected_chunk_shape[block.slicing_dim] = frames_per_chunk
DIMS = [0, 1, 2]
non_slicing_dims = list(set(DIMS) - set([block.slicing_dim]))
for dim in non_slicing_dims:
expected_chunk_shape[dim] = block.global_shape[dim]

with h5py.File(tmp_path / FILE_NAME, "r") as f:
chunk_shape = f[DATA_PATH].chunks

if frames_per_chunk != 0:
assert chunk_shape == tuple(expected_chunk_shape)
else:
assert chunk_shape is None


@pytest.mark.mpi
@pytest.mark.skipif(
MPI.COMM_WORLD.size != 2, reason="Only rank-2 MPI is supported with this test"
)
@pytest.mark.parametrize("frames_per_chunk", [0, 1, 5])
def test_save_intermediate_data_frames_per_chunk_mpi(
tmp_path: Path,
frames_per_chunk: int,
):
COMM = MPI.COMM_WORLD
tmp_path = COMM.bcast(tmp_path)
FILE_NAME = "test_file.h5"
DATA_PATH = "/data"
SLICING_DIM = 0
GLOBAL_SHAPE = (10, 10, 10)
CHUNK_SIZE = GLOBAL_SHAPE[SLICING_DIM] // 2
global_data = np.arange(np.prod(GLOBAL_SHAPE), dtype=np.float32).reshape(
GLOBAL_SHAPE
)
aux_data = AuxiliaryData(angles=np.ones(GLOBAL_SHAPE[0], dtype=np.float32))
rank_data = (
global_data[:CHUNK_SIZE, :, :]
if COMM.rank == 0
else global_data[CHUNK_SIZE:, :, :]
)
block = DataSetBlock(
data=rank_data,
aux_data=aux_data,
slicing_dim=0,
block_start=0,
chunk_start=0 if COMM.rank == 0 else CHUNK_SIZE,
global_shape=GLOBAL_SHAPE,
chunk_shape=(CHUNK_SIZE, GLOBAL_SHAPE[1], GLOBAL_SHAPE[2]),
)

with h5py.File(tmp_path / FILE_NAME, "w", driver="mpio", comm=COMM) as f:
save_intermediate_data(
data=block.data,
global_shape=block.global_shape,
global_index=block.global_index,
slicing_dim=block.slicing_dim,
file=f,
frames_per_chunk=frames_per_chunk,
path=DATA_PATH,
detector_x=block.global_shape[2],
detector_y=block.global_shape[1],
angles=block.angles,
)

# Define the expected chunk shape, based on the `frames_per_chunk` value and the slicing
# dim of the data that was saved
expected_chunk_shape = [0, 0, 0]
expected_chunk_shape[block.slicing_dim] = frames_per_chunk
DIMS = [0, 1, 2]
non_slicing_dims = list(set(DIMS) - set([block.slicing_dim]))
for dim in non_slicing_dims:
expected_chunk_shape[dim] = block.global_shape[dim]

with h5py.File(tmp_path / FILE_NAME, "r") as f:
chunk_shape = f[DATA_PATH].chunks

if frames_per_chunk != 0:
assert chunk_shape == tuple(expected_chunk_shape)
else:
assert chunk_shape is None

0 comments on commit 3bde4b9

Please sign in to comment.