-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest3.py
58 lines (47 loc) · 2.09 KB
/
test3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from argparse import ArgumentParser
from data import get_dataset
from model import get_model
from trainer import build_trainer
def _build_batch(model, text, max_length):
    """Wrap one raw document into the batch dict whose keys ``main`` passes to ``generate``.

    NOTE(review): the batch layout below is a reconstruction. This file never
    shows how the data pipeline builds ``prompt_ids`` / ``id`` /
    ``text_tokenized``. It assumes the model exposes a Hugging Face-style
    ``tokenizer`` attribute and that ``generate`` accepts a batch of one.
    Confirm against the collate logic in ``data.py`` before trusting the output.

    Args:
        model: the object returned by ``get_model``.
        text: raw document text to summarize.
        max_length: truncation length for the tokenized document.

    Returns:
        dict with keys ``"prompt_ids"``, ``"id"`` and ``"text_tokenized"``.

    Raises:
        AttributeError: if ``model`` has no ``tokenizer`` attribute.
    """
    tokenizer = getattr(model, "tokenizer", None)
    if tokenizer is None:
        # Fail with a clear message instead of an opaque error deep inside generate().
        raise AttributeError(
            "model has no 'tokenizer' attribute; cannot tokenize final.txt"
        )
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    return {
        "prompt_ids": input_ids,
        # NOTE(review): placeholder id for an ad-hoc document. If retrieval looks
        # neighbours up by id, this must match the retrieval index — TODO confirm.
        "id": [0],
        "text_tokenized": input_ids,
    }


def main(args):
    """Summarize the document in ``final.txt`` with the configured model and print the result.

    Args:
        args: parsed command-line namespace. It is passed whole to
            ``get_model``; this function itself reads ``args.retrieval``
            (whether generation uses retrieval) and ``args.max_input_length``
            (truncation length for the input document).
    """
    import torch  # local import: only needed to disable autograd during inference

    # Load model
    model = get_model(args)
    # Outside a Lightning Trainer nothing switches the module to eval mode
    # (dropout off), so do it explicitly before generating.
    model.eval()

    with open("final.txt", "r", encoding="utf-8") as f:
        text = f.read()

    batch = _build_batch(model, text, args.max_input_length)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        predictions = model.generate(
            batch["prompt_ids"],
            batch["id"],
            batch["text_tokenized"],
            args.retrieval,
        )
    print(predictions)
# Command-line options, in registration order: (flag, keyword arguments for add_argument).
_CLI_OPTIONS = (
    # Model
    ("--model_name", {"type": str, "default": "retrosum"}),
    ("--from_checkpoint", {"type": str, "default": ""}),
    ("--max_input_length", {"type": int, "default": 512}),
    ("--max_output_length", {"type": int, "default": 512}),
    ("--chunk_size", {"type": int, "default": 64}),
    ("--n_neighbors", {"type": int, "default": 2}),
    ("--lr", {"type": float, "default": 1e-3}),
    ("--no_repeat_ngram_size", {"type": int, "default": 5}),
    ("--retrieval", {"action": "store_true"}),
    # Data
    ("--data_name", {"type": str, "default": "arxiv"}),
    ("--train_path", {"type": str, "default": None}),
    ("--val_path", {"type": str, "default": None}),
    ("--test_path", {"type": str, "default": None}),
    ("--batch_size", {"type": int, "default": 8}),
    ("--test_batch_size", {"type": int, "default": 1}),
    ("--num_workers", {"type": int, "default": 8}),
    # Trainer
    ("--mode", {"type": str, "default": "test"}),
    ("--max_epochs", {"type": int, "default": 10}),
    ("--accumulate_grad_batches", {"type": int, "default": 8}),
    ("--val_check_interval", {"type": float, "default": 0.5}),
    ("--monitor", {"type": str, "default": "val_loss"}),
    ("--results_filename", {"type": str, "default": "results.json"}),
    ("--fast_dev_run", {"action": "store_true"}),
)


def _build_parser():
    """Return an ArgumentParser with every option in ``_CLI_OPTIONS`` registered in order."""
    cli = ArgumentParser()
    for flag, options in _CLI_OPTIONS:
        cli.add_argument(flag, **options)
    return cli


if __name__ == "__main__":
    main(_build_parser().parse_args())