3939def main (_ ):
4040 pp .pprint (flags .FLAGS .__flags )
4141
42- if not os .path .exists (FLAGS .checkpoint_dir ):
43- os .makedirs (FLAGS .checkpoint_dir )
44- if not os .path .exists (FLAGS .sample_dir ):
45- os .makedirs (FLAGS .sample_dir )
42+ tl .files .exists_or_mkdir (FLAGS .checkpoint_dir )
43+ tl .files .exists_or_mkdir (FLAGS .sample_dir )
4644
4745 z_dim = 100
4846
@@ -138,11 +136,6 @@ def main(_):
138136 if np .mod (iter_counter , FLAGS .sample_step ) == 0 :
139137 # generate and visualize generated images
140138 img , errD , errG = sess .run ([net_g2 .outputs , d_loss , g_loss ], feed_dict = {z : sample_seed , real_images : sample_images })
141- '''
142- img255 = (np.array(img) + 1) / 2 * 255
143- tl.visualize.images2d(images=img255, second=0, saveable=True,
144- name='./{}/train_{:02d}_{:04d}'.format(FLAGS.sample_dir, epoch, idx), dtype=None, fig_idx=2838)
145- '''
146139 save_images (img , [8 , 8 ],
147140 './{}/train_{:02d}_{:04d}.png' .format (FLAGS .sample_dir , epoch , idx ))
148141 print ("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD , errG ))
@@ -159,13 +152,13 @@ def main(_):
159152 # the latest version location
160153 net_g_name = os .path .join (save_dir , 'net_g.npz' )
161154 net_d_name = os .path .join (save_dir , 'net_d.npz' )
162- # this version is for future re-check and visualization analysis
163- net_g_iter_name = os .path .join (save_dir , 'net_g_%d.npz' % iter_counter )
164- net_d_iter_name = os .path .join (save_dir , 'net_d_%d.npz' % iter_counter )
165- tl .files .save_npz (net_g .all_params , name = net_g_name , sess = sess )
166- tl .files .save_npz (net_d .all_params , name = net_d_name , sess = sess )
167- tl .files .save_npz (net_g .all_params , name = net_g_iter_name , sess = sess )
168- tl .files .save_npz (net_d .all_params , name = net_d_iter_name , sess = sess )
155+ # # this version is for future re-check and visualization analysis
156+ # net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter)
157+ # net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter)
158+ # tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
159+ # tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
160+ # tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess)
161+ # tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess)
169162 print ("[*] Saving checkpoints SUCCESS!" )
170163
171164if __name__ == '__main__' :
0 commit comments