Skip to content

Commit

Permalink
Update raymarching.py
Browse files Browse the repository at this point in the history
Fix the FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  • Loading branch information
pergyz authored Aug 10, 2024
1 parent cc8fcba commit 8d6c9dd
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions ernerf/raymarching/raymarching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.amp import custom_bwd, custom_fwd

try:
import _raymarching_face as _backend
Expand All @@ -17,7 +17,7 @@

class _near_far_from_aabb(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
''' near_far_from_aabb, CUDA implementation
Calculate rays' intersection time (near and far) with aabb
Expand Down Expand Up @@ -50,7 +50,7 @@ def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):

class _sph_from_ray(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, rays_o, rays_d, radius):
''' sph_from_ray, CUDA implementation
get spherical coordinate on the background sphere from rays.
Expand Down Expand Up @@ -127,7 +127,7 @@ def forward(ctx, indices):

class _packbits(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, grid, thresh, bitfield=None):
''' packbits, CUDA implementation
Pack up the density grid into a bit field to accelerate ray marching.
Expand Down Expand Up @@ -156,7 +156,7 @@ def forward(ctx, grid, thresh, bitfield=None):

class _morton3D_dilation(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, grid):
''' max pooling with morton coord, CUDA implementation
or maybe call it dilation... we don't support adjust kernel size.
Expand Down Expand Up @@ -185,7 +185,7 @@ def forward(ctx, grid):

class _march_rays_train(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
''' march rays to generate points (forward only)
Args:
Expand Down Expand Up @@ -261,7 +261,7 @@ def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, ste

# to support optimizing camera poses.
@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad_xyzs, grad_dirs, grad_deltas, grad_rays):
# grad_xyzs/dirs: [M, 3]

Expand All @@ -282,7 +282,7 @@ def backward(ctx, grad_xyzs, grad_dirs, grad_deltas, grad_rays):

class _composite_rays_train(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4):
''' composite rays' rgbs, according to the ray marching formula.
Args:
Expand Down Expand Up @@ -317,7 +317,7 @@ def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4):
return weights_sum, ambient_sum, depth, image

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image):

# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
Expand Down Expand Up @@ -346,7 +346,7 @@ def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image):

class _march_rays(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
''' march rays to generate points (forward only, for inference)
Args:
Expand Down Expand Up @@ -400,7 +400,7 @@ def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, den

class _composite_rays(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
@custom_fwd(cast_inputs=torch.float32, device_type='cuda') # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
''' composite rays' rgbs, according to the ray marching formula. (for inference)
Args:
Expand All @@ -425,7 +425,7 @@ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weig

class _composite_rays_ambient(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
@custom_fwd(cast_inputs=torch.float32, device_type='cuda') # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum, T_thresh=1e-2):
_backend.composite_rays_ambient(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum)
return tuple()
Expand All @@ -441,7 +441,7 @@ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambi

class _composite_rays_train_sigma(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4):
''' composite rays' rgbs, according to the ray marching formula.
Args:
Expand Down Expand Up @@ -476,7 +476,7 @@ def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4):
return weights_sum, ambient_sum, depth, image

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image):

# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
Expand All @@ -502,7 +502,7 @@ def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image):

class _composite_rays_ambient_sigma(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
@custom_fwd(cast_inputs=torch.float32, device_type='cuda') # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum, T_thresh=1e-2):
_backend.composite_rays_ambient_sigma(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum)
return tuple()
Expand All @@ -515,7 +515,7 @@ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambi
# uncertainty
class _composite_rays_train_uncertainty(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, sigmas, rgbs, ambient, uncertainty, deltas, rays, T_thresh=1e-4):
''' composite rays' rgbs, according to the ray marching formula.
Args:
Expand Down Expand Up @@ -552,7 +552,7 @@ def forward(ctx, sigmas, rgbs, ambient, uncertainty, deltas, rays, T_thresh=1e-4
return weights_sum, ambient_sum, uncertainty_sum, depth, image

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_uncertainty_sum, grad_depth, grad_image):

# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
Expand Down Expand Up @@ -580,7 +580,7 @@ def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_uncertainty_sum, grad

class _composite_rays_uncertainty(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
@custom_fwd(cast_inputs=torch.float32, device_type='cuda') # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum, T_thresh=1e-2):
_backend.composite_rays_uncertainty(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum)
return tuple()
Expand All @@ -593,7 +593,7 @@ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambi
# triplane(eye)
class _composite_rays_train_triplane(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(cast_inputs=torch.float32, device_type='cuda')
def forward(ctx, sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, T_thresh=1e-4):
''' composite rays' rgbs, according to the ray marching formula.
Args:
Expand Down Expand Up @@ -632,7 +632,7 @@ def forward(ctx, sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, T_th
return weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad_weights_sum, grad_amb_aud_sum, grad_amb_eye_sum, grad_uncertainty_sum, grad_depth, grad_image):

# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
Expand Down Expand Up @@ -662,10 +662,10 @@ def backward(ctx, grad_weights_sum, grad_amb_aud_sum, grad_amb_eye_sum, grad_unc

class _composite_rays_triplane(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
@custom_fwd(cast_inputs=torch.float32, device_type='cuda') # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambs_aud, ambs_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum, T_thresh=1e-2):
_backend.composite_rays_triplane(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambs_aud, ambs_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum)
return tuple()


composite_rays_triplane = _composite_rays_triplane.apply
composite_rays_triplane = _composite_rays_triplane.apply

0 comments on commit 8d6c9dd

Please sign in to comment.