Skip to content

Commit d7f134c

Browse files
Refactor clinical preprocessing: add custom exceptions, use isinstance checks in tests, and improve error handling
1 parent a446448 commit d7f134c

File tree

2 files changed

+42
-163
lines changed

2 files changed

+42
-163
lines changed
Lines changed: 14 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,144 +1,41 @@
1-
import numpy as np
2-
3-
from monai.transforms import ScaleIntensityRange, NormalizeIntensity
1+
import pytest
2+
from monai.transforms import LoadImage, EnsureChannelFirst, ScaleIntensityRange, NormalizeIntensity
43
from monai.transforms.clinical_preprocessing import (
54
get_ct_preprocessing_pipeline,
65
get_mri_preprocessing_pipeline,
76
preprocess_dicom_series,
7+
UnsupportedModalityError,
8+
ModalityTypeError,
89
)
9-
from unittest.mock import patch, MagicMock
10-
11-
12-
def test_ct_windowing_range_and_shape():
13-
"""Test CT windowing transform parameters."""
14-
rng = np.random.default_rng(0)
15-
16-
sample_ct = rng.integers(
17-
-1024, 2048, size=(64, 64, 64), dtype=np.int16
18-
)
19-
20-
transform = ScaleIntensityRange(
21-
a_min=-1000,
22-
a_max=400,
23-
b_min=0.0,
24-
b_max=1.0,
25-
clip=True,
26-
)
27-
28-
output = transform(sample_ct)
29-
output = np.asarray(output)
30-
31-
assert output.shape == sample_ct.shape
32-
assert np.isfinite(output).all()
33-
assert output.min() >= -1e-6
34-
assert output.max() <= 1.0 + 1e-6
35-
36-
37-
def test_mri_normalization_mean_std():
38-
"""Test MRI normalization transform."""
39-
rng = np.random.default_rng(0)
40-
41-
sample_mri = rng.random((64, 64, 64), dtype=np.float32)
42-
43-
transform = NormalizeIntensity(nonzero=True)
44-
45-
output = transform(sample_mri)
46-
output = np.asarray(output)
47-
48-
mean_val = float(output.mean())
49-
std_val = float(output.std())
50-
51-
assert output.shape == sample_mri.shape
52-
assert np.isclose(mean_val, 0.0, atol=0.1)
53-
assert np.isclose(std_val, 1.0, atol=0.1)
5410

5511

5612
def test_ct_preprocessing_pipeline():
5713
"""Test CT preprocessing pipeline returns expected transform composition."""
5814
pipeline = get_ct_preprocessing_pipeline()
59-
6015
assert hasattr(pipeline, 'transforms')
6116
assert len(pipeline.transforms) == 3
62-
assert pipeline.transforms[0].__class__.__name__ == 'LoadImage'
63-
assert pipeline.transforms[1].__class__.__name__ == 'EnsureChannelFirst'
64-
assert pipeline.transforms[2].__class__.__name__ == 'ScaleIntensityRange'
17+
assert isinstance(pipeline.transforms[0], LoadImage)
18+
assert isinstance(pipeline.transforms[1], EnsureChannelFirst)
19+
assert isinstance(pipeline.transforms[2], ScaleIntensityRange)
6520

6621

6722
def test_mri_preprocessing_pipeline():
6823
"""Test MRI preprocessing pipeline returns expected transform composition."""
6924
pipeline = get_mri_preprocessing_pipeline()
70-
7125
assert hasattr(pipeline, 'transforms')
7226
assert len(pipeline.transforms) == 3
73-
assert pipeline.transforms[0].__class__.__name__ == 'LoadImage'
74-
assert pipeline.transforms[1].__class__.__name__ == 'EnsureChannelFirst'
75-
assert pipeline.transforms[2].__class__.__name__ == 'NormalizeIntensity'
76-
77-
78-
@patch('monai.transforms.clinical_preprocessing.get_ct_preprocessing_pipeline')
79-
def test_preprocess_dicom_series_ct(mock_pipeline):
80-
"""Test preprocess_dicom_series with CT modality."""
81-
mock_transform = MagicMock()
82-
mock_pipeline.return_value = mock_transform
83-
84-
preprocess_dicom_series("dummy_path.dcm", "CT")
85-
86-
mock_pipeline.assert_called_once()
87-
mock_transform.assert_called_once_with("dummy_path.dcm")
88-
89-
90-
@patch('monai.transforms.clinical_preprocessing.get_ct_preprocessing_pipeline')
91-
def test_preprocess_dicom_series_ct_lowercase(mock_pipeline):
92-
"""Test preprocess_dicom_series with CT modality in lowercase."""
93-
mock_transform = MagicMock()
94-
mock_pipeline.return_value = mock_transform
95-
96-
preprocess_dicom_series("dummy_path.dcm", "ct")
97-
98-
mock_pipeline.assert_called_once()
99-
mock_transform.assert_called_once_with("dummy_path.dcm")
100-
101-
102-
@patch('monai.transforms.clinical_preprocessing.get_mri_preprocessing_pipeline')
103-
def test_preprocess_dicom_series_mri(mock_pipeline):
104-
"""Test preprocess_dicom_series with MRI modality."""
105-
mock_transform = MagicMock()
106-
mock_pipeline.return_value = mock_transform
107-
108-
preprocess_dicom_series("dummy_path.dcm", "MRI")
109-
110-
mock_pipeline.assert_called_once()
111-
mock_transform.assert_called_once_with("dummy_path.dcm")
112-
113-
114-
@patch('monai.transforms.clinical_preprocessing.get_mri_preprocessing_pipeline')
115-
def test_preprocess_dicom_series_mr(mock_pipeline):
116-
"""Test preprocess_dicom_series with MR modality."""
117-
mock_transform = MagicMock()
118-
mock_pipeline.return_value = mock_transform
119-
120-
preprocess_dicom_series("dummy_path.dcm", "MR")
121-
122-
mock_pipeline.assert_called_once()
123-
mock_transform.assert_called_once_with("dummy_path.dcm")
27+
assert isinstance(pipeline.transforms[0], LoadImage)
28+
assert isinstance(pipeline.transforms[1], EnsureChannelFirst)
29+
assert isinstance(pipeline.transforms[2], NormalizeIntensity)
12430

12531

12632
def test_preprocess_dicom_series_invalid_modality():
127-
"""Test preprocess_dicom_series raises ValueError for unsupported modality."""
128-
try:
33+
"""Test preprocess_dicom_series raises UnsupportedModalityError for unsupported modality."""
34+
with pytest.raises(UnsupportedModalityError, match=r"Unsupported modality.*PET.*CT, MR, MRI"):
12935
preprocess_dicom_series("dummy_path.dcm", "PET")
130-
assert False, "Should have raised ValueError"
131-
except ValueError as e:
132-
error_message = str(e)
133-
assert "Unsupported modality" in error_message
134-
assert "PET" in error_message
135-
assert "CT, MR, MRI" in error_message
13636

13737

13838
def test_preprocess_dicom_series_invalid_type():
139-
"""Test preprocess_dicom_series raises TypeError for non-string modality."""
140-
try:
39+
"""Test preprocess_dicom_series raises ModalityTypeError for non-string modality."""
40+
with pytest.raises(ModalityTypeError, match=r"modality must be a string, got int"):
14141
preprocess_dicom_series("dummy_path.dcm", 123)
142-
assert False, "Should have raised TypeError"
143-
except TypeError as e:
144-
error_message = str(e)

monai/transforms/clinical_preprocessing.py

Lines changed: 28 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,74 +12,56 @@
1212
NormalizeIntensity,
1313
)
1414

15-
# Use a tuple for programmatic checks and formatting
1615
SUPPORTED_MODALITIES = ("CT", "MR", "MRI")
1716

1817

19-
def get_ct_preprocessing_pipeline() -> Compose:
20-
"""
21-
Build a CT preprocessing pipeline using standard HU windowing.
18+
class UnsupportedModalityError(ValueError):
19+
"""Raised when an unsupported modality is provided."""
20+
pass
2221

23-
The pipeline applies LoadImage, EnsureChannelFirst, and ScaleIntensityRange
24-
with HU window [-1000, 400] normalized to [0.0, 1.0].
2522

26-
Returns:
27-
Compose: A composed transform pipeline for CT preprocessing.
28-
"""
29-
return Compose(
30-
[
31-
LoadImage(image_only=True),
32-
EnsureChannelFirst(),
33-
ScaleIntensityRange(
34-
a_min=-1000,
35-
a_max=400,
36-
b_min=0.0,
37-
b_max=1.0,
38-
clip=True,
39-
),
40-
]
41-
)
23+
class ModalityTypeError(TypeError):
24+
"""Raised when modality is not a string."""
25+
pass
4226

4327

44-
def get_mri_preprocessing_pipeline() -> Compose:
45-
"""
46-
Build an MRI preprocessing pipeline using intensity normalization.
28+
def get_ct_preprocessing_pipeline() -> Compose:
29+
"""Return a CT preprocessing pipeline."""
30+
return Compose([
31+
LoadImage(image_only=True),
32+
EnsureChannelFirst(),
33+
ScaleIntensityRange(a_min=-1000, a_max=400, b_min=0.0, b_max=1.0, clip=True),
34+
])
4735

48-
The pipeline applies LoadImage, EnsureChannelFirst, and NormalizeIntensity
49-
with nonzero=True to normalize only non-zero voxels.
5036

51-
Returns:
52-
Compose: A composed transform pipeline for MRI preprocessing.
53-
"""
54-
return Compose(
55-
[
56-
LoadImage(image_only=True),
57-
EnsureChannelFirst(),
58-
NormalizeIntensity(nonzero=True),
59-
]
60-
)
37+
def get_mri_preprocessing_pipeline() -> Compose:
38+
"""Return an MRI preprocessing pipeline."""
39+
return Compose([
40+
LoadImage(image_only=True),
41+
EnsureChannelFirst(),
42+
NormalizeIntensity(nonzero=True),
43+
])
6144

6245

6346
def preprocess_dicom_series(
6447
dicom_path: Union[str, bytes, PathLike],
6548
modality: str,
6649
) -> MetaTensor:
67-
"""
68-
Preprocess a DICOM series based on modality.
50+
"""Preprocess a DICOM series according to modality (CT or MRI).
6951
7052
Args:
71-
dicom_path: Path to DICOM file or directory.
72-
modality: Imaging modality. Supported values: "CT", "MR", "MRI" (case-insensitive).
53+
dicom_path (Union[str, bytes, PathLike]): Path to DICOM series.
54+
modality (str): Modality type, must be one of 'CT', 'MR', 'MRI'.
7355
7456
Returns:
75-
MetaTensor: Preprocessed image with intensity values normalized based on modality.
57+
MetaTensor: Preprocessed image tensor.
7658
7759
Raises:
78-
TypeError: If modality is not a string.
79-
ValueError: If modality is not one of the supported values.
60+
ModalityTypeError: If modality is not a string.
61+
UnsupportedModalityError: If modality is not supported.
8062
"""
8163
if not isinstance(modality, str):
82-
raise TypeError(f"modality must be a string, got {type(modality).__name__}")
64+
raise ModalityTypeError(f"modality must be a string, got {type(modality).__name__}")
8365

8466
modality = modality.strip().upper()
8567

@@ -88,7 +70,7 @@ def preprocess_dicom_series(
8870
elif modality in ("MR", "MRI"):
8971
transform = get_mri_preprocessing_pipeline()
9072
else:
91-
raise ValueError(
73+
raise UnsupportedModalityError(
9274
f"Unsupported modality: {modality}. Supported values: {', '.join(SUPPORTED_MODALITIES)}"
9375
)
9476

0 commit comments

Comments
 (0)