@@ -13,8 +13,7 @@ namespace graphbolt {
13
13
namespace ops {
14
14
15
15
torch::Tensor IndexSelect (torch::Tensor input, torch::Tensor index) {
16
- if (input.is_pinned () &&
17
- (index.is_pinned () || index.device ().type () == c10::DeviceType::CUDA)) {
16
+ if (utils::is_on_gpu (index) && input.is_pinned ()) {
18
17
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE (
19
18
c10::DeviceType::CUDA, " UVAIndexSelect" ,
20
19
{ return UVAIndexSelectImpl (input, index); });
@@ -26,9 +25,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
26
25
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
27
26
TORCH_CHECK (
28
27
indices.sizes ().size () == 1 , " IndexSelectCSC only supports 1d tensors" );
29
- if (utils::is_accessible_from_gpu (indptr) &&
30
- utils::is_accessible_from_gpu (indices) &&
31
- utils::is_accessible_from_gpu (nodes)) {
28
+ if (utils::is_on_gpu (nodes) && utils::is_accessible_from_gpu (indptr) &&
29
+ utils::is_accessible_from_gpu (indices)) {
32
30
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE (
33
31
c10::DeviceType::CUDA, " IndexSelectCSCImpl" ,
34
32
{ return IndexSelectCSCImpl (indptr, indices, nodes); });
0 commit comments