diff --git a/rationale_net/utils/model.py b/rationale_net/utils/model.py index 370792a..297fef4 100644 --- a/rationale_net/utils/model.py +++ b/rationale_net/utils/model.py @@ -19,11 +19,18 @@ def get_model(args, embeddings, train_data): print('\nLoading model from [%s]...' % args.snapshot) try: gen_path = learn.get_gen_path(args.snapshot) - if os.path.exists(gen_path): - gen = torch.load(gen_path) - model = torch.load(args.snapshot) + if args.cuda: + if os.path.exists(gen_path): + gen = torch.load(gen_path) + model = torch.load(args.snapshot) + else: + if os.path.exists(gen_path): + gen = torch.load(gen_path, map_location=lambda storage, loc: storage) + gen.args.cuda = "false" + model = torch.load(args.snapshot, map_location=lambda storage, loc: storage) + model.args.cuda = "false" except : - print("Sorry, This snapshot doesn't exist."); exit() + print("Sorry, This snapshot doesn't exist.") if args.num_gpus > 1: model = nn.DataParallel(model,