5353class GramMatrixCritic :
5454 """A `Critic` evaluates the features of an image to determine how it scores.
5555
56- This is a gram-based critic that computes a 2D histogram of feature cross-correlations for
57- a specific layer, and compares it to the target gram matrix.
56+ This critic computes a 2D histogram of feature cross-correlations for a specific
57+ layer, and compares it to the target gram matrix.
5858 """
5959
6060 def __init__ (self , layer , offset : float = - 1.0 ):
@@ -102,33 +102,44 @@ class SolverLBFGS:
def __init__(self, objective, image, lr=1.0):
    """Set up an L-BFGS solver that optimizes `image` against `objective`.

    Args:
        objective: Callable invoked as ``objective(image)`` on each evaluation.
        image: The value being optimized; it is handed to the optimizer as its
            only parameter.
        lr: Full learning rate. It is stored on the instance as well as passed
            to the optimizer.
    """
    self.objective, self.image, self.lr = objective, image, lr

    # Bookkeeping: list of scores (starts empty) and the evaluation counter,
    # which starts at one.
    self.scores = []
    self.iteration = 1

    # Each optimizer step is kept short: few inner iterations and evaluations,
    # with a small curvature history.
    lbfgs_options = dict(lr=lr, max_iter=2, max_eval=4, history_size=10)
    self.optimizer = torch.optim.LBFGS([image], **lbfgs_options)
109111
def step(self):
    """Perform one optimizer step and return whatever the optimizer returns.

    Before stepping, the learning rate of every parameter group is set to
    ``self.lr`` scaled by the quadratic warm-up factor
    ``min(self.iteration / 10, 1) ** 2``, so the rate ramps up to its full
    value and then stays there.

    Note that ``self.iteration`` is advanced inside the closure, i.e. once per
    objective evaluation; the optimizer may evaluate the closure more than
    once within a single call to this method.
    """
    # Warm-up: ramp the learning rate up slowly until the counter reaches 10.
    warmup = min(self.iteration / 10.0, 1.0)
    rate = self.lr * warmup ** 2
    for param_group in self.optimizer.param_groups:
        param_group["lr"] = rate

    def _evaluate():
        # Count the evaluation, clear previously accumulated gradients, then
        # compute the objective for the current image.
        self.iteration += 1
        self.optimizer.zero_grad()
        return self.objective(self.image)

    # The optimizer decides when and how often to call the closure.
    return self.optimizer.step(_evaluate)
116125
117126
118127class MultiCriticObjective :
119- """An `Objective` that defines a problem by evaluating possible solutions (i.e. images).
128+ """An `Objective` that defines a problem to be solved by evaluating candidate
129+ solutions (i.e. images) and returning an error.
120130
121- This objective evaluates a list of critics to produce a final "loss" that's the sum of all the
122- scores returned by the critics. It's also responsible for computing the gradients.
131+ This objective evaluates a list of critics to produce a final "loss" that's the sum
132+ of all the scores returned by the critics. It's also responsible for computing the
133+ gradients.
123134 """
124135
def __init__(self, encoder, critics):
    """Keep references to the encoder and the critics used by this objective.

    Args:
        encoder: Stored unchanged as ``self.encoder``.
        critics: Stored unchanged as ``self.critics`` (no copy is made).
    """
    self.encoder, self.critics = encoder, critics
128139
129140 def __call__ (self , image ):
130- """Main evaluation function that's called by the solver. Processes the image, computes the
131- gradients, and returns the loss.
141+ """Main evaluation function that's called by the solver. Processes the image,
142+ computes the gradients, and returns the loss.
132143 """
133144
134145 image .data .clamp_ (0.0 , 1.0 )
@@ -202,8 +213,10 @@ def run(self, seed_img, critics):
202213
203214 # See if we can terminate the optimization early.
204215 if previous is not None and abs (loss - previous ) < self .precision :
216+ assert i > 10 , f"Optimization stalled at iteration { i } ."
205217 progress .max_value = i
206218 break
219+
207220 previous = loss
208221
209222 progress .finish ()
@@ -240,7 +253,7 @@ def run(config, source):
240253 # Each octave we start a new optimization process.
241254 synth = TextureSynthesizer (
242255 encoder ,
243- lr = 0.5 ,
256+ lr = 1.0 ,
244257 precision = float (config ["--precision" ]),
245258 max_iter = int (config ["--iterations" ]),
246259 )
@@ -260,7 +273,7 @@ def run(config, source):
260273 # Compute the seed image for this octave, sprinkling a bit of gaussian noise.
261274 size = result_sz [0 ] // scale , result_sz [1 ] // scale
262275 seed_img = F .interpolate (result_img , size , mode = "bicubic" , align_corners = False )
263- seed_img += torch .empty_like (seed_img , dtype = torch .float32 ).normal_ (std = 0.1 )
276+ seed_img += torch .empty_like (seed_img , dtype = torch .float32 ).normal_ (std = 0.2 )
264277 print ("<- seed:" , tuple (seed_img .shape [2 :]), end = "\n \n " )
265278 del result_img
266279
0 commit comments