Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding torch.compile compatiblity #9

Merged
merged 1 commit into from
Oct 24, 2024
Merged

Conversation

mitkotak
Copy link
Contributor

@mitkotak mitkotak commented Sep 1, 2024

Hey e3nn developer here. Thanks for the great work ! Interested in upstreaming the kernels into the main repo if you guys are interested.

We recently landed full PT2 compatibility. I re-ran the benchmarks and results were pretty dramatic. This was on an RTX A5500. I was wondering if it's possible to rerun the benchmarks on an XPU.

Interested in figuring out when to pivot to custom kernels and when to let Inductor take charge.

Also please feel free to edit my PR to suit your needs. Just wanted to point out the flag changes.

Thanks again !

Update: Did not check numerical accuracy so maybe that might be leading to some gains ? The flipping of the order is also interesting since I would have expected the approach highlighted in the paper to be better for higher Ls.

python benchmark.py cuda 1 --min_log_size 2.0 --max_log_size 6.0 && 
python benchmark.py cuda 2 --min_log_size 2.0 --max_log_size 6.0 &&
python benchmark.py cuda 3 --min_log_size 2.0 --max_log_size 6.0 &&
python benchmark.py cuda 4 --min_log_size 2.0 --max_log_size 6.0

Speedup = e3nn time / custom triton kernel time

speedup_comparison

@laserkelvin
Copy link
Contributor

Hi @mitkotak, sorry about the delayed response - I didn't have notifications for this repo turned on(?). For your specific PR, I'll give it a try as soon as I get a chance, so thank you very much for pointing out the recent optimizations!

RE: upstreaming back to e3nn we have thought about, and we can still have that discussion but hesitated mainly because the results aren't numerically exactly the same, given that we're using literals pretty heavily, and didn't want to affect the e3nn user experience (e.g. in case someone uses weights trained without EquiTriton, then uses EquiTriton and gets different results). We are definitely happy to collaborate, and would love to discuss!

@mitkotak
Copy link
Contributor Author

mitkotak commented Oct 4, 2024

No worries ! Appreciate the follow up and thanks again for the cool work. Yup agreed on the numerical part. Have been seeing similar funky behavior but hopefully we can contain it with some rigorous numerical asserts. Also do you guys have a TF32 equivalent for XPUs cause I have seen the numerics get funky with TF32 so was wondering what it might look like on XPUs

If you are interested, I am happy to put all of this into our experimental folder and then we can test out the numerics with downstream users who might be interested in testing this out.

@laserkelvin
Copy link
Contributor

XPUs do have TF32 support, but AFAIK it's not enabled by default for the same reason as why it's no longer default for PyTorch in general since it's to do with the dtype and not so much the hardware implementation. Are your observations for training, or for inference, or both? For MD I would think it make a huge difference, but I would have thought it would be stable enough for training.

RE: experimental folder we would like to see these kernels get used somehow, and I think that could be a good way to go. Just so you're aware we also have the monkey patching option here as an under-the-hood replacement, so in principle users could also install this repo and just import the patch to not have to make further code changes. That way if these kernels get updated, you won't have to manually port things over back to e3nn each time.

About a month or two ago I'd also worked on a branch in my fork that I haven't had time to merge in yet, but we also wrote a paper about the irreducible representation latent space. If you're interested in reading it/chatting about it, I can send you a copy to your email.

@mitkotak
Copy link
Contributor Author

mitkotak commented Oct 4, 2024

XPUs do have TF32 support, but AFAIK it's not enabled by default for the same reason as why it's no longer default for PyTorch in general since it's to do with the dtype and not so much the hardware implementation. Are your observations for training, or for inference, or both? For MD I would think it make a huge difference, but I would have thought it would be stable enough for training.

Yup training is fine since we even use mixed-precision there. Yup not stable enough for MD.

RE: experimental folder we would like to see these kernels get used somehow, and I think that could be a good way to go. Just so you're aware we also have the monkey patching option here as an under-the-hood replacement, so in principle users could also install this repo and just import the patch to not have to make further code changes. That way if these kernels get updated, you won't have to manually port things over back to e3nn each time.

Ahh gotcha cool yup that makes sense then. Yeah we can just have a flag that folks can turn on if they need the kernels. We will just have to make sure they don't cause any graph breaks.

Copy link
Contributor

@laserkelvin laserkelvin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @mitkotak for the PR again! I checked out and tested your suggestions, and I'm happy to merge them in because it's the "optimal" comparison.

I can talk about the results concretely offline; at the very least I can say it works and I'm able to reproduce the relative speedups with a nightly build of PyTorch 2.6 (2.6.0a0+git487873f). I haven't got an A5500 to test, and just looking at the spec sheet in comparison to the Intel GPUs (specifically GPU Max 1100 was what I used to reproduce), it does nominally have close to 2x the memory bandwidth compared to the A5500. I haven't checked to see if this is actually memory bandwidth bound, but the fact that you see relative performance drop off at higher node counts might lend itself to explain why you don't see as much relative speedup comparing torch.compile and the Triton result.

If you have access to a datacenter level (instead of workstation grade) of Nvidia GPU (e.g. A100 or H100) you can also try running on them to see if you get results similar to the A5500, or the Intel GPU results from our first paper. I'll merge this PR to make it easier for you and anyone else to make these comparisons.

@laserkelvin
Copy link
Contributor

Just to follow up on this discussion - I re-ran the benchmarks this morning, and I take back what I said in the previous message! The performance with e3nn + torch.compile seems to be better than the first version of these kernels. I'll need to do more testing, but it does seem like torch.compile changes do make the first version kernels less dramatic in performance difference.

For tracking this, I've created #17.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants