Adding torch.compile compatibility #9
Conversation
Hi @mitkotak, sorry about the delayed response - I didn't have notifications for this repo turned on(?). For your specific PR, I'll give it a try as soon as I get a chance, so thank you very much for pointing out the recent optimizations! RE: upstreaming back to
No worries! Appreciate the follow-up, and thanks again for the cool work. Yup, agreed on the numerical part. I've been seeing similar funky behavior, but hopefully we can contain it with some rigorous numerical asserts. Also, do you have a TF32 equivalent for XPUs? I've seen the numerics get funky with TF32, so I was wondering what it might look like on XPUs. If you are interested, I'm happy to put all of this into our experimental folder, and then we can test out the numerics with downstream users who might want to try this out.
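As a rough illustration of the kind of numerical assert being discussed, a minimal sketch could look like the one below. `reference_tp` and `triton_tp` are hypothetical callables standing in for the e3nn reference and the custom Triton kernel, and the tolerances are placeholders that would likely need loosening under TF32 or mixed precision.

```python
import torch

def check_numerics(reference_tp, triton_tp, inputs, rtol=1e-5, atol=1e-6):
    """Assert that a custom kernel matches the reference implementation.

    `reference_tp` and `triton_tp` are placeholder callables; the tolerances
    are illustrative and would need tuning for the precision mode in use.
    """
    with torch.no_grad():
        expected = reference_tp(*inputs)
        actual = triton_tp(*inputs)
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
```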
XPUs do have TF32 support, but AFAIK it's not enabled by default, for the same reason it's no longer the default for PyTorch in general: it's to do with the numerics.

RE: the experimental folder, we would like to see these kernels get used somehow, and I think that could be a good way to go. Just so you're aware, we also have the monkey-patching option here as an under-the-hood replacement, so in principle users could also install this repo and just import the patch without having to make further code changes. That way, if these kernels get updated, you won't have to manually port things back over to e3nn.

About a month or two ago I'd also worked on a branch in my fork that I haven't had time to merge in yet, but we also wrote a paper about the irreducible representation latent space. If you're interested in reading it/chatting about it, I can send you a copy to your email.
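For reference, these are the standard PyTorch-side knobs for TF32-style float32 matmuls; how the XPU backend maps onto them is exactly the open question here, so the CUDA flags below are shown only as the familiar example.

```python
import torch

# Since PyTorch 1.12, float32 matmuls default to "highest" precision (no TF32);
# "high" opts back into TF32-style matmuls where the backend supports them.
torch.set_float32_matmul_precision("highest")  # full float32 accuracy
# torch.set_float32_matmul_precision("high")   # allow TF32-style matmuls

# The older, CUDA-specific switches, set explicitly here for clarity:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
```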
Yup, training is fine since we even use mixed precision there. But yup, it's not stable enough for MD.
Ahh gotcha, cool, yup that makes sense then. Yeah, we can just have a flag that folks can turn on if they need the kernels. We will just have to make sure they don't cause any graph breaks.
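A minimal way to guard against graph breaks, assuming the patched model is an ordinary `nn.Module` (`model` and `example_inputs` are placeholder names): compile with `fullgraph=True` so any break raises instead of silently falling back, or ask `torch._dynamo.explain` for a report.

```python
import torch

def assert_no_graph_breaks(model, *example_inputs):
    """Fail loudly if the (hypothetical) patched model breaks the Dynamo graph."""
    # fullgraph=True turns any graph break into an error instead of a fallback.
    compiled = torch.compile(model, fullgraph=True)
    compiled(*example_inputs)

    # Alternatively, inspect where breaks happen without erroring out:
    report = torch._dynamo.explain(model)(*example_inputs)
    print(report.graph_break_count)
```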
Thanks a lot @mitkotak for the PR again! I checked out and tested your suggestions, and I'm happy to merge them in because it's the "optimal" comparison.
I can talk about the results concretely offline; at the very least I can say it works, and I'm able to reproduce the relative speedups with a nightly build of PyTorch 2.6 (2.6.0a0+git487873f). I don't have an A5500 to test on, but just looking at the spec sheet in comparison to the Intel GPUs (specifically, the GPU Max 1100 is what I used to reproduce), it nominally has close to 2x the memory bandwidth of the A5500. I haven't checked whether this workload is actually memory-bandwidth bound, but the fact that you see relative performance drop off at higher node counts might help explain why you don't see as much relative speedup when comparing torch.compile and the Triton result.
If you have access to a datacenter-level (instead of workstation-grade) Nvidia GPU (e.g. A100 or H100), you can also try running on it to see if you get results similar to the A5500, or to the Intel GPU results from our first paper. I'll merge this PR to make it easier for you and anyone else to make these comparisons.
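One quick way to sanity-check the memory-bandwidth-bound hypothesis is to compare achieved bandwidth against the device's spec-sheet peak. The sketch below assumes you can estimate the bytes read and written per call from the kernel's tensor shapes; the function and argument names are placeholders, not part of this repo.

```python
from torch.utils import benchmark

def achieved_bandwidth_gbs(fn, bytes_moved, *args):
    """Estimate achieved memory bandwidth (GB/s) for a callable.

    `bytes_moved` is an assumed estimate of total bytes read + written per
    invocation, worked out by hand from the input/output tensor shapes.
    """
    t = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    seconds = t.blocked_autorange().median
    return bytes_moved / seconds / 1e9

# Compare the result against the device's spec-sheet peak bandwidth;
# a ratio close to 1 suggests the kernel is memory-bandwidth bound.
```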
Just to follow up on this discussion - I re-ran the benchmarks this morning, and I take back what I said in the previous message! The performance with
For tracking this, I've created #17.
Hey, e3nn developer here. Thanks for the great work! We'd be interested in upstreaming the kernels into the main repo if you are. We recently landed full PT2 compatibility, so I re-ran the benchmarks, and the results were pretty dramatic. This was on an RTX A5500. I was wondering if it's possible to rerun the benchmarks on an XPU.
Interested in figuring out when to pivot to custom kernels and when to let Inductor take charge.
Also, please feel free to edit my PR to suit your needs. I just wanted to point out the flag changes.
Thanks again!
Update: I did not check numerical accuracy, so maybe that might be leading to some of the gains? The flipping of the order is also interesting, since I would have expected the approach highlighted in the paper to be better at higher Ls.
Speedup = e3nn time / custom Triton kernel time
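Since the speedup metric above is just a ratio of timings, a minimal benchmarking sketch in the spirit of this comparison might look like the following. The irreps, batch size, and the use of `torch.compile` as the "optimized" path are illustrative assumptions, not the actual benchmark script from this repo or PR.

```python
import torch
from torch.utils import benchmark
from e3nn import o3

# A small e3nn tensor product as the reference workload (shapes are arbitrary).
irreps_in1 = o3.Irreps("32x0e + 32x1o + 32x2e")
irreps_in2 = o3.Irreps("0e + 1o + 2e")
irreps_out = o3.Irreps("32x0e + 32x1o + 32x2e")

device = "cuda" if torch.cuda.is_available() else "cpu"
tp = o3.FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out).to(device)
x1 = irreps_in1.randn(10_000, -1, device=device)
x2 = irreps_in2.randn(10_000, -1, device=device)

# Stand-in for whichever optimized path is being compared (compiled or custom kernel).
compiled_tp = torch.compile(tp)

def timeit(fn):
    t = benchmark.Timer(stmt="fn(x1, x2)", globals={"fn": fn, "x1": x1, "x2": x2})
    return t.blocked_autorange().median

eager_s, optimized_s = timeit(tp), timeit(compiled_tp)
print(f"speedup = {eager_s / optimized_s:.2f}x")  # e3nn time / optimized time
```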