From 553f208eb5139ca52c972450466193aa34591c1a Mon Sep 17 00:00:00 2001 From: YiwenShaoStephen Date: Wed, 26 Aug 2020 15:50:00 -0400 Subject: [PATCH] match train and test in data parallel --- test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test.py b/test.py index baae58a..f770fc9 100755 --- a/test.py +++ b/test.py @@ -56,7 +56,7 @@ def main(): print(model) if use_cuda: - model = torch.nn.DataParallel(model).cuda() + model = model.cuda() # Load checkpoint. print('==> Resuming from checkpoint..') @@ -71,6 +71,8 @@ def test(testloader, model, output_file, use_cuda): model.eval() with open(output_file, 'wb') as f: for i, (inputs, input_lengths, utt_ids) in enumerate(testloader): + if use_cuda: + inputs = inputs.cuda() lprobs, output_lengths = model(inputs, input_lengths) for j in range(inputs.size(0)): output_length = output_lengths[j]