From aa4c17b9ec1c4bb7352876e96d3cbb7f69d97043 Mon Sep 17 00:00:00 2001
From: Yizhan
Date: Fri, 21 Nov 2025 17:15:19 -0800
Subject: [PATCH] fix: Enhance object detection grouping by introducing batch
 index handling in DetPredictor. This change addresses issues with multi-table
 cell detection by ensuring correct alignment of detections post-NMS. Added
 logic to extract or construct batch indices for improved grouping of
 predictions.

---
 .../models/object_detection/predictor.py | 107 ++++++++++++++++--
 1 file changed, 96 insertions(+), 11 deletions(-)

diff --git a/paddlex/inference/models/object_detection/predictor.py b/paddlex/inference/models/object_detection/predictor.py
index c049cc83c7..6740327f48 100644
--- a/paddlex/inference/models/object_detection/predictor.py
+++ b/paddlex/inference/models/object_detection/predictor.py
@@ -163,6 +163,7 @@ def _format_output(self, pred: Sequence[Any]) -> List[dict]:
                     compatible with SOLOv2 output.
                 - When len(pred) == 3, it is expected to be in the format [boxes, box_nums, masks],
                     compatible with Instance Segmentation output.
+                - When len(pred) >= 2 and pred[2] exists as batch_inds, use batch_inds for grouping.
 
         Returns:
             List[dict]: A list of dictionaries, each containing either 'class_id' and 'masks' (for SOLOv2),
@@ -185,18 +186,102 @@
                 for i in range(len(pred_class_id))
             ]
 
-        if len(pred) == 3:
-            # Adapt to Instance Segmentation
-            pred_mask = []
-        for idx in range(len(pred[1])):
-            np_boxes_num = pred[1][idx]
-            box_idx_end = box_idx_start + np_boxes_num
-            np_boxes = pred[0][box_idx_start:box_idx_end]
-            pred_box.append(np_boxes)
+        # Fix for multi-table cell detection issue:
+        # When multiple table crops are processed in a single batch, post-NMS ordering
+        # of detections can shift, causing cell assignments to no longer align with
+        # their original tables. Solution: group by batch_idx instead of relying on
+        # the order of flattened predictions.
+        # Reference: https://github.com/PaddlePaddle/PaddleX/issues/17133
+
+        # Check if batch_inds is available (for RT-DETR and similar models)
+        # RT-DETR models may output batch_inds as a separate output from raw inference
+        batch_inds = None
+
+        # First, try to find batch_inds in the raw RT-DETR outputs
+        # For RT-DETR, batch_inds might be in a separate output
+        # We need to distinguish it from masks (which are 3D arrays for Instance Segmentation)
+        # Check all outputs after box_nums for potential batch_inds
+        for i in range(2, len(pred)):
+            candidate = pred[i]
+            # batch_inds should be 1D array of integers matching boxes length
+            # masks for Instance Segmentation are typically 3D or have different shape
+            if (
+                isinstance(candidate, np.ndarray)
+                and candidate.ndim == 1
+                and candidate.dtype in (np.int32, np.int64, np.int8, np.int16)
+                and len(candidate) == len(pred[0])
+                and len(pred[1]) > 0
+                and candidate.max() < len(pred[1])  # batch_id should be < batch_size
+                and candidate.min() >= 0  # batch_id should be >= 0
+            ):
+                batch_inds = candidate.astype(np.int32)
+                break  # Found batch_inds, stop searching
+
+        # If batch_inds is not found in outputs, construct it from box_nums
+        # This assumes boxes are ordered correctly before any NMS or reordering
+        if batch_inds is None and len(pred) >= 2 and len(pred[1]) > 0:
+            box_nums = pred[1]
+            if isinstance(box_nums, np.ndarray) and box_nums.ndim == 1:
+                total_boxes = len(pred[0])
+                expected_total = int(box_nums.sum())
+                # Only construct batch_inds if the total matches (boxes haven't been reordered)
+                if total_boxes == expected_total:
+                    batch_inds = np.zeros(total_boxes, dtype=np.int32)
+                    box_idx = 0
+                    for batch_idx, num_boxes in enumerate(box_nums):
+                        num_boxes = int(num_boxes)
+                        if box_idx + num_boxes <= total_boxes:
+                            batch_inds[box_idx : box_idx + num_boxes] = batch_idx
+                            box_idx += num_boxes
+                        else:
+                            # If mismatch, don't use batch_inds
+                            batch_inds = None
+                            break
+
+        # Use batch_inds for grouping if available
+        # This ensures correct grouping even when post-NMS reorders detections
+        if batch_inds is not None:
+            unique_batch_ids = np.unique(batch_inds)
+            pred_box = []
+            # Find mask output if exists (for Instance Segmentation)
+            # Masks are typically not 1D integer arrays (unlike batch_inds)
+            mask_output_idx = None
+            if len(pred) >= 3:
+                for i in range(2, len(pred)):
+                    candidate = pred[i]
+                    # If this is not the batch_inds we found, and it looks like masks
+                    if (
+                        isinstance(candidate, np.ndarray)
+                        and not (candidate.ndim == 1 and candidate.dtype in (np.int32, np.int64, np.int8, np.int16) and len(candidate) == len(pred[0]) and candidate.max() < len(pred[1]) and candidate.min() >= 0)
+                        and len(candidate) == len(pred[0])
+                    ):
+                        mask_output_idx = i
+                        break
+
+            pred_mask = [] if mask_output_idx is not None else None
+
+            # Group by batch_idx instead of slicing (fixes post-NMS ordering issue)
+            for batch_id in unique_batch_ids:
+                mask = batch_inds == batch_id
+                np_boxes = pred[0][mask]
+                pred_box.append(np_boxes)
+                if pred_mask is not None and mask_output_idx is not None:
+                    np_masks = pred[mask_output_idx][mask]
+                    pred_mask.append(np_masks)
+        else:
+            # Fallback to original box_nums slicing method
             if len(pred) == 3:
-                np_masks = pred[2][box_idx_start:box_idx_end]
-                pred_mask.append(np_masks)
-            box_idx_start = box_idx_end
+                # Adapt to Instance Segmentation
+                pred_mask = []
+            for idx in range(len(pred[1])):
+                np_boxes_num = pred[1][idx]
+                box_idx_end = box_idx_start + np_boxes_num
+                np_boxes = pred[0][box_idx_start:box_idx_end]
+                pred_box.append(np_boxes)
+                if len(pred) == 3:
+                    np_masks = pred[2][box_idx_start:box_idx_end]
+                    pred_mask.append(np_masks)
+                box_idx_start = box_idx_end
 
         if len(pred) == 3:
             return [