Use .scalar_type() in AT_DISPATCH
The `AT_DISPATCH_FLOATING_TYPES(x.type(), ...)` pattern has been deprecated since pytorch/pytorch#17996; pass `x.scalar_type()` instead.
malfet authored and Luthaf committed Nov 7, 2024
1 parent 877370a commit b003104
Showing 1 changed file with 3 additions and 3 deletions.
sphericart-torch/src/torch_cuda_wrapper.cpp: 3 additions, 3 deletions

@@ -9,7 +9,7 @@
 #include "cuda_base.hpp"
 #include "sphericart.hpp"
 
-#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_SAME_DTYPE(x, y) \
     TORCH_CHECK(x.scalar_type() == y.scalar_type(), #x " and " #y " must have the same dtype.")
@@ -34,7 +34,7 @@ torch::Tensor sphericart_torch::spherical_harmonics_backward_cuda(
     xyz_grad = torch::empty_like(xyz);
 
     AT_DISPATCH_FLOATING_TYPES(
-        xyz.type(), "spherical_harmonics_backward_cuda", ([&] {
+        xyz.scalar_type(), "spherical_harmonics_backward_cuda", ([&] {
             sphericart::cuda::spherical_harmonics_backward_cuda_base<scalar_t>(
                 dsph.data_ptr<scalar_t>(),
                 sph_grad.data_ptr<scalar_t>(),
@@ -48,4 +48,4 @@ torch::Tensor sphericart_torch::spherical_harmonics_backward_cuda(
     }
     // synchronization happens within spherical_harmonics_backward_cuda_base
     return xyz_grad;
-}
+}
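
For context, here is a minimal standalone sketch (not taken from the sphericart sources) of the non-deprecated dispatch pattern the commit switches to. In current PyTorch, `Tensor::type()` returns a deprecated `DeprecatedTypeProperties` object, while `Tensor::scalar_type()` returns the `c10::ScalarType` that `AT_DISPATCH_FLOATING_TYPES` expects. The helper name `scale_by_two` and the CPU loop below are illustrative assumptions, not part of this commit.

// Minimal sketch (assumption, not from this commit): dispatch on the dtype
// returned by Tensor::scalar_type() instead of the deprecated Tensor::type().
#include <torch/extension.h>

// Hypothetical helper: doubles every element of a contiguous CPU
// floating-point tensor.
torch::Tensor scale_by_two(const torch::Tensor& input) {
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
    auto output = torch::empty_like(input);
    AT_DISPATCH_FLOATING_TYPES(
        input.scalar_type(), "scale_by_two", ([&] {
            // scalar_t is defined by the dispatch macro (float or double here).
            const scalar_t* in = input.data_ptr<scalar_t>();
            scalar_t* out = output.data_ptr<scalar_t>();
            for (int64_t i = 0; i < input.numel(); ++i) {
                out[i] = in[i] * static_cast<scalar_t>(2);
            }
        }));
    return output;
}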
