Skip to content

Commit 681e817

Browse files
committed
update code & data for ijcai2020
1 parent 1386239 commit 681e817

10 files changed

+710
-197
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,58 @@
11
# Data
2-
dataset_name: 'quac'
3-
trainset: '../data/quac/train.json'
4-
devset: '../data/quac/dev.json'
5-
testset: null
6-
embed_file: '../data/coqa/glove.840B.300d.txt'
7-
saved_vocab_file: '../data/quac/word_model_min_5'
2+
dataset_name: 'doqa'
3+
trainset: '../data/doqa/train.json'
4+
devset: '../data/doqa/dev.json'
5+
testset: '../data/doqa/test.json'
6+
embed_file: '/home/cheny39/glove-vectors/glove.840B.300d.txt'
7+
saved_vocab_file: '../data/doqa/word_model_min_5'
88
pretrained: null
99

1010

1111
# Output
12-
out_dir: '../out/quac/graphflow_static_graph'
12+
out_dir: '../out/doqa/graphflow_dynamic_graph'
1313

1414

1515
# Preprocessing
1616
min_freq: 5
1717
top_vocab: 200000
18-
n_history: 2
18+
n_history: 2 # 2!
1919
no_pre_question: False
2020
no_pre_answer: False
21-
max_turn_num: 20
21+
max_turn_num: 8
2222

2323

2424

2525
# Model
2626
embed_type: 'glove'
2727
vocab_embed_size: 300
28-
fix_vocab_embed: True
28+
fix_vocab_embed: True # True!
2929
f_qem: True # Context exact match feature
3030
f_pos: True # Context POS feature
3131
f_ner: True # Context NER feature
3232
f_tf: False # Context TF feature
3333
ctx_exact_match_embed_dim: 3
3434
ctx_pos_embed_dim: 12
3535
ctx_ner_embed_dim: 8
36-
answer_marker_embed_dim: 10
36+
answer_marker_embed_dim: 10 # 10!
3737
use_ques_marker: True
38-
ques_marker_embed_dim: 3
39-
ques_turn_marker_embed_dim: 5
38+
ques_marker_embed_dim: 3 # 3!
39+
ques_turn_marker_embed_dim: 5 # 5!
4040

41-
hidden_size: 300
42-
word_dropout: 0.3
43-
bert_dropout: 0.4
44-
rnn_dropout: 0.3
41+
hidden_size: 128 # 128!
42+
word_dropout: 0.4 # 0.4!
43+
bert_dropout: 0.2 # 0.2!
44+
rnn_dropout: 0.4 # 0.4!
4545
rnn_input_dropout: null
4646

4747
# Graph neural networks
4848
use_gnn: True
4949
bignn: False
50-
static_graph: True
50+
static_graph: False
5151
temporal_gnn: True
52-
ctx_graph_hops: 3
53-
ctx_graph_topk: 10
54-
graph_learner_num_pers: 1
52+
ctx_graph_hops: 5 # 5!
53+
ctx_graph_topk: 10 # 10!
54+
graph_learner_num_pers: 1 # 1
55+
stacked_layer: False # False
5556

5657

5758
# Spatial kernels
@@ -63,36 +64,35 @@ position_emb_size: 50
6364

6465

6566
# BERT configuration
66-
use_bert: True
67-
finetune_bert: False
67+
use_bert: True # True
68+
finetune_bert: False # False
6869
use_bert_weight: True
6970
use_bert_gamma: False
7071
bert_model: 'bert-large-uncased'
7172
bert_dim: 1024
7273
bert_max_seq_len: 500
73-
bert_doc_stride: 250
74+
bert_doc_stride: 250
7475
bert_layer_indexes:
7576
- 0
7677
- 24
7778

7879

7980
# Optimizer
8081
optimizer: 'adamax'
81-
learning_rate: 0.001
82-
grad_clipping: 10
82+
learning_rate: 0.0005 # 0.0005!
83+
grad_clipping: 5 # 5!
8384

8485

8586
# Training & testing
8687
random_seed: 1234
8788
shuffle: True # Whether to shuffle the examples during training
88-
batch_size: 1 # No. of dialogs per batch
89+
batch_size: 1 # No. of dialogs per batch, 1!
8990
grad_accumulated_steps: 1
90-
test_batch_size: 1
9191
max_epochs: 30
9292
patience: 10
9393
verbose: 1000 # Print every X batches
94-
unk_answer_threshold: 0.3
95-
max_answer_len: 35 # Set max answer length for decoding
94+
unk_answer_threshold: 0.2 # 0.2!
95+
max_answer_len: 30 # Set max answer length for decoding (previously 35)
9696
predict_train: True # Whether to predict on training set
9797
out_predictions: True # Whether to output predictions
9898
predict_raw_text: True # Whether to use raw text and offsets for prediction
@@ -103,4 +103,4 @@ out_pred_in_folder: True # Turn it off for Codalab
103103

104104
# Device
105105
no_cuda: False
106-
cuda_id: 0
106+
cuda_id: -1

src/config/graphflow_static_graph_coqa.yml

-103
This file was deleted.

src/core/layers/graphs.py

+11-16
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,16 @@ def forward(self, context, ctx_mask):
7474
# attention = torch.mean(torch.matmul(context_fc, context_fc.transpose(-1, -2)), dim=2)
7575

7676

77-
# # 3) Best attention mechanism
78-
# context_fc = context.unsqueeze(2) * torch.relu(self.weight_tensor).unsqueeze(0).unsqueeze(0).unsqueeze(-2)
79-
# attention = torch.mean(torch.matmul(context_fc, context.unsqueeze(2).transpose(-1, -2)), dim=2)
77+
# 3) Best attention mechanism
78+
context_fc = context.unsqueeze(2) * torch.relu(self.weight_tensor).unsqueeze(0).unsqueeze(0).unsqueeze(-2)
79+
attention = torch.mean(torch.matmul(context_fc, context.unsqueeze(2).transpose(-1, -2)), dim=2)
8080

8181

82-
# 4)weighted cosine
83-
context_fc = context.unsqueeze(2) * self.weight_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-2)
84-
context_norm = F.normalize(context_fc, p=2, dim=-1)
85-
attention = torch.matmul(context_norm, context_norm.transpose(-1, -2)).mean(2)
86-
markoff_value = 0
82+
# # 4) Weighted cosine
83+
# context_fc = context.unsqueeze(2) * self.weight_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-2)
84+
# context_norm = F.normalize(context_fc, p=2, dim=-1)
85+
# attention = torch.matmul(context_norm, context_norm.transpose(-1, -2)).mean(2)
86+
# markoff_value = 0
8787

8888

8989
if ctx_mask is not None:
@@ -205,12 +205,8 @@ def forward(self, node_state, weighted_adjacency_matrix):
205205
return node_state
206206

207207
def bignn_update(self, node_state, weighted_adjacency_matrix):
208-
# weighted_adjacency_matrix_in = torch.softmax(weighted_adjacency_matrix, dim=-1)
209-
# weighted_adjacency_matrix_out = torch.softmax(weighted_adjacency_matrix.transpose(-1, -2), dim=-1)
210-
211-
weighted_adjacency_matrix_in = weighted_adjacency_matrix / torch.clamp(torch.sum(weighted_adjacency_matrix, dim=-1, keepdim=True), min=VERY_SMALL_NUMBER)
212-
weighted_adjacency_matrix_out = weighted_adjacency_matrix.transpose(-1, -2) / torch.clamp(torch.sum(weighted_adjacency_matrix.transpose(-1, -2), dim=-1, keepdim=True), min=VERY_SMALL_NUMBER)
213-
208+
weighted_adjacency_matrix_in = torch.softmax(weighted_adjacency_matrix, dim=-1)
209+
weighted_adjacency_matrix_out = torch.softmax(weighted_adjacency_matrix.transpose(-1, -2), dim=-1)
214210

215211
for _ in range(self.graph_hops):
216212
agg_state_in = self.aggregate_avgpool(node_state, weighted_adjacency_matrix_in)
@@ -220,8 +216,7 @@ def bignn_update(self, node_state, weighted_adjacency_matrix):
220216
return node_state
221217

222218
def gnn_update(self, node_state, weighted_adjacency_matrix):
223-
# weighted_adjacency_matrix = torch.softmax(weighted_adjacency_matrix, dim=-1)
224-
weighted_adjacency_matrix = weighted_adjacency_matrix / torch.clamp(torch.sum(weighted_adjacency_matrix, dim=-1, keepdim=True), min=VERY_SMALL_NUMBER)
219+
weighted_adjacency_matrix = torch.softmax(weighted_adjacency_matrix, dim=-1)
225220

226221

227222
for _ in range(self.graph_hops):

src/core/model.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212

1313
from .utils.coqa import compute_eval_metric
14-
from .utils.quac import eval_fn
14+
from .utils.quac import eval_fn as quac_eval_fn
15+
from .utils.doqa import eval_fn as doqa_eval_fn
1516
from .utils import constants as Constants
1617
from .word_model import WordModel
1718
from .models.graphflow import GraphFlow
@@ -331,7 +332,7 @@ class QuACModel(Model):
331332

332333
def __init__(self, config, train_set=None):
333334
super(QuACModel, self).__init__(config, train_set)
334-
335+
self.eval_fn = quac_eval_fn if config['dataset_name'] == 'quac' else doqa_eval_fn
335336

336337
def predict(self, ex, step, update=True, out_predictions=False):
337338
# Train/Eval mode
@@ -342,13 +343,12 @@ def predict(self, ex, step, update=True, out_predictions=False):
342343
score_s, score_e, unk_probs, score_yesno, score_followup = res['start_logits'], res['end_logits'], res['unk_probs'], res['score_yesno'], res['score_followup']
343344

344345
output = {
345-
'metrics': {'f1': 0.0, 'heq': 0.0, 'dheq': 0.0},
346+
'metrics': None,
346347
'loss': 0.0,
347348
'total_qs': 0,
348349
'total_dials': 0
349350
}
350351

351-
352352
# Compute loss
353353
loss = self.compute_span_loss(score_s, score_e, ex['targets'], ex['span_mask'])
354354
loss = loss + self.compute_answer_type_loss(unk_probs, score_yesno, score_followup, ex['unk_answer_targets'], ex['yesno_targets'], ex['followup_targets'], res['turn_mask'])
@@ -375,7 +375,7 @@ def predict(self, ex, step, update=True, out_predictions=False):
375375

376376
if (not update) or self.config['predict_train']:
377377
predictions, spans, yesnos, followups = self.extract_predictions(ex, score_s, score_e, unk_probs, score_yesno, score_followup, self.config['unk_answer_threshold'], res['turn_mask'])
378-
output['metrics'], total_qs, total_dials = eval_fn(ex['answers'], predictions, ex['raw_evidence_text'])
378+
output['metrics'], total_qs, total_dials = self.eval_fn(ex['answers'], predictions, ex['raw_evidence_text'])
379379
output['total_qs'] = total_qs
380380
output['total_dials'] = total_dials
381381

@@ -434,12 +434,20 @@ def extract_predictions(self, ex, score_s, score_e, unk_probs, score_yesno, scor
434434
yesno = Constants.QuAC_YESNO_OTHER
435435

436436
followup_type = np.argmax(_followup[j]).item()
437-
if followup_type == Constants.QuAC_FOLLOWUP_YES_LABEL:
438-
followup = Constants.QuAC_FOLLOWUP_YES
439-
elif followup_type == Constants.QuAC_FOLLOWUP_NO_LABEL:
440-
followup = Constants.QuAC_FOLLOWUP_NO
437+
438+
if self.config['dataset_name'] == 'quac':
439+
if followup_type == Constants.QuAC_FOLLOWUP_YES_LABEL:
440+
followup = Constants.QuAC_FOLLOWUP_YES
441+
elif followup_type == Constants.QuAC_FOLLOWUP_NO_LABEL:
442+
followup = Constants.QuAC_FOLLOWUP_NO
443+
else:
444+
followup = Constants.QuAC_FOLLOWUP_OTHER
445+
441446
else:
442-
followup = Constants.QuAC_FOLLOWUP_OTHER
447+
if followup_type == Constants.DoQA_FOLLOWUP_YES_LABEL:
448+
followup = Constants.DoQA_FOLLOWUP_YES
449+
else:
450+
followup = Constants.DoQA_FOLLOWUP_NO
443451

444452
para_pred.append(pred)
445453
para_span.append(span)

0 commit comments

Comments
 (0)