DMFF Converts all of the JAX Code to Float64 #194

Open
jacob-jacob-jacob opened this issue Mar 12, 2025 · 1 comment
Labels
wontfix This will not be worked on

Comments

@jacob-jacob-jacob

Summary

Thank you for the great library!

DMFF switches JAX to float64 globally, so any code that imports DMFF has to be refactored if it is supposed to keep using float32.

DMFF Version

--

JAX Version

0.4.20

OpenMM Version

--

Python Version, CUDA Version, GCC Version, Operating System Version etc

No response

Details

DMFF sets JAX to use float64.
As a result, any allocation made afterwards (jnp.zeros, random.normal, jnp.linspace, ...) returns float64 instead of JAX's default float32, which is also the most common choice for neural networks in general.
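The effect can be reproduced without DMFF. The sketch below assumes the change comes from JAX's global `jax_enable_x64` flag (the issue does not show how DMFF sets it) and checks the dtype of an allocation with the flag on and off. The helper name `default_float_dtype` is made up for this example.

```python
import jax
import jax.numpy as jnp


def default_float_dtype():
    """Return the dtype JAX picks for an allocation with no explicit dtype."""
    return jnp.zeros(1).dtype


# Assumption: this mimics what a library does when it enables 64-bit mode on import.
jax.config.update("jax_enable_x64", True)
dtype_x64 = default_float_dtype()

# Turning the flag off again makes later allocations default to float32.
jax.config.update("jax_enable_x64", False)
dtype_x32 = default_float_dtype()

print(dtype_x64, dtype_x32)  # float64 float32
```

Resetting the flag after `import dmff` would also make DMFF's own computations run in float32. Whether that is acceptable for DMFF is not something this thread confirms.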

To keep an existing code base on float32 (float64 would slow down neural network training a lot), the whole code base has to be refactored after importing dmff: every single allocation statement has to be found and given an explicit dtype=jnp.float32.
This procedure is quite error-prone. Overlooking even one of these statements can leave all or part of the code running in float64 and slow down the algorithms.
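One way to catch overlooked allocations is to audit a parameter pytree for float64 leaves and cast them back. This is only a sketch: `float64_leaf_paths` and `cast_floats_to_float32` are hypothetical helpers written for this example, not part of DMFF or JAX, and the `params` dictionary stands in for real model parameters.

```python
import jax
import jax.numpy as jnp

# Assumption: 64-bit mode is on, as it would be after importing DMFF.
jax.config.update("jax_enable_x64", True)


def float64_leaf_paths(tree):
    """List the pytree paths of all leaves whose dtype is float64."""
    leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves
        if getattr(leaf, "dtype", None) == jnp.float64
    ]


def cast_floats_to_float32(tree):
    """Cast every float64 leaf to float32 and leave all other leaves unchanged."""
    def cast(leaf):
        if getattr(leaf, "dtype", None) == jnp.float64:
            return leaf.astype(jnp.float32)
        return leaf
    return jax.tree_util.tree_map(cast, tree)


# Hypothetical parameters: "w" was allocated without a dtype, "b" with one.
params = {
    "w": jnp.ones((2, 2)),
    "b": jnp.zeros(2, dtype=jnp.float32),
}

offenders = float64_leaf_paths(params)   # only the entry for "w"
fixed = cast_floats_to_float32(params)
remaining = float64_leaf_paths(fixed)    # empty list
```

This only finds arrays that already exist. Intermediate values created inside functions can still be promoted to float64, so it complements an explicit dtype on allocations rather than replacing it.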

Is there any way to make this more user-friendly?

@jacob-jacob-jacob jacob-jacob-jacob added the wontfix This will not be worked on label Mar 12, 2025
@KuangYu
Collaborator

KuangYu commented Apr 10, 2025

Currently there is a single setting (PRECITION="float" or "double") in dmff/settings.py, which is a global setting supposed to control the precision of the code. Maybe try that out first to see if it does the job? Right now the default setting is indeed "double"
