diff --git a/src/pytom_tm/entry_points.py b/src/pytom_tm/entry_points.py index ce5e2b32..8a642179 100644 --- a/src/pytom_tm/entry_points.py +++ b/src/pytom_tm/entry_points.py @@ -436,7 +436,7 @@ def extract_candidates(argv=None): type=pathlib.Path, required=False, action=CheckFileExists, - help="Here you can provide a mask for the extraction with dimensions equal to " + help="Here you can provide a mask for the extraction with dimensions (in pixels) equal to " "the tomogram. All values in the mask that are smaller or equal to 0 will be " "removed, all values larger than 0 are considered regions of interest. It can " "be used to extract annotations only within a specific cellular region." @@ -687,7 +687,7 @@ def match_template(argv=None): type=pathlib.Path, required=False, action=CheckFileExists, - help="Here you can provide a mask for matching with dimensions equal to " + help="Here you can provide a mask for matching with dimensions (in pixels) equal to " "the tomogram. If a subvolume only has values <= 0 for this mask it will be skipped.", ) diff --git a/src/pytom_tm/extract.py b/src/pytom_tm/extract.py index e7653fcb..a72f83cd 100644 --- a/src/pytom_tm/extract.py +++ b/src/pytom_tm/extract.py @@ -216,6 +216,11 @@ def extract_particles( tomogram_mask = read_mrc(job.tomogram_mask) if tomogram_mask is not None: + if tomogram_mask.shape != job.tomo_shape: + raise ValueError( + "Tomogram mask does not have the same number of pixels as the tomogram.\n" + f"Tomogram mask shape: {tomogram_mask.shape}, tomogram shape: {job.tomo_shape}" + ) slices = [ slice(origin, origin + size) for origin, size in zip(job.search_origin, job.search_size) diff --git a/src/pytom_tm/tmjob.py b/src/pytom_tm/tmjob.py index 5a12b229..f0aa17e6 100644 --- a/src/pytom_tm/tmjob.py +++ b/src/pytom_tm/tmjob.py @@ -371,6 +371,11 @@ def __init__( self.tomogram_mask = tomogram_mask if tomogram_mask is not None: temp = read_mrc(tomogram_mask) + if temp.shape != self.tomo_shape: + raise ValueError( + "Tomogram mask does not have the same number of pixels as the tomogram.\n" + f"Tomogram mask shape: {temp.shape}, tomogram shape: {self.tomo_shape}" + ) if np.all(temp <= 0): raise ValueError( f"No values larger than 0 found in the tomogram mask: {tomogram_mask}" diff --git a/tests/test_tmjob.py b/tests/test_tmjob.py index 242dea42..4e4a0bdb 100644 --- a/tests/test_tmjob.py +++ b/tests/test_tmjob.py @@ -20,6 +20,7 @@ TEST_DATA_DIR = pathlib.Path(__file__).parent.joinpath("test_data") TEST_TOMOGRAM = TEST_DATA_DIR.joinpath("tomogram.mrc") TEST_BROKEN_TOMOGRAM_MASK = TEST_DATA_DIR.joinpath("broken_tomogram_mask.mrc") +TEST_WRONG_SIZE_TOMO_MASK = TEST_DATA_DIR.joinpath("wrong_size_tomogram_mask.mrc") TEST_EXTRACTION_MASK_OUTSIDE = TEST_DATA_DIR.joinpath("extraction_mask_outside.mrc") TEST_EXTRACTION_MASK_INSIDE = TEST_DATA_DIR.joinpath("extraction_mask_inside.mrc") TEST_TEMPLATE = TEST_DATA_DIR.joinpath("template.mrc") @@ -124,10 +125,17 @@ def setUpClass(cls) -> None: broken_tomogram_mask = np.zeros(TOMO_SHAPE, dtype=np.float32) write_mrc(TEST_BROKEN_TOMOGRAM_MASK, broken_tomogram_mask, 1.0) + # write wrong size tomogram mask + size = list(TOMO_SHAPE) + size[0] += 1 + wrong_size_tomogram_mask = np.ones(tuple(size), dtype=np.float32) + write_mrc(TEST_WRONG_SIZE_TOMO_MASK, wrong_size_tomogram_mask, 1.0) + @classmethod def tearDownClass(cls) -> None: TEST_MASK.unlink() TEST_BROKEN_TOMOGRAM_MASK.unlink() + TEST_WRONG_SIZE_TOMO_MASK.unlink() TEST_EXTRACTION_MASK_OUTSIDE.unlink() TEST_EXTRACTION_MASK_INSIDE.unlink() TEST_TEMPLATE.unlink() @@ -251,6 +259,19 @@ def test_tm_job_errors(self): voxel_size=1.0, tomogram_mask=TEST_BROKEN_TOMOGRAM_MASK, ) + # Test wrong size template mask + with self.assertRaisesRegex(ValueError, str(TOMO_SHAPE)): + TMJob( + "0", + 10, + TEST_TOMOGRAM, + TEST_TEMPLATE, + TEST_MASK, + TEST_DATA_DIR, + angle_increment=ANGULAR_SEARCH, + voxel_size=1.0, + tomogram_mask=TEST_WRONG_SIZE_TOMO_MASK, + ) def test_tm_job_copy(self): copy = self.job.copy() @@ -680,3 +701,24 @@ def test_extraction(self): msg="We expected a detected particle with a extraction mask that " "covers the object.", ) + + # test mask that is the wrong size raises an error + with self.assertRaisesRegex(ValueError, str(TOMO_SHAPE)): + _, _ = extract_particles( + job, + 5, + 100, + tomogram_mask_path=TEST_WRONG_SIZE_TOMO_MASK, + create_plot=False, + ) + + # Also test the raise if it somehow got attached to the job + job = self.job.copy() + job.tomogram_mask = TEST_WRONG_SIZE_TOMO_MASK + with self.assertRaisesRegex(ValueError, str(TOMO_SHAPE)): + _, _ = extract_particles( + job, + 5, + 100, + create_plot=False, + )