@@ -826,7 +826,7 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor
826826 boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
827827
828828 Returns:
829- The output is always floating-point:
829+ The output is always floating-point (size: (N, M)) :
830830 - if ``boxes1`` has a floating-point dtype, the same dtype is used.
831831 - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
832832
@@ -844,7 +844,7 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor
844844
845845 inter , union = _box_inter_union (boxes1_t , boxes2_t , compute_dtype = COMPUTE_DTYPE )
846846
847- # compute IoU and convert back to original box_dtype or float32
847+ # compute IoU and convert back to original box_dtype or torch. float32
848848 iou_t = inter / (union + torch .finfo (COMPUTE_DTYPE ).eps ) # (N,M)
849849 if not box_dtype .is_floating_point :
850850 box_dtype = COMPUTE_DTYPE
@@ -871,7 +871,9 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
871871 boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
872872
873873 Returns:
874- GIoU, with size of (N,M) and same data type as ``boxes1``
874+ The output is always floating-point (size: (N, M)):
875+ - if ``boxes1`` has a floating-point dtype, the same dtype is used.
876+ - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
875877
876878 Reference:
877879 https://giou.stanford.edu/GIoU.pdf
@@ -889,7 +891,7 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
889891
890892 # we do computation with compute_dtype to avoid overflow
891893 box_dtype = boxes1_t .dtype
892-
894+
893895 inter , union = _box_inter_union (boxes1_t , boxes2_t , compute_dtype = COMPUTE_DTYPE )
894896 iou = inter / (union + torch .finfo (COMPUTE_DTYPE ).eps ) # (N,M)
895897
@@ -908,12 +910,15 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
908910
909911 # GIoU
910912 giou_t = iou - (enclosure - union ) / (enclosure + torch .finfo (COMPUTE_DTYPE ).eps )
913+ if not box_dtype .is_floating_point :
914+ box_dtype = COMPUTE_DTYPE
911915 giou_t = giou_t .to (dtype = box_dtype )
916+
912917 if torch .isnan (giou_t ).any () or torch .isinf (giou_t ).any ():
913918 raise ValueError ("Box GIoU is NaN or Inf." )
914919
915920 # convert tensor back to numpy if needed
916- giou , * _ = convert_to_dst_type (src = giou_t , dst = boxes1 )
921+ giou , * _ = convert_to_dst_type (src = giou_t , dst = boxes1 , dtype = box_dtype )
917922 return giou
918923
919924
@@ -929,7 +934,9 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
929934 boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode``
930935
931936 Returns:
932- paired GIoU, with size of (N,) and same data type as ``boxes1``
937+ The output is always floating-point (size: (N,)):
938+ - if ``boxes1`` has a floating-point dtype, the same dtype is used.
939+ - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
933940
934941 Reference:
935942 https://giou.stanford.edu/GIoU.pdf
@@ -986,12 +993,15 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
986993 enclosure = torch .prod (wh , dim = - 1 , keepdim = False ) # (N,)
987994
988995 giou_t : torch .Tensor = iou - (enclosure - union ) / (enclosure + torch .finfo (COMPUTE_DTYPE ).eps ) # type: ignore
996+ if not box_dtype .is_floating_point :
997+ box_dtype = COMPUTE_DTYPE
989998 giou_t = giou_t .to (dtype = box_dtype ) # (N,spatial_dims)
999+
9901000 if torch .isnan (giou_t ).any () or torch .isinf (giou_t ).any ():
9911001 raise ValueError ("Box GIoU is NaN or Inf." )
9921002
9931003 # convert tensor back to numpy if needed
994- giou , * _ = convert_to_dst_type (src = giou_t , dst = boxes1 )
1004+ giou , * _ = convert_to_dst_type (src = giou_t , dst = boxes1 , dtype = box_dtype )
9951005 return giou
9961006
9971007
0 commit comments