Skip to content

Commit

Permalink
add check on tomo mask shape (#210)
Browse files Browse the repository at this point in the history
* add check on tomo mask shape

* make f-string and add test also in extraction

* add tests for the errors

* update tooltips to mention that the dimensions should be in pixels

* use correct input name
  • Loading branch information
sroet authored Aug 7, 2024
1 parent 0793404 commit 6de4c01
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/pytom_tm/entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def extract_candidates(argv=None):
type=pathlib.Path,
required=False,
action=CheckFileExists,
help="Here you can provide a mask for the extraction with dimensions equal to "
help="Here you can provide a mask for the extraction with dimensions (in pixels) equal to "
"the tomogram. All values in the mask that are smaller or equal to 0 will be "
"removed, all values larger than 0 are considered regions of interest. It can "
"be used to extract annotations only within a specific cellular region."
Expand Down Expand Up @@ -687,7 +687,7 @@ def match_template(argv=None):
type=pathlib.Path,
required=False,
action=CheckFileExists,
help="Here you can provide a mask for matching with dimensions equal to "
help="Here you can provide a mask for matching with dimensions (in pixels) equal to "
"the tomogram. If a subvolume only has values <= 0 for this mask it will be skipped.",
)

Expand Down
5 changes: 5 additions & 0 deletions src/pytom_tm/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ def extract_particles(
tomogram_mask = read_mrc(job.tomogram_mask)

if tomogram_mask is not None:
if tomogram_mask.shape != job.tomo_shape:
raise ValueError(
"Tomogram mask does not have the same number of pixels as the tomogram.\n"
f"Tomogram mask shape: {tomogram_mask.shape}, tomogram shape: {job.tomo_shape}"
)
slices = [
slice(origin, origin + size)
for origin, size in zip(job.search_origin, job.search_size)
Expand Down
5 changes: 5 additions & 0 deletions src/pytom_tm/tmjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,11 @@ def __init__(
self.tomogram_mask = tomogram_mask
if tomogram_mask is not None:
temp = read_mrc(tomogram_mask)
if temp.shape != self.tomo_shape:
raise ValueError(
"Tomogram mask does not have the same number of pixels as the tomogram.\n"
f"Tomogram mask shape: {temp.shape}, tomogram shape: {self.tomo_shape}"
)
if np.all(temp <= 0):
raise ValueError(
f"No values larger than 0 found in the tomogram mask: {tomogram_mask}"
Expand Down
42 changes: 42 additions & 0 deletions tests/test_tmjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TEST_DATA_DIR = pathlib.Path(__file__).parent.joinpath("test_data")
TEST_TOMOGRAM = TEST_DATA_DIR.joinpath("tomogram.mrc")
TEST_BROKEN_TOMOGRAM_MASK = TEST_DATA_DIR.joinpath("broken_tomogram_mask.mrc")
TEST_WRONG_SIZE_TOMO_MASK = TEST_DATA_DIR.joinpath("wrong_size_tomogram_mask.mrc")
TEST_EXTRACTION_MASK_OUTSIDE = TEST_DATA_DIR.joinpath("extraction_mask_outside.mrc")
TEST_EXTRACTION_MASK_INSIDE = TEST_DATA_DIR.joinpath("extraction_mask_inside.mrc")
TEST_TEMPLATE = TEST_DATA_DIR.joinpath("template.mrc")
Expand Down Expand Up @@ -124,10 +125,17 @@ def setUpClass(cls) -> None:
broken_tomogram_mask = np.zeros(TOMO_SHAPE, dtype=np.float32)
write_mrc(TEST_BROKEN_TOMOGRAM_MASK, broken_tomogram_mask, 1.0)

# write wrong size tomogram mask
size = list(TOMO_SHAPE)
size[0] += 1
wrong_size_tomogram_mask = np.ones(tuple(size), dtype=np.float32)
write_mrc(TEST_WRONG_SIZE_TOMO_MASK, wrong_size_tomogram_mask, 1.0)

@classmethod
def tearDownClass(cls) -> None:
TEST_MASK.unlink()
TEST_BROKEN_TOMOGRAM_MASK.unlink()
TEST_WRONG_SIZE_TOMO_MASK.unlink()
TEST_EXTRACTION_MASK_OUTSIDE.unlink()
TEST_EXTRACTION_MASK_INSIDE.unlink()
TEST_TEMPLATE.unlink()
Expand Down Expand Up @@ -251,6 +259,19 @@ def test_tm_job_errors(self):
voxel_size=1.0,
tomogram_mask=TEST_BROKEN_TOMOGRAM_MASK,
)
# Test wrong size template mask
with self.assertRaisesRegex(ValueError, str(TOMO_SHAPE)):
TMJob(
"0",
10,
TEST_TOMOGRAM,
TEST_TEMPLATE,
TEST_MASK,
TEST_DATA_DIR,
angle_increment=ANGULAR_SEARCH,
voxel_size=1.0,
tomogram_mask=TEST_WRONG_SIZE_TOMO_MASK,
)

def test_tm_job_copy(self):
copy = self.job.copy()
Expand Down Expand Up @@ -680,3 +701,24 @@ def test_extraction(self):
msg="We expected a detected particle with a extraction mask that "
"covers the object.",
)

# test mask that is the wrong size raises an error
with self.assertRaisesRegex(ValueError, str(TOMO_SHAPE)):
_, _ = extract_particles(
job,
5,
100,
tomogram_mask_path=TEST_WRONG_SIZE_TOMO_MASK,
create_plot=False,
)

# Also test the raise if it somehow got attached to the job
job = self.job.copy()
job.tomogram_mask = TEST_WRONG_SIZE_TOMO_MASK
with self.assertRaisesRegex(ValueError, str(TOMO_SHAPE)):
_, _ = extract_particles(
job,
5,
100,
create_plot=False,
)

0 comments on commit 6de4c01

Please sign in to comment.