diff --git a/export/orbax/export/jax_module.py b/export/orbax/export/jax_module.py index 5852e1e4..9fac6799 100644 --- a/export/orbax/export/jax_module.py +++ b/export/orbax/export/jax_module.py @@ -21,14 +21,15 @@ from absl import logging import jax from jax.experimental import jax2tf -import jaxtyping import orbax.checkpoint as ocp from orbax.export import dtensor_utils +from orbax.export import typing as orbax_export_typing import tensorflow as tf from tensorflow.experimental import dtensor -PyTree = jaxtyping.PyTree -ApplyFn = Callable[[PyTree, PyTree], PyTree] + +PyTree = orbax_export_typing.PyTree +ApplyFn = orbax_export_typing.ApplyFn def get_obx_export_tf_preprocess_only() -> bool: @@ -70,9 +71,13 @@ class _NonTrackableMetadata: """ apply_fn_map: Mapping[str, ApplyFn] - var_treedef: Any + tf_var_treedef: Any var_trainable: Mapping[str, bool] var_pspecs: Optional[Mapping[str, PyTree]] + model_params: PyTree + jax2tf_kwargs_map: Mapping[str, Any] + input_polymorphic_shape_map: Mapping[str, Any] + allow_multi_axis_sharding_conslidation: Optional[bool] class JaxModule(tf.Module): @@ -89,7 +94,7 @@ def __init__( params: PyTree, apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]], trainable: Optional[Union[bool, PyTree]] = None, - input_polymorphic_shape: Union[PyTree, Mapping[str, PyTree]] = None, + input_polymorphic_shape: Union[PyTree, Mapping[str, PyTree], None] = None, jax2tf_kwargs: Optional[Mapping[str, Any]] = None, jit_compile: Union[bool, Mapping[str, bool]] = True, pspecs: Optional[PyTree] = None, @@ -144,7 +149,11 @@ def __init__( # Check if `apply_fn`, `input_polymorphic_shape` and `jax2tf_kwargs` have # the same structure. apply_fn_map = apply_fn - if not isinstance(input_polymorphic_shape, Mapping) or not _same_keys( + if input_polymorphic_shape is None: + input_polymorphic_shape = jax.tree_util.tree_map( + lambda x: None, apply_fn_map + ) + elif not isinstance(input_polymorphic_shape, Mapping) or not _same_keys( input_polymorphic_shape, apply_fn_map ): raise ValueError( @@ -183,8 +192,8 @@ def __init__( if get_obx_export_tf_preprocess_only(): # Skip the heavy jax_params_to_tf_variables() call in TF preprocess only # mode. - var_treedef = None - self._var_leaves = None + tf_var_treedef = None + self._tf_var_leaves = None self._methods = dict() else: tf_vars = _jax_params_to_tf_variables( @@ -192,7 +201,7 @@ def __init__( ) # Do not attach `tf_vars` to `self` directly, otherwise its structure will # be mutated by `tf.Module.__setattr__`. - self._var_leaves, var_treedef = jax.tree_util.tree_flatten(tf_vars) + self._tf_var_leaves, tf_var_treedef = jax.tree_util.tree_flatten(tf_vars) self._methods = jax.tree_util.tree_map( self._make_tf_closure, apply_fn_map, @@ -201,17 +210,39 @@ def __init__( jit_compile, ) - self._jax2tf_kwargs_map = jax2tf_kwargs - self._jax_methods = _make_closures(params, apply_fn_map) - - # Keep the following Metadata for variable update. + # # Preserve the original structure of this Metadata object to prevent + # unintended conversion by TF tf.Module (e.g., Dict to DictWrapper). 
self._nontrackable_metadata = _NonTrackableMetadata( apply_fn_map=apply_fn_map, - var_treedef=var_treedef, + tf_var_treedef=tf_var_treedef, var_trainable=trainable, var_pspecs=pspecs, + model_params=params, + jax2tf_kwargs_map=jax2tf_kwargs, + input_polymorphic_shape_map=input_polymorphic_shape, + allow_multi_axis_sharding_conslidation=allow_multi_axis_sharding_conslidation, ) + @property + def apply_fn_map(self) -> Mapping[str, ApplyFn]: + """Returns the apply_fn_map.""" + return self._nontrackable_metadata.apply_fn_map + + @property + def model_params(self) -> PyTree: + """Returns the model parameters.""" + return self._nontrackable_metadata.model_params + + @property + def jax2tf_kwargs_map(self) -> Mapping[str, Any]: + """Returns the jax2tf_kwargs_map.""" + return self._nontrackable_metadata.jax2tf_kwargs_map + + @property + def input_polymorphic_shape_map(self) -> Mapping[str, PyTree]: + """Returns the polymorphic shapes.""" + return self._nontrackable_metadata.input_polymorphic_shape_map + def update_variables(self, params: PyTree): """Updates the variables associated with self. @@ -221,31 +252,31 @@ def update_variables(self, params: PyTree): shape and dtype of each parameter must be the same as the original parameter. """ + # Update jax model_params + object.__setattr__(self._nontrackable_metadata, 'model_params', params) + + # Update TF model_params _, treedef = jax.tree_util.tree_flatten(params) - if treedef != self._nontrackable_metadata.var_treedef: + if treedef != self._nontrackable_metadata.tf_var_treedef: raise ValueError( 'The PyTree structure of the updated parameters must be the same as' f' that of the original parameters. Got new treedef: {treedef},' - f' original treedef: {self._nontrackable_metadata.var_treedef}' + f' original treedef: {self._nontrackable_metadata.tf_var_treedef}' ) new_vars = _jax_params_to_tf_variables( - params, + self._nontrackable_metadata.model_params, self._nontrackable_metadata.var_trainable, self._nontrackable_metadata.var_pspecs, + self._nontrackable_metadata.allow_multi_axis_sharding_conslidation, ) - jax.tree_util.tree_map( lambda v, new_v: v.assign(new_v), self._get_variable_tree(), new_vars ) - self._jax_methods = _make_closures( - params, self._nontrackable_metadata.apply_fn_map - ) - def _get_variable_tree(self) -> PyTree: """Returns the PyTree of the tf.Variables associated with self.""" return jax.tree_util.tree_unflatten( - self._nontrackable_metadata.var_treedef, self._var_leaves + self._nontrackable_metadata.tf_var_treedef, self._tf_var_leaves ) def _make_tf_closure( @@ -298,7 +329,9 @@ def methods(self) -> Mapping[str, Callable[..., Any]]: @property def jax_methods(self) -> Mapping[str, Callable[..., Any]]: """Named methods in JAX context for validation.""" - return self._jax_methods + params = self._nontrackable_metadata.model_params + apply_fn_map = self._nontrackable_metadata.apply_fn_map + return _make_closures(params, apply_fn_map) def _get_param_names(params: PyTree) -> PyTree: diff --git a/export/orbax/export/jax_module_test.py b/export/orbax/export/jax_module_test.py index 0db23712..3dacf5bc 100644 --- a/export/orbax/export/jax_module_test.py +++ b/export/orbax/export/jax_module_test.py @@ -17,14 +17,15 @@ import collections from absl.testing import parameterized +import chex import jax import jax.numpy as jnp -from jax.sharding import Mesh -from jax.sharding import PartitionSpec import numpy as np -from orbax.export.jax_module import JaxModule +from orbax import export as obx_export import tensorflow as tf 
+DEFAULT_METHOD_KEY = obx_export.JaxModule.DEFAULT_METHOD_KEY + def _register_custom_dict_to_jax(dict_cls): def _flatten_with_keys(xs): @@ -67,7 +68,8 @@ class MyDict(dict): empty_nodes=[dict(), tuple(), list(), MyDict(), YetAnotherDict()], ) variable_names_to_vals = { - v.name: v for v in JaxModule(params, lambda params, x: x).variables + v.name: v + for v in obx_export.JaxModule(params, lambda params, x: x).variables } self.assertEqual( { @@ -94,7 +96,8 @@ def test_variable_names_contains_tilde(self): } } variable_names_to_vals = { - v.name: v for v in JaxModule(params, lambda params, x: x).variables + v.name: v + for v in obx_export.JaxModule(params, lambda params, x: x).variables } self.assertEqual( { @@ -113,43 +116,50 @@ class MyDict(dict): a=jnp.array(1), b=[jnp.array([5, 6]), jnp.array([7, 8])], ) - variables = JaxModule(params, lambda params, x: x).variables + variables = obx_export.JaxModule(params, lambda params, x: x).variables names = {v.name for v in variables} self.assertLen(names, len(variables)) def test_trainable(self): params = {'x': jnp.array(1), 'y': jnp.array(2)} trainable = {'x': True, 'y': False} - jm = JaxModule(params, lambda params, x: x, trainable=trainable) + jm = obx_export.JaxModule(params, lambda params, x: x, trainable=trainable) self.assertLen(jm.trainable_variables, 1) self.assertEqual(jm.trainable_variables[0].name, 'x:0') self.assertEqual(jm.trainable_variables[0], jnp.array(1)) self.assertTrue(jm.with_gradient) - jm = JaxModule(params, lambda params, x: x) + jm = obx_export.JaxModule(params, lambda params, x: x) self.assertEmpty(jm.trainable_variables) self.assertFalse(jm.with_gradient) - jm = JaxModule(params, lambda params, x: x, trainable=True) + jm = obx_export.JaxModule(params, lambda params, x: x, trainable=True) self.assertLen(jm.trainable_variables, 2) self.assertTrue(jm.with_gradient) - jm = JaxModule(params, lambda params, x: x, trainable=False) + jm = obx_export.JaxModule(params, lambda params, x: x, trainable=False) self.assertEmpty(jm.trainable_variables) self.assertFalse(jm.with_gradient) def test_jax_array(self): - global_mesh = Mesh(np.array(jax.local_devices(backend='cpu')), 'x') - mesh_axes = PartitionSpec('x') + global_mesh = jax.sharding.Mesh( + np.array(jax.local_devices(backend='cpu')), 'x' + ) + mesh_axes = jax.sharding.PartitionSpec('x') global_input_shape = (jax.device_count('cpu'), 2) - global_input_data = np.arange( - np.prod(global_input_shape)).reshape(global_input_shape) + global_input_data = np.arange(np.prod(global_input_shape)).reshape( + global_input_shape + ) arr = jax.make_array_from_callback( - global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), - lambda idx: global_input_data[idx]) + global_input_shape, + jax.sharding.NamedSharding(global_mesh, mesh_axes), + lambda idx: global_input_data[idx], + ) self.assertIsInstance(arr, jax.Array) - variables = JaxModule({'arr': arr}, lambda params, x: x).variables + variables = obx_export.JaxModule( + {'arr': arr}, lambda params, x: x + ).variables self.assertLen(variables, 1) self.assertEqual(variables[0].name, 'arr:0') self.assertAllEqual(variables[0], global_input_data) @@ -167,9 +177,49 @@ def linear(params, x): } x = jax.random.normal(key_x, shape=(8, 1)) - jax_module = JaxModule(params, linear, jit_compile=jit_compile) - self.assertAllClose(jax_module.methods[JaxModule.DEFAULT_METHOD_KEY](x), - jax_module.jax_methods[JaxModule.DEFAULT_METHOD_KEY](x)) + jax_module = obx_export.JaxModule(params, linear, jit_compile=jit_compile) + self.assertAllClose( + 
jax_module.methods[DEFAULT_METHOD_KEY](x), + jax_module.jax_methods[DEFAULT_METHOD_KEY](x), + ) + + @parameterized.parameters(True, False) + def test_jax_module_property(self, jit_compile): + + def linear1(params, x): + return params['w'] @ x + params['b'] + + def linear2(params, x): + return params['w'] @ x + params['b'] * 0.1 + + key_w, key_b = jax.random.split(jax.random.PRNGKey(1234), 2) + params = { + 'w': jax.random.normal(key_w, shape=(8, 8)), + 'b': jax.random.normal(key_b, shape=(8, 1)), + } + + j_module = obx_export.JaxModule( + params, + {'linear1': linear1, 'linear2': linear2}, + jit_compile=jit_compile, + ) + self.assertEqual( + set(j_module.apply_fn_map.keys()), set(['linear1', 'linear2']) + ) + self.assertEqual( + set(j_module.jax2tf_kwargs_map.keys()), set(['linear1', 'linear2']) + ) + self.assertEqual( + set(j_module.input_polymorphic_shape_map.keys()), + set(['linear1', 'linear2']), + ) + chex.assert_trees_all_equal(j_module.model_params, params) + new_params = { + 'w': jax.random.normal(key_w, shape=(8, 8)), + 'b': jax.random.normal(key_b, shape=(8, 1)), + } + j_module.update_variables(new_params) + self.assertEqual(j_module.model_params, new_params) @parameterized.parameters(True, False) def test_polymorphic_shapes(self, jit_compile): @@ -183,23 +233,21 @@ def linear(params, batch): 'b': jax.random.normal(key_b, shape=(8, 1)), } - with self.assertRaisesRegex(ValueError, - 'Do not use `polymorphic_shapes`'): - JaxModule( - params, - linear, - jax2tf_kwargs={'polymorphic_shapes': [None, 'b, ...']}) + with self.assertRaisesRegex(ValueError, 'Do not use `polymorphic_shapes`'): + obx_export.JaxModule( + params, linear, jax2tf_kwargs={'polymorphic_shapes': [None, 'b, ...']} + ) - jax_module = JaxModule( + jax_module = obx_export.JaxModule( params, linear, jit_compile=jit_compile, - input_polymorphic_shape='b, ...') + input_polymorphic_shape='b, ...', + ) - @tf.function( - input_signature=[tf.TensorSpec([None, 8, 1], tf.float32)]) + @tf.function(input_signature=[tf.TensorSpec([None, 8, 1], tf.float32)]) def traced(x): - return jax_module.methods[JaxModule.DEFAULT_METHOD_KEY](x) + return jax_module.methods[DEFAULT_METHOD_KEY](x) key_x1, key_x2 = jax.random.split(key_x, 2) x1 = jax.random.normal(key_x1, shape=(8, 8, 1)) # batch size is 8 @@ -226,7 +274,7 @@ def linear(params, x): input_signature=[tf.TensorSpec([None, 1], tf.float32)], ) def traced(x): - return jax_module.methods[JaxModule.DEFAULT_METHOD_KEY](x) + return jax_module.methods[DEFAULT_METHOD_KEY](x) x = jax.random.normal(key_x, shape=(2, 1)) # batch size is 2 @@ -236,11 +284,13 @@ def traced(x): Exception, "Symbolic dimension comparison 'b' > '1' is inconclusive.", ): - jax_module = JaxModule(params, linear, input_polymorphic_shape='b, _') + jax_module = obx_export.JaxModule( + params, linear, input_polymorphic_shape='b, _' + ) _ = traced(x) # With user provided constraints, the trace compiling should succeed. - jax_module = JaxModule( + jax_module = obx_export.JaxModule( params, linear, input_polymorphic_shape='b, _', @@ -249,7 +299,7 @@ def traced(x): self.assertAllClose(traced(x), linear(params, x)) def test_multi_functions(self): - jax_module = JaxModule( + jax_module = obx_export.JaxModule( params={'delta': jnp.ones((), jnp.int32)}, apply_fn={ 'add': lambda params, x: x + params['delta'], @@ -258,54 +308,62 @@ def test_multi_functions(self): input_polymorphic_shape={ 'add': None, 'sub': 'b, ...', # Make `sub` batch polymorphic. - }) + }, + ) # `add` cannot accept polymorphic shapes. 
with self.assertRaisesRegex(ValueError, 'syntax error'): jax_module.methods['add'].get_concrete_function( - tf.TensorSpec([None], tf.int32)) + tf.TensorSpec([None], tf.int32) + ) # `add` can accept fixed shapes. jax_module.methods['add'].get_concrete_function( - tf.TensorSpec([1], tf.int32)) + tf.TensorSpec([1], tf.int32) + ) # `sub` can accept polymorphic shapes. jax_module.methods['sub'].get_concrete_function( - tf.TensorSpec([None], tf.int32)) + tf.TensorSpec([None], tf.int32) + ) def test_init_invalid_argument(self): - params = {'delta': jnp.ones((), jnp.int32)}, + params = ({'delta': jnp.ones((), jnp.int32)},) apply_fns = { 'add': lambda params, x: x + params['delta'], 'sub': lambda params, x: x - params['delta'], } with self.assertRaisesRegex(ValueError, '`input_polymorphic_shape` must'): - JaxModule(params, apply_fns) - - with self.assertRaisesRegex(ValueError, '`input_polymorphic_shape` must'): - JaxModule( - params, apply_fns, input_polymorphic_shape={ + obx_export.JaxModule( + params, + apply_fns, + input_polymorphic_shape={ 'add': None, - }) + }, + ) with self.assertRaisesRegex(ValueError, '`jax2tf_kwargs` must'): - JaxModule( + obx_export.JaxModule( params, apply_fns, input_polymorphic_shape=jax.tree_util.tree_map( - lambda x: None, apply_fns), - jax2tf_kwargs={'enable_xla': False}) + lambda x: None, apply_fns + ), + jax2tf_kwargs={'enable_xla': False}, + ) with self.assertRaisesRegex(ValueError, '`jit_compile` must'): - JaxModule( + obx_export.JaxModule( params, apply_fns, input_polymorphic_shape=jax.tree_util.tree_map( - lambda x: None, apply_fns), - jit_compile={'add': False}) + lambda x: None, apply_fns + ), + jit_compile={'add': False}, + ) with self.assertRaisesRegex(ValueError, 'contains trainable'): - JaxModule( + obx_export.JaxModule( params, lambda p, x: x, trainable=True, @@ -313,7 +371,7 @@ def test_init_invalid_argument(self): ) with self.assertRaisesRegex(ValueError, 'does not contain trainable'): - JaxModule( + obx_export.JaxModule( params, lambda p, x: x, trainable=False, @@ -330,7 +388,9 @@ def linear(params, batch): 'b': jax.random.normal(key_b, shape=(8, 1)), } - jax_module = JaxModule(params, linear, input_polymorphic_shape='b, ...') + jax_module = obx_export.JaxModule( + params, linear, input_polymorphic_shape='b, ...' 
+    )
 
     new_params = jax.tree_util.tree_map(lambda x: x + 1.0, params)
     jax_module.update_variables(new_params)
@@ -338,15 +398,13 @@ def linear(params, batch):
     expected_res = linear(new_params, x)
 
     self.assertAllClose(
-        jax_module.jax_methods[JaxModule.DEFAULT_METHOD_KEY](x), expected_res
-    )
-    self.assertAllClose(
-        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY](x), expected_res
+        jax_module.jax_methods[DEFAULT_METHOD_KEY](x), expected_res
     )
+    self.assertAllClose(jax_module.methods[DEFAULT_METHOD_KEY](x), expected_res)
 
   def test_variable_update_error(self):
     params = {'w': np.zeros((4, 8), dtype=np.float32)}
-    jax_module = JaxModule(params, lambda params, x: params['w'] @ x)
+    jax_module = obx_export.JaxModule(params, lambda params, x: params['w'] @ x)
 
     with self.assertRaisesRegex(
         ValueError,
diff --git a/export/orbax/export/typing.py b/export/orbax/export/typing.py
index 6c272093..ab098f9b 100644
--- a/export/orbax/export/typing.py
+++ b/export/orbax/export/typing.py
@@ -14,7 +14,8 @@
 
 """Common typing for export."""
 
-from typing import Any, Mapping, Sequence, TypeVar, Union
+from typing import Any, Callable, Mapping, Sequence, TypeVar, Union
+import jaxtyping
 from orbax.export import utils as orbax_export_utils
 import tensorflow as tf
 
@@ -26,3 +27,9 @@
 NestedTfTensorSpec = Nested[
     Union[tf.TensorSpec, orbax_export_utils.TensorSpecWithDefault]
 ]
+
+PyTree = jaxtyping.PyTree
+
+# An ApplyFn takes two arguments: the first is the model params and the second
+# is the model inputs.
+ApplyFn = Callable[[PyTree, PyTree], PyTree]
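Note (illustration, not part of the diff): a minimal sketch of how the new JaxModule properties and the reworked update_variables might be used, following the import style of the tests above. The `linear` function and the parameter shapes are made-up assumptions for illustration.

import jax
import jax.numpy as jnp
from orbax import export as obx_export


def linear(params, x):
  # An ApplyFn: (model params, model inputs) -> model outputs.
  return params['w'] @ x + params['b']


params = {
    'w': jnp.ones((8, 8), jnp.float32),
    'b': jnp.ones((8, 1), jnp.float32),
}
jax_module = obx_export.JaxModule(
    params, linear, input_polymorphic_shape='b, ...'
)

# The new read-only properties expose the constructor arguments now kept in the
# non-trackable metadata; a single apply_fn is keyed by DEFAULT_METHOD_KEY.
print(jax_module.apply_fn_map.keys())
print(jax_module.input_polymorphic_shape_map)
print(jax_module.jax2tf_kwargs_map)

# update_variables refreshes both the JAX-side model_params and the
# tf.Variables, so jax_methods and methods stay consistent afterwards.
new_params = jax.tree_util.tree_map(lambda v: v + 1.0, params)
jax_module.update_variables(new_params)
print(jax_module.model_params is new_params)  # The params object is stored as-is.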