diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 19341db9c..640905c77 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -76,7 +76,7 @@ jobs: strategy: matrix: python-version: ["3.9"] - jax-version: ["newest", "0.4.26"] # keep in sync with minimum version in export/pyproject.toml + jax-version: ["newest", "0.4.30"] # keep in sync with minimum version in export/pyproject.toml steps: - name: Cancel previous uses: styfle/cancel-workflow-action@0.8.0 diff --git a/export/orbax/export/jax_module.py b/export/orbax/export/jax_module.py index 9fac67995..8a9db9041 100644 --- a/export/orbax/export/jax_module.py +++ b/export/orbax/export/jax_module.py @@ -16,14 +16,16 @@ import dataclasses import os -from typing import Any, Callable, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union from absl import logging import jax +from jax import export as jax_export from jax.experimental import jax2tf import orbax.checkpoint as ocp from orbax.export import dtensor_utils from orbax.export import typing as orbax_export_typing +from orbax.export import utils as orbax_export_utils import tensorflow as tf from tensorflow.experimental import dtensor @@ -333,6 +335,23 @@ def jax_methods(self) -> Mapping[str, Callable[..., Any]]: apply_fn_map = self._nontrackable_metadata.apply_fn_map return _make_closures(params, apply_fn_map) + def to_jax_exported_map( + self, model_inputs: PyTree, output_dir: Union[str, None] = None + ) -> Mapping[str, jax_export.Exported]: + """Converts the orbax.export JaxModule to jax_export.Exported. + + Args: + model_inputs: The model inputs. + output_dir: The output directory to save the jax_exported_map. + + Returns: + A mapping from method key to jax_export.Exported. + """ + jax_exported_map = _jax_module_to_jax_exported_map(self, model_inputs) + if output_dir is not None: + orbax_export_utils.save_jax_exported_map(output_dir, jax_exported_map) + return jax_exported_map + def _get_param_names(params: PyTree) -> PyTree: """Gets parameter names for PyTree elements.""" @@ -429,3 +448,74 @@ def _to_tf_variable(x, name, trainable, pspec): return jax.tree_util.tree_map( _to_tf_variable, params, names, trainable, pspecs ) + + +def _jax_module_to_jax_exported_map( + j_module: JaxModule, + model_inputs: PyTree, +) -> Mapping[str, jax_export.Exported]: + """Convert the orbax.export JaxModule to jax_export.Exported. + + Args: + j_module: The orbax.export JaxModule. + model_inputs: The model inputs. + + Returns: + A mapping from method key to jax_export.Exported. + """ + apply_fn_map = j_module.apply_fn_map + model_params = j_module.model_params + input_polymorphic_shape_map = j_module.input_polymorphic_shape_map + jax2tf_kwargs_map = j_module.jax2tf_kwargs_map + + jax_exported_map = {} + + def _symbolic_args_specs(model_inputs, method_key): + input_polymorphic_shape = input_polymorphic_shape_map[method_key] + polymorphic_constraints: Sequence[str] = () + if 'polymorphic_constraints' in jax2tf_kwargs_map[method_key]: + polymorphic_constraints = jax2tf_kwargs_map[method_key][ + 'polymorphic_constraints' + ] + if input_polymorphic_shape is None: + return model_inputs + else: + return jax_export.symbolic_args_specs( + model_inputs, + input_polymorphic_shape, + constraints=polymorphic_constraints, + ) + + symbolic_model_inputs_map = { + k: _symbolic_args_specs(model_inputs, k) + for k in input_polymorphic_shape_map.keys() + } + + def _lowering_platforms( + jax2tf_kwargs: Any, + ) -> Optional[Sequence[str]]: + if jax2tf_kwargs and 'native_serialization_platforms' in jax2tf_kwargs: + return tuple(jax2tf_kwargs['native_serialization_platforms']) + else: + return None + + lowering_platforms_map = { + k: _lowering_platforms(v) for k, v in jax2tf_kwargs_map.items() + } + + for method_key, apply_fn in apply_fn_map.items(): + if not hasattr(apply_fn, 'trace'): + apply_fn = jax.jit(apply_fn) + if method_key not in input_polymorphic_shape_map: + raise ValueError( + f'Method key {method_key} not found in input_polymorphic_shape_map.' + ) + if method_key not in lowering_platforms_map: + raise ValueError( + f'Method key {method_key} not found in lowering_platforms_map.' + ) + jax_exported = jax_export.export( + apply_fn, platforms=lowering_platforms_map[method_key] + )(model_params, symbolic_model_inputs_map[method_key]) + jax_exported_map[method_key] = jax_exported + return jax_exported_map diff --git a/export/orbax/export/jax_module_test.py b/export/orbax/export/jax_module_test.py index 3dacf5bc6..dfa20bfbf 100644 --- a/export/orbax/export/jax_module_test.py +++ b/export/orbax/export/jax_module_test.py @@ -15,6 +15,7 @@ """Tests for jax_module.""" import collections +import os from absl.testing import parameterized import chex @@ -22,9 +23,11 @@ import jax.numpy as jnp import numpy as np from orbax import export as obx_export +from orbax.export import utils as orbax_export_utils import tensorflow as tf DEFAULT_METHOD_KEY = obx_export.JaxModule.DEFAULT_METHOD_KEY +JaxModule = obx_export.JaxModule def _register_custom_dict_to_jax(dict_cls): @@ -419,6 +422,66 @@ def test_variable_update_error(self): with self.assertRaisesRegex(ValueError, 'Incompatible type conversion'): jax_module.update_variables({'w': np.zeros((4, 8), dtype=np.int32)}) + def test_save_load_as_jax_exported_map(self): + + def linear(params, x): + return params['w'] @ x + params['b'] + + key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(1234), 3) + model_params = { + 'w': jax.random.normal(key_w, shape=(8, 8)), + 'b': jax.random.normal(key_b, shape=(8, 1)), + } + model_inputs = jax.random.normal(key_x, shape=(8, 1)) + lowering_platforms = ['cpu', 'tpu'] + + j_module = JaxModule( + model_params, + linear, + jax2tf_kwargs={'native_serialization_platforms': lowering_platforms}, + ) + root_dir = self.create_tempdir().full_path + saved_dir = os.path.join(root_dir, 'jax_exported_map') + jax_exported_map = j_module.to_jax_exported_map(model_inputs, saved_dir) + restored_jax_exported_map = orbax_export_utils.load_jax_exported_map( + saved_dir + ) + self.assertEqual( + set(restored_jax_exported_map.keys()), + set(jax_exported_map.keys()), + f'{restored_jax_exported_map.keys()} vs {jax_exported_map.keys()}', + ) + self.assertEqual( + set(restored_jax_exported_map.keys()), + set(j_module.apply_fn_map.keys()), + f'{restored_jax_exported_map.keys()} vs {j_module.apply_fn_map.keys()}', + ) + chex.assert_trees_all_close( + restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY].call( + model_params, model_inputs + ), + linear(model_params, model_inputs), + ) + chex.assert_equal( + set(restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY].platforms), + set(lowering_platforms), + ) + args_kwargs = ((model_params, model_inputs), {}) + in_tree = jax.tree.structure(args_kwargs) + in_avals = tuple(jax.tree.leaves(args_kwargs)) + chex.assert_equal( + in_tree, + restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY].in_tree, + ) + chex.assert_trees_all_equal_shapes( + in_avals, + restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY].in_avals, + ) + chex.assert_trees_all_equal_dtypes( + in_avals, + restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY].in_avals, + ) + if __name__ == '__main__': tf.test.main() diff --git a/export/orbax/export/typing.py b/export/orbax/export/typing.py index ab098f9be..42ee68781 100644 --- a/export/orbax/export/typing.py +++ b/export/orbax/export/typing.py @@ -16,7 +16,6 @@ from typing import Any, Callable, Mapping, Sequence, TypeVar, Union import jaxtyping -from orbax.export import utils as orbax_export_utils import tensorflow as tf @@ -24,9 +23,7 @@ Nested = Union[T, tuple[Any, ...], Sequence[Any], Mapping[str, Any]] WarmupExample = Union[list[Mapping[str, Any]], Mapping[str, Any]] NestedTfTrackable = Nested[tf.saved_model.experimental.TrackableResource] -NestedTfTensorSpec = Nested[ - Union[tf.TensorSpec, orbax_export_utils.TensorSpecWithDefault] -] + PyTree = jaxtyping.PyTree diff --git a/export/orbax/export/utils.py b/export/orbax/export/utils.py index 6410ef83b..db42162d3 100644 --- a/export/orbax/export/utils.py +++ b/export/orbax/export/utils.py @@ -18,8 +18,11 @@ import dataclasses import functools import inspect -from typing import Any, Callable, Optional +import os +from typing import Any, Callable, Optional, Union +from absl import logging import jax +from jax import export as jax_export import jaxtyping import tensorflow as tf @@ -27,6 +30,8 @@ PyTree = jaxtyping.PyTree SignatureDef = Any +_FILE_TYPE = 'jax_exported' + @dataclasses.dataclass class TensorSpecWithDefault: @@ -62,6 +67,11 @@ def __post_init__(self): ) +NestedTfTensorSpec = jaxtyping.PyTree[ + Union[tf.TensorSpec, TensorSpecWithDefault] +] + + def remove_signature_defaults(input_signature: PyTree) -> PyTree: """Removes TensorSpecWithDefault from an input_signature.""" @@ -329,3 +339,64 @@ def from_saved_model( def signatures(self): """Returns a mapping for signature names to python callables.""" return self._signatures + + +def save_jax_exported_to_disk( + exp: jax_export.Exported, + bin_file_path: str, + vjp_order: int = 0, +) -> None: + if tf.io.gfile.exists(bin_file_path): + raise ValueError(f'File {bin_file_path} already exists.') + with tf.io.gfile.GFile(bin_file_path, 'wb') as f: + f.write(exp.serialize(vjp_order=vjp_order)) + + +def load_jax_exported_from_disk(bin_file_path: str) -> jax_export.Exported: + if not tf.io.gfile.exists(bin_file_path): + raise ValueError(f'File {bin_file_path} does not exist.') + with tf.io.gfile.GFile(bin_file_path, 'rb') as f: + exp = jax_export.deserialize(bytearray(f.read())) + return exp + + +def save_jax_exported_map( + dir_path: str, + jax_exported_map: Mapping[str, jax_export.Exported], +): + """Saves the orbax.export JaxExported Map to disk.""" + if tf.io.gfile.exists(dir_path): + raise ValueError(f'Directory {dir_path} already exists.') + + tf.io.gfile.makedirs(dir_path) + for method_key, jax_exported in jax_exported_map.items(): + file_path = os.path.join(dir_path, f'{method_key}.{_FILE_TYPE}') + save_jax_exported_to_disk(jax_exported, os.path.join(dir_path, file_path)) + logging.info('Saved JaxExported Map to %s successfully.', dir_path) + + +def load_jax_exported_map(dir_path: str) -> Mapping[str, jax_export.Exported]: + """Loads the orbax.export ApplyFn JaxExported Map from disk. + + Args: + dir_path: The directory path to load the ApplyFn Map. + + Returns: + A map of method_key to JaxExported object. + """ + jax_exported_map = {} + + if not tf.io.gfile.exists(dir_path): + raise ValueError(f'Directory {dir_path} does not exist.') + + for method_key in tf.io.gfile.listdir(dir_path): + if not method_key.endswith(f'.{_FILE_TYPE}'): + continue + jax_exported = load_jax_exported_from_disk( + os.path.join(dir_path, method_key) + ) + jax_exported_map[method_key[: -len(f'.{_FILE_TYPE}')]] = jax_exported + if not jax_exported_map: + raise ValueError(f'No .{_FILE_TYPE} files found in {dir_path}.') + logging.info('Loaded ApplyFn JaxExported Map from %s successfully.', dir_path) + return jax_exported_map diff --git a/export/pyproject.toml b/export/pyproject.toml index db67046b2..27e8713ee 100644 --- a/export/pyproject.toml +++ b/export/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ 'absl-py', 'etils', 'orbax-checkpoint', - 'jax >= 0.4.26', + 'jax >= 0.4.30', 'jaxlib', 'numpy', 'dataclasses-json',