Skip to content

Commit

Permalink
Use 1 frame per chunk if flag value exceeds slicing dim length
Browse files Browse the repository at this point in the history
If the value provided for the `--frames-per-chunk` flag is larger than
the length of the data's slicing dimension, then the number of frames
per-chunk will be set to 1 (and will be logged in the verbose logfile).

This is to avoid `h5py` errors with trying to create a chunk whose shape
exceeds the boundaries of the data being written.
  • Loading branch information
yousefmoazzam committed Jul 25, 2024
1 parent 307ee62 commit aef83c9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
12 changes: 11 additions & 1 deletion httomo/methods.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
import pathlib
from typing import Tuple
import numpy as np
import h5py

import httomo
from httomo.runner.dataset import DataSetBlock
from httomo.utils import xp
from httomo.utils import log_once, xp

__all__ = ["calculate_stats", "save_intermediate_data"]

Expand Down Expand Up @@ -46,6 +47,15 @@ def save_intermediate_data(
angles: np.ndarray,
) -> None:
"""Saves intermediate data to a file, including auxiliary"""
if frames_per_chunk > data.shape[slicing_dim]:
warn_message = (
f"frames_per_chunk={frames_per_chunk} exceeds number of elements in "
f"slicing dim={slicing_dim} of data with shape {data.shape}. Falling "
"back to 1 frame per-chunk"
)
log_once(warn_message, logging.DEBUG)
frames_per_chunk = 1

if frames_per_chunk > 0:
chunk_shape = [0, 0, 0]
chunk_shape[slicing_dim] = frames_per_chunk
Expand Down
12 changes: 8 additions & 4 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_save_intermediate_data_mpi(tmp_path: Path):
np.testing.assert_array_equal(file["data_dims"]["detector_x_y"], [10, 20])


@pytest.mark.parametrize("frames_per_chunk", [0, 1, 5])
@pytest.mark.parametrize("frames_per_chunk", [0, 1, 5, 1000])
def test_save_intermediate_data_frames_per_chunk(
tmp_path: Path,
frames_per_chunk: int,
Expand Down Expand Up @@ -247,7 +247,9 @@ def test_save_intermediate_data_frames_per_chunk(
# Define the expected chunk shape, based on the `frames_per_chunk` value and the slicing
# dim of the data that was saved
expected_chunk_shape = [0, 0, 0]
expected_chunk_shape[block.slicing_dim] = frames_per_chunk
expected_chunk_shape[block.slicing_dim] = (
frames_per_chunk if frames_per_chunk != 1000 else 1
)
DIMS = [0, 1, 2]
non_slicing_dims = list(set(DIMS) - set([block.slicing_dim]))
for dim in non_slicing_dims:
Expand All @@ -266,7 +268,7 @@ def test_save_intermediate_data_frames_per_chunk(
@pytest.mark.skipif(
MPI.COMM_WORLD.size != 2, reason="Only rank-2 MPI is supported with this test"
)
@pytest.mark.parametrize("frames_per_chunk", [0, 1, 5])
@pytest.mark.parametrize("frames_per_chunk", [0, 1, 5, 1000])
def test_save_intermediate_data_frames_per_chunk_mpi(
tmp_path: Path,
frames_per_chunk: int,
Expand Down Expand Up @@ -314,7 +316,9 @@ def test_save_intermediate_data_frames_per_chunk_mpi(
# Define the expected chunk shape, based on the `frames_per_chunk` value and the slicing
# dim of the data that was saved
expected_chunk_shape = [0, 0, 0]
expected_chunk_shape[block.slicing_dim] = frames_per_chunk
expected_chunk_shape[block.slicing_dim] = (
frames_per_chunk if frames_per_chunk != 1000 else 1
)
DIMS = [0, 1, 2]
non_slicing_dims = list(set(DIMS) - set([block.slicing_dim]))
for dim in non_slicing_dims:
Expand Down

0 comments on commit aef83c9

Please sign in to comment.