Skip to content

Commit

Permalink
Switch tensorflow_gnn/experimental/sampler/eval_dag.py to tf_internal.py
Browse files Browse the repository at this point in the history
for dealing with internals of different Keras versions.

PiperOrigin-RevId: 604655724
  • Loading branch information
arnoegw authored and tensorflower-gardener committed Feb 6, 2024
1 parent ffa453f commit e1d9210
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 18 deletions.
19 changes: 3 additions & 16 deletions tensorflow_gnn/experimental/sampler/eval_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,10 @@

from tensorflow_gnn.experimental.sampler import interfaces
from tensorflow_gnn.experimental.sampler import proto as pb
from tensorflow_gnn.graph import tf_internal

try:
input_layer = tf._keras_internal.engine.input_layer # pylint:disable=g-import-not-at-top # pytype: disable=import-error # pylint:disable=protected-access
except AttributeError:
try:
from tf_keras.src.engine import input_layer # pylint:disable=g-import-not-at-top # pytype: disable=import-error
except ImportError:
import keras # pylint:disable=g-import-not-at-top # pytype: disable=import-error
if not keras.__version__.startswith('2.'):
raise ImportError(
'tensorflow_gnn requires tf_keras to be installed or keras version <'
f' 3. Instead got keras version {keras.__version__}.'
) from None # This is Keras version mismatch, not just lacking tf_keras.
if hasattr(keras, 'src'):
from keras.src.engine import input_layer # pylint:disable=g-import-not-at-top # pytype: disable=import-error
else:
from keras.engine import input_layer # pylint:disable=g-import-not-at-top # pytype: disable=import-error
# More portable version of input_layer = tf._keras_internal.engine.input_layer
input_layer = tf_internal.keras_input_layer_module


@dataclasses.dataclass
Expand Down
13 changes: 11 additions & 2 deletions tensorflow_gnn/graph/tf_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
if tf.__version__.startswith("2.12."):
# tf.keras is keras 2.12, which does not yet have the `src` subdirectory.
from keras import backend as keras_backend
from keras.engine import input_layer
from keras.engine import keras_tensor as kt
from keras.layers import core as core_layers
# In 2.12, these symbols are not exposed yet under tf.keras.__internal__.
Expand All @@ -76,30 +77,34 @@
# tf.keras is keras.
# For TF 2.14, there also exists a tf_keras package, but TF does not use it.
from keras.src import backend as keras_backend
from keras.src.engine import input_layer
from keras.src.engine import keras_tensor as kt
from keras.src.layers import core as core_layers

elif tf.__version__.startswith("2.15."):
KerasTensor = tf.keras.__internal__.KerasTensor
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
# OSS TensorFlow 2.15 can choose between keras 2.15 and tf_keras 2.15
# BUT THESE ARE DIFFERENT PACKAGES WITH SEPARATE GLOBAL REGISTRIES
# OSS TensorFlow 2.15 can choose between keras 2.15 and tf_keras 2.15 but
# THESE ARE DIFFERENT PACKAGES WITH SEPARATE GLOBAL REGISTRIES (b/324019542)
# so it is essential that we pick the right one by replicating the logic from
# https://github.com/tensorflow/tensorflow/blob/r2.15/tensorflow/python/util/lazy_loader.py#L96
if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"):
from tf_keras.src import backend as keras_backend
from tf_keras.src.layers import core as core_layers
from tf_keras.src.engine import input_layer
from tf_keras.src.engine import keras_tensor as kt
else:
from keras.src import backend as keras_backend
from keras.src.layers import core as core_layers
from keras.src.engine import input_layer
from keras.src.engine import keras_tensor as kt

elif hasattr(tf, "_keras_internal"): # Special case: internal.
KerasTensor = tf.keras.__internal__.KerasTensor
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
kt = tf._keras_internal.engine.keras_tensor # pylint: disable=protected-access
core_layers = tf._keras_internal.layers.core # pylint: disable=protected-access
input_layer = tf._keras_internal.engine.input_layer # pylint:disable=protected-access
keras_backend = tf._keras_internal.backend # pylint: disable=protected-access

else: # TF2.16 and onwards.
Expand All @@ -110,6 +115,7 @@
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
from tf_keras.src import backend as keras_backend
from tf_keras.src.layers import core as core_layers
from tf_keras.src.engine import input_layer
from tf_keras.src.engine import keras_tensor as kt

# pytype: enable=import-error
Expand All @@ -119,6 +125,9 @@
delegate_method = core_layers._delegate_method # pylint: disable=protected-access
unique_keras_object_name = keras_backend.unique_object_name

# tensorflow_gnn/experimental/sampler/eval_dag.py uses the module object itself.
keras_input_layer_module = input_layer

# Delete imports, in their order above.
del composite_tensor
del type_spec
Expand Down

0 comments on commit e1d9210

Please sign in to comment.