-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
58 lines (47 loc) · 2.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from argparse import ArgumentParser
from data import get_dataset
from model import get_model
from trainer import build_trainer
def main(args):
    """Build the model, dataset and trainer, then train or evaluate.

    Args:
        args: Parsed command-line namespace produced by the ``__main__``
            block. ``args.mode`` selects the action by substring match:
            if it contains ``"train"`` the model is fit; otherwise, if it
            contains ``"test"``, the model is evaluated. ``"train"`` wins
            when both substrings appear.

    Raises:
        ValueError: If ``args.mode`` contains neither ``"train"`` nor
            ``"test"``.
    """
    # Validate the mode before constructing the model, data and trainer
    # (potentially expensive), so an invalid value fails fast and the
    # process exits with a non-zero status instead of silently succeeding.
    do_train = "train" in args.mode
    if not do_train and "test" not in args.mode:
        raise ValueError(f"Unrecognized mode: {args.mode}")

    # Load model
    model = get_model(args)
    # Load dataset.
    # NOTE(review): the dataset is built from the model rather than from
    # `args` -- presumably so it can reuse model-specific settings; confirm.
    data = get_dataset(model)
    # Get Trainer
    trainer = build_trainer(args)

    if do_train:
        trainer.fit(model, data)
    else:
        trainer.test(model, data, verbose=True)
if __name__ == "__main__":
    # Every command-line option, as (flag, add_argument keyword arguments).
    # Options are registered in list order, which is also the order in
    # which they are listed by --help.
    cli_options = [
        # Model
        ("--model_name", {"type": str, "default": "retrosum"}),
        ("--from_checkpoint", {"type": str, "default": ""}),
        ("--max_input_length", {"type": int, "default": 512}),
        ("--max_output_length", {"type": int, "default": 512}),
        ("--chunk_size", {"type": int, "default": 64}),
        ("--n_neighbors", {"type": int, "default": 2}),
        ("--lr", {"type": float, "default": 1e-3}),
        ("--no_repeat_ngram_size", {"type": int, "default": 5}),
        ("--retrieval", {"action": "store_true"}),
        # Data
        ("--data_name", {"type": str, "default": "arxiv"}),
        ("--train_path", {"type": str, "default": None}),
        ("--val_path", {"type": str, "default": None}),
        ("--test_path", {"type": str, "default": None}),
        ("--batch_size", {"type": int, "default": 8}),
        ("--test_batch_size", {"type": int, "default": 1}),
        ("--num_workers", {"type": int, "default": 8}),
        # Trainer
        ("--mode", {"type": str, "default": "test"}),
        ("--max_epochs", {"type": int, "default": 10}),
        ("--accumulate_grad_batches", {"type": int, "default": 8}),
        ("--val_check_interval", {"type": float, "default": 0.5}),
        ("--monitor", {"type": str, "default": "val_loss"}),
        ("--results_filename", {"type": str, "default": "results.json"}),
        ("--fast_dev_run", {"action": "store_true"}),
    ]

    cli = ArgumentParser()
    for flag, spec in cli_options:
        cli.add_argument(flag, **spec)

    # Parse sys.argv and hand the resulting namespace to the entry point.
    main(cli.parse_args())