-
Notifications
You must be signed in to change notification settings - Fork 18
/
utils.py
88 lines (73 loc) · 2.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os

import imageio
import numpy as np
import pytorch3d
import torch
from pytorch3d.renderer import (
    AlphaCompositor,
    PointsRasterizationSettings,
    PointsRenderer,
    PointsRasterizer,
)
def save_checkpoint(epoch, model, args, best=False):
    """
    Serialize ``model.state_dict()`` into ``args.checkpoint_dir`` with torch.save.

    Args:
        epoch (int): Epoch number embedded in the filename of regular checkpoints.
        model (torch.nn.Module): Model whose state dict is written.
        args: Object exposing a ``checkpoint_dir`` attribute (the directory is
            not created here).
        best (bool): If True, write ``best_model.pt`` instead of the
            per-epoch ``model_epoch_<epoch>.pt`` file.
    """
    filename = 'best_model.pt' if best else 'model_epoch_{}.pt'.format(epoch)
    target_path = os.path.join(args.checkpoint_dir, filename)
    torch.save(model.state_dict(), target_path)
def create_dir(directory):
    """
    Create ``directory`` (and any missing parent directories) if it does not
    already exist.

    ``exist_ok=True`` replaces the former separate ``os.path.exists`` check.
    A check-then-create sequence races when several processes create the same
    directory concurrently: the loser crashes with ``FileExistsError``.

    Raises:
        FileExistsError: if ``directory`` already exists but is not a directory.
    """
    os.makedirs(directory, exist_ok=True)
def get_points_renderer(
    image_size=256, device=None, radius=0.01, background_color=(1, 1, 1)
):
    """
    Returns a Pytorch3D renderer for point clouds.

    Args:
        image_size (int): The rendered image size.
        device (torch.device): The torch device to use (CPU or GPU). If not
            specified, will automatically use GPU if available, otherwise CPU.
            The returned renderer is moved to this device.
        radius (float): The radius of the rendered point in NDC.
        background_color (tuple): The background color of the rendered image.

    Returns:
        PointsRenderer, moved to ``device``.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    raster_settings = PointsRasterizationSettings(image_size=image_size, radius=radius)
    renderer = PointsRenderer(
        rasterizer=PointsRasterizer(raster_settings=raster_settings),
        compositor=AlphaCompositor(background_color=background_color),
    )
    # `device` used to be resolved and then ignored; honour it so the
    # returned renderer matches the documented contract.
    return renderer.to(device)
def viz_seg(verts, labels, path, device):
    """
    Visualize a segmentation result as a turntable animation.

    The point cloud is rendered from 30 camera azimuths (12 degrees apart,
    elevation 0, distance 3). The frames are written to ``path`` with
    ``imageio.mimsave`` at 15 fps; the original intent is a 360-degree gif.

    Args:
        verts (torch.Tensor): Point coordinates of shape (N, 3).
        labels (torch.Tensor): Per-point integer class labels of shape (N,).
            Labels 0-5 map to the fixed palette below; any other label is
            rendered black.
        path (str): Output file path.
        device (torch.device): Device used for the cameras, the point cloud
            and the renderer.
    """
    image_size = 256
    background_color = (1, 1, 1)
    colors = [[1.0, 1.0, 1.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0],
              [1.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]

    # Construct various camera viewpoints
    dist = 3
    elev = 0
    azim = [180 - 12 * i for i in range(30)]
    num_views = len(azim)
    R, T = pytorch3d.renderer.cameras.look_at_view_transform(
        dist=dist, elev=elev, azim=azim, device=device
    )
    cameras = pytorch3d.renderer.FoVPerspectiveCameras(R=R, T=T, fov=60, device=device)

    # One copy of the cloud per view. Rendering is visualization only,
    # so detach it from any autograd graph.
    sample_verts = (
        verts.detach().unsqueeze(0).repeat(num_views, 1, 1)
        .to(device=device, dtype=torch.float)
    )

    # Colorize points based on segmentation labels. The buffer is sized from
    # the actual point count (previously hard-coded to 10000). Masking happens
    # on CPU so labels living on a GPU can still index the CPU colour buffer.
    num_points = verts.shape[0]
    point_labels = labels.detach().reshape(-1).cpu()
    point_colors = torch.zeros((num_points, 3), dtype=torch.float)
    for label, rgb in enumerate(colors):
        point_colors[point_labels == label] = torch.tensor(rgb)
    # Points and features must share a device for Pointclouds.
    sample_colors = point_colors.unsqueeze(0).repeat(num_views, 1, 1).to(device)

    point_cloud = pytorch3d.structures.Pointclouds(
        points=sample_verts, features=sample_colors
    ).to(device)
    renderer = get_points_renderer(
        image_size=image_size, background_color=background_color, device=device
    )
    rend = renderer(point_cloud, cameras=cameras).detach().cpu().numpy()  # (num_views, H, W, 3)
    # Clip before the uint8 cast so values marginally outside [0, 1] cannot wrap.
    rend = (np.clip(rend, 0.0, 1.0) * 255).astype(np.uint8)
    imageio.mimsave(path, rend, fps=15)