forked from akanazawa/cmr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
119 lines (89 loc) · 2.98 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Demo of CMR.
Note that CMR assumes that the object has been detected, so please use a picture of a bird that is centered and well cropped.
Sample usage:
python -m cmr.demo --name bird_net --num_train_epoch 500 --img_path cmr/demo_data/img1.jpg
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags, app
import numpy as np
import skimage.io as io
import torch
from .nnutils import test_utils
from .nnutils import predictor as pred_util
from .utils import image as img_util
# Command-line flags specific to the demo; they are read through `opts` below.
flags.DEFINE_string(
    name='img_path',
    default='data/im1963.jpg',
    help='Image to run',
)
flags.DEFINE_integer(
    name='img_size',
    default=256,
    help='image size the network was trained on.',
)

# Handle to absl's global flag registry.
opts = flags.FLAGS
def preprocess_image(img_path, img_size=256):
    """Load an image and convert it to a square, channel-first array.

    The image is read from ``img_path``, divided by 255, resized so that its
    longer side equals ``img_size``, cropped with an ``img_size``-wide window
    placed around the image center, and finally transposed from HxWxC to
    CxHxW.

    Grayscale inputs are replicated to three channels and an alpha channel,
    if present, is discarded, so the result always has three channels.

    Args:
        img_path: Path of the image file to load.
        img_size: Side length in pixels of the square crop window.

    Returns:
        The preprocessed image as a 3xHxW array.
    """
    img = io.imread(img_path) / 255.

    # The output is documented as 3xHxW and the final transpose needs a
    # channel axis, so normalise the channel count up front: a 2-D grayscale
    # image would otherwise make the transpose fail, and an RGBA image would
    # leak a fourth channel into the result.
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    num_channels = img.shape[2]
    if num_channels < 3:
        # Gray or gray+alpha: replicate the intensity channel.
        img = np.repeat(img[:, :, :1], 3, axis=2)
    elif num_channels > 3:
        # RGBA: keep RGB only (contiguous copy for the resize helper).
        img = np.ascontiguousarray(img[:, :, :3])

    # Scale the max image size to be img_size
    scale_factor = float(img_size) / np.max(img.shape[:2])
    img, _ = img_util.resize_img(img, scale_factor)

    # Crop img_size x img_size from the center. img.shape is (rows, cols),
    # so the center is reversed to get it in (x, y) order for the bbox.
    center = np.round(np.array(img.shape[:2]) / 2).astype(int)[::-1]
    half_size = img_size / 2.
    bbox = np.hstack([center - half_size, center + half_size])
    # bgval=1. is forwarded to the project's crop helper -- presumably the
    # fill value for parts of the window outside the image; confirm there.
    img = img_util.crop(img, bbox, bgval=1.)

    # Transpose the image to 3xHxW
    return np.transpose(img, (2, 0, 1))
def visualize(img, outputs, renderer):
    """Show the input image next to renderings of the predicted mesh.

    Draws a 2x3 matplotlib figure containing the input image, the predicted
    mesh rendered without and with texture, and three textured renderings
    from rotated viewpoints. The call blocks until the figure window is
    closed.

    Args:
        img: Input image in channel-first layout; it is transposed with
            ``(1, 2, 0)`` before being displayed.
        outputs: Mapping with the entries ``'verts'``, ``'cam_pred'`` and
            ``'texture'``. Only the first element of each is visualized.
        renderer: Object callable as ``renderer(vert, cam)`` and
            ``renderer(vert, cam, texture=...)`` that also offers
            ``diff_vp(...)`` for rendering from a rotated viewpoint.
    """
    vert = outputs['verts'][0]
    cam = outputs['cam_pred'][0]
    texture = outputs['texture'][0]

    shape_pred = renderer(vert, cam)
    img_pred = renderer(vert, cam, texture=texture)

    # Different viewpoints.
    vp1 = renderer.diff_vp(
        vert, cam, angle=30, axis=[0, 1, 0], texture=texture, extra_elev=True)
    vp2 = renderer.diff_vp(
        vert, cam, angle=60, axis=[0, 1, 0], texture=texture, extra_elev=True)
    vp3 = renderer.diff_vp(
        vert, cam, angle=60, axis=[1, 0, 0], texture=texture)

    # Back to HxWxC for imshow.
    img = np.transpose(img, (1, 2, 0))

    # (image, title) for each cell of the 2x3 grid, in row-major order.
    # A title of None leaves the panel untitled.
    panels = (
        (img, 'input'),
        (shape_pred, 'pred mesh'),
        (img_pred, 'pred mesh w/texture'),
        (vp1, 'different viewpoints'),
        (vp2, None),
        (vp3, None),
    )

    import matplotlib.pyplot as plt
    plt.ion()
    plt.figure(1)
    plt.clf()
    for position, (panel, title) in enumerate(panels, start=1):
        plt.subplot(2, 3, position)
        plt.imshow(panel)
        if title is not None:
            plt.title(title)
        plt.axis('off')
    plt.draw()
    # Interactive mode makes a plain show() return immediately, which used to
    # be worked around with a debugger breakpoint. Block explicitly instead so
    # the window stays open without requiring a debugger package.
    plt.show(block=True)
def main(_):
    """Run the demo: predict a mesh for ``opts.img_path`` and display it.

    The single positional argument (supplied by ``app.run``) is ignored.
    """
    image = preprocess_image(opts.img_path, img_size=opts.img_size)

    # The predictor is fed a batch, so add a leading axis of length one.
    single_image_batch = {'img': torch.Tensor(np.expand_dims(image, 0))}

    predictor = pred_util.MeshPredictor(opts)
    outputs = predictor.predict(single_image_batch)

    # Configure the lighting on the predictor's visualization renderer
    # before using it to draw the results.
    predictor.vis_rend.set_light_dir([0, 1, -1], 0.4)

    visualize(image, outputs, predictor.vis_rend)
# Script entry point.
if __name__ == '__main__':
    # Override the batch size before handing control to absl: the demo
    # processes exactly one image. `batch_size` is not declared in this file;
    # presumably one of the imported project modules defines the flag.
    opts.batch_size = 1
    # Hand control to absl, which calls main().
    app.run(main)