Use highway network and employ decaying learning rate #58

Open · wants to merge 3 commits into master
11 changes: 9 additions & 2 deletions README.md
@@ -34,10 +34,17 @@ optional arguments:
   --num_epochs NUM_EPOCHS
                         Number of training epochs (default: 100)
   --evaluate_every EVALUATE_EVERY
-                        Evaluate model on dev set after this many steps
-                        (default: 100)
+                        Evaluate model on dev set after this many steps (default: 100)
   --checkpoint_every CHECKPOINT_EVERY
                         Save model after this many steps (default: 100)
+  --learning_rate LEARNING_RATE
+                        The start learning rate (default: 0.001)
+  --decay_step DECAY_STEP
+                        Decay step for rmsprop (default: 500)
+  --decay_rate DECAY_RATE
+                        Decay rate for rmsprop (default: 0.98)
+  --use_highway USE_HIGHWAY
+                        Use the highway network (default: True)
   --allow_soft_placement ALLOW_SOFT_PLACEMENT
                         Allow device soft device placement
   --noallow_soft_placement
8 changes: 4 additions & 4 deletions data_helpers.py
@@ -51,15 +51,15 @@ def batch_iter(data, batch_size, num_epochs, shuffle=True):
     """
     data = np.array(data)
     data_size = len(data)
-    num_batches_per_epoch = int((len(data)-1)/batch_size) + 1
-    for epoch in range(num_epochs):
+    num_batches_per_epoch = int((data_size-1)/batch_size) + 1
+    for epoch in xrange(num_epochs):
         # Shuffle the data at each epoch
         if shuffle:
             shuffle_indices = np.random.permutation(np.arange(data_size))
             shuffled_data = data[shuffle_indices]
         else:
             shuffled_data = data
-        for batch_num in range(num_batches_per_epoch):
+        for batch_num in xrange(num_batches_per_epoch):
             start_index = batch_num * batch_size
             end_index = min((batch_num + 1) * batch_size, data_size)
-            yield shuffled_data[start_index:end_index]
+            yield shuffled_data[start_index:end_index], epoch
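With this change, batch_iter yields (batch, epoch) tuples instead of bare batches, so every caller has to unpack two values. Below is a minimal, self-contained sketch of the new contract (a toy re-implementation for illustration, not the repo's data pipeline):

```python
import numpy as np

def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """Toy version of the modified generator: yields (batch, epoch) pairs."""
    data = np.array(data)
    data_size = len(data)
    num_batches_per_epoch = int((data_size - 1) / batch_size) + 1
    for epoch in range(num_epochs):  # the PR uses xrange, i.e. Python 2
        shuffled_data = data[np.random.permutation(data_size)] if shuffle else data
        for batch_num in range(num_batches_per_epoch):
            start = batch_num * batch_size
            end = min(start + batch_size, data_size)
            yield shuffled_data[start:end], epoch

# Callers now unpack two values, as eval.py and train.py do after this change:
for batch, epoch in batch_iter(list(range(10)), batch_size=4, num_epochs=2):
    print("epoch", epoch, "batch size", len(batch))
```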
2 changes: 1 addition & 1 deletion eval.py
@@ -77,7 +77,7 @@
         # Collect the predictions here
         all_predictions = []

-        for x_test_batch in batches:
+        for x_test_batch, _ in batches:
             batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
             all_predictions = np.concatenate([all_predictions, batch_predictions])

36 changes: 30 additions & 6 deletions text_cnn.py
@@ -1,5 +1,20 @@
 import tensorflow as tf
 import numpy as np

+# highway layer that borrowed from https://github.com/carpedm20/lstm-char-cnn-tensorflow
+def highway(input_, size, layer_size=1, bias=-2, f=tf.nn.relu):
+    """Highway Network (cf. http://arxiv.org/abs/1505.00387).
+    t = sigmoid(Wy + b)
+    z = t * g(Wy + b) + (1 - t) * y
+    where g is nonlinearity, t is transform gate, and (1 - t) is carry gate.
+    """
+    output = input_
+    for idx in xrange(layer_size):
+        output = f(tf.nn.rnn_cell._linear(output, size, 0, scope='output_lin_%d' % idx))
+        transform_gate = tf.sigmoid(
+            tf.nn.rnn_cell._linear(input_, size, 0, scope='transform_lin_%d' % idx) + bias)
+        carry_gate = 1. - transform_gate
+        output = transform_gate * output + carry_gate * input_
+    return output


 class TextCNN(object):
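The gating above is easier to see outside TensorFlow. Here is a minimal NumPy sketch of one highway layer; W_H, b_H, W_T, b_T are illustrative stand-ins for the weights that tf.nn.rnn_cell._linear creates internally:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def highway_layer(y, W_H, b_H, W_T, b_T, g=lambda x: np.maximum(x, 0.0)):
    """z = t * g(W_H y + b_H) + (1 - t) * y, with transform gate t = sigmoid(W_T y + b_T)."""
    t = sigmoid(y @ W_T + b_T)                   # transform gate
    return t * g(y @ W_H + b_H) + (1.0 - t) * y  # (1 - t) is the carry gate

# Input and output keep the same width, which is why text_cnn.py can feed
# h_pool_flat straight through the layer without reshaping.
rng = np.random.default_rng(0)
size = 4
y = rng.normal(size=(2, size))
W_H, W_T = rng.normal(size=(size, size)), rng.normal(size=(size, size))
b_H, b_T = np.zeros(size), np.full(size, -2.0)   # bias=-2 is the function default; text_cnn.py passes 0
z = highway_layer(y, W_H, b_H, W_T, b_T)
print(z.shape)                                   # (2, 4)
```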
@@ -8,8 +23,8 @@ class TextCNN(object):
     Uses an embedding layer, followed by a convolutional, max-pooling and softmax layer.
     """
     def __init__(
-      self, sequence_length, num_classes, vocab_size,
-      embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0):
+            self, sequence_length, num_classes, vocab_size,
+            embedding_size, filter_sizes, num_filters, use_highway=True, l2_reg_lambda=0.0):

         # Placeholders for input, output and dropout
         self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")
@@ -57,9 +72,18 @@ def __init__(
         self.h_pool = tf.concat(3, pooled_outputs)
         self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])

-        # Add dropout
-        with tf.name_scope("dropout"):
-            self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)
+        if use_highway:
+            # Add highway
+            with tf.name_scope("highway"):
+                self.h_highway = highway(self.h_pool_flat, self.h_pool_flat.get_shape()[1], 1, 0)
+
+            # Add dropout
+            with tf.name_scope("dropout"):
+                self.h_drop = tf.nn.dropout(self.h_highway, self.dropout_keep_prob)
+        else:
+            # Add dropout
+            with tf.name_scope("dropout"):
+                self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)

         # Final (unnormalized) scores and predictions
         with tf.name_scope("output"):
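For context, this is roughly how train.py passes the new constructor argument after this PR. The concrete numbers are illustrative placeholders, and the call assumes the TF 0.x-era API that this repository targets:

```python
from text_cnn import TextCNN

# Illustrative values; in train.py these come from FLAGS and the loaded data.
cnn = TextCNN(
    sequence_length=56,
    num_classes=2,
    vocab_size=18000,
    embedding_size=128,
    filter_sizes=[3, 4, 5],
    num_filters=128,
    use_highway=True,   # new: insert the highway layer between max-pooling and dropout
    l2_reg_lambda=0.0)
```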
32 changes: 23 additions & 9 deletions train.py
@@ -29,6 +29,10 @@
 tf.flags.DEFINE_integer("num_epochs", 200, "Number of training epochs (default: 200)")
 tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps (default: 100)")
 tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps (default: 100)")
+tf.flags.DEFINE_float("learning_rate", 0.001, "The start learning rate (default: 0.001)")
+tf.flags.DEFINE_float("decay_step", 500, "Decay step for rmsprop (default: 500)")
+tf.flags.DEFINE_float("decay_rate", 0.98, "Decay rate for rmsprop (default: 0.98)")
+tf.flags.DEFINE_boolean("use_highway", True, "Use the highway network (default: True)")
 # Misc Parameters
 tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
 tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
@@ -84,11 +88,15 @@
             embedding_size=FLAGS.embedding_dim,
             filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
             num_filters=FLAGS.num_filters,
+            use_highway=FLAGS.use_highway,
             l2_reg_lambda=FLAGS.l2_reg_lambda)

         # Define Training procedure
         global_step = tf.Variable(0, name="global_step", trainable=False)
-        optimizer = tf.train.AdamOptimizer(1e-3)
+        learning_rate = tf.train.exponential_decay(
+            FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate)
+        optimizer = tf.train.AdamOptimizer(learning_rate)
+
         grads_and_vars = optimizer.compute_gradients(cnn.loss)
         train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
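With staircase left at its default (False), tf.train.exponential_decay computes lr = learning_rate * decay_rate ** (global_step / decay_step), so the rate shrinks smoothly rather than in steps. A small pure-Python sketch of the schedule under the PR's defaults:

```python
# Approximate schedule produced by tf.train.exponential_decay(0.001, step, 500, 0.98)
learning_rate, decay_step, decay_rate = 0.001, 500, 0.98

def decayed_lr(global_step):
    return learning_rate * decay_rate ** (global_step / decay_step)

for step in (0, 500, 5000, 50000):
    print(step, decayed_lr(step))
# 0 -> 0.001, 500 -> 0.00098, 5000 -> ~0.000817, 50000 -> ~0.000133
```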

Expand Down Expand Up @@ -126,15 +134,18 @@
checkpoint_prefix = os.path.join(checkpoint_dir, "model")
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
saver = tf.train.Saver(tf.global_variables())
saver = tf.train.Saver()

# Write vocabulary
vocab_processor.save(os.path.join(out_dir, "vocab"))

# Initialize all variables
sess.run(tf.global_variables_initializer())
try:
sess.run(tf.global_variables_initializer())
except:
sess.run(tf.initialize_all_variables())

def train_step(x_batch, y_batch):
def train_step(x_batch, y_batch, epoch):
"""
A single training step
"""
@@ -143,11 +154,11 @@ def train_step(x_batch, y_batch):
               cnn.input_y: y_batch,
               cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
             }
-            _, step, summaries, loss, accuracy = sess.run(
-                [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy],
+            _, step, lr, summaries, loss, accuracy = sess.run(
+                [train_op, global_step, learning_rate, train_summary_op, cnn.loss, cnn.accuracy],
                 feed_dict)
             time_str = datetime.datetime.now().isoformat()
-            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
+            print("{}: epoch {}/{}, step {}, lr {:.6f} , loss {:g}, acc {:g}".format(time_str, epoch, FLAGS.num_epochs, step, lr, loss, accuracy))
             train_summary_writer.add_summary(summaries, step)

         def dev_step(x_batch, y_batch, writer=None):
@@ -171,9 +182,9 @@ def dev_step(x_batch, y_batch, writer=None):
         batches = data_helpers.batch_iter(
             list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
         # Training loop. For each batch...
-        for batch in batches:
+        for batch, epoch in batches:
             x_batch, y_batch = zip(*batch)
-            train_step(x_batch, y_batch)
+            train_step(x_batch, y_batch, epoch)
             current_step = tf.train.global_step(sess, global_step)
             if current_step % FLAGS.evaluate_every == 0:
                 print("\nEvaluation:")
@@ -182,3 +193,6 @@
             if current_step % FLAGS.checkpoint_every == 0:
                 path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                 print("Saved model checkpoint to {}\n".format(path))
+        current_step = tf.train.global_step(sess, global_step)
+        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
+        print("Saved model checkpoint to {}\n".format(path))