forked from NVlabs/FoundationPose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_demo.py
79 lines (62 loc) · 3.16 KB
/
run_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from estimater import *
from datareader import *
import argparse
def _reset_debug_dir(debug_dir):
    """Empty ``debug_dir`` and (re)create its ``track_vis`` and ``ob_in_cam`` subdirectories.

    Python equivalent of the shell command
    ``rm -rf {debug_dir}/* && mkdir -p {debug_dir}/track_vis {debug_dir}/ob_in_cam``,
    but without spawning a shell, so a ``--debug_dir`` containing spaces or
    shell metacharacters can neither break the command nor inject new ones.
    As with the shell ``*`` glob, top-level entries whose names start with
    ``.`` are left in place. A missing ``debug_dir`` is simply created.
    """
    import glob
    import shutil

    # glob.escape so that '[', '*', '?' in the directory path are taken literally.
    for entry in glob.glob(os.path.join(glob.escape(debug_dir), '*')):
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry)
        else:
            # Regular files and symlinks (including symlinks to directories).
            os.remove(entry)
    for sub in ('track_vis', 'ob_in_cam'):
        os.makedirs(os.path.join(debug_dir, sub), exist_ok=True)


def main():
    """Run FoundationPose on a video scene: register frame 0, then track every later frame.

    Reads a mesh and a scene directory (via ``YcbineoatReader``) from the command
    line. The object pose is estimated on the first frame with the first-frame
    mask, then tracked frame by frame. Each 4x4 pose is written to
    ``{debug_dir}/ob_in_cam/<frame id>.txt``.

    ``--debug`` levels (cumulative):
      * >=1: show a live window with the posed 3D box and XYZ axes.
      * >=2: also save each overlay image to ``{debug_dir}/track_vis``.
      * >=3: on the first frame, also export the posed mesh and the scene
        point cloud to ``{debug_dir}``.
    """
    parser = argparse.ArgumentParser()
    code_dir = os.path.dirname(os.path.realpath(__file__))
    parser.add_argument('--mesh_file', type=str, default=f'{code_dir}/demo_data/mustard0/mesh/textured_simple.obj')
    parser.add_argument('--test_scene_dir', type=str, default=f'{code_dir}/demo_data/mustard0')
    parser.add_argument('--est_refine_iter', type=int, default=5)
    parser.add_argument('--track_refine_iter', type=int, default=2)
    parser.add_argument('--debug', type=int, default=1)
    parser.add_argument('--debug_dir', type=str, default=f'{code_dir}/debug')
    args = parser.parse_args()

    set_logging_format()
    set_seed(0)

    mesh = trimesh.load(args.mesh_file)
    debug = args.debug
    debug_dir = args.debug_dir
    # Also creates ob_in_cam/ and track_vis/, so the loop never needs to.
    _reset_debug_dir(debug_dir)

    # to_origin is the transform returned by trimesh's oriented_bounds. Its
    # inverse is applied to each pose below so the box is drawn in the mesh's
    # oriented-bounds frame. bbox holds the (min, max) corners of that box.
    to_origin, extents = trimesh.bounds.oriented_bounds(mesh)
    bbox = np.stack([-extents / 2, extents / 2], axis=0).reshape(2, 3)

    scorer = ScorePredictor()
    refiner = PoseRefinePredictor()
    glctx = dr.RasterizeCudaContext()
    est = FoundationPose(model_pts=mesh.vertices, model_normals=mesh.vertex_normals, mesh=mesh, scorer=scorer, refiner=refiner, debug_dir=debug_dir, debug=debug, glctx=glctx)
    logging.info("estimator initialization done")

    reader = YcbineoatReader(video_dir=args.test_scene_dir, shorter_side=None, zfar=np.inf)

    for i in range(len(reader.color_files)):
        logging.info(f'i:{i}')
        color = reader.get_color(i)
        depth = reader.get_depth(i)

        if i == 0:
            # Only the first frame uses a mask: full registration from scratch.
            mask = reader.get_mask(0).astype(bool)
            pose = est.register(K=reader.K, rgb=color, depth=depth, ob_mask=mask, iteration=args.est_refine_iter)

            if debug >= 3:
                m = mesh.copy()
                m.apply_transform(pose)
                m.export(f'{debug_dir}/model_tf.obj')
                xyz_map = depth2xyzmap(depth, reader.K)
                # Drop near-zero / invalid depth. The 0.1 threshold is presumably
                # in the reader's depth units (metres?) -- TODO confirm.
                valid = depth >= 0.1
                pcd = toOpen3dCloud(xyz_map[valid], color[valid])
                o3d.io.write_point_cloud(f'{debug_dir}/scene_complete.ply', pcd)
        else:
            # Later frames refine from the previous estimate (tracking).
            pose = est.track_one(rgb=color, depth=depth, K=reader.K, iteration=args.track_refine_iter)

        np.savetxt(f'{debug_dir}/ob_in_cam/{reader.id_strs[i]}.txt', pose.reshape(4, 4))

        if debug >= 1:
            # Express the pose in the oriented-bounding-box frame before drawing.
            center_pose = pose @ np.linalg.inv(to_origin)
            vis = draw_posed_3d_box(reader.K, img=color, ob_in_cam=center_pose, bbox=bbox)
            # Draw the axes on top of the box image (vis), not on the raw frame,
            # so the box is kept even if draw_posed_3d_box returns a new array.
            vis = draw_xyz_axis(vis, ob_in_cam=center_pose, scale=0.1, K=reader.K, thickness=3, transparency=0, is_input_rgb=True)
            # vis is RGB; OpenCV displays BGR.
            cv2.imshow('1', vis[..., ::-1])
            cv2.waitKey(1)

            if debug >= 2:
                imageio.imwrite(f'{debug_dir}/track_vis/{reader.id_strs[i]}.png', vis)


if __name__ == '__main__':
    main()