torch_cluster.radius is not compatible with the CUDA graph #123
Thanks for reporting. I need to look into this. Any help or references would be highly appreciated, as I'm pretty unfamiliar with CUDA graphs.
To "compile" a piece of code into a CUDA graph, one must first run it in what is called "capture mode": basically a dry run in which CUDA just records the kernel launches and their arguments. The captured code must comply with a series of restrictions to be CUDA-graph compatible, broadly speaking:
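A minimal sketch of the capture/replay cycle, following the pattern in the PyTorch CUDA-graph notes linked in the issue. `capture_and_replay` is a hypothetical helper (not part of torch_cluster), the lambda is only a stand-in for a capturable function such as a sync-free `radius`, and a CUDA device is assumed.

```python
import torch


def capture_and_replay(fn, static_input: torch.Tensor, warmup_iters: int = 3):
    """Record fn(static_input) into a CUDA graph (hypothetical helper).

    Returns (graph, static_output). graph.replay() re-launches the recorded
    kernels on whatever data is currently in static_input and writes the
    result into static_output. Requires a CUDA device.
    """
    # Warm up on a side stream before capture, as the PyTorch notes recommend.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(warmup_iters):
            fn(static_input)
    torch.cuda.current_stream().wait_stream(side)

    # Capture mode: kernel launches and their arguments are recorded, not run.
    # A CPU-GPU sync inside fn (e.g. .item(), masked_select) makes this fail.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)
    return graph, static_output


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(8, 3, device="cuda")
    graph, out = capture_and_replay(lambda t: t * 2 + 1, x)
    x.fill_(1.0)     # update the static input in place; never reassign it
    graph.replay()   # re-runs the recorded kernels on the new contents
    torch.cuda.synchronize()
    print(out)       # every entry is 3.0
```

Because replay reuses the captured memory addresses, new data must be copied into the same input tensor rather than passed as a fresh tensor.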
In the particular case of the radius kernel, several things prevent the function from being captured. For instance, this snippet suffers from a CPU-GPU sync (`pytorch_cluster/torch_cluster/radius.py`, lines 55 to 61 at commit 82e9df9).
You could rewrite it as:

```python
import torch

# Keep batch_size as a device tensor: calling .item() or int() on it would
# force a CPU-GPU sync, which is not allowed during graph capture.
batch_size = torch.tensor(1)
if batch_x is not None:
    assert x.size(0) == batch_x.numel()
    batch_size = batch_x.max().to(dtype=torch.int) + 1
if batch_y is not None:
    assert y.size(0) == batch_y.numel()
    batch_size = torch.max(batch_size, batch_y.max().to(dtype=torch.int) + 1)
```

Additionally, this line does not have static control flow (`pytorch_cluster/torch_cluster/radius.py`, line 65 at commit 82e9df9),
since `batch_size` depends on the contents of the input tensors. @raimis's example does not exercise this code path, because `batch_x`/`batch_y` are not being passed. For that example, I believe the culprit is this line (`pytorch_cluster/csrc/cuda/radius_cuda.cu`, line 93 at commit 82e9df9): `masked_select` requires synchronization, because the size of its output depends on how many mask entries are true.
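To illustrate what a capture-friendly replacement for `masked_select` could look like, here is a sketch with a hypothetical helper `first_k_true_indices` (not torch_cluster's API, and not how `radius_cuda.cu` actually works). It returns, for each row of a boolean mask, the first `k` matching column indices padded with -1. The output shape depends only on `mask.shape` and the Python int `k`, and there is no Python `if` on tensor values, so it meets both the no-sync and static-control-flow requirements. The assumption is that something like `radius`'s `max_num_neighbors` argument would play the role of `k`.

```python
import torch


def first_k_true_indices(mask: torch.Tensor, k: int) -> torch.Tensor:
    """Column indices of the first `k` True entries in each row of a 2-D
    boolean `mask`, padded with -1. Output shape: [n, min(k, m)].

    Unlike torch.masked_select, the output shape never depends on the mask's
    contents, so no host sync is needed to size the result.
    """
    n, m = mask.shape
    k = min(k, m)  # topk requires k <= m; both are host-side ints
    cols = torch.arange(m, device=mask.device).expand(n, m)
    # False entries get the sentinel key m, so they sort after every True entry.
    keyed = torch.where(mask, cols, torch.full_like(cols, m))
    # The k smallest keys per row are the first k True columns, in order.
    vals, _ = torch.topk(keyed, k, dim=1, largest=False, sorted=True)
    return torch.where(vals < m, vals, torch.full_like(vals, -1))


if __name__ == "__main__":
    # Dense radius-style query (O(N * M) memory, for illustration only).
    x = torch.tensor([[0.0], [1.0], [5.0]])
    y = torch.tensor([[0.5], [10.0]])
    mask = torch.cdist(y, x) <= 1.0
    print(first_k_true_indices(mask, 2).tolist())  # [[0, 1], [-1, -1]]

    if torch.cuda.is_available():
        dist = torch.cdist(y.cuda(), x.cuda())
        cuda_mask = dist <= 1.0
        torch.cuda.set_sync_debug_mode("warn")  # report synchronizing CUDA calls
        first_k_true_indices(cuda_mask, 2)      # expected: no sync warning
        torch.masked_select(dist, cuda_mask)    # expected: sync warning
        torch.cuda.set_sync_debug_mode("default")
```

The trade-off is that the caller receives a padded `[n, k]` tensor plus a -1 sentinel instead of a compact list, and has to filter the padding downstream (or keep working with the padded layout) without reintroducing a sync.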
Thanks for this insightful issue. Would you be interested in fixing the aforementioned problems? Would that be straightforward to integrate?
We use this functionality in several places, so I am eager to help. I am really new to torch, though, so my torch-fu is not very good -.-
torch_cluster.radius is not compatible with the CUDA graph (https://pytorch.org/docs/stable/notes/cuda.html#cuda-graphs)