From 1523b3e9929a06d798871eb9afc4c9f770743baf Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Thu, 9 Feb 2023 21:02:04 -0800 Subject: [PATCH] Add ONNX support to KRCNNConvDeconvUpsampleHead Summary: In order to export `KRCNNConvDeconvUpsampleHead` to ONNX using `torch.jit.script`, changes to both PyTorch and detectron2: - Pytorch has a bug which prevents a tensor wrapped by a list as float. Refer to the required [fix](https://github.com/pytorch/pytorch/pull/81386) - `detectron2/structures/keypoints.py::heatmaps_to_keypoints` internally does advanced indexing on a `squeeze`d tensor. The aforementioned `squeeze` fails rank inference due to the presence of `onnx::If` on its implementation (to support dynamic dims). The fix is replacing `squeeze` by `reshape`. A possible fix to `squeeze` on PyTorch side might be done too (TBD and would take some time), but the proposed change here does not bring any consequence to detectron2 while it enables ONNX support with scriptable `KRCNNConvDeconvUpsampleHead `. After the proposed changes, the `KRCNNConvDeconvUpsampleHead` does include a `Loop` node to represent a for-loop inside the model and `dynamic outputs`, as shown below: ![image](https://user-images.githubusercontent.com/5469809/179559001-f60fb8af-ec79-4758-b271-736467b5d96f.png) This PR has been tested with ONNX Runtime (this [PR](https://github.com/facebookresearch/detectron2/pull/4205)) to ensure the ONNX output matches PyTorch's for different `gen_input(X, Y)` combinations and it succeeded. The model was converted to ONNX once with a particular input and tested with inputs of different shapes and compared to equality to PyTorch's Depends on: https://github.com/pytorch/pytorch/pull/81386 and https://github.com/facebookresearch/detectron2/pull/4291 Pull Request resolved: https://github.com/facebookresearch/detectron2/pull/4315 Reviewed By: newstzpz Differential Revision: D42756423 fbshipit-source-id: dc410df18da07f48c14f4cae9a4a91530a0ec602 --- detectron2/structures/keypoints.py | 14 +++---- detectron2/utils/testing.py | 15 ++++++++ tests/test_export_onnx.py | 62 ++++++++++++++++++++++++++++-- 3 files changed, 79 insertions(+), 12 deletions(-) diff --git a/detectron2/structures/keypoints.py b/detectron2/structures/keypoints.py index d0ee8724ac..b93ebed4f6 100644 --- a/detectron2/structures/keypoints.py +++ b/detectron2/structures/keypoints.py @@ -179,10 +179,6 @@ def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tenso we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate. """ - # The decorator use of torch.no_grad() was not supported by torchscript. - # https://github.com/pytorch/pytorch/issues/44768 - maps = maps.detach() - rois = rois.detach() offset_x = rois[:, 0] offset_y = rois[:, 1] @@ -202,11 +198,11 @@ def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tenso for i in range(num_rois): outsize = (int(heights_ceil[i]), int(widths_ceil[i])) - roi_map = F.interpolate( - maps[[i]], size=outsize, mode="bicubic", align_corners=False - ).squeeze( - 0 - ) # #keypoints x H x W + roi_map = F.interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False) + + # Although semantically equivalent, `reshape` is used instead of `squeeze` due + # to limitation during ONNX export of `squeeze` in scripting mode + roi_map = roi_map.reshape(roi_map.shape[1:]) # keypoints x H x W # softmax over the spatial region max_score, _ = roi_map.view(num_keypoints, -1).max(1) diff --git a/detectron2/utils/testing.py b/detectron2/utils/testing.py index b597ed92fd..3f5b9dbe44 100644 --- a/detectron2/utils/testing.py +++ b/detectron2/utils/testing.py @@ -176,6 +176,21 @@ def min_torch_version(min_version: str) -> bool: return installed_version >= min_version +def has_dynamic_axes(onnx_model): + """ + Return True when all ONNX input/output have only dynamic axes for all ranks + """ + return all( + not dim.dim_param.isnumeric() + for inp in onnx_model.graph.input + for dim in inp.type.tensor_type.shape.dim + ) and all( + not dim.dim_param.isnumeric() + for out in onnx_model.graph.output + for dim in out.type.tensor_type.shape.dim + ) + + def register_custom_op_onnx_export( opname: str, symbolic_fn: Callable, opset_version: int, min_version: str ) -> None: diff --git a/tests/test_export_onnx.py b/tests/test_export_onnx.py index ffab02ae48..aa15e1a406 100644 --- a/tests/test_export_onnx.py +++ b/tests/test_export_onnx.py @@ -10,11 +10,17 @@ from detectron2.config import get_cfg from detectron2.export import STABLE_ONNX_OPSET_VERSION from detectron2.export.flatten import TracingAdapter +from detectron2.export.torchscript_patch import patch_builtin_len +from detectron2.layers import ShapeSpec from detectron2.modeling import build_model +from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead +from detectron2.structures import Boxes, Instances from detectron2.utils.testing import ( _pytorch1111_symbolic_opset9_repeat_interleave, _pytorch1111_symbolic_opset9_to, get_sample_coco_image, + has_dynamic_axes, + random_boxes, register_custom_op_onnx_export, skipIfOnCPUCI, skipIfUnsupportedMinOpsetVersion, @@ -26,6 +32,8 @@ @unittest.skipIf(not _check_module_exists("onnx"), "ONNX not installed.") @skipIfUnsupportedMinTorchVersion("1.10") class TestONNXTracingExport(unittest.TestCase): + opset_version = STABLE_ONNX_OPSET_VERSION + def testMaskRCNNFPN(self): def inference_func(model, images): with warnings.catch_warnings(record=True): @@ -85,9 +93,55 @@ def inference_func(model, image): self._test_model_zoo_from_config_path( "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func, - opset_version=STABLE_ONNX_OPSET_VERSION, ) + def testKeypointHead(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = KRCNNConvDeconvUpsampleHead( + ShapeSpec(channels=4, height=14, width=14), num_keypoints=17, conv_dims=(4,) + ) + + def forward(self, x, predbox1, predbox2): + inst = [ + Instances((100, 100), pred_boxes=Boxes(predbox1)), + Instances((100, 100), pred_boxes=Boxes(predbox2)), + ] + ret = self.model(x, inst) + return tuple(x.pred_keypoints for x in ret) + + model = M() + model.eval() + + def gen_input(num1, num2): + feat = torch.randn((num1 + num2, 4, 14, 14)) + box1 = random_boxes(num1) + box2 = random_boxes(num2) + return feat, box1, box2 + + with patch_builtin_len(): + onnx_model = self._test_model( + model, + gen_input(1, 2), + input_names=["features", "pred_boxes", "pred_classes"], + output_names=["box1", "box2"], + dynamic_axes={ + "features": {0: "batch", 1: "static_four", 2: "height", 3: "width"}, + "pred_boxes": {0: "batch", 1: "static_four"}, + "pred_classes": {0: "batch", 1: "static_four"}, + "box1": {0: "num_instance", 1: "K", 2: "static_three"}, + "box2": {0: "num_instance", 1: "K", 2: "static_three"}, + }, + ) + + # Although ONNX models are not executable by PyTorch to verify + # support of batches with different sizes, we can verify model's IR + # does not hard-code input and/or output shapes. + # TODO: Add tests with different batch sizes when detectron2's CI + # support ONNX Runtime backend. + assert has_dynamic_axes(onnx_model) + ################################################################################ # Testcase internals - DO NOT add tests below this point ################################################################################ @@ -114,6 +168,9 @@ def _test_model( save_onnx_graph_path=None, **export_kwargs, ): + # Not imported in the beginning of file to prevent runtime errors + # for environments without ONNX. + # This testcase checks dependencies before running import onnx # isort:skip f = io.BytesIO() @@ -138,6 +195,7 @@ def _test_model( assert onnx_model is not None if save_onnx_graph_path: onnx.save(onnx_model, save_onnx_graph_path) + return onnx_model def _test_model_zoo_from_config_path( self, @@ -171,9 +229,7 @@ def _test_model_from_config_path( point_rend.add_pointrend_config(cfg) cfg.merge_from_file(config_path) cfg.freeze() - model = build_model(cfg) - image = get_sample_coco_image() inputs = tuple(image.clone() for _ in range(batch)) return self._test_model(