Skip to content

Commit

Permalink
match train and test in data parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
YiwenShaoStephen committed Aug 26, 2020
1 parent 8684887 commit 553f208
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def main():
print(model)

if use_cuda:
model = torch.nn.DataParallel(model).cuda()
model = model.cuda()

# Load checkpoint.
print('==> Resuming from checkpoint..')
Expand All @@ -71,6 +71,8 @@ def test(testloader, model, output_file, use_cuda):
model.eval()
with open(output_file, 'wb') as f:
for i, (inputs, input_lengths, utt_ids) in enumerate(testloader):
if use_cuda:
inputs = inputs.cuda()
lprobs, output_lengths = model(inputs, input_lengths)
for j in range(inputs.size(0)):
output_length = output_lengths[j]
Expand Down

0 comments on commit 553f208

Please sign in to comment.