-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathparams.py
56 lines (43 loc) · 1.26 KB
/
params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tensorflow as tf
def basic_params():
    """Return the baseline hyperparameter set used by every preset.

    The values are gathered into a plain dict (insertion order matches the
    original keyword order) and then unpacked into
    ``tf.contrib.training.HParams``.

    Returns:
        tf.contrib.training.HParams: model, keyword-net, attention,
        output-layer, training and beam-search settings.

    Note:
        ``tf.contrib`` ships only with TensorFlow 1.x; under TensorFlow 2.x
        the attribute lookup ``tf.contrib`` raises ``AttributeError``.
    """
    settings = {
        # Model / data settings
        'dtype': tf.float32,
        'voca_size': 34004,  # presumably the vocabulary size -- confirm
        'embedding_trainable': False,
        'hidden_size': 800,
        'encoder_layer': 1,
        'decoder_layer': 1,
        'answer_layer': 1,
        'dec_init_ans': True,
        'maxlen_q_train': 32,
        'maxlen_q_dev': 27,
        'maxlen_q_test': 27,
        'rnn_dropout': 0.4,
        'start_token': 1,  # <GO> index
        'end_token': 2,  # <EOS> index
        # Keyword-net related parameters
        'use_keyword': 2,
        # Attention related parameters
        'attn': 'normed_bahdanau',
        # Output layer related parameters
        'if_wean': True,
        # Training related parameters
        'batch_size': 64,
        'learning_rate': 0.001,
        'decay_step': None,
        'decay_rate': 0.5,
        # Beam search
        'beam_width': 10,
        'length_penalty_weight': 2.1,
    }
    return tf.contrib.training.HParams(**settings)
def h200_batch64():
    """Return the baseline hyperparameters with hidden_size=200, batch_size=64.

    batch_size is assigned explicitly even though basic_params() already
    uses 64, so the preset states both of the values in its name.
    """
    hparams = basic_params()
    overrides = (('hidden_size', 200), ('batch_size', 64))
    for field, value in overrides:
        setattr(hparams, field, value)
    return hparams
def h512_batch128():
    """Return the baseline hyperparameters with hidden_size=512, batch_size=128."""
    hparams = basic_params()
    overrides = (('hidden_size', 512), ('batch_size', 128))
    for field, value in overrides:
        setattr(hparams, field, value)
    return hparams