diff --git a/nerf/nerf_helpers.py b/nerf/nerf_helpers.py index 4fcb372..520250c 100644 --- a/nerf/nerf_helpers.py +++ b/nerf/nerf_helpers.py @@ -103,9 +103,11 @@ def get_ray_bundle( ], dim=-1, ) - ray_directions = torch.sum( - directions[..., None, :] * tform_cam2world[:3, :3], dim=-1 - ) + ray_directions = torch.tensordot( + tform_cam2world[:3, :3], + directions.T, + dims=([1], [0]) + ).T ray_origins = tform_cam2world[:3, -1].expand(ray_directions.shape) return ray_origins, ray_directions