From e553fd461f3dd39c84cee5f90980fd7f008290d7 Mon Sep 17 00:00:00 2001 From: Depersonalizc <1046770165@qq.com> Date: Wed, 31 Mar 2021 13:40:58 +0800 Subject: [PATCH] Accelerate get_ray_bundle() A tiny tweak of using torch.tensordot() to compute c2w ray transform, up to ~6x acceleration. --- nerf/nerf_helpers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nerf/nerf_helpers.py b/nerf/nerf_helpers.py index 4fcb372..520250c 100644 --- a/nerf/nerf_helpers.py +++ b/nerf/nerf_helpers.py @@ -103,9 +103,11 @@ def get_ray_bundle( ], dim=-1, ) - ray_directions = torch.sum( - directions[..., None, :] * tform_cam2world[:3, :3], dim=-1 - ) + ray_directions = torch.tensordot( + tform_cam2world[:3, :3], + directions.T, + dims=([1], [0]) + ).T ray_origins = tform_cam2world[:3, -1].expand(ray_directions.shape) return ray_origins, ray_directions