Skip to content

Commit b3a6ac2

Browse files
Add clinical DICOM preprocessing utilities for CT/MRI with unit tests
1 parent 798f8af commit b3a6ac2

File tree

2 files changed

+140
-36
lines changed

2 files changed

+140
-36
lines changed
Lines changed: 109 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,129 @@
1+
"""Unit tests for clinical DICOM preprocessing utilities."""
2+
13
import numpy as np
4+
from unittest.mock import patch, MagicMock
25
import pytest
36

47
from monai.transforms import ScaleIntensityRange, NormalizeIntensity
8+
from monai.transforms.clinical_preprocessing import (
9+
get_ct_preprocessing_pipeline,
10+
get_mri_preprocessing_pipeline,
11+
preprocess_dicom_series,
12+
)
513

614

7-
def test_ct_windowing_range_and_shape():
15+
def test_ct_windowing_range_and_shape_direct():
16+
"""Test ScaleIntensityRange transform on sample CT data."""
817
rng = np.random.default_rng(0)
9-
10-
sample_ct = rng.integers(
11-
-1024, 2048, size=(64, 64, 64), dtype=np.int16
12-
)
13-
14-
transform = ScaleIntensityRange(
15-
a_min=-1000,
16-
a_max=400,
17-
b_min=0.0,
18-
b_max=1.0,
19-
clip=True,
20-
)
21-
22-
output = transform(sample_ct)
23-
output = np.asarray(output)
18+
sample_ct = rng.integers(-1024, 2048, size=(64, 64, 64), dtype=np.int16)
19+
transform = ScaleIntensityRange(a_min=-1000, a_max=400, b_min=0.0, b_max=1.0, clip=True)
20+
output = np.asarray(transform(sample_ct))
2421

2522
assert output.shape == sample_ct.shape
2623
assert np.isfinite(output).all()
2724
assert output.min() >= -1e-6
2825
assert output.max() <= 1.0 + 1e-6
2926

3027

31-
def test_mri_normalization_mean_std():
28+
def test_mri_normalization_mean_std_direct():
29+
"""Test NormalizeIntensity transform on sample MRI data."""
3230
rng = np.random.default_rng(0)
33-
3431
sample_mri = rng.random((64, 64, 64), dtype=np.float32)
35-
3632
transform = NormalizeIntensity(nonzero=True)
33+
output = np.asarray(transform(sample_mri))
3734

38-
output = transform(sample_mri)
39-
output = np.asarray(output)
35+
assert output.shape == sample_mri.shape
36+
assert np.isclose(float(output.mean()), 0.0, atol=0.1)
37+
assert np.isclose(float(output.std()), 1.0, atol=0.1)
4038

41-
mean_val = float(output.mean())
42-
std_val = float(output.std())
4339

44-
assert output.shape == sample_mri.shape
45-
assert np.isclose(mean_val, 0.0, atol=0.1)
46-
assert np.isclose(std_val, 1.0, atol=0.1)
40+
@patch("monai.transforms.clinical_preprocessing.LoadImage")
41+
def test_ct_pipeline(mock_loadimage):
42+
"""Test get_ct_preprocessing_pipeline returns correct transform sequence."""
43+
pipeline = get_ct_preprocessing_pipeline()
44+
assert len(pipeline.transforms) == 3
45+
assert pipeline.transforms[0].__class__.__name__ == "LoadImage"
46+
assert pipeline.transforms[1].__class__.__name__ == "EnsureChannelFirst"
47+
assert pipeline.transforms[2].__class__.__name__ == "ScaleIntensityRange"
48+
49+
50+
@patch("monai.transforms.clinical_preprocessing.LoadImage")
51+
def test_mri_pipeline(mock_loadimage):
52+
"""Test get_mri_preprocessing_pipeline returns correct transform sequence."""
53+
pipeline = get_mri_preprocessing_pipeline()
54+
assert len(pipeline.transforms) == 3
55+
assert pipeline.transforms[0].__class__.__name__ == "LoadImage"
56+
assert pipeline.transforms[1].__class__.__name__ == "EnsureChannelFirst"
57+
assert pipeline.transforms[2].__class__.__name__ == "NormalizeIntensity"
58+
59+
60+
@patch("monai.transforms.clinical_preprocessing.get_ct_preprocessing_pipeline")
61+
def test_preprocess_dicom_series_ct(mock_pipeline):
62+
"""Test preprocess_dicom_series with CT modality."""
63+
mock_transform = MagicMock()
64+
mock_pipeline.return_value = mock_transform
65+
preprocess_dicom_series("dummy_path.dcm", "CT")
66+
mock_pipeline.assert_called_once()
67+
mock_transform.assert_called_once_with("dummy_path.dcm")
68+
69+
70+
@patch("monai.transforms.clinical_preprocessing.get_ct_preprocessing_pipeline")
71+
def test_preprocess_dicom_series_ct_lowercase(mock_pipeline):
72+
"""Test preprocess_dicom_series with lowercase CT modality."""
73+
mock_transform = MagicMock()
74+
mock_pipeline.return_value = mock_transform
75+
preprocess_dicom_series("dummy_path.dcm", "ct")
76+
mock_pipeline.assert_called_once()
77+
mock_transform.assert_called_once_with("dummy_path.dcm")
78+
79+
80+
@patch("monai.transforms.clinical_preprocessing.get_mri_preprocessing_pipeline")
81+
def test_preprocess_dicom_series_mri(mock_pipeline):
82+
"""Test preprocess_dicom_series with MRI modality."""
83+
mock_transform = MagicMock()
84+
mock_pipeline.return_value = mock_transform
85+
preprocess_dicom_series("dummy_path.dcm", "MRI")
86+
mock_pipeline.assert_called_once()
87+
mock_transform.assert_called_once_with("dummy_path.dcm")
88+
89+
90+
@patch("monai.transforms.clinical_preprocessing.get_mri_preprocessing_pipeline")
91+
def test_preprocess_dicom_series_mr(mock_pipeline):
92+
"""Test preprocess_dicom_series with MR modality."""
93+
mock_transform = MagicMock()
94+
mock_pipeline.return_value = mock_transform
95+
preprocess_dicom_series("dummy_path.dcm", "MR")
96+
mock_pipeline.assert_called_once()
97+
mock_transform.assert_called_once_with("dummy_path.dcm")
98+
99+
100+
def test_preprocess_dicom_series_invalid_modality():
101+
"""Test preprocess_dicom_series raises ValueError for unsupported modality."""
102+
with pytest.raises(ValueError) as exc:
103+
preprocess_dicom_series("dummy_path.dcm", "PET")
104+
assert "Unsupported modality" in str(exc.value)
105+
assert "PET" in str(exc.value)
106+
107+
108+
def test_preprocess_dicom_series_invalid_type():
109+
"""Test preprocess_dicom_series raises TypeError for non-string modality."""
110+
with pytest.raises(TypeError) as exc:
111+
preprocess_dicom_series("dummy_path.dcm", 123)
112+
assert "modality must be a string" in str(exc.value)
113+
114+
115+
def test_preprocess_dicom_series_none_modality():
116+
"""Test preprocess_dicom_series raises TypeError for None modality."""
117+
with pytest.raises(TypeError) as exc:
118+
preprocess_dicom_series("dummy_path.dcm", None)
119+
assert "modality must be a string" in str(exc.value)
120+
121+
122+
@patch("monai.transforms.clinical_preprocessing.get_ct_preprocessing_pipeline")
123+
def test_preprocess_dicom_series_whitespace(mock_pipeline):
124+
"""Test preprocess_dicom_series handles whitespace in modality."""
125+
mock_transform = MagicMock()
126+
mock_pipeline.return_value = mock_transform
127+
preprocess_dicom_series("dummy_path.dcm", " CT ")
128+
mock_pipeline.assert_called_once()
129+
mock_transform.assert_called_once_with("dummy_path.dcm")

monai/transforms/clinical_preprocessing.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
"""Clinical DICOM preprocessing utilities for CT and MRI modalities."""
2+
13
from typing import Union
4+
from os import PathLike
25

36
from monai.transforms import (
47
Compose,
@@ -8,10 +11,18 @@
811
NormalizeIntensity,
912
)
1013

14+
SUPPORTED_MODALITIES = "CT, MR, MRI"
15+
1116

12-
def get_ct_preprocessing_pipeline():
17+
def get_ct_preprocessing_pipeline() -> Compose:
1318
"""
14-
CT preprocessing pipeline using standard HU windowing.
19+
Build a CT preprocessing pipeline using standard HU windowing.
20+
21+
The pipeline applies LoadImage, EnsureChannelFirst, and ScaleIntensityRange
22+
with HU window [-1000, 400] normalized to [0.0, 1.0].
23+
24+
Returns:
25+
Compose: A composed transform pipeline for CT preprocessing.
1526
"""
1627
return Compose(
1728
[
@@ -28,9 +39,15 @@ def get_ct_preprocessing_pipeline():
2839
)
2940

3041

31-
def get_mri_preprocessing_pipeline():
42+
def get_mri_preprocessing_pipeline() -> Compose:
3243
"""
33-
MRI preprocessing pipeline using intensity normalization.
44+
Build an MRI preprocessing pipeline using intensity normalization.
45+
46+
The pipeline applies LoadImage, EnsureChannelFirst, and NormalizeIntensity
47+
with nonzero=True to normalize only non-zero voxels.
48+
49+
Returns:
50+
Compose: A composed transform pipeline for MRI preprocessing.
3451
"""
3552
return Compose(
3653
[
@@ -42,21 +59,25 @@ def get_mri_preprocessing_pipeline():
4259

4360

4461
def preprocess_dicom_series(
45-
dicom_path: Union[str, bytes],
62+
dicom_path: Union[str, bytes, PathLike],
4663
modality: str,
47-
):
64+
) -> "MetaTensor":
4865
"""
4966
Preprocess a DICOM series based on modality.
5067
5168
Args:
5269
dicom_path: Path to DICOM file or directory.
53-
modality: CT, MR, or MRI.
70+
modality: Imaging modality. Supported values: "CT", "MR", "MRI" (case-insensitive).
5471
5572
Returns:
56-
Preprocessed image.
73+
MetaTensor: Preprocessed image with intensity values normalized based on modality.
74+
75+
Raises:
76+
TypeError: If modality is not a string.
77+
ValueError: If modality is not one of the supported values.
5778
"""
5879
if not isinstance(modality, str):
59-
raise TypeError("modality must be a string")
80+
raise TypeError(f"modality must be a string, got {type(modality).__name__}")
6081

6182
modality = modality.strip().upper()
6283

@@ -65,6 +86,6 @@ def preprocess_dicom_series(
6586
elif modality in ("MR", "MRI"):
6687
transform = get_mri_preprocessing_pipeline()
6788
else:
68-
raise ValueError("Unsupported modality")
89+
raise ValueError(f"Unsupported modality: {modality}. Supported values: {SUPPORTED_MODALITIES}")
6990

7091
return transform(dicom_path)

0 commit comments

Comments
 (0)