|
15 | 15 |
|
16 | 16 | def featureL2Norm(feature):
|
17 | 17 | epsilon = 1e-6
|
18 |
| - # print(feature.size()) |
19 |
| - # print(torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).size()) |
20 | 18 | norm = torch.pow(torch.sum(torch.pow(feature, 2), 1) + epsilon,
|
21 | 19 | 0.5).unsqueeze(1).expand_as(feature)
|
22 | 20 | return torch.div(feature, norm)
|
@@ -136,12 +134,7 @@ def __init__(self, shape='3D', normalization=True):
|
136 | 134 | self.shape = shape
|
137 | 135 | self.ReLU = nn.ReLU()
|
138 | 136 |
|
139 |
| - def forward(self, feature_A, feature_B, cues_A, cues_B): |
140 |
| - if cues_A is None or cues_B is None: |
141 |
| - # localization cues are available |
142 |
| - pass # TODO: create uniform localization weight over all the images. |
143 |
| - |
144 |
| - # TODO: update this to use localization cues. |
| 137 | + def forward(self, feature_A, feature_B): |
145 | 138 | b, c, h, w = feature_A.size()
|
146 | 139 | if self.shape == '3D':
|
147 | 140 | # reshape features for matrix multiplication
|
@@ -223,9 +216,6 @@ def __init__(
|
223 | 216 | use_cuda=True,
|
224 | 217 | delf_path=''
|
225 | 218 | ):
|
226 |
| - # regressor_channels_1 = 128, |
227 |
| - # regressor_channels_2 = 64): |
228 |
| - |
229 | 219 | super(CNNGeometric, self).__init__()
|
230 | 220 | self.use_cuda = use_cuda
|
231 | 221 | self.feature_self_matching = feature_self_matching
|
@@ -259,11 +249,8 @@ def forward(self, tnf_batch):
|
259 | 249 | # feature extraction
|
260 | 250 | feature_A = self.FeatureExtraction(tnf_batch['source_image'])
|
261 | 251 | feature_B = self.FeatureExtraction(tnf_batch['target_image'])
|
262 |
| - # localization cues |
263 |
| - cues_A = tnf_batch.get('source_cues') |
264 |
| - cues_B = tnf_batch.get('target_cues') |
265 | 252 | # feature correlation
|
266 |
| - correlation = self.FeatureCorrelation(feature_A, feature_B, cues_A, cues_B) |
| 253 | + correlation = self.FeatureCorrelation(feature_A, feature_B) |
267 | 254 | # regression to tnf parameters theta
|
268 | 255 | theta = self.FeatureRegression(correlation)
|
269 | 256 |
|
|
0 commit comments