Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 70 additions & 15 deletions monai/apps/deepedit/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,44 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda
return d


class NormalizeLabelsInDatasetd(MapTransform):
class RemapLabelsToSequentiald(MapTransform):
"""
Remap label values from a dataset-specific schema to sequential indices (0, 1, 2, 3, ...).

This transform takes labels with arbitrary values defined in a label dictionary and remaps them
to a sequential range starting from 1 (with background always set to 0). This is useful for
standardizing labels across different datasets or ensuring labels are in a contiguous range.

The output label indices are assigned in alphabetical order by label name to ensure
deterministic behavior regardless of input dictionary ordering.

Args:
keys: The ``keys`` parameter will be used to get and set the actual data item to transform
label_names: Dictionary mapping label names to their current values in the dataset.
For example: {"spleen": 1, "liver": 6, "background": 0}
Will be remapped to: {"background": 0, "liver": 1, "spleen": 2}
(alphabetically sorted, excluding background)
allow_missing_keys: If True, missing keys in the data dictionary will not raise an error

Example:
>>> transform = RemapLabelsToSequentiald(
... keys="label",
... label_names={"liver": 6, "spleen": 1, "background": 0}
... )
>>> # Input label has values [0, 1, 6]
>>> # Output label will have values [0, 1, 2] (background=0, liver=1, spleen=2)
>>> # And updates d["label_names"] to {"background": 0, "liver": 1, "spleen": 2}

Note:
- Background label (if present) is always mapped to 0
- Non-background labels are mapped to sequential indices 1, 2, 3, ... in alphabetical order
- Undefined labels (not in label_names) will be set to 0 (background)
- The transform updates the data dictionary with a new "label_names" key containing the remapped values
"""

def __init__(
self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False
):
"""
Normalize label values according to label names dictionary

Args:
keys: The ``keys`` parameter will be used to get and set the actual data item to transform
label_names: all label names
"""
super().__init__(keys, allow_missing_keys)

self.label_names = label_names or {}
Expand All @@ -106,13 +132,20 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda
# Dictionary containing new label numbers
new_label_names = {}
label = np.zeros(d[key].shape)
# Making sure the range values and number of labels are the same
for idx, (key_label, val_label) in enumerate(self.label_names.items(), start=1):
if key_label != "background":
new_label_names[key_label] = idx
label[d[key] == val_label] = idx
if key_label == "background":
new_label_names["background"] = 0

# Sort label names to ensure deterministic ordering (exclude background)
sorted_labels = sorted(
[(k, v) for k, v in self.label_names.items() if k != "background"]
)

# Always set background to 0 first
if "background" in self.label_names:
new_label_names["background"] = 0

# Assign sequential indices to sorted non-background labels
for idx, (key_label, val_label) in enumerate(sorted_labels, start=1):
new_label_names[key_label] = idx
label[d[key] == val_label] = idx

d["label_names"] = new_label_names
if isinstance(d[key], MetaTensor):
Expand All @@ -122,6 +155,28 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda
return d


class NormalizeLabelsInDatasetd(RemapLabelsToSequentiald):
"""
.. deprecated:: 1.5.0
`NormalizeLabelsInDatasetd` is deprecated. Use :class:`RemapLabelsToSequentiald` instead.

This class is maintained for backward compatibility. Please use RemapLabelsToSequentiald
which better describes the transform's functionality.
"""

def __init__(
self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False
):
warnings.warn(
"NormalizeLabelsInDatasetd is deprecated and will be removed in a future version. "
"Please use RemapLabelsToSequentiald instead, which better describes what the transform does: "
"remapping label values to sequential indices (0, 1, 2, 3, ...).",
DeprecationWarning,
stacklevel=2,
)
super().__init__(keys, label_names, allow_missing_keys)


class SingleLabelSelectiond(MapTransform):

def __init__(
Expand Down
65 changes: 65 additions & 0 deletions tests/apps/deepedit/test_deepedit_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
FindAllValidSlicesMissingLabelsd,
FindDiscrepancyRegionsDeepEditd,
NormalizeLabelsInDatasetd,
RemapLabelsToSequentiald,
ResizeGuidanceMultipleLabelDeepEditd,
SingleLabelSelectiond,
SplitPredsLabeld,
Expand Down Expand Up @@ -282,6 +283,70 @@ def test_correct_results(self, arguments, input_data, expected_result):
result = add_fn(input_data)
self.assertEqual(len(np.unique(result["label"])), expected_result)

def test_ordering_determinism(self):
"""Test that different input ordering produces the same output (alphabetical)"""
# Create a label array with different label values
label = np.array([[[0, 1, 6, 3]]]) # background=0, spleen=1, liver=6, kidney=3

# Test case 1: liver first, then kidney, then spleen
data1 = {"label": label.copy()}
transform1 = RemapLabelsToSequentiald(
keys="label",
label_names={"liver": 6, "kidney": 3, "spleen": 1, "background": 0}
)
result1 = transform1(data1)

# Test case 2: spleen first, then kidney, then liver (different order)
data2 = {"label": label.copy()}
transform2 = RemapLabelsToSequentiald(
keys="label",
label_names={"spleen": 1, "kidney": 3, "liver": 6, "background": 0}
)
result2 = transform2(data2)

# Both should produce the same output (alphabetically sorted)
# Expected mapping: background=0, kidney=1, liver=2, spleen=3
np.testing.assert_array_equal(result1["label"], result2["label"])

# Verify the actual mapping is alphabetical
expected_output = np.array([[[0, 3, 2, 1]]]) # kidney=1, liver=2, spleen=3, background=0
np.testing.assert_array_equal(result1["label"], expected_output)

# Verify label_names is correct
self.assertEqual(result1["label_names"], {"background": 0, "kidney": 1, "liver": 2, "spleen": 3})
self.assertEqual(result2["label_names"], {"background": 0, "kidney": 1, "liver": 2, "spleen": 3})

def test_multiple_labels(self):
"""Test with multiple non-background labels"""
label = np.array([[[0, 1, 2, 5]]]) # background, spleen, kidney, liver
data = {"label": label.copy()}
transform = RemapLabelsToSequentiald(
keys="label",
label_names={"spleen": 1, "kidney": 2, "liver": 5, "background": 0}
)
result = transform(data)

# Expected: background=0, kidney=1, liver=2, spleen=3 (alphabetical)
expected = np.array([[[0, 3, 1, 2]]])
np.testing.assert_array_equal(result["label"], expected)
self.assertEqual(result["label_names"], {"background": 0, "kidney": 1, "liver": 2, "spleen": 3})

def test_deprecated_name_warning(self):
"""Test that using the deprecated name raises a warning"""
import warnings

data = {"label": np.array([[[0, 1]]])}

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
transform = NormalizeLabelsInDatasetd(keys="label", label_names={"spleen": 1, "background": 0})
_ = transform(data) # Call to trigger the warning

# Check that a deprecation warning was raised
self.assertEqual(len(w), 1)
self.assertTrue(issubclass(w[0].category, DeprecationWarning))
self.assertIn("RemapLabelsToSequentiald", str(w[0].message))


class TestResizeGuidanceMultipleLabelCustomd(unittest.TestCase):

Expand Down
Loading