forked from ArminKmz/im2latex
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparams.py
93 lines (80 loc) · 2.31 KB
/
params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
# Single flat configuration dict for training / evaluation.
# Top-level keys hold paths and scalar hyperparameters; the CNN layer
# settings are nested under ``cnn_params``.
config = dict(
    # ------------------------------------------------------------------
    # Paths (relative strings; presumably resolved against the working
    # directory by the consumer -- TODO confirm)
    # ------------------------------------------------------------------
    formulas_train_path='train.txt',
    formulas_validation_path='validate.txt',
    formulas_test_path='test.txt',
    images_train_path='train_imgs',
    images_validation_path='val_imgs',
    images_test_path='test_imgs',
    checkpoints_dir='checkpoints/',
    # NOTE(review): despite the "_dir" suffix this value is a file path.
    log_dir='logs/log.txt',
    # ------------------------------------------------------------------
    # General training parameters
    # ------------------------------------------------------------------
    batch_size=32,
    epochs=40,
    print_every_batch=100,
    learning_rate=1e-3,
    # Presumably a multiplicative decay applied every
    # `learning_rate_decay_step` epochs, floored at `learning_rate_min`
    # -- confirm against the training loop.
    learning_rate_decay=0.5,
    learning_rate_decay_step=3,
    learning_rate_min=1e-6,
    # Kept as the int 1 (not 1.0) to match the original value exactly.
    teacher_forcing_ratio=1,
    teacher_forcing_ratio_decay=0.95,
    teacher_forcing_ratio_decay_step=2,
    teacher_forcing_ratio_min=0.7,
    clip=5,
    unk_token_threshold=10,
    generation_method='greedy',  # alternative value: 'beam-search'
    # Use the GPU when torch reports CUDA as available, else the CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # ------------------------------------------------------------------
    # CNN parameters: per layer, *_c = output channels, *_k = kernel,
    # *_s = stride, *_p = padding (tuple ordering presumably (H, W) as in
    # torch.nn -- TODO confirm).
    # ------------------------------------------------------------------
    cnn_params=dict(
        conv1_c=64, conv1_k=(3, 3), conv1_s=(1, 1), conv1_p=(1, 1),
        pool1_k=(2, 2), pool1_s=(2, 2), pool1_p=(0, 0),
        conv2_c=128, conv2_k=(3, 3), conv2_s=(1, 1), conv2_p=(1, 1),
        pool2_k=(2, 2), pool2_s=(2, 2), pool2_p=(0, 0),
        conv3_c=256, conv3_k=(3, 3), conv3_s=(1, 1), conv3_p=(1, 1),
        conv4_c=256, conv4_k=(3, 3), conv4_s=(1, 1), conv4_p=(1, 1),
        pool3_k=(2, 1), pool3_s=(2, 1), pool3_p=(0, 0),
        conv5_c=512, conv5_k=(3, 3), conv5_s=(1, 1), conv5_p=(1, 1),
        pool4_k=(1, 2), pool4_s=(1, 2), pool4_p=(0, 0),
        conv6_c=512, conv6_k=(3, 3), conv6_s=(1, 1), conv6_p=(1, 1),
    ),
    # ------------------------------------------------------------------
    # Seq2seq parameters
    # ------------------------------------------------------------------
    embedding_size=80,
    # NOTE(review): 512 == 2 * 256; presumably sized for the concatenated
    # forward/backward encoder state when `bidirectional` is True -- verify.
    decoder_hidden_size=512,
    encoder_hidden_size=256,
    bidirectional=True,
)