Skip to content

Commit 991bc89

Browse files
committedOct 6, 2018
Scalar tensor
1 parent b5a0a83 commit 991bc89

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed
 

‎train_strong.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def process_epoch(
154154
tnf_batch = batch_preprocessing_fn(batch)
155155
theta = model(tnf_batch)
156156
loss = loss_fn(theta, tnf_batch['theta_GT'])
157-
loss_np = loss.data.cpu().numpy()[0]
157+
loss_np = loss.data.cpu().numpy()
158158
epoch_loss += loss_np
159159
if mode == 'train':
160160
loss.backward()

‎train_weak.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def process_epoch(
225225
optimizer.zero_grad()
226226
tnf_batch = batch_preprocessing_fn(batch)
227227
loss = loss_fn(tnf_batch)
228-
loss_np = loss.data.cpu().numpy()[0]
228+
loss_np = loss.data.cpu().numpy()
229229
epoch_loss += loss_np
230230
if mode == 'train':
231231
loss.backward()

‎util/train_test_fn.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def train_fun_strong(
2222
loss = loss_fn(theta, tnf_batch['theta_GT'])
2323
loss.backward()
2424
optimizer.step()
25-
train_loss += loss.data.cpu().numpy()[0]
25+
train_loss += loss.data.cpu().numpy()
2626
if batch_idx % log_interval == 0:
2727
print(
2828
'Train Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format(
@@ -42,7 +42,7 @@ def test_fun_strong(model, loss_fn, dataloader, pair_generation_tnf, use_cuda=Tr
4242
tnf_batch = pair_generation_tnf(batch)
4343
theta = model(tnf_batch)
4444
loss = loss_fn(theta, tnf_batch['theta_GT'])
45-
test_loss += loss.data.cpu().numpy()[0]
45+
test_loss += loss.data.cpu().numpy()
4646

4747
test_loss /= len(dataloader)
4848
print('Test set: Average loss: {:.4f}'.format(test_loss))
@@ -96,7 +96,7 @@ def train_fun_weak(
9696

9797
loss.backward()
9898
optimizer.step()
99-
train_loss += loss.data.cpu().numpy()[0]
99+
train_loss += loss.data.cpu().numpy()
100100
print_train_progress(log_interval, batch_idx, len(dataloader), epoch, loss.data[0])
101101
train_loss /= len(dataloader)
102102
print('Train set: Average loss: {:.4f}'.format(train_loss))
@@ -134,7 +134,7 @@ def test_fun_weak(
134134
inliers_pos = loss_fn(theta_pos, corr_pos)
135135
inliers_neg = loss_fn(theta_neg, corr_neg)
136136
loss = torch.sum(inliers_neg - inliers_pos)
137-
test_loss += loss.data.cpu().numpy()[0]
137+
test_loss += loss.data.cpu().numpy()
138138

139139
test_loss /= len(dataloader)
140140
print('Test set: Average loss: {:.4f}'.format(test_loss))

0 commit comments

Comments
 (0)