@@ -22,7 +22,7 @@ def train_fun_strong(
        loss = loss_fn(theta, tnf_batch['theta_GT'])
        loss.backward()
        optimizer.step()
-       train_loss += loss.data.cpu().numpy()[0]
+       train_loss += loss.data.cpu().numpy()
        if batch_idx % log_interval == 0:
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format(
@@ -42,7 +42,7 @@ def test_fun_strong(model, loss_fn, dataloader, pair_generation_tnf, use_cuda=Tr
        tnf_batch = pair_generation_tnf(batch)
        theta = model(tnf_batch)
        loss = loss_fn(theta, tnf_batch['theta_GT'])
-       test_loss += loss.data.cpu().numpy()[0]
+       test_loss += loss.data.cpu().numpy()

    test_loss /= len(dataloader)
    print('Test set: Average loss: {:.4f}'.format(test_loss))
@@ -96,7 +96,7 @@ def train_fun_weak(

        loss.backward()
        optimizer.step()
-       train_loss += loss.data.cpu().numpy()[0]
+       train_loss += loss.data.cpu().numpy()
        print_train_progress(log_interval, batch_idx, len(dataloader), epoch, loss.data[0])
    train_loss /= len(dataloader)
    print('Train set: Average loss: {:.4f}'.format(train_loss))
@@ -134,7 +134,7 @@ def test_fun_weak(
        inliers_pos = loss_fn(theta_pos, corr_pos)
        inliers_neg = loss_fn(theta_neg, corr_neg)
        loss = torch.sum(inliers_neg - inliers_pos)
-       test_loss += loss.data.cpu().numpy()[0]
+       test_loss += loss.data.cpu().numpy()

    test_loss /= len(dataloader)
    print('Test set: Average loss: {:.4f}'.format(test_loss))
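
Note on the change (my reading, not stated in the commit): on PyTorch 0.4 and later a scalar loss is a 0-dim tensor, so loss.data.cpu().numpy() returns a 0-d NumPy array, and indexing it with [0] raises IndexError. That appears to be why the trailing [0] is dropped in every accumulation line above. A minimal sketch of the old and new idioms, assuming PyTorch >= 0.4 (the imports and toy tensors below are illustrative, not taken from this repo):

    import torch
    import torch.nn.functional as F

    # Toy scalar loss: a 0-dim tensor on PyTorch >= 0.4.
    loss = F.mse_loss(torch.randn(4, 3), torch.zeros(4, 3))

    train_loss = 0.0
    # Old pattern removed by this diff: .numpy() on a 0-dim tensor gives a
    # 0-d array, so indexing it with [0] raises IndexError.
    #   train_loss += loss.data.cpu().numpy()[0]
    # Pattern the diff adopts (a 0-d array adds fine to a Python float):
    train_loss += loss.data.cpu().numpy()
    # Equivalent and arguably cleaner on modern PyTorch:
    train_loss += loss.item()

The same consideration would seem to apply to loss.data[0] in the print_train_progress call, which this commit leaves unchanged.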