From 181aae36820af025eed1e33e58390f7ed9261e1a Mon Sep 17 00:00:00 2001 From: Alan Li Date: Mon, 22 Apr 2024 21:03:16 -0700 Subject: [PATCH] E2E OCR Python Predictor Setup Summary: Python predictor setup for the E2E OCR recognition model with detection + recognition. It currently takes non-normalized images (0.0 to 1.0). Differential Revision: D55776288 fbshipit-source-id: b8779b065f9ad6c1599d50aefbd546c30587eeb1 --- detectron2/layers/nms.py | 1 + detectron2/layers/roi_align_rotated.py | 1 + detectron2/modeling/anchor_generator.py | 9 +++++++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/detectron2/layers/nms.py b/detectron2/layers/nms.py index 1019e7f4c8..65afb746bf 100644 --- a/detectron2/layers/nms.py +++ b/detectron2/layers/nms.py @@ -22,6 +22,7 @@ def batched_nms( # Note: this function (nms_rotated) might be moved into # torchvision/ops/boxes.py in the future +@torch.compiler.disable def nms_rotated(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float): """ Performs non-maximum suppression (NMS) on the rotated boxes according diff --git a/detectron2/layers/roi_align_rotated.py b/detectron2/layers/roi_align_rotated.py index 2a523992e7..7e25a310db 100644 --- a/detectron2/layers/roi_align_rotated.py +++ b/detectron2/layers/roi_align_rotated.py @@ -8,6 +8,7 @@ class _ROIAlignRotated(Function): @staticmethod + @torch.compiler.disable def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio): ctx.save_for_backward(roi) ctx.output_size = _pair(output_size) diff --git a/detectron2/modeling/anchor_generator.py b/detectron2/modeling/anchor_generator.py index ac94e72396..ca65c56406 100644 --- a/detectron2/modeling/anchor_generator.py +++ b/detectron2/modeling/anchor_generator.py @@ -162,7 +162,7 @@ def num_anchors(self): """ return [len(cell_anchors) for cell_anchors in self.cell_anchors] - def _grid_anchors(self, grid_sizes: List[List[int]]): + def _grid_anchors(self, grid_sizes): """ Returns: list[Tensor]: #featuremap tensors, each 
is (#locations x #cell_anchors) x 4 @@ -317,7 +317,12 @@ def num_anchors(self): def _grid_anchors(self, grid_sizes): anchors = [] - for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors): + for size, stride, base_anchors in zip( + grid_sizes, + self.strides, + self.cell_anchors._buffers.values(), + strict=False + ): shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors) zeros = torch.zeros_like(shift_x) shifts = torch.stack((shift_x, shift_y, zeros, zeros, zeros), dim=1)