diff --git a/modeling.py b/modeling.py index fed525971..49f99d4b7 100644 --- a/modeling.py +++ b/modeling.py @@ -17,7 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - +from packaging import version import collections import copy import json @@ -28,6 +28,9 @@ import tensorflow as tf +if version.parse(tf.__version__)>=version.parse("2"): + tf.compat.v1.disable_eager_execution() + class BertConfig(object): """Configuration for `BertModel`.""" @@ -90,9 +93,15 @@ def from_dict(cls, json_object): @classmethod def from_json_file(cls, json_file): """Constructs a `BertConfig` from a json file of parameters.""" - with tf.gfile.GFile(json_file, "r") as reader: - text = reader.read() - return cls.from_dict(json.loads(text)) + if version.parse(tf.__version__)=version.parse("2"): + with tf.io.gfile.GFile(json_file, "r") as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + def to_dict(self): """Serializes this instance to a Python dictionary.""" @@ -167,9 +176,18 @@ def __init__(self, if token_type_ids is None: token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) - - with tf.variable_scope(scope, default_name="bert"): - with tf.variable_scope("embeddings"): + if version.parse(tf.__version__)