Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 650383101
  • Loading branch information
maxwillzq authored and Orbax Authors committed Jul 8, 2024
1 parent 94eff14 commit 0745731
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 85 deletions.
81 changes: 57 additions & 24 deletions export/orbax/export/jax_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
from absl import logging
import jax
from jax.experimental import jax2tf
import jaxtyping
import orbax.checkpoint as ocp
from orbax.export import dtensor_utils
from orbax.export import typing as orbax_export_typing
import tensorflow as tf
from tensorflow.experimental import dtensor

PyTree = jaxtyping.PyTree
ApplyFn = Callable[[PyTree, PyTree], PyTree]

PyTree = orbax_export_typing.PyTree
ApplyFn = orbax_export_typing.ApplyFn


def get_obx_export_tf_preprocess_only() -> bool:
Expand Down Expand Up @@ -70,9 +71,13 @@ class _NonTrackableMetadata:
"""

apply_fn_map: Mapping[str, ApplyFn]
var_treedef: Any
tf_var_treedef: Any
var_trainable: Mapping[str, bool]
var_pspecs: Optional[Mapping[str, PyTree]]
model_params: PyTree
jax2tf_kwargs_map: Mapping[str, Any]
input_polymorphic_shape_map: Mapping[str, Any]
allow_multi_axis_sharding_conslidation: Optional[bool]


class JaxModule(tf.Module):
Expand All @@ -89,7 +94,7 @@ def __init__(
params: PyTree,
apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]],
trainable: Optional[Union[bool, PyTree]] = None,
input_polymorphic_shape: Union[PyTree, Mapping[str, PyTree]] = None,
input_polymorphic_shape: Union[PyTree, Mapping[str, PyTree], None] = None,
jax2tf_kwargs: Optional[Mapping[str, Any]] = None,
jit_compile: Union[bool, Mapping[str, bool]] = True,
pspecs: Optional[PyTree] = None,
Expand Down Expand Up @@ -144,7 +149,11 @@ def __init__(
# Check if `apply_fn`, `input_polymorphic_shape` and `jax2tf_kwargs` have
# the same structure.
apply_fn_map = apply_fn
if not isinstance(input_polymorphic_shape, Mapping) or not _same_keys(
if input_polymorphic_shape is None:
input_polymorphic_shape = jax.tree_util.tree_map(
lambda x: None, apply_fn_map
)
elif not isinstance(input_polymorphic_shape, Mapping) or not _same_keys(
input_polymorphic_shape, apply_fn_map
):
raise ValueError(
Expand Down Expand Up @@ -183,16 +192,16 @@ def __init__(
if get_obx_export_tf_preprocess_only():
# Skip the heavy jax_params_to_tf_variables() call in TF preprocess only
# mode.
var_treedef = None
self._var_leaves = None
tf_var_treedef = None
self._tf_var_leaves = None
self._methods = dict()
else:
tf_vars = _jax_params_to_tf_variables(
params, trainable, pspecs, allow_multi_axis_sharding_conslidation
)
# Do not attach `tf_vars` to `self` directly, otherwise its structure will
# be mutated by `tf.Module.__setattr__`.
self._var_leaves, var_treedef = jax.tree_util.tree_flatten(tf_vars)
self._tf_var_leaves, tf_var_treedef = jax.tree_util.tree_flatten(tf_vars)
self._methods = jax.tree_util.tree_map(
self._make_tf_closure,
apply_fn_map,
Expand All @@ -201,17 +210,39 @@ def __init__(
jit_compile,
)

self._jax2tf_kwargs_map = jax2tf_kwargs
self._jax_methods = _make_closures(params, apply_fn_map)

# Keep the following Metadata for variable update.
# # Preserve the original structure of this Metadata object to prevent
# unintended conversion by TF tf.Module (e.g., Dict to DictWrapper).
self._nontrackable_metadata = _NonTrackableMetadata(
apply_fn_map=apply_fn_map,
var_treedef=var_treedef,
tf_var_treedef=tf_var_treedef,
var_trainable=trainable,
var_pspecs=pspecs,
model_params=params,
jax2tf_kwargs_map=jax2tf_kwargs,
input_polymorphic_shape_map=input_polymorphic_shape,
allow_multi_axis_sharding_conslidation=allow_multi_axis_sharding_conslidation,
)

@property
def apply_fn_map(self) -> Mapping[str, ApplyFn]:
"""Returns the apply_fn_map."""
return self._nontrackable_metadata.apply_fn_map

@property
def model_params(self) -> PyTree:
"""Returns the model parameters."""
return self._nontrackable_metadata.model_params

@property
def jax2tf_kwargs_map(self) -> Mapping[str, Any]:
"""Returns the jax2tf_kwargs_map."""
return self._nontrackable_metadata.jax2tf_kwargs_map

@property
def input_polymorphic_shape_map(self) -> Mapping[str, PyTree]:
"""Returns the polymorphic shapes."""
return self._nontrackable_metadata.input_polymorphic_shape_map

def update_variables(self, params: PyTree):
"""Updates the variables associated with self.
Expand All @@ -221,31 +252,31 @@ def update_variables(self, params: PyTree):
shape and dtype of each parameter must be the same as the original
parameter.
"""
# Update jax model_params
object.__setattr__(self._nontrackable_metadata, 'model_params', params)

# Update TF model_params
_, treedef = jax.tree_util.tree_flatten(params)
if treedef != self._nontrackable_metadata.var_treedef:
if treedef != self._nontrackable_metadata.tf_var_treedef:
raise ValueError(
'The PyTree structure of the updated parameters must be the same as'
f' that of the original parameters. Got new treedef: {treedef},'
f' original treedef: {self._nontrackable_metadata.var_treedef}'
f' original treedef: {self._nontrackable_metadata.tf_var_treedef}'
)
new_vars = _jax_params_to_tf_variables(
params,
self._nontrackable_metadata.model_params,
self._nontrackable_metadata.var_trainable,
self._nontrackable_metadata.var_pspecs,
self._nontrackable_metadata.allow_multi_axis_sharding_conslidation,
)

jax.tree_util.tree_map(
lambda v, new_v: v.assign(new_v), self._get_variable_tree(), new_vars
)

self._jax_methods = _make_closures(
params, self._nontrackable_metadata.apply_fn_map
)

def _get_variable_tree(self) -> PyTree:
"""Returns the PyTree of the tf.Variables associated with self."""
return jax.tree_util.tree_unflatten(
self._nontrackable_metadata.var_treedef, self._var_leaves
self._nontrackable_metadata.tf_var_treedef, self._tf_var_leaves
)

def _make_tf_closure(
Expand Down Expand Up @@ -298,7 +329,9 @@ def methods(self) -> Mapping[str, Callable[..., Any]]:
@property
def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
"""Named methods in JAX context for validation."""
return self._jax_methods
params = self._nontrackable_metadata.model_params
apply_fn_map = self._nontrackable_metadata.apply_fn_map
return _make_closures(params, apply_fn_map)


def _get_param_names(params: PyTree) -> PyTree:
Expand Down
Loading

0 comments on commit 0745731

Please sign in to comment.