Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Adds the ability to save a function on disk with flexible input shapes using JAX export module.
The decorator
cacheable_function_with_export
takes a function and it's shape polymorphism information exports it and saves it on disk and returns the exported artefact. This happens if the config optionenable_persistent_cache
isTrue
. Otherwise, the decorator returns the original function.Used as:
A function can be exported as:
args
only (default)kwargs
only - optionexport_with_kwargs
should be passed asTrue
@cacheable_function_with_export("f", shape_struct, export_with_kwargs=True)
static_argnums
andstatic_argnames
fromjax.jit
vjp_order
to support reverse-mode AD@cacheable_function_with_export("f", shape_struct, vjp_order=1)
For example, call
x
can be 2D array and 3D array. The function should be exported twice and on subsequent calls, the decorator decides which function to load based on the arguments passed.Usage:
Limitations
Currently, there's no support for forward-mode AD offered by JAX export. Attempts to export a function used in
jvp
results in an error. Thevmap
transformation is not supported as well, as there is notcustom_batching_rule
developed for exported artefacts.