Skip to content
This repository has been archived by the owner on Jul 30, 2024. It is now read-only.

Accelerate get_ray_bundle() #29

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

Depersonalizc
Copy link

A tiny tweak of using torch.tensordot() to compute c2w ray transform, up to ~6x acceleration.

A tiny tweak of using torch.tensordot() to compute c2w ray transform, up to ~6x acceleration.
@BjoernHaefner
Copy link

I did do

ray_directions = torch.einsum("mn,whn->whm", tform_cam2world[:3, :3], directions)

As the proposed version with tensordot resulted in a UserWarning with PyTorch 1.13

UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541702/work/aten/src/ATen/native/TensorShape.cpp:3277.)

This also gave another speedup:

start = time.time()
for i in range(1000):
    torch.einsum("mn,whn->whm", tform_cam2world[:3, :3], directions)
stop = time.time()
print(f"time: {(stop-start)} [einsum]")

start = time.time()
for i in range(1000):
    torch.sum(directions[..., None, :] * tform_cam2world[:3, :3], dim=-1)
stop = time.time()
print(f"time: {(stop-start)} [sum]")

start = time.time()
for i in range(1000):
    torch.tensordot(tform_cam2world[:3, :3], directions.T, dims=([1], [0])).T
stop = time.time()
print(f"time: {(stop-start)} [tensordot]")


time: 0.06764698028564453 [einsum]
time: 0.29248738288879395 [sum]
time: 0.10972046852111816 [tensordot]

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants