From dde114835589a315915fd3b2b4068e7a8ce4499a Mon Sep 17 00:00:00 2001 From: Marco Cheung <73602420+LAFLAMIE1024@users.noreply.github.com> Date: Mon, 1 Aug 2022 22:26:20 +0800 Subject: [PATCH] Update projected_model.py --- models/projected_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/models/projected_model.py b/models/projected_model.py index 477f63e1..f9194f75 100644 --- a/models/projected_model.py +++ b/models/projected_model.py @@ -50,15 +50,16 @@ def initialize(self, opt): self.netArc = self.netArc.cuda() self.netArc.eval() self.netArc.requires_grad_(False) + if not self.isTrain: pretrained_path = opt.checkpoints_dir self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) return + self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{}) # self.netD.feature_network.requires_grad_(False) self.netD.cuda() - if self.isTrain: # define loss functions self.criterionFeat = nn.L1Loss() @@ -83,6 +84,7 @@ def initialize(self, opt): self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path) self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path) + torch.cuda.empty_cache() def cosin_metric(self, x1, x2):