@@ -826,7 +826,10 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor
826826 boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
827827
828828 Returns:
829- IoU, with size of (N,M) and same data type as ``boxes1``
829+ An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
830+ floating-point with size ``(N, M)``:
831+ - if ``boxes1`` has a floating-point dtype, the same dtype is used.
832+ - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
830833
831834 """
832835
@@ -842,16 +845,18 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor
842845
843846 inter , union = _box_inter_union (boxes1_t , boxes2_t , compute_dtype = COMPUTE_DTYPE )
844847
845- # compute IoU and convert back to original box_dtype
848+ # compute IoU and convert back to original box_dtype or torch.float32
846849 iou_t = inter / (union + torch .finfo (COMPUTE_DTYPE ).eps ) # (N,M)
850+ if not box_dtype .is_floating_point :
851+ box_dtype = COMPUTE_DTYPE
847852 iou_t = iou_t .to (dtype = box_dtype )
848853
849854 # check if NaN or Inf
850855 if torch .isnan (iou_t ).any () or torch .isinf (iou_t ).any ():
851856 raise ValueError ("Box IoU is NaN or Inf." )
852857
853858 # convert tensor back to numpy if needed
854- iou , * _ = convert_to_dst_type (src = iou_t , dst = boxes1 )
859+ iou , * _ = convert_to_dst_type (src = iou_t , dst = boxes1 , dtype = box_dtype )
855860 return iou
856861
857862
@@ -867,7 +872,10 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
867872 boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
868873
869874 Returns:
870- GIoU, with size of (N,M) and same data type as ``boxes1``
875+ An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
876+ floating-point with size ``(N, M)``:
877+ - if ``boxes1`` has a floating-point dtype, the same dtype is used.
878+ - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
871879
872880 Reference:
873881 https://giou.stanford.edu/GIoU.pdf
@@ -904,12 +912,15 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
904912
905913 # GIoU
906914 giou_t = iou - (enclosure - union ) / (enclosure + torch .finfo (COMPUTE_DTYPE ).eps )
915+ if not box_dtype .is_floating_point :
916+ box_dtype = COMPUTE_DTYPE
907917 giou_t = giou_t .to (dtype = box_dtype )
918+
908919 if torch .isnan (giou_t ).any () or torch .isinf (giou_t ).any ():
909920 raise ValueError ("Box GIoU is NaN or Inf." )
910921
911922 # convert tensor back to numpy if needed
912- giou , * _ = convert_to_dst_type (src = giou_t , dst = boxes1 )
923+ giou , * _ = convert_to_dst_type (src = giou_t , dst = boxes1 , dtype = box_dtype )
913924 return giou
914925
915926
@@ -925,7 +936,10 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
925936 boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode``
926937
927938 Returns:
928- paired GIoU, with size of (N,) and same data type as ``boxes1``
939+ An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
940+ floating-point with size ``(N, )``:
941+ - if ``boxes1`` has a floating-point dtype, the same dtype is used.
942+ - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
929943
930944 Reference:
931945 https://giou.stanford.edu/GIoU.pdf
@@ -982,12 +996,15 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
982996 enclosure = torch .prod (wh , dim = - 1 , keepdim = False ) # (N,)
983997
984998 giou_t : torch .Tensor = iou - (enclosure - union ) / (enclosure + torch .finfo (COMPUTE_DTYPE ).eps ) # type: ignore
999+ if not box_dtype .is_floating_point :
1000+ box_dtype = COMPUTE_DTYPE
9851001 giou_t = giou_t .to (dtype = box_dtype ) # (N,spatial_dims)
1002+
9861003 if torch .isnan (giou_t ).any () or torch .isinf (giou_t ).any ():
9871004 raise ValueError ("Box GIoU is NaN or Inf." )
9881005
9891006 # convert tensor back to numpy if needed
990- giou , * _ = convert_to_dst_type (src = giou_t , dst = boxes1 )
1007+ giou , * _ = convert_to_dst_type (src = giou_t , dst = boxes1 , dtype = box_dtype )
9911008 return giou
9921009
9931010
0 commit comments