# model_builder.py
import importlib
import logging
import os
import pickle
from functools import wraps

import torch
import torch.nn as nn

from core.config import cfg
from ops import RoIPool, RoIAlign
import modeling.heads as heads
import utils.vgg_weights_helper as vgg_utils
import utils.hrnet_weights_helper as hrnet_utils

logger = logging.getLogger(__name__)

def get_func(func_name):
"""Helper to return a function object by name. func_name must identify a
function in this module or the path to a function relative to the base
'modeling' module.
"""
if func_name == '':
return None
try:
parts = func_name.split('.')
# Refers to a function in this module
if len(parts) == 1:
return globals()[parts[0]]
# Otherwise, assume we're referencing a module under modeling
module_name = 'modeling.' + '.'.join(parts[:-1])
module = importlib.import_module(module_name)
return getattr(module, parts[-1])
except Exception:
logger.error('Failed to find function: %s', func_name)
raise
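
# Illustrative usage of get_func (the 'vgg16.VGG16_conv_body' path below is
# hypothetical and may not exist in this repo):
#   conv_body_cls = get_func('vgg16.VGG16_conv_body')   # imports modeling.vgg16
#   rcnn_cls = get_func('Generalized_RCNN')             # resolves in this module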
def compare_state_dict(sa, sb):
if sa.keys() != sb.keys():
return False
for k, va in sa.items():
if not torch.equal(va, sb[k]):
return False
return True
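
# Illustrative usage of compare_state_dict, e.g. to verify a save/reload round
# trip ('ckpt.pth' is a placeholder path, not from this repo):
#   torch.save(model.state_dict(), 'ckpt.pth')
#   reloaded = torch.load('ckpt.pth')
#   assert compare_state_dict(model.state_dict(), reloaded)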
def check_inference(net_func):
@wraps(net_func)
def wrapper(self, *args, **kwargs):
if not self.training:
if cfg.PYTORCH_VERSION_LESS_THAN_040:
return net_func(self, *args, **kwargs)
else:
with torch.no_grad():
return net_func(self, *args, **kwargs)
else:
            raise ValueError('You should call this function only in inference mode. '
                             'Set the network to inference mode with net.eval().')
return wrapper
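
# check_inference is used below on Generalized_RCNN.convbody_net; call net.eval()
# first, since the wrapper raises ValueError while the module is still in training
# mode and wraps the call in torch.no_grad() on PyTorch >= 0.4.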
def testing_function(predict_cls, predict_det, ref_cls_score, ref_iou_score, return_dict):
    """Collect test-time outputs: fuse each refinement's classification and IoU
    scores by element-wise product and drop the first (background) column."""
    res = []
    for cls_score, iou_score in zip(ref_cls_score, ref_iou_score):
        preds = cls_score * iou_score
        res.append(preds[:, 1:])
    return_dict['refine_score'] = res
    return return_dict
class Generalized_RCNN(nn.Module):
def __init__(self):
super().__init__()
# For cache
self.mapping_to_detectron = None
self.orphans_in_detectron = None
cls_num = cfg.MODEL.NUM_CLASSES + 1
# feature extraction
self.Conv_Body = get_func(cfg.MODEL.CONV_BODY)()
self.Box_Head = get_func(cfg.FAST_RCNN.ROI_BOX_HEAD)(self.Conv_Body.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale)
self.cls_iou_model = heads.cls_iou_model(self.Box_Head.dim_out, cls_num, cfg.REFINE_TIMES,
class_agnostic=False)
self.CIM_layer_list = []
for ref_time in range(cfg.REFINE_TIMES):
self.CIM_layer_list.append(heads.CIM_layer(p_seed=cfg.p_seed,
cls_thr=0.25 + cfg.step_rate * ref_time,
iou_thr=0.5 + cfg.step_rate * ref_time,
Anti_noise_sampling=cfg.Anti_noise_sampling,
))
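        # For illustration only: with a hypothetical cfg.step_rate of 0.1 and
        # cfg.REFINE_TIMES = 3, the per-refinement thresholds grow linearly:
        # cls_thr = 0.25, 0.35, 0.45 and iou_thr = 0.5, 0.6, 0.7.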
self.using_CIM = [True, True, True]
# load pre-trained weights
self._init_modules()
def _init_modules(self):
if cfg.VGG_CLS_FEATURE:
if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS:
vgg_utils.load_pretrained_imagenet_weights(self)
# Note: resnet using pre-trained weight from torch
#####
# if not cfg.ResNet_CLS_FEATURE:
# if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS:
# resnet_utils.load_pretrained_imagenet_weights(self)
if cfg.HRNET_CLS_FEATURE:
if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS:
hrnet_utils.load_pretrained_imagenet_weights(self)
if cfg.TRAIN.FREEZE_CONV_BODY: # False
for p in self.Conv_Body.parameters():
p.requires_grad = False
    def forward(self, data, rois, masks, labels, gtrois, mat, path=None, index=None):
        """In training, returns return_dict['losses'] with 'bag_loss', 'pcl_loss',
        'cls_loss' and 'iou_loss'; at inference, returns return_dict['refine_score']."""
        with torch.set_grad_enabled(self.training):
            im_data = data
if self.training:
index = index.squeeze(dim=0).type(im_data.dtype)
rois = rois.squeeze(dim=0).type(im_data.dtype)
masks = masks.squeeze(dim=0).type(im_data.dtype)
labels = labels.squeeze(dim=0).type(im_data.dtype)
mat = mat.squeeze(dim=0).type(im_data.dtype)
return_dict = {} # A dict to collect return variables
blob_conv = self.Conv_Body(im_data)
# if not self.training:
return_dict['blob_conv'] = blob_conv
masks.requires_grad = False
seg_x = self.Box_Head(blob_conv, rois, masks.detach())
file_name = os.path.splitext(os.path.split(path)[1])[0]
iou_dir = cfg.iou_dir
asy_iou_dir = cfg.asy_iou_dir
predict_cls, predict_det, ref_cls_score, ref_iou_score = self.cls_iou_model(seg_x)
if self.training:
index = index.long()
                try:
                    with open(os.path.join(iou_dir, file_name + ".pkl"), "rb") as f:
                        iou_map = pickle.load(f)
                    iou_map = torch.tensor(iou_map, device=labels.device)[index][:, index]
                except Exception:
                    print("iou_map missing: " + os.path.join(iou_dir, file_name + ".pkl"))
                    raise NotImplementedError("Please generate or download iou_map")
                try:
                    with open(os.path.join(asy_iou_dir, file_name + ".pkl"), "rb") as f:
                        asy_iou_map = pickle.load(f)
                    asy_iou_map = torch.tensor(asy_iou_map, device=labels.device)[index][:, index]
                except Exception:
                    print("asy_iou_map missing: " + os.path.join(asy_iou_dir, file_name + ".pkl"))
                    raise NotImplementedError("Please generate or download asy_iou_map")
return_dict['losses'] = {}
# Anti-noise branch loss
return_dict['losses']['bag_loss'] = torch.tensor(0, dtype=torch.float32, device=seg_x.device)
return_dict['losses']['pcl_loss'] = torch.tensor(0, dtype=torch.float32, device=seg_x.device)
# Refinement branch loss
return_dict['losses']['cls_loss'] = torch.tensor(0, dtype=torch.float32, device=seg_x.device)
return_dict['losses']['iou_loss'] = torch.tensor(0, dtype=torch.float32, device=seg_x.device)
for i, (cls_score, iou_score, CIM_layer) in enumerate(zip(ref_cls_score, ref_iou_score, self.CIM_layer_list)):
# follow WSDDN
lmda = 3 if i == 0 else 1
#########
if i == 0:
pseudo_labels, pseudo_iou_labels, loss_weights = CIM_layer(predict_cls,
predict_det,
rois, labels, iou_map,
asy_iou_map,
using_CIM=self.using_CIM[i])
else:
pseudo_labels, pseudo_iou_labels, loss_weights = CIM_layer(ref_cls_score[i - 1],
ref_iou_score[i - 1],
rois, labels, iou_map,
asy_iou_map,
using_CIM=self.using_CIM[i])
                    if pseudo_labels is None:
                        continue
pseudo_labels = pseudo_labels.detach()
pseudo_iou_labels = pseudo_iou_labels.detach()
loss_weights = lmda * loss_weights.detach()
cls_loss, iou_loss, bag_loss = heads.cls_iou_loss(cls_score, iou_score, pseudo_labels, pseudo_iou_labels, loss_weights, labels)
return_dict['losses']['cls_loss'] += cls_loss.clone()
return_dict['losses']['iou_loss'] += 3 * iou_loss.clone()
return_dict['losses']['bag_loss'] += bag_loss.clone()
return_dict['losses']['bag_loss'] += heads.mil_bag_loss(predict_cls, predict_det, labels)
pcl_loss = heads.PCL_loss(predict_cls, mat, labels)
return_dict['losses']['pcl_loss'] += pcl_loss
for k, v in return_dict['losses'].items():
return_dict['losses'][k] = v.unsqueeze(0)
else:
# Testing
return_dict = testing_function(predict_cls, predict_det, ref_cls_score, ref_iou_score, return_dict)
return return_dict
def roi_feature_transform(self, blobs_in, rois, method='RoIPoolF',
resolution=7, spatial_scale=1. / 16., sampling_ratio=0):
"""Add the specified RoI pooling method. The sampling_ratio argument
is supported for some, but not all, RoI transform methods.
RoIFeatureTransform abstracts away:
- Use of FPN or not
- Specifics of the transform method
"""
        assert method in {'RoIPoolF', 'RoICrop', 'RoIAlign'}, \
            'Unknown pooling method: {}'.format(method)
        if method == 'RoIPoolF':
            xform_out = RoIPool(resolution, spatial_scale)(blobs_in, rois)
        elif method == 'RoIAlign':
            xform_out = RoIAlign(
                resolution, spatial_scale, sampling_ratio)(blobs_in.contiguous(), rois.contiguous())
        else:
            raise NotImplementedError('RoICrop is not implemented in this model')
        return xform_out
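
    # Example call (illustrative values): pool a 7x7 feature per RoI with RoIAlign
    # at 1/16 scale and two samples per bin:
    #   feats = self.roi_feature_transform(blob_conv, rois, method='RoIAlign',
    #                                      resolution=7, spatial_scale=1. / 16.,
    #                                      sampling_ratio=2)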
@check_inference
def convbody_net(self, data):
"""For inference. Run Conv Body only"""
blob_conv = self.Conv_Body(data)
return blob_conv
@property
def detectron_weight_mapping(self):
if self.mapping_to_detectron is None:
d_wmap = {} # detectron_weight_mapping
d_orphan = [] # detectron orphan weight list
for name, m_child in self.named_children():
if list(m_child.parameters()): # if module has any parameter
try:
child_map, child_orphan = m_child.detectron_weight_mapping()
d_orphan.extend(child_orphan)
for key, value in child_map.items():
new_key = name + '.' + key
d_wmap[new_key] = value
                    except Exception:
                        print("module {} does not have a detectron_weight_mapping function".format(m_child))
self.mapping_to_detectron = d_wmap
self.orphans_in_detectron = d_orphan
return self.mapping_to_detectron, self.orphans_in_detectron
def _add_loss(self, return_dict, key, value):
"""Add loss tensor to returned dictionary"""
return_dict['losses'][key] = value