import json
import os
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from utils import find_entity_id_from_tokens, seq_padding, sort_all


class TreeDataset(Dataset):
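    """Dataset for tree-structured (subject, object, predicate) extraction.

    Each JSON line provides "text" and "spo_list". For training, every
    sentence is expanded by spo_to_seq into one instance per decoding prefix,
    with start/end target vectors for the next element to decode. Relation
    labels are appended to the token sequence after a '[pre]' marker so the
    predicate can be located by pointing into the sequence.
    """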
    def __init__(self, data_dir: str = './data/CNShipNet', data_type: str = "train",
                 word_vocab: str = 'word_vocab.json', ontology_vocab: str = 'attribute_vocab.json',
                 order: Tuple[str, str, str] = ("subject", "object", "predicate")):
        print('Loading {} data...'.format(data_type))
        self.data_dir = data_dir
        self.order = order
        with open(os.path.join(data_dir, word_vocab)) as f:
            self.word_vocab = json.load(f)
        with open(os.path.join(data_dir, ontology_vocab)) as f:
            self.ontology_vocab = json.load(f)
        vocab_size = len(self.word_vocab)
        # Relation labels get ids after the word vocabulary so they can be
        # appended to the input sequence as extra tokens.
        rel_token = {k: (v + vocab_size) for k, v in self.ontology_vocab.items()}
        self.word_vocab.update(rel_token)
        self.word_vocab['[pre]'] = len(self.word_vocab)
        self.ontology_class = list(rel_token.keys())

        self.tokenizer = lambda text: text.split(" ")
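
        # Per-instance teacher-forcing targets. S1/S2, O1/O2 and P1/P2 are 0/1
        # vectors marking start/end token positions of the gold subjects,
        # objects and predicates; *_K1_in/*_K2_in hold the start/end indices
        # of the element decoded at the previous step (0 when there is none).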
        self.text = []
        self.text_length = []
        self.spo_list = []
        self.token_ids = []
        self.S1 = []
        self.S2 = []
        self.S_K1_in = []
        self.O_K1_in = []
        self.S_K2_in = []
        self.O_K2_in = []
        self.O1 = []
        self.O2 = []
        self.P1 = []
        self.P2 = []
        self.P_K1_in = []
        self.P_K2_in = []
        with open(os.path.join(self.data_dir, "new_{}_data.json".format(data_type))) as f:
            lines = f.read().strip().split('\n')
        for line in tqdm(lines):
            instance = json.loads(line)
            if data_type == 'train':
                # Expand one annotated sentence into several teacher-forcing
                # instances, one per partially decoded triple prefix.
                instances = self.spo_to_seq(instance["text"], instance["spo_list"], self.tokenizer,
                                            self.ontology_class)
            else:
                # At inference time only the input sequence is needed, with
                # the relation labels appended after the '[pre]' marker.
                token = self.tokenizer(instance["text"]) + ['[pre]'] + self.ontology_class
                instance['text'] = ' '.join(token)
                instances = [instance]
            for instance in instances:
                text = instance["text"]
                spo_list = instance["spo_list"]
                text_id = []
                for c in self.tokenizer(text):
                    text_id.append(self.word_vocab.get(c, self.word_vocab["<oov>"]))
                self.text_length.append(len(text_id))
                assert len(text_id) > 0
                self.token_ids.append(text_id)

                s_k1 = instance.get("s_k1", 0)
                s_k2 = instance.get("s_k2", 0)
                o_k1 = instance.get("o_k1", 0)
                o_k2 = instance.get("o_k2", 0)
                p_k1 = instance.get("p_k1", 0)
                p_k2 = instance.get("p_k2", 0)

                s1_gt = instance.get("s1_gt", [])
                s2_gt = instance.get("s2_gt", [])
                o1_gt = instance.get("o1_gt", [])
                o2_gt = instance.get("o2_gt", [])
                p1_gt = instance.get("p1_gt", [])
                p2_gt = instance.get("p2_gt", [])

                self.text.append(self.tokenizer(text))
                self.spo_list.append(spo_list)

                self.S1.append(s1_gt)
                self.S2.append(s2_gt)
                self.O1.append(o1_gt)
                self.O2.append(o2_gt)
                self.P1.append(p1_gt)
                self.P2.append(p2_gt)
                self.S_K1_in.append([s_k1])
                self.S_K2_in.append([s_k2])
                self.O_K1_in.append([o_k1])
                self.O_K2_in.append([o_k2])
                self.P_K1_in.append([p_k1])
                self.P_K2_in.append([p_k2])
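
        # seq_padding (from utils) is assumed to pad every list to a common
        # length so the per-instance targets can be stacked into dense arrays.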
        self.token_ids = np.array(seq_padding(self.token_ids))

        # training targets
        self.S1 = np.array(seq_padding(self.S1))
        self.S2 = np.array(seq_padding(self.S2))
        self.O1 = np.array(seq_padding(self.O1))
        self.O2 = np.array(seq_padding(self.O2))
        self.P1 = np.array(seq_padding(self.P1))
        self.P2 = np.array(seq_padding(self.P2))

        # self.K1_in, self.K2_in = np.array(self.K1_in), np.array(self.K2_in)
        # only two time steps are used for training
        self.S_K1_in = np.array(self.S_K1_in)
        self.S_K2_in = np.array(self.S_K2_in)
        self.O_K1_in = np.array(self.O_K1_in)
        self.O_K2_in = np.array(self.O_K2_in)
        self.P_K1_in = np.array(self.P_K1_in)
        self.P_K2_in = np.array(self.P_K2_in)

    def __getitem__(self, index):
        return (
            self.token_ids[index],
            self.S1[index],
            self.S2[index],
            self.O1[index],
            self.O2[index],
            self.P1[index],
            self.P2[index],
            self.S_K1_in[index],
            self.S_K2_in[index],
            self.O_K1_in[index],
            self.O_K2_in[index],
            self.P_K1_in[index],
            self.P_K2_in[index],
            self.text[index],  # original text
            self.text_length[index],  # token length
            self.spo_list[index],  # spo list
        )

    def __len__(self):
        return len(self.text)

    def spo_to_seq(self, text, spo_list, tokenizer, ontology_class):
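        """Expand one sentence into teacher-forcing instances.

        The relation labels are appended to the token sequence after '[pre]'.
        For every (t1_in, t2_in) prefix produced by spo_to_tree, the gold
        start/end vectors for subject, object and predicate, together with the
        start/end indices of the already-decoded elements, are serialized into
        one dict per prefix.
        """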
        tree = self.spo_to_tree(spo_list, self.order)
        tokens = tokenizer(text) + ['[pre]'] + ontology_class

        def to_ent(outp):
            # Build 0/1 start/end vectors over `tokens` (read from the
            # enclosing scope) for every entity name in `outp`.
            ent1, ent2 = [[0] * len(tokens) for _ in range(2)]
            for name in outp:
                _id = find_entity_id_from_tokens(tokens, self.tokenizer(name))
                ent1[_id] = 1
                ent2[_id + len(self.tokenizer(name)) - 1] = 1
            return ent1, ent2

        def to_in_key(inp, name):
            # Start/end indices of the previously decoded element `inp` in
            # `tokens`; (0, 0) when there is no previous element.
            if not inp:
                return 0, 0
            k1 = find_entity_id_from_tokens(tokens, self.tokenizer(inp))
            k2 = k1 + len(self.tokenizer(inp)) - 1
            return k1, k2

        results = []
        for t in tree:
            t1_in, t2_in, t1_out, t2_out, t3_out = t
            for name, ori_out, ori_in in zip(
                self.order, (t1_out, t2_out, t3_out), (t1_in, t2_in, None)
            ):
                new_out = to_ent(ori_out)
                if name == "predicate":
                    p1, p2 = new_out
                    p_k1, p_k2 = to_in_key(ori_in, name)
                elif name == "subject":
                    s1, s2 = new_out
                    s_k1, s_k2 = to_in_key(ori_in, name)
                elif name == "object":
                    o1, o2 = new_out
                    o_k1, o_k2 = to_in_key(ori_in, name)
                else:
                    raise ValueError("name should be one of subject, object, predicate")

            result = {
                "text": ' '.join(tokens),
                "spo_list": spo_list,
                "s_k1": s_k1,
                "s_k2": s_k2,
                "o_k1": o_k1,
                "o_k2": o_k2,
                "p_k1": p_k1,
                "p_k2": p_k2,
                "s1_gt": s1,
                "s2_gt": s2,
                "o1_gt": o1,
                "o2_gt": o2,
                "p1_gt": p1,
                "p2_gt": p2,
            }

            results.append(result)
        return results

    def spo_to_tree(self, spo_list: List[Dict[str, str]], order=("subject", "object", "predicate")):
        """Decompose the spo list into the decoding tree used for teacher forcing.

        Following `order`, the first element is decoded from the text alone,
        the second conditioned on the first, and the third conditioned on the
        first two. Each tuple gives the conditioning prefix (t1_in, t2_in)
        together with the multi-label targets at every step (t1_out, t2_out,
        t3_out).
        Arguments:
            spo_list {List[Dict[str, str]]} -- gold triples for one sentence
        Returns:
            List[Tuple] -- [(t1_in, t2_in, t1_out, t2_out, t3_out)]
        """
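        # Worked example with the default order (subject, object, predicate):
        #   spo_list = [{"subject": "A", "object": "B", "predicate": "near"},
        #               {"subject": "A", "object": "C", "predicate": "near"}]
        # yields (up to set ordering) two tuples:
        #   ("A", "B", ["A"], ["B", "C"], ["near"])
        #   ("A", "C", ["A"], ["B", "C"], ["near"])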
        result = []
        t1_out = list(set(t[order[0]] for t in spo_list))
        for t1_in in t1_out:
            t2_out = list(set(t[order[1]] for t in spo_list if t[order[0]] == t1_in))
            for t2_in in t2_out:
                t3_out = list(
                    set(
                        t[order[2]]
                        for t in spo_list
                        if t[order[0]] == t1_in and t[order[1]] == t2_in
                    )
                )
                result.append((t1_in, t2_in, t1_out, t2_out, t3_out))
        return result


def collate(batch: List[Tuple]):
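    """Collate a list of TreeDataset items into a batch of tensors.

    Items are sorted by token length via sort_all (from utils); target vectors
    become FloatTensors, token ids and span indices become LongTensors, while
    text and spo_list stay as Python lists.
    """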
    batch_data = list(zip(*batch))
    token_len = batch_data[-2]
    batch_data, orig_idx = sort_all(batch_data, token_len)
    (token_ids, s1, s2, o1, o2, p1, p2, s_k1_in, s_k2_in,
     o_k1_in, o_k2_in, p_k1_in, p_k2_in, text, token_len, spo_list) = batch_data
    token_ids = torch.LongTensor(np.array(token_ids))
    s1 = torch.FloatTensor(np.array(s1))
    s2 = torch.FloatTensor(np.array(s2))
    o1 = torch.FloatTensor(np.array(o1))
    o2 = torch.FloatTensor(np.array(o2))
    p1 = torch.FloatTensor(np.array(p1))
    p2 = torch.FloatTensor(np.array(p2))
    s_k1_in = torch.LongTensor(np.array(s_k1_in))
    s_k2_in = torch.LongTensor(np.array(s_k2_in))
    o_k1_in = torch.LongTensor(np.array(o_k1_in))
    o_k2_in = torch.LongTensor(np.array(o_k2_in))
    p_k1_in = torch.LongTensor(np.array(p_k1_in))
    p_k2_in = torch.LongTensor(np.array(p_k2_in))
    token_len = torch.LongTensor(token_len)
    return {'token_ids': token_ids, 's1': s1, 's2': s2, 'o1': o1, 'o2': o2, 'p1': p1, 'p2': p2,
            's_k1_in': s_k1_in, 's_k2_in': s_k2_in, 'o_k1_in': o_k1_in, 'o_k2_in': o_k2_in,
            'p_k1_in': p_k1_in, 'p_k2_in': p_k2_in, 'text': text, 'token_len': token_len,
            'spo_list': spo_list}


if __name__ == '__main__':
    x = TreeDataset()
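    # A minimal usage sketch (not part of the original script): wrap the
    # dataset in a DataLoader with the collate function defined above. The
    # batch size and shuffle flag here are illustrative assumptions.
    from torch.utils.data import DataLoader
    loader = DataLoader(x, batch_size=4, shuffle=True, collate_fn=collate)
    batch = next(iter(loader))
    print(batch['token_ids'].shape, batch['token_len'])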