@@ -43,7 +43,6 @@ def main(_):
     tl.files.exists_or_mkdir(FLAGS.sample_dir)
 
     z_dim = 100
-
     with tf.device("/gpu:0"):
         ##========================= DEFINE MODEL ===========================##
         z = tf.placeholder(tf.float32, [FLAGS.batch_size, z_dim], name='z_noise')
@@ -94,15 +93,14 @@ def main(_):
     net_d_name = os.path.join(save_dir, 'net_d.npz')
 
     data_files = glob(os.path.join("./data", FLAGS.dataset, "*.jpg"))
-    # sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
-    sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
+
+    sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)  # sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
 
     ##========================= TRAIN MODELS ================================##
     iter_counter = 0
     for epoch in range(FLAGS.epoch):
         ## shuffle data
         shuffle(data_files)
-        print("[*] Dataset shuffled!")
 
         ## update sample files based on shuffled data
         sample_files = data_files[0:FLAGS.sample_size]
@@ -119,46 +117,28 @@ def main(_):
             # more image augmentation functions in http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html
             batch = [get_image(batch_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale=0) for batch_file in batch_files]
             batch_images = np.array(batch).astype(np.float32)
-            # batch_z = np.random.uniform(low=-1, high=1, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
-            batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
+            batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)  # batch_z = np.random.uniform(low=-1, high=1, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
             start_time = time.time()
             # updates the discriminator
             errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images})
             # updates the generator, run generator twice to make sure that d_loss does not go to zero (difference from paper)
             for _ in range(2):
                 errG, _ = sess.run([g_loss, g_optim], feed_dict={z: batch_z})
             print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
-                % (epoch, FLAGS.epoch, idx, batch_idxs,
-                    time.time() - start_time, errD, errG))
-            sys.stdout.flush()
+                % (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errD, errG))
 
             iter_counter += 1
             if np.mod(iter_counter, FLAGS.sample_step) == 0:
                 # generate and visualize generated images
                 img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
-                save_images(img, [8, 8],
-                    './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
+                tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
                 print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))
-                sys.stdout.flush()
 
             if np.mod(iter_counter, FLAGS.save_step) == 0:
                 # save current network parameters
                 print("[*] Saving checkpoints...")
-                img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
-                model_dir = "%s_%s_%s" % (FLAGS.dataset, FLAGS.batch_size, FLAGS.output_size)
-                save_dir = os.path.join(FLAGS.checkpoint_dir, model_dir)
-                if not os.path.exists(save_dir):
-                    os.makedirs(save_dir)
-                # the latest version location
-                net_g_name = os.path.join(save_dir, 'net_g.npz')
-                net_d_name = os.path.join(save_dir, 'net_d.npz')
-                # # this version is for future re-check and visualization analysis
-                # net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter)
-                # net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter)
-                # tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
-                # tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
-                # tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess)
-                # tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess)
+                tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
+                tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
                 print("[*] Saving checkpoints SUCCESS!")
 
 if __name__ == '__main__':
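The two `tl.files.save_npz` calls introduced in this change write the latest generator and discriminator parameters to the `net_g.npz` / `net_d.npz` paths prepared earlier in `main`. For reference, a minimal sketch of loading those checkpoints back with TensorLayer's matching loader, assuming `net_g` and `net_d` have been rebuilt with the same architecture in an active session (this snippet is illustrative and not part of the commit):

    # Illustrative only: restore the checkpoints written by the save step,
    # assuming net_g/net_d were re-created identically and sess is active.
    tl.files.load_and_assign_npz(sess=sess, name=net_g_name, network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=net_d_name, network=net_d)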