Skip to content

Commit

Permalink
Pass padded chunk shape from loader to block constructor
Browse files Browse the repository at this point in the history
Due to the previous change, the loader's `read_block()` method was
passing the unpadded chunk shape to the block's constructor, which is
incorrect.
  • Loading branch information
yousefmoazzam committed Sep 12, 2024
1 parent 6187cd6 commit 1e98edb
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
9 changes: 7 additions & 2 deletions httomo/loaders/standard_tomo_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from httomo.runner.dataset import DataSetBlock
from httomo.runner.dataset_store_interfaces import DataSetSource
from httomo.runner.loader import LoaderInterface
from httomo.utils import Pattern, log_once
from httomo.utils import Pattern, log_once, make_3d_shape_from_shape


class StandardTomoLoader(DataSetSource):
Expand Down Expand Up @@ -284,14 +284,19 @@ def read_block(self, start: int, length: int) -> DataSetBlock:
slices_read[0], slices_read[1], slices_read[2]
]

padded_chunk_shape_list = list(self._chunk_shape)
padded_chunk_shape_list[self._slicing_dim] += (
self._padding[0] + self._padding[1]
)

return DataSetBlock(
data=block_data,
aux_data=self._aux_data,
slicing_dim=self._slicing_dim,
block_start=start - self._padding[0],
chunk_start=self._chunk_index[self._slicing_dim],
global_shape=self._global_shape,
chunk_shape=self._chunk_shape,
chunk_shape=make_3d_shape_from_shape(padded_chunk_shape_list),
padding=self._padding,
)

Expand Down
67 changes: 67 additions & 0 deletions tests/loaders/test_standard_tomo_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,73 @@ def test_standard_tomo_loader_properties_reflect_nonzero_padding(
assert loader.global_index == EXPECTED_GLOBAL_INDEX


def test_non_zero_loader_padding_loaded_block_shape_properties(
standard_data_path: str,
standard_image_key_path: str,
):
IN_FILE_PATH = Path(__file__).parent.parent / "test_data/tomo_standard.nxs"
DARKS_FLATS_CONFIG = DarksFlatsFileConfig(
file=IN_FILE_PATH,
data_path=standard_data_path,
image_key_path=standard_image_key_path,
)
ANGLES_CONFIG = RawAngles(data_path="/entry1/tomo_entry/data/rotation_angle")
SLICING_DIM: SlicingDimType = 0
COMM = MPI.COMM_WORLD
PADDING = (2, 3)
PROJS, DET_Y, DET_X = (180, 128, 160)
PREVIEW_CONFIG = PreviewConfig(
angles=PreviewDimConfig(start=0, stop=PROJS),
detector_y=PreviewDimConfig(start=0, stop=DET_Y),
detector_x=PreviewDimConfig(start=0, stop=DET_X),
)

with mock.patch(
"httomo.darks_flats.get_darks_flats",
return_value=(np.zeros(1), np.zeros(1)),
):
loader = StandardTomoLoader(
in_file=IN_FILE_PATH,
data_path=DARKS_FLATS_CONFIG.data_path,
image_key_path=DARKS_FLATS_CONFIG.image_key_path,
darks=DARKS_FLATS_CONFIG,
flats=DARKS_FLATS_CONFIG,
angles=ANGLES_CONFIG,
preview_config=PREVIEW_CONFIG,
slicing_dim=SLICING_DIM,
comm=COMM,
padding=PADDING,
)

BLOCK_START = 0
BLOCK_LENGTH = 4
block = loader.read_block(BLOCK_START, BLOCK_LENGTH)

BLOCK_EXPECTED_GLOBAL_SHAPE = (PROJS, DET_Y, DET_X)
BLOCK_EXPECTED_CHUNK_SHAPE_UNPADDED = BLOCK_EXPECTED_GLOBAL_SHAPE
BLOCK_EXPECTED_CHUNK_SHAPE = (
BLOCK_EXPECTED_GLOBAL_SHAPE[0] + PADDING[0] + PADDING[1],
BLOCK_EXPECTED_GLOBAL_SHAPE[1],
BLOCK_EXPECTED_GLOBAL_SHAPE[2],
)
BLOCK_EXPECTED_SHAPE_UNPADDED = (
BLOCK_LENGTH,
BLOCK_EXPECTED_GLOBAL_SHAPE[1],
BLOCK_EXPECTED_GLOBAL_SHAPE[2],
)
BLOCK_EXPECTED_SHAPE = (
BLOCK_LENGTH + PADDING[0] + PADDING[1],
BLOCK_EXPECTED_GLOBAL_SHAPE[1],
BLOCK_EXPECTED_GLOBAL_SHAPE[2],
)

assert block.global_shape == BLOCK_EXPECTED_GLOBAL_SHAPE
assert block.chunk_shape == BLOCK_EXPECTED_CHUNK_SHAPE
assert block.chunk_shape_unpadded == BLOCK_EXPECTED_CHUNK_SHAPE_UNPADDED
assert block.shape == BLOCK_EXPECTED_SHAPE
assert block.shape_unpadded == BLOCK_EXPECTED_SHAPE_UNPADDED


@pytest.mark.parametrize(
"preview_config",
[
Expand Down

0 comments on commit 1e98edb

Please sign in to comment.