11# neural-texturize β Copyright (c) 2020, Novelty Factory KG. See LICENSE for details.
22
33import os
4+ import collections
45import progressbar
56
67import torch
@@ -77,18 +78,20 @@ class NotebookLog:
7778 class ProgressBar :
7879 def __init__ (self , max_iter ):
7980 import ipywidgets
81+
8082 self .bar = ipywidgets .IntProgress (
8183 value = 0 ,
8284 min = 0 ,
8385 max = max_iter ,
8486 step = 1 ,
85- description = '' ,
86- bar_style = '' ,
87- orientation = ' horizontal' ,
87+ description = "" ,
88+ bar_style = "" ,
89+ orientation = " horizontal" ,
8890 layout = ipywidgets .Layout (width = "100%" ),
8991 )
9092
9193 from IPython .display import display
94+
9295 display (self .bar )
9396
9497 def update (self , value , ** keywords ):
@@ -100,10 +103,17 @@ def finish(self):
100103 def create_progress_bar (self , iterations ):
101104 return NotebookLog .ProgressBar (iterations )
102105
103- def debug (self , * args ): pass
104- def notice (self , * args ): pass
105- def info (self , * args ): pass
106- def warn (self , * args ): pass
106+ def debug (self , * args ):
107+ pass
108+
109+ def notice (self , * args ):
110+ pass
111+
112+ def info (self , * args ):
113+ pass
114+
115+ def warn (self , * args ):
116+ pass
107117
108118
109119def get_default_log ():
@@ -114,6 +124,11 @@ def get_default_log():
114124 return EmptyLog ()
115125
116126
127+ Result = collections .namedtuple (
128+ "Result" , ["images" , "loss" , "octave" , "scale" , "iteration" ]
129+ )
130+
131+
117132@torch .no_grad ()
118133def process_octaves (
119134 source ,
@@ -127,12 +142,12 @@ def process_octaves(
127142 device : str = None ,
128143 precision : str = None ,
129144):
130- # Setup the output and logging to use throughout the synthesis.
145+ # Setup the output and logging to use throughout the synthesis.
131146 log = log or get_default_log ()
132147
133148 # Determine which device and dtype to use by default, then set it up.
134149 device = torch .device (device or ("cuda" if torch .cuda .is_available () else "cpu" ))
135- precision = getattr (torch , precision or "float32" )
150+ precision = getattr (torch , precision or "float32" )
136151
137152 # Load the original image, always on the host device to save memory.
138153 texture_img = load_tensor_from_image (source , device = "cpu" ).to (dtype = precision )
@@ -171,22 +186,18 @@ def process_octaves(
171186 ).to (dtype = precision )
172187
173188 # Coarse-to-fine rendering, number of octaves specified by user.
174- for i , octave in enumerate (2 ** s for s in range (octaves - 1 , - 1 , - 1 )):
175- if i == 5 :
176- precision = torch .float16
177- encoder = encoder .half ()
178-
189+ for octave , scale in enumerate (2 ** s for s in range (octaves - 1 , - 1 , - 1 )):
179190 # Each octave we start a new optimization process.
180191 synth = TextureSynthesizer (
181192 device , encoder , lr = 1.0 , threshold = threshold , max_iter = iterations ,
182193 )
183- log .info (f"\n OCTAVE #{ i } " )
184- log .debug ("<- scale:" , f"1/{ octave } " )
194+ log .info (f"\n OCTAVE #{ octave } " )
195+ log .debug ("<- scale:" , f"1/{ scale } " )
185196
186197 # Create downscaled version of original texture to match this octave.
187198 texture_cur = F .interpolate (
188199 texture_img ,
189- scale_factor = 1.0 / octave ,
200+ scale_factor = 1.0 / scale ,
190201 mode = "area" ,
191202 recompute_scale_factor = False ,
192203 ).to (device = device , dtype = precision )
@@ -195,7 +206,7 @@ def process_octaves(
195206 del texture_cur
196207
197208 # Compute the seed image for this octave, sprinkling a bit of gaussian noise.
198- result_size = size [1 ] // octave , size [0 ] // octave
209+ result_size = size [1 ] // scale , size [0 ] // scale
199210 seed_img = F .interpolate (
200211 result_img , result_size , mode = "bicubic" , align_corners = False
201212 )
@@ -207,33 +218,41 @@ def process_octaves(
207218
208219 # Now we can enable the automatic gradient computation to run the optimization.
209220 with torch .enable_grad ():
210- for loss , result_img in synth .run (log , seed_img .to (dtype = precision ), critics ):
221+ for iteration , (loss , result_img ) in enumerate (
222+ synth .run (log , seed_img .to (dtype = precision ), critics )
223+ ):
211224 pass
212225 del synth
213226
214227 output_img = F .interpolate (
215228 result_img , size = (size [1 ], size [0 ]), mode = "nearest"
216229 ).cpu ()
217- yield octave , loss , [
218- save_tensor_to_image (output_img [j : j + 1 ])
219- for j in range (output_img .shape [0 ])
220- ]
230+ yield Result (
231+ loss = loss ,
232+ octave = octave ,
233+ scale = scale ,
234+ iteration = iteration ,
235+ images = [
236+ save_tensor_to_image (output_img [j : j + 1 ])
237+ for j in range (output_img .shape [0 ])
238+ ],
239+ )
221240 del output_img
222241
223242
224243def process_single_file (source , log : object , output : str = None , ** config : dict ):
225- for octave , _ , result_img in process_octaves (
244+ for result in process_octaves (
226245 load_image_from_file (source ), log , ** config
227246 ):
228247 filenames = []
229- for i , result in enumerate (result_img ):
248+ for i , image in enumerate (result . images ):
230249 # Save the files for each octave to disk.
231250 filename = output .format (
232- octave = octave ,
251+ octave = result . octave ,
233252 source = os .path .splitext (os .path .basename (source ))[0 ],
234253 variation = i ,
235254 )
236- result .save (filename )
255+ image .save (filename )
237256 log .debug ("\n => output:" , filename )
238257 filenames .append (filename )
239258
0 commit comments