Skip to content

Commit 7f158b2

Browse files
committed
fix: prevent integer truncation in box_giou and box_pair_giou by applying IoU dtype rules
Signed-off-by: reworld223 <[email protected]>
1 parent c2bd375 commit 7f158b2

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

monai/data/box_utils.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor
826826
boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
827827
828828
Returns:
829-
The output is always floating-point:
829+
The output is always floating-point (size: (N, M)):
830830
- if ``boxes1`` has a floating-point dtype, the same dtype is used.
831831
- if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
832832
@@ -844,7 +844,7 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor
844844

845845
inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=COMPUTE_DTYPE)
846846

847-
# compute IoU and convert back to original box_dtype or float32
847+
# compute IoU and convert back to original box_dtype or torch.float32
848848
iou_t = inter / (union + torch.finfo(COMPUTE_DTYPE).eps) # (N,M)
849849
if not box_dtype.is_floating_point:
850850
box_dtype = COMPUTE_DTYPE
@@ -871,7 +871,9 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
871871
boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
872872
873873
Returns:
874-
GIoU, with size of (N,M) and same data type as ``boxes1``
874+
The output is always floating-point (size: (N, M)):
875+
- if ``boxes1`` has a floating-point dtype, the same dtype is used.
876+
- if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
875877
876878
Reference:
877879
https://giou.stanford.edu/GIoU.pdf
@@ -889,7 +891,7 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
889891

890892
# we do computation with compute_dtype to avoid overflow
891893
box_dtype = boxes1_t.dtype
892-
894+
893895
inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=COMPUTE_DTYPE)
894896
iou = inter / (union + torch.finfo(COMPUTE_DTYPE).eps) # (N,M)
895897

@@ -908,12 +910,15 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
908910

909911
# GIoU
910912
giou_t = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps)
913+
if not box_dtype.is_floating_point:
914+
box_dtype = COMPUTE_DTYPE
911915
giou_t = giou_t.to(dtype=box_dtype)
916+
912917
if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():
913918
raise ValueError("Box GIoU is NaN or Inf.")
914919

915920
# convert tensor back to numpy if needed
916-
giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1)
921+
giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1, dtype=box_dtype)
917922
return giou
918923

919924

@@ -929,7 +934,9 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
929934
boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode``
930935
931936
Returns:
932-
paired GIoU, with size of (N,) and same data type as ``boxes1``
937+
The output is always floating-point (size: (N,)):
938+
- if ``boxes1`` has a floating-point dtype, the same dtype is used.
939+
- if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
933940
934941
Reference:
935942
https://giou.stanford.edu/GIoU.pdf
@@ -986,12 +993,15 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
986993
enclosure = torch.prod(wh, dim=-1, keepdim=False) # (N,)
987994

988995
giou_t: torch.Tensor = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps) # type: ignore
996+
if not box_dtype.is_floating_point:
997+
box_dtype = COMPUTE_DTYPE
989998
giou_t = giou_t.to(dtype=box_dtype) # (N,spatial_dims)
999+
9901000
if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():
9911001
raise ValueError("Box GIoU is NaN or Inf.")
9921002

9931003
# convert tensor back to numpy if needed
994-
giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1)
1004+
giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1, dtype= box_dtype)
9951005
return giou
9961006

9971007

0 commit comments

Comments
 (0)