This repository was archived by the owner on Sep 25, 2025. It is now read-only.
190 changes: 152 additions & 38 deletions modeling.py
@@ -17,7 +17,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from packaging import version
import collections
import copy
import json
@@ -28,6 +28,9 @@
import tensorflow as tf


if version.parse(tf.__version__)>=version.parse("2"):
tf.compat.v1.disable_eager_execution()
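The version.parse(tf.__version__) < version.parse("2") check recurs throughout this diff for variables, scopes, and dense layers. A small dispatch helper along the following lines (a hypothetical sketch, not part of this change) could centralize the branching; tf.compat.v1 itself ships with recent 1.x releases (roughly 1.13 onward), so on those versions the two branches coincide anyway.

from packaging import version
import tensorflow as tf

# Hypothetical helpers, not part of this diff.
IS_TF1 = version.parse(tf.__version__) < version.parse("2")

def get_variable_compat(*args, **kwargs):
    # Route to tf.get_variable on TF 1.x and tf.compat.v1.get_variable on TF 2.x.
    if IS_TF1:
        return tf.get_variable(*args, **kwargs)
    return tf.compat.v1.get_variable(*args, **kwargs)

def variable_scope_compat(*args, **kwargs):
    # Same dispatch for the variable scopes used in BertModel and transformer_model.
    if IS_TF1:
        return tf.variable_scope(*args, **kwargs)
    return tf.compat.v1.variable_scope(*args, **kwargs)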

class BertConfig(object):
"""Configuration for `BertModel`."""

@@ -90,9 +93,15 @@ def from_dict(cls, json_object):
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with tf.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
if version.parse(tf.__version__)<version.parse("2"):
with tf.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
elif version.parse(tf.__version__)>=version.parse("2"):
with tf.io.gfile.GFile(json_file, "r") as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
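If support for TF releases older than about 1.13 is not required (an assumption to check against this repo's minimum supported version), the two branches above collapse into one, because tf.io.gfile.GFile exists in those 1.x releases as well as in TF 2.x. A minimal sketch:

import json
import tensorflow as tf

def read_json_config(json_file):
    # Single code path; assumes tf.io.gfile is available (TF >= ~1.13 or TF 2.x).
    with tf.io.gfile.GFile(json_file, "r") as reader:
        return json.loads(reader.read())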


def to_dict(self):
"""Serializes this instance to a Python dictionary."""
@@ -167,9 +176,18 @@ def __init__(self,

if token_type_ids is None:
token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

with tf.variable_scope(scope, default_name="bert"):
with tf.variable_scope("embeddings"):
if version.parse(tf.__version__)<version.parse("2"):
scope_version=tf.variable_scope(scope, default_name="bert")
embedding_scope=tf.variable_scope("embeddings")
encoder_scope=tf.variable_scope("encoder")
pooler_scope=tf.variable_scope("pooler")
else:
scope_version=tf.compat.v1.variable_scope(scope, default_name="bert")
embedding_scope=tf.compat.v1.variable_scope("embeddings")
encoder_scope=tf.compat.v1.variable_scope("encoder")
pooler_scope=tf.compat.v1.variable_scope("pooler")
with scope_version:
with embedding_scope:
# Perform embedding lookup on the word ids.
(self.embedding_output, self.embedding_table) = embedding_lookup(
input_ids=input_ids,
@@ -193,7 +211,7 @@ def __init__(self,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)

with tf.variable_scope("encoder"):
with encoder_scope:
# This converts a 2D mask of shape [batch_size, seq_length] to a 3D
# mask of shape [batch_size, seq_length, seq_length] which is used
# for the attention scores.
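For context on the comment above: the 2D-to-3D expansion is done by create_attention_mask_from_input_mask in the unmodified part of this file, and amounts to the broadcast sketched below (the helper name and shapes here are illustrative only, not code from the diff).

import tensorflow as tf

def expand_attention_mask(input_mask, from_seq_length):
    # input_mask: [batch_size, to_seq_length], 1 for real tokens and 0 for padding.
    batch_size = tf.shape(input_mask)[0]
    to_seq_length = tf.shape(input_mask)[1]
    to_mask = tf.cast(tf.reshape(input_mask, [batch_size, 1, to_seq_length]), tf.float32)
    # Broadcast against a column of ones -> [batch_size, from_seq_length, to_seq_length].
    broadcast_ones = tf.ones([batch_size, from_seq_length, 1], dtype=tf.float32)
    return broadcast_ones * to_mask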
@@ -221,11 +239,21 @@ def __init__(self,
# [batch_size, hidden_size]. This is necessary for segment-level
# (or segment-pair-level) classification tasks where we need a fixed
# dimensional representation of the segment.
with tf.variable_scope("pooler"):

with pooler_scope:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token. We assume that this has been pre-trained
first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
self.pooled_output = tf.layers.dense(
if version.parse(tf.__version__)<version.parse("2"):

self.pooled_output = tf.layers.dense(
first_token_tensor,
config.hidden_size,
activation=tf.tanh,
kernel_initializer=create_initializer(config.initializer_range))
else:

self.pooled_output = tf.compat.v1.layers.dense(
first_token_tensor,
config.hidden_size,
activation=tf.tanh,
@@ -356,13 +384,17 @@ def dropout(input_tensor, dropout_prob):
return input_tensor

output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
return output
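One detail this hunk may still need (worth verifying against the TF release notes): in TF 2.x the second positional argument of tf.nn.dropout is rate rather than keep_prob, so passing 1.0 - dropout_prob positionally would invert the intended drop probability there. A version-gated sketch, assuming the imports at the top of this file:

def dropout_compat(input_tensor, dropout_prob):
    # Sketch only; mirrors dropout() above but gates on the TF major version.
    if dropout_prob is None or dropout_prob == 0.0:
        return input_tensor
    if version.parse(tf.__version__) < version.parse("2"):
        # TF 1.x: the second argument is keep_prob.
        return tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
    # TF 2.x: the corresponding argument is rate, the fraction of units to drop.
    return tf.nn.dropout(input_tensor, rate=dropout_prob)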


def layer_norm(input_tensor, name=None):
"""Run layer normalization on the last dimension of the tensor."""
return tf.contrib.layers.layer_norm(
if version.parse(tf.__version__)<version.parse("2"):
return tf.contrib.layers.layer_norm(
inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
if version.parse(tf.__version__)>=version.parse("2"):
return tf.keras.layers.LayerNormalization(name=name,axis=-1,epsilon=1e-15,dtype=tf.float32)(input_tensor)
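A caveat worth flagging for the TF 2 branch above (an observation, not something this diff addresses): tf.keras.layers.LayerNormalization creates auto-named layer variables, so the resulting names may not line up with the .../LayerNorm/gamma and .../LayerNorm/beta names in existing BERT checkpoints, which would affect init_checkpoint restoration. One possible alternative is to build the normalization from compat.v1 ops so the original naming is kept; a rough sketch, with the epsilon value and scope naming treated as assumptions to verify:

def layer_norm_compat(input_tensor, name=None):
    # Hypothetical replacement for the TF 2 branch; keeps "<scope>/LayerNorm/{gamma,beta}" names.
    with tf.compat.v1.variable_scope(name, default_name="LayerNorm"):
        width = input_tensor.shape[-1]
        gamma = tf.compat.v1.get_variable("gamma", [width], initializer=tf.ones_initializer())
        beta = tf.compat.v1.get_variable("beta", [width], initializer=tf.zeros_initializer())
        mean, variance = tf.nn.moments(input_tensor, axes=[-1], keepdims=True)
        normalized = (input_tensor - mean) * tf.math.rsqrt(variance + 1e-12)
        return normalized * gamma + beta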


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
@@ -374,7 +406,10 @@ def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):

def create_initializer(initializer_range=0.02):
"""Creates a `truncated_normal_initializer` with the given range."""
return tf.truncated_normal_initializer(stddev=initializer_range)
if version.parse(tf.__version__)<version.parse("2"):
return tf.truncated_normal_initializer(stddev=initializer_range)
else:
return tf.compat.v1.truncated_normal_initializer(stddev=initializer_range,dtype=tf.dtypes.float32)


def embedding_lookup(input_ids,
@@ -405,11 +440,17 @@ def embedding_lookup(input_ids,
# reshape to [batch_size, seq_length, 1].
if input_ids.shape.ndims == 2:
input_ids = tf.expand_dims(input_ids, axis=[-1])

embedding_table = tf.get_variable(
name=word_embedding_name,
shape=[vocab_size, embedding_size],
initializer=create_initializer(initializer_range))
if version.parse(tf.__version__)<version.parse("2"):
embedding_table = tf.get_variable(
name=word_embedding_name,
shape=[vocab_size, embedding_size],
initializer=create_initializer(initializer_range))

else:
embedding_table = tf.compat.v1.get_variable(
name=word_embedding_name,
shape=[vocab_size, embedding_size],
initializer=create_initializer(initializer_range))

flat_input_ids = tf.reshape(input_ids, [-1])
if use_one_hot_embeddings:
@@ -473,7 +514,13 @@ def embedding_postprocessor(input_tensor,
if token_type_ids is None:
raise ValueError("`token_type_ids` must be specified if"
"`use_token_type` is True.")
token_type_table = tf.get_variable(
if version.parse(tf.__version__)<version.parse("2"):
token_type_table = tf.get_variable(
name=token_type_embedding_name,
shape=[token_type_vocab_size, width],
initializer=create_initializer(initializer_range))
else:
token_type_table = tf.compat.v1.get_variable(
name=token_type_embedding_name,
shape=[token_type_vocab_size, width],
initializer=create_initializer(initializer_range))
@@ -487,9 +534,18 @@
output += token_type_embeddings

if use_position_embeddings:
assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
if version.parse(tf.__version__)<version.parse("2"):
assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
else:
assert_op = tf.compat.v1.assert_less_equal(seq_length, max_position_embeddings)
with tf.control_dependencies([assert_op]):
full_position_embeddings = tf.get_variable(
if version.parse(tf.__version__)<version.parse("2"):
full_position_embeddings = tf.get_variable(
name=position_embedding_name,
shape=[max_position_embeddings, width],
initializer=create_initializer(initializer_range))
else:
full_position_embeddings = tf.compat.v1.get_variable(
name=position_embedding_name,
shape=[max_position_embeddings, width],
initializer=create_initializer(initializer_range))
@@ -663,29 +719,55 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
to_tensor_2d = reshape_to_matrix(to_tensor)

# `query_layer` = [B*F, N*H]
query_layer = tf.layers.dense(

if version.parse(tf.__version__)<version.parse("2"):
query_layer = tf.layers.dense(
from_tensor_2d,
num_attention_heads * size_per_head,
activation=query_act,
name="query",
kernel_initializer=create_initializer(initializer_range))

# `key_layer` = [B*T, N*H]
key_layer = tf.layers.dense(

# `key_layer` = [B*T, N*H]
key_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=key_act,
name="key",
kernel_initializer=create_initializer(initializer_range))

# `value_layer` = [B*T, N*H]
value_layer = tf.layers.dense(
# `value_layer` = [B*T, N*H]
value_layer = tf.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=value_act,
name="value",
kernel_initializer=create_initializer(initializer_range))
else:
query_layer = tf.compat.v1.layers.dense(
from_tensor_2d,
num_attention_heads * size_per_head,
activation=query_act,
name="query",
kernel_initializer=create_initializer(initializer_range))


# `key_layer` = [B*T, N*H]
key_layer = tf.compat.v1.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=key_act,
name="key",
kernel_initializer=create_initializer(initializer_range))

# `value_layer` = [B*T, N*H]
value_layer = tf.compat.v1.layers.dense(
to_tensor_2d,
num_attention_heads * size_per_head,
activation=value_act,
name="value",
kernel_initializer=create_initializer(initializer_range))
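Since the only difference between the two branches above is the tf.layers.dense versus tf.compat.v1.layers.dense namespace, a thin wrapper (hypothetical, in the spirit of the helper sketched near the top of the file) would avoid duplicating the query/key/value projections; the imports are assumed to match the top of this file.

def dense_compat(inputs, units, activation=None, name=None, kernel_initializer=None):
    # Sketch only: pick the dense implementation by TF major version.
    if version.parse(tf.__version__) < version.parse("2"):
        dense_fn = tf.layers.dense
    else:
        dense_fn = tf.compat.v1.layers.dense
    return dense_fn(inputs, units, activation=activation, name=name,
                    kernel_initializer=kernel_initializer)

# Example use for the three projections:
# query_layer = dense_compat(from_tensor_2d, num_attention_heads * size_per_head,
#                            activation=query_act, name="query",
#                            kernel_initializer=create_initializer(initializer_range))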
# `query_layer` = [B, N, F, H]
query_layer = transpose_for_scores(query_layer, batch_size,
num_attention_heads, from_seq_length,
@@ -824,12 +906,24 @@ def transformer_model(input_tensor,

all_layer_outputs = []
for layer_idx in range(num_hidden_layers):
with tf.variable_scope("layer_%d" % layer_idx):
if version.parse(tf.__version__)<version.parse("2"):
layer_scope=tf.variable_scope("layer_%d" % layer_idx)
attention_scope=tf.variable_scope("attention")
self_scope=tf.variable_scope("self")
# Two separate scope objects so the attention-level "output" scope and the
# layer-level "output" scope stay distinct, as in the original code.
attention_output_scope=tf.variable_scope("output")
layer_output_scope=tf.variable_scope("output")
intermediate_scope=tf.variable_scope("intermediate")
else:
layer_scope=tf.compat.v1.variable_scope("layer_%d" % layer_idx)
attention_scope=tf.compat.v1.variable_scope("attention")
self_scope=tf.compat.v1.variable_scope("self")
attention_output_scope=tf.compat.v1.variable_scope("output")
layer_output_scope=tf.compat.v1.variable_scope("output")
intermediate_scope=tf.compat.v1.variable_scope("intermediate")
with layer_scope:
layer_input = prev_output

with tf.variable_scope("attention"):
with attention_scope:
attention_heads = []
with tf.variable_scope("self"):
with self_scope:
attention_head = attention_layer(
from_tensor=layer_input,
to_tensor=layer_input,
@@ -854,25 +948,45 @@ def transformer_model(input_tensor,

# Run a linear projection of `hidden_size` then add a residual
# with `layer_input`.
with tf.variable_scope("output"):
attention_output = tf.layers.dense(
with attention_output_scope:
if version.parse(tf.__version__)<version.parse("2"):

attention_output = tf.layers.dense(
attention_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
attention_output = dropout(attention_output, hidden_dropout_prob)
attention_output = layer_norm(attention_output + layer_input)

attention_output = dropout(attention_output, hidden_dropout_prob)
attention_output = layer_norm(attention_output + layer_input)
else:
attention_output = tf.compat.v1.layers.dense(
attention_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
attention_output = dropout(attention_output, hidden_dropout_prob)
attention_output = layer_norm(attention_output + layer_input)
# The activation is only applied to the "intermediate" hidden layer.
with tf.variable_scope("intermediate"):
intermediate_output = tf.layers.dense(
with intermediate_scope:
if version.parse(tf.__version__)<version.parse("2"):
intermediate_output = tf.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
kernel_initializer=create_initializer(initializer_range))
else:
intermediate_output = tf.compat.v1.layers.dense(
attention_output,
intermediate_size,
activation=intermediate_act_fn,
kernel_initializer=create_initializer(initializer_range))

# Down-project back to `hidden_size` then add the residual.
with tf.variable_scope("output"):
layer_output = tf.layers.dense(
with layer_output_scope:
if version.parse(tf.__version__)<version.parse("2"):
layer_output = tf.layers.dense(
intermediate_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
else:
layer_output = tf.compat.v1.layers.dense(
intermediate_output,
hidden_size,
kernel_initializer=create_initializer(initializer_range))
8 changes: 6 additions & 2 deletions modeling_test.py
@@ -15,7 +15,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from packaging import version
import collections
import json
import random
@@ -135,8 +135,12 @@ def test_config_to_json_string(self):
def run_tester(self, tester):
with self.test_session() as sess:
ops = tester.create_model()
init_op = tf.group(tf.global_variables_initializer(),
if version.parse(tf.__version__)<version.parse("2"):
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
else:
init_op = tf.group(tf.compat.v1.global_variables_initializer(),
tf.compat.v1.local_variables_initializer())
sess.run(init_op)
output_result = sess.run(ops)
tester.check_output(output_result)
2 changes: 1 addition & 1 deletion optimization.py
@@ -84,7 +84,7 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
return train_op
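The single change to this file is the optimizer base class just below: tf.train.Optimizer becomes tf.compat.v1.train.Optimizer. As a sanity sketch (hypothetical, not part of the diff), that base class can be subclassed the same way under TF 1.13+ and TF 2.x, provided v1 graph-mode execution is in effect as arranged in modeling.py:

import tensorflow as tf

# Minimal sketch: shows the compat.v1 base class resolving under either major
# version. Not a usable optimizer; AdamWeightDecayOptimizer below overrides
# apply_gradients to do the real Adam-with-weight-decay update.
class DummyOptimizer(tf.compat.v1.train.Optimizer):
    def __init__(self, learning_rate=0.01, name="DummyOptimizer"):
        super(DummyOptimizer, self).__init__(use_locking=False, name=name)
        self.learning_rate = learning_rate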


class AdamWeightDecayOptimizer(tf.train.Optimizer):
class AdamWeightDecayOptimizer(tf.compat.v1.train.Optimizer):
"""A basic Adam optimizer that includes "correct" L2 weight decay."""

def __init__(self,