diff --git a/httomo/methods.py b/httomo/methods.py index 03232d253..e30ec04c1 100644 --- a/httomo/methods.py +++ b/httomo/methods.py @@ -1,3 +1,4 @@ +import logging import pathlib from typing import Tuple import numpy as np @@ -5,7 +6,7 @@ import httomo from httomo.runner.dataset import DataSetBlock -from httomo.utils import xp +from httomo.utils import log_once, xp __all__ = ["calculate_stats", "save_intermediate_data"] @@ -46,6 +47,15 @@ def save_intermediate_data( angles: np.ndarray, ) -> None: """Saves intermediate data to a file, including auxiliary""" + if frames_per_chunk > data.shape[slicing_dim]: + warn_message = ( + f"frames_per_chunk={frames_per_chunk} exceeds number of elements in " + f"slicing dim={slicing_dim} of data with shape {data.shape}. Falling " + "back to 1 frame per-chunk" + ) + log_once(warn_message, logging.DEBUG) + frames_per_chunk = 1 + if frames_per_chunk > 0: chunk_shape = [0, 0, 0] chunk_shape[slicing_dim] = frames_per_chunk diff --git a/tests/test_methods.py b/tests/test_methods.py index 4e4665477..7e3687e23 100644 --- a/tests/test_methods.py +++ b/tests/test_methods.py @@ -208,7 +208,7 @@ def test_save_intermediate_data_mpi(tmp_path: Path): np.testing.assert_array_equal(file["data_dims"]["detector_x_y"], [10, 20]) -@pytest.mark.parametrize("frames_per_chunk", [0, 1, 5]) +@pytest.mark.parametrize("frames_per_chunk", [0, 1, 5, 1000]) def test_save_intermediate_data_frames_per_chunk( tmp_path: Path, frames_per_chunk: int, @@ -247,7 +247,9 @@ def test_save_intermediate_data_frames_per_chunk( # Define the expected chunk shape, based on the `frames_per_chunk` value and the slicing # dim of the data that was saved expected_chunk_shape = [0, 0, 0] - expected_chunk_shape[block.slicing_dim] = frames_per_chunk + expected_chunk_shape[block.slicing_dim] = ( + frames_per_chunk if frames_per_chunk != 1000 else 1 + ) DIMS = [0, 1, 2] non_slicing_dims = list(set(DIMS) - set([block.slicing_dim])) for dim in non_slicing_dims: @@ -266,7 +268,7 @@ def test_save_intermediate_data_frames_per_chunk( @pytest.mark.skipif( MPI.COMM_WORLD.size != 2, reason="Only rank-2 MPI is supported with this test" ) -@pytest.mark.parametrize("frames_per_chunk", [0, 1, 5]) +@pytest.mark.parametrize("frames_per_chunk", [0, 1, 5, 1000]) def test_save_intermediate_data_frames_per_chunk_mpi( tmp_path: Path, frames_per_chunk: int, @@ -314,7 +316,9 @@ def test_save_intermediate_data_frames_per_chunk_mpi( # Define the expected chunk shape, based on the `frames_per_chunk` value and the slicing # dim of the data that was saved expected_chunk_shape = [0, 0, 0] - expected_chunk_shape[block.slicing_dim] = frames_per_chunk + expected_chunk_shape[block.slicing_dim] = ( + frames_per_chunk if frames_per_chunk != 1000 else 1 + ) DIMS = [0, 1, 2] non_slicing_dims = list(set(DIMS) - set([block.slicing_dim])) for dim in non_slicing_dims: