Skip to content

Commit 690959f

Browse files
committed
fix: address CodeRabbit review comments on align_corners changes
- scale_affine: handle dst dim == 1 special case (avoid spurious non-zero scale via max(n-1, 1)) and document that `centered` is ignored when align_corners=True. - kwargs_from_pending: docstring now notes align_corners extraction from extra_info. - test_affine_transform: use detach().cpu().numpy() consistently and pin identity_affine dtype. - test_affine: switch the known method_0/align_corners=True mismatch to subTest + SkipTest so it stays visible in CI; deepcopy lazy params. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent b721472 commit 690959f

4 files changed

Lines changed: 26 additions & 15 deletions

File tree

monai/transforms/lazy/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,11 @@ def affine_from_pending(pending_item):
9090

9191

9292
def kwargs_from_pending(pending_item):
93-
"""Extract kwargs from a pending transform item."""
93+
"""Extract kwargs from a pending transform item.
94+
95+
When ``pending_item`` is a dict, ``align_corners`` is also extracted from its ``extra_info`` entry
96+
(if present and boolean) so the lazy pipeline preserves the original transform's alignment.
97+
"""
9498
if not isinstance(pending_item, dict):
9599
return {}
96100
ret = {

monai/transforms/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2105,6 +2105,7 @@ def scale_affine(spatial_size, new_spatial_size, centered: bool = True, align_co
21052105
spatial_size: original spatial size.
21062106
new_spatial_size: new spatial size.
21072107
centered: whether the scaling is with respect to the image center (True, default) or corner (False).
2108+
Ignored when ``align_corners=True``, since corner-aligned scaling is inherently centered.
21082109
align_corners: if True, use (size-1) based scaling to match torch.nn.functional.interpolate behavior.
21092110
21102111
Returns:
@@ -2115,9 +2116,10 @@ def scale_affine(spatial_size, new_spatial_size, centered: bool = True, align_co
21152116
if spatial_size == new_spatial_size:
21162117
return np.eye(r + 1)
21172118
if align_corners:
2118-
# Match interpolate behavior: (src-1)/(dst-1)
2119+
# Match interpolate behavior: (src-1)/(dst-1); when dst == 1 the scale collapses to 0
21192120
s = np.array(
2120-
[(float(o) - 1) / max(float(n) - 1, 1) for o, n in zip(spatial_size, new_spatial_size)], dtype=float
2121+
[0.0 if float(n) == 1 else (float(o) - 1) / (float(n) - 1) for o, n in zip(spatial_size, new_spatial_size)],
2122+
dtype=float,
21212123
)
21222124
else:
21232125
# Standard scaling: src/dst

tests/networks/layers/test_affine_transform.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,17 +392,19 @@ def test_align_corners_consistency(self):
392392
image = torch.arange(1.0, 13.0).view(1, 1, 3, 4)
393393

394394
# Identity affine in pixel space (i, j, k convention with reverse_indexing=True)
395-
identity_affine = torch.eye(3).unsqueeze(0)
395+
identity_affine = torch.eye(3, dtype=torch.float32).unsqueeze(0)
396396

397397
# Test with align_corners=True (the default)
398398
xform_true = AffineTransform(align_corners=True)
399399
out_true = xform_true(image, identity_affine)
400-
np.testing.assert_allclose(out_true.numpy(), image.numpy(), atol=1e-5, rtol=_rtol)
400+
np.testing.assert_allclose(out_true.detach().cpu().numpy(), image.detach().cpu().numpy(), atol=1e-5, rtol=_rtol)
401401

402402
# Test with align_corners=False
403403
xform_false = AffineTransform(align_corners=False)
404404
out_false = xform_false(image, identity_affine)
405-
np.testing.assert_allclose(out_false.numpy(), image.numpy(), atol=1e-5, rtol=_rtol)
405+
np.testing.assert_allclose(
406+
out_false.detach().cpu().numpy(), image.detach().cpu().numpy(), atol=1e-5, rtol=_rtol
407+
)
406408

407409
def test_align_corners_true_translation(self):
408410
"""
@@ -425,7 +427,7 @@ def test_align_corners_true_translation(self):
425427

426428
# Expected: shift columns left by 1, rightmost column becomes 0
427429
expected = torch.tensor([[[[2, 3, 4, 0], [6, 7, 8, 0], [10, 11, 12, 0], [14, 15, 16, 0]]]], dtype=torch.float32)
428-
np.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-4, rtol=_rtol)
430+
np.testing.assert_allclose(out.detach().cpu().numpy(), expected.detach().cpu().numpy(), atol=1e-4, rtol=_rtol)
429431

430432

431433
if __name__ == "__main__":

tests/transforms/test_affine.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_affine(self, input_param, input_data, expected_val):
189189
set_track_meta(True)
190190

191191
# test lazy
192-
lazy_input_param = input_param.copy()
192+
lazy_input_param = deepcopy(input_param)
193193
for align_corners in [True, False]:
194194
lazy_input_param["align_corners"] = align_corners
195195
resampler = Affine(**lazy_input_param)
@@ -238,13 +238,16 @@ def method_3(im, ac):
238238

239239
for call in (method_0, method_1, method_2, method_3):
240240
for ac in (False, True):
241-
# Skip method_0 with align_corners=True due to known issue with lazy pipeline
242-
# padding_mode override when using align_corners=True in optimized path
243-
if call == method_0 and ac:
244-
continue
245-
out = call(im, ac)
246-
ref = Resize(align_corners=ac, spatial_size=(sp_size, sp_size), mode="bilinear")(im)
247-
assert_allclose(out, ref, rtol=1e-4, atol=1e-4, type_test=False)
241+
with self.subTest(method=call.__name__, align_corners=ac):
242+
if call is method_0 and ac:
243+
# Known issue: lazy pipeline padding_mode override mismatches
244+
# when using align_corners=True in the optimized path.
245+
raise unittest.SkipTest(
246+
"method_0 with align_corners=True is a known mismatch in the lazy pipeline."
247+
)
248+
out = call(im, ac)
249+
ref = Resize(align_corners=ac, spatial_size=(sp_size, sp_size), mode="bilinear")(im)
250+
assert_allclose(out, ref, rtol=1e-4, atol=1e-4, type_test=False)
248251

249252

250253
if __name__ == "__main__":

0 commit comments

Comments
 (0)