Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache functions with JAX export #18

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open

Conversation

shtopane
Copy link
Contributor

@shtopane shtopane commented Oct 16, 2024

Adds the ability to save a function on disk with flexible input shapes using JAX export module.
The decorator cacheable_function_with_export takes a function and it's shape polymorphism information exports it and saves it on disk and returns the exported artefact. This happens if the config option enable_persistent_cache is True. Otherwise, the decorator returns the original function.
Used as:

@cacheable_function_with_export("f", {
      # place here the shape of x in a tuple, with the first item symbolic shape spec and the second as the dtype
      # "a, b" is the symbolic shape information passed to jax export api. Note, a != b
      "x": ("a, b", jnp.float64)
})
def f(x):
      .....

A function can be exported as:

  • with args only (default)
  • with kwargs only - option export_with_kwargs should be passed as True
@cacheable_function_with_export("f", shape_struct, export_with_kwargs=True)
  • with static_argnums and static_argnames from jax.jit
@cacheable_function_with_export("f", shape_struct, skip_jitting=True)
@partial(jax.jit, static_argnums=(1,))
  • with export vjp_order to support reverse-mode AD
@cacheable_function_with_export("f", shape_struct, vjp_order=1)
  • with alternative shape structure. This case is when a function is called with different dimension arguments between calls.
    For example, call x can be 2D array and 3D array. The function should be exported twice and on subsequent calls, the decorator decides which function to load based on the arguments passed.
    Usage:
@cacheable_function_with_export("f", {
      # first export structure
      "x": ("a, b", jnp.float64)
},
{
   # alternative export structure
   "x": ("a, b, c", jnp.float64)
}
)
def f(x):
      .....

Limitations

Currently, there's no support for forward-mode AD offered by JAX export. Attempts to export a function used in jvp results in an error. The vmap transformation is not supported as well, as there is not custom_batching_rule developed for exported artefacts.

@gboehl
Copy link
Owner

gboehl commented Oct 16, 2024

Thank you! Could you add a short description above?

@gboehl
Copy link
Owner

gboehl commented Oct 16, 2024

Thanks again! I'll keep this PR open in the prospect that JAX will implement forward AD for rehydrated functions. Good work!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants