
Commit 5537421

Initialize the project.
1 parent effb5f0 commit 5537421

13 files changed, +1055 -1 lines changed

.gitignore

+1
@@ -127,3 +127,4 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+.idea

README.md

+28 -1
@@ -1 +1,28 @@
-# AtTGen
+# AtTGen
+
+## Usage
+
+```
+usage: python3 main.py [-h] [--name NAME] [--do_train] [--do_eval]
+                       [--data_dir DATA_DIR] [--seed SEED] [--gpu_ids GPU_IDS]
+                       [--batch_size BATCH_SIZE] [--lr LR] [--epoch EPOCH]
+                       [--emb_dim EMB_DIM] [--encode_dim ENCODE_DIM]
+
+configuration
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --name NAME           Experiment name, for logging and saving models
+  --do_train            Whether to run training.
+  --do_eval             Whether to run eval on the test set.
+  --data_dir DATA_DIR   The input data dir.
+  --seed SEED           The random seed for initialization
+  --gpu_ids GPU_IDS     The GPU ids
+  --batch_size BATCH_SIZE
+                        Total batch size for training.
+  --lr LR               The initial learning rate for Adam.
+  --epoch EPOCH         Total number of training epochs to perform.
+  --emb_dim EMB_DIM     The dimension of the embedding
+  --encode_dim ENCODE_DIM
+                        The dimension of the encoding
+```
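For reference, a typical training-plus-eval run using the defaults from args.py below might look like the following. This is an illustration only: main.py itself is not among the files shown in this excerpt, and the experiment name and data layout are assumptions.

```
python3 main.py --name demo --do_train --do_eval \
    --data_dir ./data/CNShipNet --batch_size 512 --lr 2e-4 --epoch 40
```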

args.py

+84
@@ -0,0 +1,84 @@
import argparse

import torch

from utils import get_device


def get_args():
    parser = argparse.ArgumentParser(description='configuration')

    parser.add_argument("--name",
                        default="1",
                        type=str,
                        help="Experiment name, for logging and saving models")
    parser.add_argument("--do_train",
                        action='store_true',
                        default=True,
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        default=True,
                        help="Whether to run eval on the test set.")
    parser.add_argument("--data_dir",
                        default="./data/CNShipNet",
                        type=str,
                        help="The input data dir.")
    parser.add_argument("--word_vocab",
                        default="word_vocab.json",
                        type=str,
                        help="The vocabulary file.")
    parser.add_argument("--ontology_vocab",
                        default="attribute_vocab.json",
                        type=str,
                        help="The ontology class file.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="The random seed for initialization")
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='0',
                        help="The GPU ids")

    # Hyperparameters
    # Batch size
    parser.add_argument("--batch_size",
                        default=512,
                        type=int,
                        help="Total batch size for training.")
    # Learning rate
    parser.add_argument("--lr",
                        default=2e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    # Epochs
    parser.add_argument("--epoch",
                        default=40,
                        type=int,
                        help="Total number of training epochs to perform.")
    # emb_dim
    parser.add_argument("--emb_dim",
                        default=200,
                        type=int,
                        help="The dimension of the embedding")
    # encode_dim
    parser.add_argument("--encode_dim",
                        default=200,
                        type=int,
                        help="The dimension of the encoding")

    args = parser.parse_args()
    if args.gpu_ids == "":
        n_gpu = 0
        device = torch.device('cpu')
    else:
        gpu_ids = [int(device_id) for device_id in args.gpu_ids.split()]
        args.gpu_ids = gpu_ids
        device, n_gpu = get_device(gpu_ids[0])
        if n_gpu > 1:
            n_gpu = len(gpu_ids)
    args.device = device
    args.n_gpu = n_gpu

    return args
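args.py imports get_device from utils, which is not among the files shown in this excerpt. A minimal sketch of what such a helper could look like, purely as an assumption about its contract (given a GPU id it must return a (device, n_gpu) pair, which is how get_args consumes it); this is not the project's actual implementation:

```
import torch


def get_device(gpu_id=0):
    # Hypothetical sketch: select the requested CUDA device when available,
    # otherwise fall back to CPU. Returns (device, n_gpu) to match args.py.
    if gpu_id >= 0 and torch.cuda.is_available():
        device = torch.device('cuda:{}'.format(gpu_id))
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device('cpu')
        n_gpu = 0
    return device, n_gpu
```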

data/.gitkeep

Whitespace-only changes.

dataloader.py

+255
@@ -0,0 +1,255 @@
import json
import os
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from utils import find_entity_id_from_tokens, seq_padding, sort_all


class TreeDataset(Dataset):
    def __init__(self, data_dir: str = './data/CNShipNet', data_type: str = "train",
                 word_vocab: str = 'word_vocab.json', ontology_vocab: str = 'attribute_vocab.json',
                 order: List[str] = ("subject", "object", "predicate")):
        print('Loading {} data...'.format(data_type))
        self.data_dir = data_dir
        self.order = order
        self.word_vocab = json.load(open(os.path.join(data_dir, word_vocab)))
        self.ontology_vocab = json.load(open(os.path.join(data_dir, ontology_vocab)))
        vocab_size = len(self.word_vocab)
        rel_token = {k: (v + vocab_size) for k, v in self.ontology_vocab.items()}
        self.word_vocab.update(rel_token)
        self.word_vocab['[pre]'] = len(self.word_vocab)
        self.ontology_class = list(rel_token.keys())

        self.tokenizer = lambda text: text.split(" ")

        self.text = []
        self.text_length = []
        self.spo_list = []
        self.token_ids = []
        self.S1 = []
        self.S2 = []
        self.S_K1_in = []
        self.O_K1_in = []
        self.S_K2_in = []
        self.O_K2_in = []
        self.O1 = []
        self.O2 = []
        self.P1 = []
        self.P2 = []
        self.P_K1_in = []
        self.P_K2_in = []
        file = open(os.path.join(self.data_dir, "new_{}_data.json".format(data_type))).read().strip().split('\n')
        for line in tqdm(file):
            instance = json.loads(line)
            if data_type == 'train':
                expanded_instances = self.spo_to_seq(instance["text"], instance["spo_list"], self.tokenizer,
                                                     self.ontology_class)
                instances = expanded_instances
            else:
                token = self.tokenizer(instance["text"]) + ['[pre]'] + self.ontology_class
                instance['text'] = ' '.join(token)
                instances = [instance]
            for instance in instances:
                text = instance["text"]
                spo_list = instance["spo_list"]
                text_id = []
                for c in (self.tokenizer(text)):
                    text_id.append(self.word_vocab.get(c, self.word_vocab["<oov>"]))
                self.text_length.append(len(text_id))
                assert len(text_id) > 0
                self.token_ids.append(text_id)

                s_k1 = instance.get("s_k1", 0)
                s_k2 = instance.get("s_k2", 0)
                o_k1 = instance.get("o_k1", 0)
                o_k2 = instance.get("o_k2", 0)
                p_k1 = instance.get("p_k1", 0)
                p_k2 = instance.get("p_k2", 0)

                s1_gt = instance.get("s1_gt", [])
                s2_gt = instance.get("s2_gt", [])
                o1_gt = instance.get("o1_gt", [])
                o2_gt = instance.get("o2_gt", [])
                p1_gt = instance.get("p1_gt", [])
                p2_gt = instance.get("p2_gt", [])

                self.text.append(self.tokenizer(text))
                self.spo_list.append(spo_list)

                self.S1.append(s1_gt)
                self.S2.append(s2_gt)
                self.O1.append(o1_gt)
                self.O2.append(o2_gt)
                self.P1.append(p1_gt)
                self.P2.append(p2_gt)
                self.S_K1_in.append([s_k1])
                self.S_K2_in.append([s_k2])
                self.O_K1_in.append([o_k1])
                self.O_K2_in.append([o_k2])
                self.P_K1_in.append([p_k1])
                self.P_K2_in.append([p_k2])

        self.token_ids = np.array(seq_padding(self.token_ids))

        # training
        self.S1 = np.array(seq_padding(self.S1))
        self.S2 = np.array(seq_padding(self.S2))
        self.O1 = np.array(seq_padding(self.O1))
        self.O2 = np.array(seq_padding(self.O2))
        self.P1 = np.array(seq_padding(self.P1))
        self.P2 = np.array(seq_padding(self.P2))

        # self.K1_in, self.K2_in = np.array(self.K1_in), np.array(self.K2_in)
        # only two time steps are used for training
        self.S_K1_in = np.array(self.S_K1_in)
        self.S_K2_in = np.array(self.S_K2_in)
        self.O_K1_in = np.array(self.O_K1_in)
        self.O_K2_in = np.array(self.O_K2_in)
        self.P_K1_in = np.array(self.P_K1_in)
        self.P_K2_in = np.array(self.P_K2_in)

    def __getitem__(self, index):
        return (
            self.token_ids[index],
            self.S1[index],
            self.S2[index],
            self.O1[index],
            self.O2[index],
            self.P1[index],
            self.P2[index],
            self.S_K1_in[index],
            self.S_K2_in[index],
            self.O_K1_in[index],
            self.O_K2_in[index],
            self.P_K1_in[index],
            self.P_K2_in[index],
            self.text[index],  # original text
            self.text_length[index],  # token length
            self.spo_list[index],  # spo list
        )

    def __len__(self):
        return len(self.text)

    def spo_to_seq(self, text, spo_list, tokenizer, ontology_class):
        tree = self.spo_to_tree(spo_list, self.order)
        tokens = tokenizer(text) + ['[pre]'] + ontology_class

        def to_ent(outp):
            # side effect!
            ent1, ent2 = [[0] * len(tokens) for _ in range(2)]
            for name in outp:
                _id = find_entity_id_from_tokens(tokens, self.tokenizer(name))
                ent1[_id] = 1
                ent2[_id + len(self.tokenizer(name)) - 1] = 1
            return ent1, ent2

        def to_in_key(inp, name):
            # side effect!
            if not inp:
                return 0, 0

            k1 = find_entity_id_from_tokens(tokens, self.tokenizer(inp))
            k2 = k1 + len(self.tokenizer(inp)) - 1
            out = k1, k2
            return out

        results = []
        for t in tree:
            t1_in, t2_in, t1_out, t2_out, t3_out = t
            for name, ori_out, ori_in in zip(
                    self.order, (t1_out, t2_out, t3_out), (t1_in, t2_in, None)
            ):
                new_out = to_ent(ori_out)
                if name == "predicate":
                    p1, p2 = new_out
                    p_k1, p_k2 = to_in_key(ori_in, name)
                elif name == "subject":
                    s1, s2 = new_out
                    s_k1, s_k2 = to_in_key(ori_in, name)
                elif name == "object":
                    o1, o2 = new_out
                    o_k1, o_k2 = to_in_key(ori_in, name)
                else:
                    raise ValueError("should be in predicate, subject, object")

            result = {
                "text": ' '.join(tokens),
                "spo_list": spo_list,
                "s_k1": s_k1,
                "s_k2": s_k2,
                "o_k1": o_k1,
                "o_k2": o_k2,
                "p_k1": p_k1,
                "p_k2": p_k2,
                "s1_gt": s1,
                "s2_gt": s2,
                "o1_gt": o1,
                "o2_gt": o2,
                "p1_gt": p1,
                "p2_gt": p2,
            }

            results.append(result)
        return results

    def spo_to_tree(self, spo_list: List[Dict[str, str]], order=("subject", "object", "predicate")):
        """Return the ground truth of the tree (rel, subj, obj), used for teacher forcing.
        r: given text, one of the relations
        s: given r_1, one of the subjects
        rel: multi-label classification of relation
        subj: multi-label classification of subject
        obj: multi-label classification of object
        Arguments:
            spo_list {List[Dict[str, str]]} -- [description]
        Returns:
            List[Tuple[str]] -- [(r, s, rel, subj, obj)]
        """
        result = []
        t1_out = list(set(t[order[0]] for t in spo_list))
        for t1_in in t1_out:
            t2_out = list(set(t[order[1]] for t in spo_list if t[order[0]] == t1_in))
            for t2_in in t2_out:
                t3_out = list(
                    set(
                        t[order[2]]
                        for t in spo_list
                        if t[order[0]] == t1_in and t[order[1]] == t2_in
                    )
                )
                result.append((t1_in, t2_in, t1_out, t2_out, t3_out))
        return result


def collate(batch: List[Tuple]):
    batch_data = list(zip(*batch))
    token_len = batch_data[-2]
    batch_data, orig_idx = sort_all(batch_data, token_len)
    token_ids, s1, s2, o1, o2, p1, p2, s_k1_in, s_k2_in, o_k1_in, o_k2_in, p_k1_in, p_k2_in, text, token_len, spo_list = batch_data
    token_ids = torch.LongTensor(np.array(token_ids))
    s1 = torch.FloatTensor(np.array(s1))
    s2 = torch.FloatTensor(np.array(s2))
    o1 = torch.FloatTensor(np.array(o1))
    o2 = torch.FloatTensor(np.array(o2))
    p1 = torch.FloatTensor(np.array(p1))
    p2 = torch.FloatTensor(np.array(p2))
    s_k1_in = torch.LongTensor(np.array(s_k1_in))
    s_k2_in = torch.LongTensor(np.array(s_k2_in))
    o_k1_in = torch.LongTensor(np.array(o_k1_in))
    o_k2_in = torch.LongTensor(np.array(o_k2_in))
    p_k1_in = torch.LongTensor(np.array(p_k1_in))
    p_k2_in = torch.LongTensor(np.array(p_k2_in))
    token_len = torch.LongTensor(token_len)
    return {'token_ids': token_ids, 's1': s1, 's2': s2, 'o1': o1, 'o2': o2, 'p1': p1, 'p2': p2, 's_k1_in': s_k1_in,
            's_k2_in': s_k2_in, 'o_k1_in': o_k1_in, 'o_k2_in': o_k2_in, 'p_k1_in': p_k1_in, 'p_k2_in': p_k2_in,
            'text': text, 'token_len': token_len, 'spo_list': spo_list}


if __name__ == '__main__':
    x = TreeDataset()
    exit()
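A minimal sketch of how TreeDataset and collate might be wired into a torch DataLoader. This assumes ./data/CNShipNet contains word_vocab.json, attribute_vocab.json, and new_train_data.json as expected by the constructor; the batch size and shuffling are illustrative, not taken from this commit:

```
from torch.utils.data import DataLoader

from dataloader import TreeDataset, collate

# Build the training split; during construction TreeDataset expands every
# sentence into teacher-forcing examples via spo_to_seq.
train_set = TreeDataset(data_dir='./data/CNShipNet', data_type='train')

# collate sorts each batch by token length (via utils.sort_all) and converts
# the fields to tensors, returning them as a dict.
train_loader = DataLoader(train_set, batch_size=512, shuffle=True, collate_fn=collate)

for batch in train_loader:
    print(batch['token_ids'].shape, batch['s1'].shape)
    break
```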
