forked from AnetaKaczynska/MeVGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprogan.patch
57 lines (51 loc) · 1.92 KB
/
progan.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
diff --git a/models/trainer/standard_configurations/pgan_config.py b/models/trainer/standard_configurations/pgan_config.py
index 59ad0ea..77b38e9 100644
--- a/models/trainer/standard_configurations/pgan_config.py
+++ b/models/trainer/standard_configurations/pgan_config.py
@@ -45,7 +45,7 @@ _C.alphaSizeJumps = [0, 32, 32, 32, 32, 32, 32, 32, 32, 32]
_C.depthScales = [512, 512, 512, 512, 256, 128, 64, 32, 16]
# Mini batch size
-_C.miniBatchSize = 16
+_C.miniBatchSize = 8
# Dimension of the latent vector
_C.dimLatentVector = 512
diff --git a/models/utils/utils.py b/models/utils/utils.py
index 2a215a7..cef5d9b 100644
--- a/models/utils/utils.py
+++ b/models/utils/utils.py
@@ -341,3 +341,26 @@ def saveScore(outPath, outValue, *args):
json.dump(fullDict, file, indent=2)
os.remove(flagPath)
+
+
+def load_progan(name='jelito3d_batchsize8', checkPointDir='output_networks/jelito3d_batchsize8', freeze_pgan_disc=True):
+ """Load pretrained ProGAN from checkpoint."""
+ checkpointData = getLastCheckPoint(checkPointDir, name, scale=None, iter=None)
+ modelConfig, pathModel, _ = checkpointData
+ _, scale, _ = parse_state_name(pathModel)
+
+ module = 'PGAN'
+ packageStr, modelTypeStr = getNameAndPackage(module)
+ modelType = loadmodule(packageStr, modelTypeStr)
+
+ with open(modelConfig, 'rb') as file:
+ config = json.load(file)
+
+ model = modelType(useGPU=True, storeAVG=True, **config)
+ model.load(pathModel)
+
+ if freeze_pgan_disc:
+ for param in model.netD.parameters():
+ param.requires_grad = False
+
+ return model
diff --git a/visualization/visualizer.py b/visualization/visualizer.py
index f3aab21..3874796 100644
--- a/visualization/visualizer.py
+++ b/visualization/visualizer.py
@@ -6,7 +6,7 @@ import torchvision.utils as vutils
import numpy as np
import random
-vis = visdom.Visdom()
+# vis = visdom.Visdom()
def resizeTensor(data, out_size_image):