Jitting SMRT to improve processing time #28
ECCCBen started this conversation in Show and tell
Below is an example of how to use the JAX `jit` function to compile SMRT and reduce its computation time. See here for more information on JAX. To run the code below, one extra library must be installed: NumPyro. It can be installed using `conda` on Linux and `pip` on Windows. The code below uses the parameters and snow properties from Montpetit et al., 2024:
Standard SMRT output and `%timeit` result:
-17.673293907762805
-17.673293907762805
-17.673293907762805
-17.673293907762805
-17.673293907762805
-17.673293907762805
-17.673293907762805
-17.673293907762805
1.7 s ± 88.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Jitted SMRT output and `%timeit` result:
-17.673294
-17.673294
-17.673294
-17.673294
-17.673294
-17.673294
-17.673294
-17.673294
The slowest run took 18.16 times longer than the fastest. This could mean that an intermediate result is being cached.
48.6 µs ± 78.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
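The timings above compare a plain call with a jitted call. Below is a minimal sketch of the same jit pattern, assuming a hypothetical toy function `toy_backscatter_db` as a stand-in for the SMRT forward model (it is not part of the SMRT API, and its "physics" is illustrative only). The first call to the jitted function traces and compiles it, so it is made before timing; `block_until_ready()` is used because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a radiative transfer forward model (NOT the SMRT API):
# a toy two-way attenuation through a layered snowpack, returned in dB.
def toy_backscatter_db(thickness, density, ks):
    # Toy extinction coefficient: scattering plus a density-dependent term
    ke = ks + 1e-3 * density
    # Two-way transmissivity accumulated over all layers
    transmissivity = jnp.exp(-2.0 * jnp.sum(ke * thickness))
    # Convert a toy linear backscatter to decibels
    return 10.0 * jnp.log10(0.05 * transmissivity + 1e-6)


# jax.jit traces the function once per input shape/dtype and compiles it with XLA
toy_backscatter_db_jit = jax.jit(toy_backscatter_db)

thickness = jnp.array([0.1, 0.2, 0.15])
density = jnp.array([250.0, 300.0, 350.0])
ks = jnp.array([0.5, 0.8, 1.2])

# Warm-up call: triggers tracing and compilation so it is excluded from the timing
toy_backscatter_db_jit(thickness, density, ks).block_until_ready()

# block_until_ready() makes each iteration wait for the actual computation,
# not just for the asynchronous dispatch
n_calls = 1000
start = time.perf_counter()
for _ in range(n_calls):
    result = toy_backscatter_db_jit(thickness, density, ks).block_until_ready()
elapsed = (time.perf_counter() - start) / n_calls

print(f"{float(result):.6f} dB, {elapsed * 1e6:.1f} µs per call")
```

Because compilation happens once per input shape and dtype, the cost of the warm-up call is amortized when the same model is evaluated many times with different parameter values of the same shape.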
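The jitted SMRT gives the same output at lower precision (agreement to about 1e-6, consistent with JAX's default single precision, float32), but runs about 35,000 times faster than the standard SMRT on average (48.6 µs vs 1.7 s per call).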
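If full double precision is needed, JAX's 64-bit mode can be enabled with the `jax_enable_x64` configuration flag. The sketch below uses a hypothetical function `toy_value` only to show that, once 64-bit mode is on, a jitted function keeps float64 inputs in float64. Whether this reproduces the standard SMRT value to more digits, and how much of the speedup survives (float64 arithmetic can be much slower, especially on GPUs), are assumptions to check on the actual model.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats; set this at startup, before creating any JAX arrays
jax.config.update("jax_enable_x64", True)


@jax.jit
def toy_value(x):
    # Hypothetical computation, used only to inspect the output dtype
    return 10.0 * jnp.log10(x)


x64 = jnp.array(0.0171, dtype=jnp.float64)
x32 = jnp.array(0.0171, dtype=jnp.float32)

print(toy_value(x64).dtype)  # float64
print(toy_value(x32).dtype)  # float32
```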