Skip to content

Commit cc87e7e

Browse files
Antoine Salioufacebook-github-bot
authored andcommitted
Add assertion error if model not c2 compatible
Summary: Pull Request resolved: #4248 The caffe 2 model exported was not handling MODEL.RPN.BBOX_REG_WEIGHTS) different of (1., 1., 1., 1.) or (1., 1., 1., 1., 1.) correctly because the config is not passed to caffe2 generate_proposals config. Modifying it won't slove the issue as there will be an api change customers like boltnn will have to manage. This fix is to fail during caffe2 export for non supported config to avoid non regression behavior troubleshooting Reviewed By: sstsai-adl Differential Revision: D36489096 fbshipit-source-id: 47493ebd5283fd3b3614c764aac5b96d58bf7474
1 parent e091a07 commit cc87e7e

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

detectron2/export/c10.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Copyright (c) Facebook, Inc. and its affiliates.
22

33
import math
4+
from typing import Dict
5+
46
import torch
57
import torch.nn.functional as F
68

7-
from detectron2.layers import cat
9+
from detectron2.layers import ShapeSpec, cat
810
from detectron2.layers.roi_align_rotated import ROIAlignRotated
911
from detectron2.modeling import poolers
1012
from detectron2.modeling.proposal_generator import rpn
@@ -162,6 +164,14 @@ def _set_tensor_mode(self, v):
162164

163165

164166
class Caffe2RPN(Caffe2Compatible, rpn.RPN):
167+
168+
@classmethod
169+
def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
170+
ret = super(Caffe2Compatible, cls).from_config(cfg, input_shape)
171+
assert tuple(cfg.MODEL.RPN.BBOX_REG_WEIGHTS) == (1., 1., 1., 1.) or \
172+
tuple(cfg.MODEL.RPN.BBOX_REG_WEIGHTS) == (1., 1., 1., 1., 1.)
173+
return ret
174+
165175
def _generate_proposals(
166176
self, images, objectness_logits_pred, anchor_deltas_pred, gt_instances=None
167177
):

tests/export/test_c10.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import unittest
3+
from detectron2.config import get_cfg
4+
from detectron2.export.c10 import Caffe2RPN
5+
from detectron2.layers import ShapeSpec
6+
7+
8+
class TestCaffe2RPN(unittest.TestCase):
9+
10+
def test_instantiation(self):
11+
cfg = get_cfg()
12+
cfg.MODEL.RPN.BBOX_REG_WEIGHTS = (1, 1, 1, 1, 1)
13+
input_shapes = {'res4': ShapeSpec(channels=256, stride=4)}
14+
rpn = Caffe2RPN(cfg, input_shapes)
15+
assert rpn is not None
16+
cfg.MODEL.RPN.BBOX_REG_WEIGHTS = (10, 10, 5, 5, 1)
17+
with self.assertRaises(AssertionError):
18+
rpn = Caffe2RPN(cfg, input_shapes)

0 commit comments

Comments
 (0)