diff --git a/tensorflow_gnn/experimental/sampler/eval_dag.py b/tensorflow_gnn/experimental/sampler/eval_dag.py index 1e40dd3e..afa00f3e 100644 --- a/tensorflow_gnn/experimental/sampler/eval_dag.py +++ b/tensorflow_gnn/experimental/sampler/eval_dag.py @@ -27,23 +27,10 @@ from tensorflow_gnn.experimental.sampler import interfaces from tensorflow_gnn.experimental.sampler import proto as pb +from tensorflow_gnn.graph import tf_internal -try: - input_layer = tf._keras_internal.engine.input_layer # pylint:disable=g-import-not-at-top # pytype: disable=import-error # pylint:disable=protected-access -except AttributeError: - try: - from tf_keras.src.engine import input_layer # pylint:disable=g-import-not-at-top # pytype: disable=import-error - except ImportError: - import keras # pylint:disable=g-import-not-at-top # pytype: disable=import-error - if not keras.__version__.startswith('2.'): - raise ImportError( - 'tensorflow_gnn requires tf_keras to be installed or keras version <' - f' 3. Instead got keras version {keras.__version__}.' - ) from None # This is Keras version mismatch, not just lacking tf_keras. - if hasattr(keras, 'src'): - from keras.src.engine import input_layer # pylint:disable=g-import-not-at-top # pytype: disable=import-error - else: - from keras.engine import input_layer # pylint:disable=g-import-not-at-top # pytype: disable=import-error +# More portable version of input_layer = tf._keras_internal.engine.input_layer +input_layer = tf_internal.keras_input_layer_module @dataclasses.dataclass diff --git a/tensorflow_gnn/graph/tf_internal.py b/tensorflow_gnn/graph/tf_internal.py index 213cca5b..640091cd 100644 --- a/tensorflow_gnn/graph/tf_internal.py +++ b/tensorflow_gnn/graph/tf_internal.py @@ -64,6 +64,7 @@ if tf.__version__.startswith("2.12."): # tf.keras is keras 2.12, which does not yet have the `src` subdirectory. from keras import backend as keras_backend + from keras.engine import input_layer from keras.engine import keras_tensor as kt from keras.layers import core as core_layers # In 2.12, these symbols are not exposed yet under tf.keras.__internal__. @@ -76,23 +77,26 @@ # tf.keras is keras. # For TF 2.14, there also exists a tf_keras package, but TF does not use it. from keras.src import backend as keras_backend + from keras.src.engine import input_layer from keras.src.engine import keras_tensor as kt from keras.src.layers import core as core_layers elif tf.__version__.startswith("2.15."): KerasTensor = tf.keras.__internal__.KerasTensor RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor - # OSS TensorFlow 2.15 can choose between keras 2.15 and tf_keras 2.15 - # BUT THESE ARE DIFFERENT PACKAGES WITH SEPARATE GLOBAL REGISTRIES + # OSS TensorFlow 2.15 can choose between keras 2.15 and tf_keras 2.15 but + # THESE ARE DIFFERENT PACKAGES WITH SEPARATE GLOBAL REGISTRIES (b/324019542) # so it is essential that we pick the right one by replicating the logic from # https://github.com/tensorflow/tensorflow/blob/r2.15/tensorflow/python/util/lazy_loader.py#L96 if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"): from tf_keras.src import backend as keras_backend from tf_keras.src.layers import core as core_layers + from tf_keras.src.engine import input_layer from tf_keras.src.engine import keras_tensor as kt else: from keras.src import backend as keras_backend from keras.src.layers import core as core_layers + from keras.src.engine import input_layer from keras.src.engine import keras_tensor as kt elif hasattr(tf, "_keras_internal"): # Special case: internal. @@ -100,6 +104,7 @@ RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor kt = tf._keras_internal.engine.keras_tensor # pylint: disable=protected-access core_layers = tf._keras_internal.layers.core # pylint: disable=protected-access + input_layer = tf._keras_internal.engine.input_layer # pylint:disable=protected-access keras_backend = tf._keras_internal.backend # pylint: disable=protected-access else: # TF2.16 and onwards. @@ -110,6 +115,7 @@ RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor from tf_keras.src import backend as keras_backend from tf_keras.src.layers import core as core_layers + from tf_keras.src.engine import input_layer from tf_keras.src.engine import keras_tensor as kt # pytype: enable=import-error @@ -119,6 +125,9 @@ delegate_method = core_layers._delegate_method # pylint: disable=protected-access unique_keras_object_name = keras_backend.unique_object_name +# tensorflow_gnn/experimental/sampler/eval_dag.py uses the module object itself. +keras_input_layer_module = input_layer + # Delete imports, in their order above. del composite_tensor del type_spec