Skip to content

Commit 0584744

Browse files
committed
πŸ’… Improved API to return a result object.
1 parent f317890 commit 0584744

File tree

5 files changed

+84
-54
lines changed

5 files changed

+84
-54
lines changed

β€Žsrc/texturize/api.pyβ€Ž

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# neural-texturize β€” Copyright (c) 2020, Novelty Factory KG. See LICENSE for details.
22

33
import os
4+
import collections
45
import progressbar
56

67
import torch
@@ -77,18 +78,20 @@ class NotebookLog:
7778
class ProgressBar:
7879
def __init__(self, max_iter):
7980
import ipywidgets
81+
8082
self.bar = ipywidgets.IntProgress(
8183
value=0,
8284
min=0,
8385
max=max_iter,
8486
step=1,
85-
description='',
86-
bar_style='',
87-
orientation='horizontal',
87+
description="",
88+
bar_style="",
89+
orientation="horizontal",
8890
layout=ipywidgets.Layout(width="100%"),
8991
)
9092

9193
from IPython.display import display
94+
9295
display(self.bar)
9396

9497
def update(self, value, **keywords):
@@ -100,10 +103,17 @@ def finish(self):
100103
def create_progress_bar(self, iterations):
101104
return NotebookLog.ProgressBar(iterations)
102105

103-
def debug(self, *args): pass
104-
def notice(self, *args): pass
105-
def info(self, *args): pass
106-
def warn(self, *args): pass
106+
def debug(self, *args):
107+
pass
108+
109+
def notice(self, *args):
110+
pass
111+
112+
def info(self, *args):
113+
pass
114+
115+
def warn(self, *args):
116+
pass
107117

108118

109119
def get_default_log():
@@ -114,6 +124,11 @@ def get_default_log():
114124
return EmptyLog()
115125

116126

127+
Result = collections.namedtuple(
128+
"Result", ["images", "loss", "octave", "scale", "iteration"]
129+
)
130+
131+
117132
@torch.no_grad()
118133
def process_octaves(
119134
source,
@@ -127,12 +142,12 @@ def process_octaves(
127142
device: str = None,
128143
precision: str = None,
129144
):
130-
# Setup the output and logging to use throughout the synthesis.
145+
# Setup the output and logging to use throughout the synthesis.
131146
log = log or get_default_log()
132147

133148
# Determine which device and dtype to use by default, then set it up.
134149
device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
135-
precision = getattr(torch, precision or "float32")
150+
precision = getattr(torch, precision or "float32")
136151

137152
# Load the original image, always on the host device to save memory.
138153
texture_img = load_tensor_from_image(source, device="cpu").to(dtype=precision)
@@ -171,22 +186,18 @@ def process_octaves(
171186
).to(dtype=precision)
172187

173188
# Coarse-to-fine rendering, number of octaves specified by user.
174-
for i, octave in enumerate(2 ** s for s in range(octaves - 1, -1, -1)):
175-
if i == 5:
176-
precision = torch.float16
177-
encoder = encoder.half()
178-
189+
for octave, scale in enumerate(2 ** s for s in range(octaves - 1, -1, -1)):
179190
# Each octave we start a new optimization process.
180191
synth = TextureSynthesizer(
181192
device, encoder, lr=1.0, threshold=threshold, max_iter=iterations,
182193
)
183-
log.info(f"\n OCTAVE #{i} ")
184-
log.debug("<- scale:", f"1/{octave}")
194+
log.info(f"\n OCTAVE #{octave} ")
195+
log.debug("<- scale:", f"1/{scale}")
185196

186197
# Create downscaled version of original texture to match this octave.
187198
texture_cur = F.interpolate(
188199
texture_img,
189-
scale_factor=1.0 / octave,
200+
scale_factor=1.0 / scale,
190201
mode="area",
191202
recompute_scale_factor=False,
192203
).to(device=device, dtype=precision)
@@ -195,7 +206,7 @@ def process_octaves(
195206
del texture_cur
196207

197208
# Compute the seed image for this octave, sprinkling a bit of gaussian noise.
198-
result_size = size[1] // octave, size[0] // octave
209+
result_size = size[1] // scale, size[0] // scale
199210
seed_img = F.interpolate(
200211
result_img, result_size, mode="bicubic", align_corners=False
201212
)
@@ -207,33 +218,41 @@ def process_octaves(
207218

208219
# Now we can enable the automatic gradient computation to run the optimization.
209220
with torch.enable_grad():
210-
for loss, result_img in synth.run(log, seed_img.to(dtype=precision), critics):
221+
for iteration, (loss, result_img) in enumerate(
222+
synth.run(log, seed_img.to(dtype=precision), critics)
223+
):
211224
pass
212225
del synth
213226

214227
output_img = F.interpolate(
215228
result_img, size=(size[1], size[0]), mode="nearest"
216229
).cpu()
217-
yield octave, loss, [
218-
save_tensor_to_image(output_img[j : j + 1])
219-
for j in range(output_img.shape[0])
220-
]
230+
yield Result(
231+
loss=loss,
232+
octave=octave,
233+
scale=scale,
234+
iteration=iteration,
235+
images=[
236+
save_tensor_to_image(output_img[j : j + 1])
237+
for j in range(output_img.shape[0])
238+
],
239+
)
221240
del output_img
222241

223242

224243
def process_single_file(source, log: object, output: str = None, **config: dict):
225-
for octave, _, result_img in process_octaves(
244+
for result in process_octaves(
226245
load_image_from_file(source), log, **config
227246
):
228247
filenames = []
229-
for i, result in enumerate(result_img):
248+
for i, image in enumerate(result.images):
230249
# Save the files for each octave to disk.
231250
filename = output.format(
232-
octave=octave,
251+
octave=result.octave,
233252
source=os.path.splitext(os.path.basename(source))[0],
234253
variation=i,
235254
)
236-
result.save(filename)
255+
image.save(filename)
237256
log.debug("\n=> output:", filename)
238257
filenames.append(filename)
239258

β€Žsrc/texturize/io.pyβ€Ž

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# neural-texturize β€” Copyright (c) 2020, Novelty Factory KG. See LICENSE for details.
22

3+
import urllib
34
import asyncio
45
from io import BytesIO
56

@@ -20,6 +21,12 @@ def load_tensor_from_image(image, device):
2021
return V.to_tensor(image).unsqueeze(0).to(device)
2122

2223

24+
def load_image_from_url(url, mode="RGB"):
25+
response = urllib.request.urlopen(url)
26+
buffer = BytesIO(response.read())
27+
return PIL.Image.open(buffer).convert(mode)
28+
29+
2330
def save_tensor_to_file(tensor, filename, mode="RGB"):
2431
img = save_tensor_to_image(tensor)
2532
img.save(filename)
@@ -38,10 +45,18 @@ def save_tensor_to_image(tensor, mode="RGB"):
3845
pass
3946

4047

41-
def show_result_in_notebook(images):
48+
def show_result_in_notebook(result, title="Generated Image"):
4249
clear_output()
43-
for out in images:
44-
html = ipywidgets.HTML(value="<h3>Octave #1</h3>")
50+
51+
for out in result.images:
52+
html = ipywidgets.HTML(value=f"""
53+
<h3>{title}</h3>
54+
<ul style="font-size: 16px;">
55+
<li>octave: {result.octave}</li>
56+
<li>size: {out.size}</li>
57+
<li>scale: 1/{result.scale}</li>
58+
<li>loss: {result.loss:0.4f}</li>
59+
</ul>""")
4560

4661
buffer = io.BytesIO()
4762
out.save(buffer, format="webp", quality=80)

β€Žtests/app_gram.pyβ€Ž

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,18 @@
77

88

99
def test_gram_single(image, size=(96, 88)):
10-
for _, loss, images in process_octaves(
11-
image(size), octaves=2, size=size, mode="gram"
12-
):
13-
assert len(images) == 1
14-
assert all(isinstance(img, PIL.Image.Image) for img in images)
15-
assert all(img.size == size for img in images)
16-
assert loss < 5e-2
10+
for result in process_octaves(image(size), octaves=2, size=size, mode="gram"):
11+
assert len(result.images) == 1
12+
assert all(isinstance(img, PIL.Image.Image) for img in result.images)
13+
assert all(img.size == size for img in result.images)
14+
assert result.loss < 5e-2
1715

1816

1917
def test_gram_variations(image, size=(72, 64)):
20-
for _, loss, images in process_octaves(
18+
for result in process_octaves(
2119
image(size), variations=2, octaves=2, size=size, mode="gram"
2220
):
23-
assert len(images) == 2
24-
assert all(isinstance(img, PIL.Image.Image) for img in images)
25-
assert all(img.size == size for img in images)
26-
assert loss < 5e-1
21+
assert len(result.images) == 2
22+
assert all(isinstance(img, PIL.Image.Image) for img in result.images)
23+
assert all(img.size == size for img in result.images)
24+
assert result.loss < 5e-1

β€Žtests/app_hist.pyβ€Ž

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77

88

99
def test_hist_single(image, size=(32, 48)):
10-
for _, loss, images in process_octaves(
11-
image(size), octaves=2, size=size, mode="hist"
12-
):
13-
assert len(images) == 1
14-
assert all(isinstance(img, PIL.Image.Image) for img in images)
15-
assert all(img.size == size for img in images)
16-
assert loss < 1e-1
10+
for result in process_octaves(image(size), octaves=2, size=size, mode="hist"):
11+
assert len(result.images) == 1
12+
assert all(isinstance(img, PIL.Image.Image) for img in result.images)
13+
assert all(img.size == size for img in result.images)
14+
assert result.loss < 1e-1

β€Žtests/app_patch.pyβ€Ž

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88

99
def test_patch_single(image, size=(64, 48)):
10-
for _, loss, images in process_octaves(
10+
for result in process_octaves(
1111
image(size), octaves=2, size=size, mode="patch", threshold=1e-3
1212
):
13-
assert len(images) == 1
14-
assert all(isinstance(img, PIL.Image.Image) for img in images)
15-
assert all(img.size == size for img in images)
16-
assert loss < 5.0
13+
assert len(result.images) == 1
14+
assert all(isinstance(img, PIL.Image.Image) for img in result.images)
15+
assert all(img.size == size for img in result.images)
16+
assert result.loss < 5.0

0 commit comments

Comments
Β (0)