diff --git a/Global/models/mapping_model.py b/Global/models/mapping_model.py index e030f0f6..ab3d8121 100755 --- a/Global/models/mapping_model.py +++ b/Global/models/mapping_model.py @@ -345,6 +345,11 @@ def inference(self, label, inst): fake_image = self.netG_B.forward(label_feat_map, flow="dec") return fake_image + def save(self, which_epoch): + self.save_network(self.netD, 'D', which_epoch, self.gpu_ids) + self.save_network(self.mapping_net, 'mapping_net', which_epoch, self.gpu_ids) + self.save_optimizer(self.optimizer_D,"D",which_epoch) + self.save_optimizer(self.optimizer_mapping, 'mapping_net', which_epoch) class InferenceModel(Pix2PixHDModel_Mapping): def forward(self, label, inst):