diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py index a982b01427..b09b86b605 100644 --- a/monai/data/box_utils.py +++ b/monai/data/box_utils.py @@ -826,7 +826,10 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` Returns: - IoU, with size of (N,M) and same data type as ``boxes1`` + An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always + floating-point with size ``(N, M)``: + - if ``boxes1`` has a floating-point dtype, the same dtype is used. + - if ``boxes1`` has an integer dtype, the result is ``float32`` (``torch.float32`` or ``np.float32``). """ @@ -842,8 +845,10 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=COMPUTE_DTYPE) - # compute IoU and convert back to original box_dtype + # compute IoU and convert back to original box_dtype or torch.float32 iou_t = inter / (union + torch.finfo(COMPUTE_DTYPE).eps) # (N,M) + if not box_dtype.is_floating_point: + box_dtype = COMPUTE_DTYPE iou_t = iou_t.to(dtype=box_dtype) # check if NaN or Inf @@ -851,7 +856,7 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor raise ValueError("Box IoU is NaN or Inf.") # convert tensor back to numpy if needed - iou, *_ = convert_to_dst_type(src=iou_t, dst=boxes1) + iou, *_ = convert_to_dst_type(src=iou_t, dst=boxes1, dtype=box_dtype) return iou @@ -867,7 +872,10 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray.
The box mode is assumed to be ``StandardMode`` Returns: - GIoU, with size of (N,M) and same data type as ``boxes1`` + An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always + floating-point with size ``(N, M)``: + - if ``boxes1`` has a floating-point dtype, the same dtype is used. + - if ``boxes1`` has an integer dtype, the result is ``float32`` (``torch.float32`` or ``np.float32``). Reference: https://giou.stanford.edu/GIoU.pdf @@ -904,12 +912,15 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso # GIoU giou_t = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps) + if not box_dtype.is_floating_point: + box_dtype = COMPUTE_DTYPE giou_t = giou_t.to(dtype=box_dtype) + if torch.isnan(giou_t).any() or torch.isinf(giou_t).any(): raise ValueError("Box GIoU is NaN or Inf.") # convert tensor back to numpy if needed - giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1) + giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1, dtype=box_dtype) return giou @@ -925,7 +936,10 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode`` Returns: - paired GIoU, with size of (N,) and same data type as ``boxes1`` + An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always + floating-point with size ``(N,)``: + - if ``boxes1`` has a floating-point dtype, the same dtype is used. + - if ``boxes1`` has an integer dtype, the result is ``float32`` (``torch.float32`` or ``np.float32``).
Reference: https://giou.stanford.edu/GIoU.pdf @@ -982,12 +996,15 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr enclosure = torch.prod(wh, dim=-1, keepdim=False) # (N,) giou_t: torch.Tensor = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps) # type: ignore + if not box_dtype.is_floating_point: + box_dtype = COMPUTE_DTYPE giou_t = giou_t.to(dtype=box_dtype) # (N,spatial_dims) + if torch.isnan(giou_t).any() or torch.isinf(giou_t).any(): raise ValueError("Box GIoU is NaN or Inf.") # convert tensor back to numpy if needed - giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1) + giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1, dtype=box_dtype) return giou diff --git a/tests/data/test_box_utils.py b/tests/data/test_box_utils.py index 390fd901fd..05778f691b 100644 --- a/tests/data/test_box_utils.py +++ b/tests/data/test_box_utils.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.data.box_utils import ( @@ -218,5 +219,55 @@ def test_value(self, input_data, mode2, expected_box, expected_area): assert_allclose(nms_box, [1], type_test=False) +class TestBoxUtilsDtype(unittest.TestCase): + @parameterized.expand( + [ + # numpy dtypes + (np.array([[1, 1, 1, 2, 2, 2]], dtype=np.int32), np.array([[1, 1, 1, 2, 2, 2]], dtype=np.int32)), + (np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32), np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32)), + # torch dtypes + ( + torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.int64), + torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.int64), + ), + ( + torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.float32), + torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.float32), + ), + # mixed numpy (int + float) + (np.array([[1, 1, 1, 2, 2, 2]], dtype=np.int32), np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32)), + # mixed torch (int + float) + ( + torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.int64), + torch.tensor([[1, 1, 1, 2, 
2, 2]], dtype=torch.float32), + ), + ] + ) + def test_dtype_behavior(self, boxes1, boxes2): + funcs = [box_iou, box_giou, box_pair_giou] + for func in funcs: + result = func(boxes1, boxes2) + + if isinstance(result, np.ndarray): + self.assertTrue( + np.issubdtype(result.dtype, np.floating), f"{func.__name__} expected float, got {result.dtype}" + ) + elif torch.is_tensor(result): + self.assertTrue( + torch.is_floating_point(result), f"{func.__name__} expected float tensor, got {result.dtype}" + ) + else: + self.fail(f"Unexpected return type {type(result)}") + + def test_integer_truncation_bug(self): + # Verify fix for #8553: IoU < 1.0 with integer inputs should not truncate to 0 + boxes1 = np.array([[0, 0, 0, 2, 2, 2]], dtype=np.int32) + boxes2 = np.array([[1, 1, 1, 3, 3, 3]], dtype=np.int32) + + iou = box_iou(boxes1, boxes2) + self.assertTrue(np.issubdtype(iou.dtype, np.floating)) + self.assertGreater(iou[0, 0], 0.0, "IoU should not be truncated to 0") + + if __name__ == "__main__": unittest.main()