Skip to content

Commit

Permalink
Internal.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 648500630
  • Loading branch information
maxwillzq authored and Orbax Authors committed Jul 3, 2024
1 parent 3767297 commit 959f22a
Show file tree
Hide file tree
Showing 6 changed files with 291 additions and 54 deletions.
2 changes: 1 addition & 1 deletion export/orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
from orbax.export.export_manager_base import ExportManagerBase
from orbax.export.jax_module import JaxModule
from orbax.export.serving_config import ServingConfig
from orbax.export.typing import TensorSpecWithDefault
# TODO(dinghua): remove them after we change all references to
# utils.remove_signature_defaults.
from orbax.export.utils import remove_signature_defaults
from orbax.export.utils import TensorSpecWithDefault


# A new PyPI release will be pushed everytime `__version__` is increased.
Expand Down
2 changes: 1 addition & 1 deletion export/orbax/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
from orbax.export.export_manager_base import ExportManagerBase
from orbax.export.jax_module import JaxModule
from orbax.export.serving_config import ServingConfig
from orbax.export.typing import TensorSpecWithDefault
# TODO(dinghua): remove them after we change all references to
# utils.remove_signature_defaults.
from orbax.export.utils import remove_signature_defaults
from orbax.export.utils import TensorSpecWithDefault


# A new PyPI release will be pushed everytime `__version__` is increased.
Expand Down
127 changes: 115 additions & 12 deletions export/orbax/export/jax_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,29 @@

import dataclasses
import os
from typing import Any, Callable, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union

from absl import logging
import jax
from jax.experimental import jax2tf
import jaxtyping
import orbax.checkpoint as ocp
from orbax.export import dtensor_utils
from orbax.export import typing as orbax_export_typing
import tensorflow as tf
from tensorflow.experimental import dtensor

PyTree = jaxtyping.PyTree
ApplyFn = Callable[[PyTree, PyTree], PyTree]
# pylint: disable=g-bad-import-order
# pylint: disable=g-import-not-at-top
if jax.__version_info__ <= (0, 4, 29):
from jax.experimental import export as jax_export
else:
from jax import export as jax_export
# pylint: enable=g-bad-import-order
# pylint: enable=g-import-not-at-top


PyTree = orbax_export_typing.PyTree
ApplyFn = orbax_export_typing.ApplyFn


def get_obx_export_tf_preprocess_only() -> bool:
Expand Down Expand Up @@ -73,6 +83,9 @@ class _NonTrackableMetadata:
var_treedef: Any
var_trainable: Mapping[str, bool]
var_pspecs: Optional[Mapping[str, PyTree]]
model_params: PyTree
jax2tf_kwargs_map: Mapping[str, Any]
input_polymorphic_shape: Mapping[str, Any]


class JaxModule(tf.Module):
Expand Down Expand Up @@ -201,17 +214,43 @@ def __init__(
jit_compile,
)

self._jax2tf_kwargs_map = jax2tf_kwargs
self._jax_methods = _make_closures(params, apply_fn_map)

# Keep the following Metadata for variable update.
self._nontrackable_metadata = _NonTrackableMetadata(
apply_fn_map=apply_fn_map,
var_treedef=var_treedef,
var_trainable=trainable,
var_pspecs=pspecs,
model_params=params,
jax2tf_kwargs_map=jax2tf_kwargs,
input_polymorphic_shape=input_polymorphic_shape,
)

@property
def apply_fn_map(self) -> Mapping[str, ApplyFn]:
"""Returns the apply_fn_map."""
return self._nontrackable_metadata.apply_fn_map

@property
def model_params(self) -> PyTree:
"""Returns the model parameters."""
return self._nontrackable_metadata.model_params

@property
def jax2tf_kwargs_map(self) -> Mapping[str, Any]:
"""Returns the jax2tf_kwargs_map."""
return self._nontrackable_metadata.jax2tf_kwargs_map

@property
def input_polymorphic_shape(self) -> Mapping[str, PyTree]:
"""Returns the polymorphic shapes."""
return self._nontrackable_metadata.input_polymorphic_shape

def to_jax_exported_map(
self, model_inputs: PyTree
) -> Mapping[str, jax_export.Exported]:
"""Converts the orbax.export JaxModule to jax_export.Exported."""
return _jax_module_to_jax_exported_map(self, model_inputs)

def update_variables(self, params: PyTree):
"""Updates the variables associated with self.
Expand All @@ -221,6 +260,10 @@ def update_variables(self, params: PyTree):
shape and dtype of each parameter must be the same as the original
parameter.
"""
# Update jax model_params
self._nontrackable_metadata.model_params = params

# Update TF model_params
_, treedef = jax.tree_util.tree_flatten(params)
if treedef != self._nontrackable_metadata.var_treedef:
raise ValueError(
Expand All @@ -238,10 +281,6 @@ def update_variables(self, params: PyTree):
lambda v, new_v: v.assign(new_v), self._get_variable_tree(), new_vars
)

self._jax_methods = _make_closures(
params, self._nontrackable_metadata.apply_fn_map
)

def _get_variable_tree(self) -> PyTree:
"""Returns the PyTree of the tf.Variables associated with self."""
return jax.tree_util.tree_unflatten(
Expand Down Expand Up @@ -298,7 +337,9 @@ def methods(self) -> Mapping[str, Callable[..., Any]]:
@property
def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
"""Named methods in JAX context for validation."""
return self._jax_methods
params = self._nontrackable_metadata.model_params
apply_fn_map = self._nontrackable_metadata.apply_fn_map
return _make_closures(params, apply_fn_map)


def _get_param_names(params: PyTree) -> PyTree:
Expand Down Expand Up @@ -396,3 +437,65 @@ def _to_tf_variable(x, name, trainable, pspec):
return jax.tree_util.tree_map(
_to_tf_variable, params, names, trainable, pspecs
)


def _jax_module_to_jax_exported_map(
    j_module: JaxModule,
    model_inputs: PyTree,
) -> Mapping[str, jax_export.Exported]:
  """Converts the orbax.export JaxModule to jax_export.Exported.

  Every apply function registered on `j_module` is lowered with
  `jax_export.export`, using the module's input polymorphic shapes and the
  lowering platforms declared via `native_serialization_platforms` in its
  jax2tf kwargs.

  Args:
    j_module: The orbax.export JaxModule.
    model_inputs: The model inputs.

  Returns:
    A mapping from method key to jax_export.Exported.

  Raises:
    ValueError: If a method key has no entry in the derived input-spec or
      lowering-platform maps.
  """
  apply_fn_map = j_module.apply_fn_map
  model_params = j_module.model_params
  input_polymorphic_shape = j_module.input_polymorphic_shape
  jax2tf_kwargs_map = j_module.jax2tf_kwargs_map

  def _symbolic_args_specs(inputs, poly_shape):
    # NOTE: removed a leftover debug `print` of the polymorphic shape here.
    # No polymorphic spec: export against the concrete input values/shapes.
    if poly_shape is None:
      return inputs
    return jax_export.symbolic_args_specs(inputs, poly_shape)

  model_inputs_map = {
      k: _symbolic_args_specs(model_inputs, v)
      for k, v in input_polymorphic_shape.items()
  }

  def _lowering_platforms(jax2tf_kwargs: Any) -> Optional[Sequence[str]]:
    # `native_serialization_platforms` in the jax2tf kwargs maps directly
    # onto the `platforms` argument of `jax_export.export`.
    if jax2tf_kwargs and 'native_serialization_platforms' in jax2tf_kwargs:
      return tuple(jax2tf_kwargs['native_serialization_platforms'])
    return None

  lowering_platforms_map = {
      k: _lowering_platforms(v) for k, v in jax2tf_kwargs_map.items()
  }

  jax_exported_map = {}
  for method_key, apply_fn in apply_fn_map.items():
    # `jax_export.export` expects a jitted callable (one exposing `.trace`).
    if not hasattr(apply_fn, 'trace'):
      apply_fn = jax.jit(apply_fn)
    if method_key not in model_inputs_map:
      raise ValueError(
          f'Method key {method_key} not found in model_inputs_map.'
      )
    if method_key not in lowering_platforms_map:
      raise ValueError(
          f'Method key {method_key} not found in lowering_platforms_map.'
      )
    jax_exported_map[method_key] = jax_export.export(
        apply_fn, platforms=lowering_platforms_map[method_key]
    )(model_params, model_inputs_map[method_key])
  return jax_exported_map
56 changes: 56 additions & 0 deletions export/orbax/export/jax_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
import collections

from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
import numpy as np
from orbax.export import utils as orbax_export_utils
from orbax.export.jax_module import JaxModule
import tensorflow as tf

Expand Down Expand Up @@ -361,6 +363,60 @@ def test_variable_update_error(self):
with self.assertRaisesRegex(ValueError, 'Incompatible type conversion'):
jax_module.update_variables({'w': np.zeros((4, 8), dtype=np.int32)})

def test_save_load_as_jax_exported_map(self):

def linear(params, x):
return params['w'] @ x + params['b']

key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(1234), 3)
model_params = {
'w': jax.random.normal(key_w, shape=(8, 8)),
'b': jax.random.normal(key_b, shape=(8, 1)),
}
model_inputs = jax.random.normal(key_x, shape=(8, 1))
lowering_platforms = ['cpu', 'tpu']

j_module = JaxModule(
model_params,
linear,
jax2tf_kwargs={'native_serialization_platforms': lowering_platforms},
)
saved_dir = self.create_tempdir().full_path
jax_exported_map = j_module.to_jax_exported_map(model_inputs)
orbax_export_utils.save_jax_exported_map(saved_dir, jax_exported_map)
restored_jax_exported_map = orbax_export_utils.load_jax_exported_map(
saved_dir
)
self.assertEqual(
set(restored_jax_exported_map.keys()),
set(j_module.apply_fn_map.keys()),
)
chex.assert_trees_all_close(
restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY].call(
model_params, model_inputs
),
linear(model_params, model_inputs),
)
chex.assert_equal(
set(restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY].platforms),
set(lowering_platforms),
)
args_kwargs = ((model_params, model_inputs), {})
in_tree = jax.tree.structure(args_kwargs)
in_avals = tuple(jax.tree.leaves(args_kwargs))
chex.assert_equal(
in_tree,
restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY].in_tree,
)
chex.assert_trees_all_equal_shapes(
in_avals,
restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY].in_avals,
)
chex.assert_trees_all_equal_dtypes(
in_avals,
restored_jax_exported_map[JaxModule.DEFAULT_METHOD_KEY].in_avals,
)


if __name__ == '__main__':
tf.test.main()
50 changes: 45 additions & 5 deletions export/orbax/export/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,56 @@
# limitations under the License.

"""Common typing for export."""
# tpying.py should not depend on any other files in orbax/export.

from typing import Any, Mapping, Sequence, TypeVar, Union
from orbax.export import utils as orbax_export_utils
import dataclasses
from typing import Any, Callable, Mapping, Sequence, TypeVar, Union
import jaxtyping
import tensorflow as tf


@dataclasses.dataclass
class TensorSpecWithDefault:
  """Extends tf.TensorSpec to hold a default value.

  Constraints due to Python function calling conventions:
    - For a python function parameter, all corresponding tensor values in the
      signature must have a TensorSpecWithDefault or none of them should.
    - Parameters with default values should be ordered before non-default ones.
  """

  # The spec the default value must conform to.
  tensor_spec: tf.TensorSpec
  # Default value for the input; normalized to a tf.Tensor on construction.
  default_val: Any
  is_primary: bool = False

  def __post_init__(self):
    if self.default_val is None:
      raise ValueError('Use TensorSpec if no defaults are needed.')

    # Normalize to a tf.Tensor so the value is usable in TF1-style signatures.
    if not isinstance(self.default_val, tf.Tensor):
      self.default_val = tf.convert_to_tensor(
          self.default_val, dtype=self.tensor_spec.dtype
      )

    default_spec = tf.TensorSpec.from_tensor(
        self.default_val,
        name=self.tensor_spec.name,
    )
    if not default_spec.is_subtype_of(self.tensor_spec):
      raise ValueError(
          f'TensorSpec {self.tensor_spec} is not compatible with'
          f' the default value {self.default_val}'
      )


T = TypeVar('T')
Nested = Union[T, tuple[Any, ...], Sequence[Any], Mapping[str, Any]]
WarmupExample = Union[list[Mapping[str, Any]], Mapping[str, Any]]
NestedTfTrackable = Nested[tf.saved_model.experimental.TrackableResource]
NestedTfTensorSpec = Nested[
Union[tf.TensorSpec, orbax_export_utils.TensorSpecWithDefault]
]
NestedTfTensorSpec = Nested[Union[tf.TensorSpec, TensorSpecWithDefault]]

PyTree = jaxtyping.PyTree

# ApplyFn take two arguments, the first one is the model_params, the second one
# is the model_inputs.
ApplyFn = Callable[[PyTree, PyTree], PyTree]
Loading

0 comments on commit 959f22a

Please sign in to comment.