-
Notifications
You must be signed in to change notification settings - Fork 13
/
experiment.py
46 lines (36 loc) · 1.39 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from hbconfig import Config
import tensorflow as tf
from data_loader import TextLoader
import dataset
from model import CharRNN
import hook
def experiment_fn(run_config, params):
    """Assemble the char-RNN ``tf.contrib.learn.Experiment``.

    Builds a ``CharRNN``-backed ``tf.estimator.Estimator`` and loads the text
    corpus with ``TextLoader``. It writes the corpus vocabulary size to
    ``Config.data.vocab_size``, then wires the train/test input functions and
    their hooks into an Experiment.

    Args:
        run_config: Run configuration, forwarded to the Estimator as ``config``.
        params: Hyper-parameters, forwarded to the Estimator. This function
            also reads ``params.batch_size`` and ``params.seq_length`` to
            configure the ``TextLoader``.

    Returns:
        The configured ``tf.contrib.learn.Experiment``.
    """
    model = CharRNN()
    char_rnn_estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config,
    )

    loader = TextLoader(
        Config.data.data_dir,
        batch_size=params.batch_size,
        seq_length=params.seq_length,
    )
    # Published on the global Config after the Estimator is created; presumably
    # the model reads it when the graph is built (confirm in model.py).
    Config.data.vocab_size = loader.vocab_size

    splits = loader.make_train_and_test_set()
    train_X, test_X, train_y, test_y = splits

    train_input_fn, train_input_hook = dataset.get_train_inputs(train_X, train_y)
    eval_input_fn, eval_input_hook = dataset.get_test_inputs(test_X, test_y)

    train_steps = Config.train.train_steps
    train_hooks = [
        train_input_hook,
        # Periodically prints these graph tensors, decoded with the loader vocab.
        hook.print_variables(
            variables=['training/output_0', 'prediction_0'],
            vocab=loader.vocab,
            every_n_iter=Config.train.check_hook_n_iter,
        ),
    ]
    eval_hooks = [eval_input_hook]

    # min_eval_frequency (Config.train.min_eval_frequency) and eval_steps are
    # not passed, so the Experiment's own defaults apply to both.
    return tf.contrib.learn.Experiment(
        estimator=char_rnn_estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        train_steps=train_steps,
        train_monitors=train_hooks,
        eval_hooks=eval_hooks,
    )