Integration with Python performance tools like JAX #3432
Comments
How about adding an xarray + Dask example? How would a Jupyter notebook that includes parallel processing work?
Just to add this for the sake of reference: about a year and a half ago I did some experiments with a Numba-based re-implementation of MetPy's CAPE calculations, as shown here: https://github.com/jthielen/cumulonumba/blob/main/examples/cumulonumba_v_metpy_rough_test.ipynb. Key takeaways were that the speedup with Numba was substantial (two or three orders of magnitude), but that the JIT compilation costs were not insignificant (and perhaps a deal breaker for some use cases). This was yet another factor favoring Cython over Numba for MetPy's purposes.
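For context on the JIT trade-off, here is a minimal sketch of a Numba-compiled loop. `toy_cape` is a hypothetical, heavily simplified buoyancy integral (trapezoid rule over height, positive layers only, no virtual temperature or LFC/EL search); it is not MetPy's CAPE algorithm or the code in the linked notebook. The no-op fallback decorator is an assumption added only so the sketch still runs when Numba isn't installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without Numba: a no-op decorator.
    def njit(func):
        return func

G = 9.80665  # standard gravity, m s^-2


@njit
def toy_cape(z, t_parcel, t_env):
    """Toy CAPE in height coordinates: g * sum of positive layer-mean (Tp - Te) / Te * dz."""
    total = 0.0
    for i in range(z.size - 1):
        b0 = G * (t_parcel[i] - t_env[i]) / t_env[i]
        b1 = G * (t_parcel[i + 1] - t_env[i + 1]) / t_env[i + 1]
        layer = 0.5 * (b0 + b1) * (z[i + 1] - z[i])  # trapezoid rule for this layer
        if layer > 0.0:  # keep only positively buoyant layers (no CIN)
            total += layer
    return total


z = np.linspace(0.0, 10000.0, 101)
t_env = 300.0 - 0.0065 * z
t_parcel = t_env + 2.0
# Under Numba the first call compiles the function; later calls reuse the compiled code.
print(toy_cape(z, t_parcel, t_env))
```

The compile step happens once per input type signature, which is why the first-call cost matters most for short scripts and one-off calculations.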
Thanks for the add, @jthielen! Have you had the chance to mess around with JAX? I know you're quite busy :) But overall I agree that Numba is not the solution.
@winash12, do you have a specific problem in mind to solve with xarray/Dask?
@ThomasMGeo Only a little bit, and not in this context, unfortunately! That being said, for some of the underlying array operations (intersection finding, fixed point iteration), my hunch is that a JAX-type approach (given its more functional way of doing things) would require more refactoring than Numba would. I could be mistaken on that, though!
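To illustrate what "a more functional way of doing things" means for fixed point iteration, here is a minimal pure-Python sketch. The helper `while_loop` is a hypothetical stand-in that mirrors the documented semantics of `jax.lax.while_loop(cond_fun, body_fun, init_val)`: the loop state is passed explicitly between pure functions instead of being mutated in place. The toy problem x = cos(x) is an assumption for illustration, not a MetPy calculation.

```python
import math


def while_loop(cond_fun, body_fun, init_val):
    """Pure-Python stand-in with the same semantics as jax.lax.while_loop."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def fixed_point(g, x0, tol=1e-12, max_iter=500):
    """Iterate x <- g(x) until the step size drops below tol (or max_iter is reached)."""

    # Loop state is an explicit tuple: (current x, size of last step, iteration count).
    def cond_fun(state):
        _, step, n = state
        return (step > tol) and (n < max_iter)

    def body_fun(state):
        x, _, n = state
        x_new = g(x)
        return (x_new, abs(x_new - x), n + 1)

    x, _, n_iter = while_loop(cond_fun, body_fun, (x0, math.inf, 0))
    return x, n_iter


root, n_iter = fixed_point(math.cos, 1.0)
print(root, n_iter)
```

To run this under JAX, `while_loop` would become `jax.lax.while_loop`, `math.cos` would become `jax.numpy.cos`, and the Python `and` in `cond_fun` would need to be `jax.numpy.logical_and`, since traced values cannot be used in Python boolean expressions. That restructuring is the kind of refactoring the comment anticipates.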
@winash12 Interoperability with Dask is one of the major technical areas we are focusing on at the moment.
Regarding the Cython usage, how do you propose to take it forward? Will the code be written in Cython and translated to C by the Cython compiler, or can we add C or C++ functions directly? The second option would need C makefiles for the build to go through, plus modifications to the LDPATH etc. From an implementation perspective, my question is: are you planning to permit cdef functions, or purely def functions? SciPy's implementation has many classes that do use cdef functions. If we are planning to use cdef functions, then it is worthwhile to look at xtensor - https://xtensor.readthedocs.io/en/latest/

@ThomasMGeo As an example, let us assume I want to calculate potential vorticity (PV) for 4 different times, and the input data is present in a single netCDF file. Can I do the PV calculation for the four different time instances in parallel? Most definitely, as they are mutually independent data snapshots. For that I need to use Dask arrays, if I am not mistaken. The last time I attended the con call, I recall everyone agreeing that there isn't a notebook yet to do this.
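A minimal sketch of the "independent snapshots" idea, using only the standard library and NumPy rather than Dask: each time slice is handed to a thread pool and the results are restacked along the time axis. `toy_vorticity` is a hypothetical stand-in for a real PV calculation (it computes only relative vorticity with `np.gradient`), and the (time, y, x) layout and grid spacing are assumptions. Many NumPy operations release the GIL, so threads can overlap some of the work. For the Dask route the comment asks about, opening the file with xarray's `xr.open_dataset(path, chunks={"time": 1})` gives Dask-backed arrays with one time step per chunk.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def toy_vorticity(u, v, dx, dy):
    """Hypothetical stand-in for a PV calculation: relative vorticity dv/dx - du/dy.

    u and v are 2-D (y, x) arrays; dx and dy are grid spacings in metres.
    """
    return np.gradient(v, dx, axis=1) - np.gradient(u, dy, axis=0)


def map_over_time(func, u, v, *args, max_workers=4):
    """Apply func to each time slice (axis 0) in a thread pool and restack the results."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(lambda i: func(u[i], v[i], *args), range(u.shape[0])))
    return np.stack(results)


# Four independent time snapshots on a (time, y, x) grid.
rng = np.random.default_rng(0)
u = rng.standard_normal((4, 50, 60))
v = rng.standard_normal((4, 50, 60))
vort = map_over_time(toy_vorticity, u, v, 1000.0, 1000.0)
print(vort.shape)  # (4, 50, 60)
```

Because each snapshot is computed independently, the parallel result matches a plain loop over the time axis; only the scheduling changes.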
Actually, looking at it again, if all we want is a faster version of NumPy, then I question the need for Cython. xtensor has a Python wrapper which we can use - https://github.com/xtensor-stack/xtensor-python
@winash12 I really need to update our roadmap with this stuff (#1655), but the plan is to only update particular places in the code that are bottlenecks to doing calculations at scale, and to see what's slow through benchmarks. The top offender that comes to mind is CAPE/CIN, mostly due to …

That does not imply we're looking at general solutions for a faster NumPy. It is really important for ease of maintenance and contributions from the community that we stick to Python. The nature of @ThomasMGeo's investigation was really to look at how well people using tools like JAX or CuPy can pass data from those libraries (which are NumPy-like) into MetPy and have things "just work". We have no plans to depend on them, however. The same can be said about our plans for supporting Dask: we want to make sure we facilitate workflows using Dask (like the one you described for multiple levels of PV analysis), but we will not be using Dask directly within MetPy. Currently on the table are:
The leader is Cython, due to how commonplace it is within the scientific Python ecosystem. Also, I am heavily interested in the ability to run Python (with MetPy) within web browsers, so any solution chosen needs to be amenable to WASM (WebAssembly), which likely rules out Numba. Rust/C++ are included for completeness (Rust mainly because there's a lot of momentum there, but I'm unclear on the NumPy integration story), but I'm 95% sure we're going down the Cython route.
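Relating to the point above about arrays from JAX or CuPy passing into MetPy and "just working": one common pattern (an assumption here, not MetPy's current design) is to look up the array module from the input via the Python array API standard's `__array_namespace__` method and fall back to NumPy otherwise. Recent NumPy releases implement that method; whether a given JAX or CuPy version does should be checked. The helper `get_namespace` is hypothetical, units (Pint) are ignored, and the Bolton (1980) saturation vapour pressure formula is used only as a convenient example calculation.

```python
import numpy as np


def get_namespace(x):
    """Return the array module for x: its array API namespace if it has one, else NumPy."""
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np


def saturation_vapor_pressure(temperature_c):
    """Bolton (1980) saturation vapour pressure in hPa from temperature in degC (unitless)."""
    xp = get_namespace(temperature_c)
    # xp.exp dispatches to whichever library owns the input array.
    return 6.112 * xp.exp(17.67 * temperature_c / (temperature_c + 243.5))


print(saturation_vapor_pressure(np.array([0.0, 20.0])))
```

Because the function never calls NumPy directly on the input, an array from another library that provides the same namespace functions stays in that library (and on its device) end to end.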
@Z-Richard mentioned I should drop my code into a public repo: … The code has a single runtime dependency on … This is dcape …
Just wanted to post this here in case it is helpful for folks working further on this: at SciPy I dug a bit into the problem of utilizing an ODE library on the compiled-code side of things (C/C++/Cython/Rust/Fortran). While there are a lot of options out there, for what we need, something on the level of LSODA would likely be best (we need automatic stiffness handling, but definitely not all the vector-based complexity of something like CVODES/Sundials). This led me to https://github.com/dilawar/libsoda-cxx, which is a C++ port (of a C port of the original Fortran, if I recall correctly) of LSODA, as it is used in https://github.com/Nicholaswogan/numbalsoda. Unfortunately, I didn't get to the point of a full demo of a moist lapse rate calculation, so I'm not fully confident that this is the best approach.
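For prototyping before committing to libsoda-cxx, SciPy already exposes an LSODA integrator (a wrapper around the ODEPACK Fortran code) through `scipy.integrate.solve_ivp(..., method="LSODA")`. The sketch below uses it to integrate a pseudoadiabatic (moist) lapse rate written in pressure coordinates, dT/dp = (Rd T + Lv rs) / (p (cp_d + Lv² rs ε / (Rd T²))). The constants, the Bolton (1980) saturation vapour pressure, and the function names are assumptions for illustration; this is not MetPy's `moist_lapse` implementation and it ignores units.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Approximate constants (SI units).
RD = 287.04749   # dry-air gas constant, J kg^-1 K^-1
CP_D = 1004.666  # dry-air specific heat at constant pressure, J kg^-1 K^-1
LV = 2.50084e6   # latent heat of vaporization, J kg^-1
EPS = 0.6219569  # ratio of molecular weights, water / dry air


def saturation_mixing_ratio(p_hpa, t_k):
    es = 6.112 * np.exp(17.67 * (t_k - 273.15) / (t_k - 29.65))  # Bolton (1980), hPa
    return EPS * es / (p_hpa - es)


def dt_dp(p_hpa, t):
    """Pseudoadiabatic lapse rate dT/dp in K per hPa (t is a length-1 state vector)."""
    t_k = t[0]
    rs = saturation_mixing_ratio(p_hpa, t_k)
    num = RD * t_k + LV * rs
    den = CP_D + LV * LV * rs * EPS / (RD * t_k * t_k)
    return [num / den / p_hpa]


def moist_lapse_lsoda(t0_k, pressures_hpa):
    """Integrate from pressures_hpa[0] to pressures_hpa[-1] (decreasing) with LSODA."""
    sol = solve_ivp(dt_dp, (pressures_hpa[0], pressures_hpa[-1]), [t0_k],
                    method="LSODA", t_eval=pressures_hpa, rtol=1e-8, atol=1e-8)
    if not sol.success:
        raise RuntimeError(sol.message)
    return sol.y[0]


p = np.array([1000.0, 850.0, 700.0, 500.0])
temps = moist_lapse_lsoda(293.15, p)
print(temps)
```

LSODA switches between non-stiff (Adams) and stiff (BDF) methods automatically, which is the "automatic stiffness handling" property the comment identifies as the main requirement.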
What should we add?
There are a few ways to speed up raw NumPy calculations. I looked at two options:
- JAX
- Numba
This was tested on an M2 MacBook Pro (so no NVIDIA GPU), but I didn't have any trouble installing either package. Overall, for two basic calculations, I saw speedups on the order of 3-10x. These results are just from a few hours of hacking and are not intended to be strict benchmarks.
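Since JIT libraries compile on the first call, a rough comparison like this one should time that first call separately from warm calls. The `benchmark` helper below is a hypothetical, standard-library-only sketch. It also calls `block_until_ready()` when the result has that method, because JAX dispatches work asynchronously and the clock would otherwise stop before the computation finishes.

```python
import statistics
import time


def _wait(result):
    # JAX returns before the computation finishes; block so the timing is honest.
    if hasattr(result, "block_until_ready"):
        result.block_until_ready()
    return result


def benchmark(func, *args, repeats=5):
    """Return (first_call_seconds, median_warm_call_seconds) for func(*args)."""
    start = time.perf_counter()
    _wait(func(*args))
    first = time.perf_counter() - start  # includes any JIT compilation

    warm = []
    for _ in range(repeats):
        start = time.perf_counter()
        _wait(func(*args))
        warm.append(time.perf_counter() - start)
    return first, statistics.median(warm)


first, warm = benchmark(sorted, list(range(10_000, 0, -1)))
print(f"first call: {first:.6f} s, warm median: {warm:.6f} s")
```

Running the same helper on a jitted function and its plain NumPy counterpart with identical inputs gives a like-for-like comparison of warm speed, while the first-call number shows the compilation overhead a short script would pay.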
My Take
JAX was much easier to use, and faster. It felt much more like a drop-in NumPy replacement than rewriting JIT'd functions that didn't support all of NumPy's functionality. If I needed to write faster NumPy code, it was straightforward to do so with JAX, with or without a GPU.
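A minimal sketch of the "drop-in" style described above: the function body uses only arithmetic operators, so it runs unchanged on `jax.numpy` arrays under `jax.jit` or on plain NumPy arrays. The fallback branch (used when JAX isn't installed) and the choice of potential temperature as the example are assumptions for illustration, not the calculations from the test notebook.

```python
import numpy as np

try:
    import jax.numpy as xp
    from jax import jit
except ImportError:
    # Fallback so the sketch still runs without JAX: plain NumPy, no compilation.
    xp = np

    def jit(func):
        return func


KAPPA = 0.2857  # approximately Rd / cp for dry air (dimensionless)


@jit
def potential_temperature(pressure_hpa, temperature_k):
    """Poisson's equation: theta = T * (1000 hPa / p) ** kappa."""
    return temperature_k * (1000.0 / pressure_hpa) ** KAPPA


p = xp.asarray([1000.0, 850.0, 500.0])
t = xp.asarray([293.15, 283.15, 253.15])
print(potential_temperature(p, t))
```

Note that JAX defaults to 32-bit floats unless 64-bit mode is enabled, so results can differ from NumPy in the last few digits.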
Future packages or workflows to consider:
- CuPy (downside: requires a CUDA-enabled GPU)
- Cython might be another option
- Multiprocessing?
Notebook
Simple test notebook is here.