Fix RNNT and BERT script
nvzhihanj committed Aug 4, 2020
1 parent bf3c0c5 commit bb04fc9
Showing 3 changed files with 124 additions and 10 deletions.
24 changes: 19 additions & 5 deletions v0.7/language/bert/accuracy-squad.py
@@ -45,6 +45,16 @@

RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])

+dtype_map = {
+    "int8": np.int8,
+    "int16": np.int16,
+    "int32": np.int32,
+    "int64": np.int64,
+    "float16": np.float16,
+    "float32": np.float32,
+    "float64": np.float64
+}

def get_final_text(pred_text, orig_text, do_lower_case):
"""Project the tokenized prediction back to the original text."""

@@ -302,18 +312,18 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
    with open(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")

-def load_loadgen_log(log_path, eval_features, output_transposed=False):
+def load_loadgen_log(log_path, eval_features, dtype=np.float32, output_transposed=False):
    with open(log_path) as f:
        predictions = json.load(f)

    results = []
    for prediction in predictions:
        qsl_idx = prediction["qsl_idx"]
        if output_transposed:
-            logits = np.frombuffer(bytes.fromhex(prediction["data"]), np.float32).reshape(2, -1)
+            logits = np.frombuffer(bytes.fromhex(prediction["data"]), dtype).reshape(2, -1)
            logits = np.transpose(logits)
        else:
-            logits = np.frombuffer(bytes.fromhex(prediction["data"]), np.float32).reshape(-1, 2)
+            logits = np.frombuffer(bytes.fromhex(prediction["data"]), dtype).reshape(-1, 2)
        # Pad logits to max_seq_length
        seq_length = logits.shape[0]
        start_logits = np.ones(max_seq_length) * -10000.0
@@ -336,8 +346,11 @@ def main():
parser.add_argument("--out_file", default="build/result/predictions.json", help="Path to output predictions file")
parser.add_argument("--features_cache_file", default="eval_features.pickle", help="Path to features' cache file")
parser.add_argument("--output_transposed", action="store_true", help="Transpose the output")
parser.add_argument("--output_dtype", default="float16", choices=dtype_map.keys(), help="Output data type")
args = parser.parse_args()

output_dtype = dtype_map[args.output_dtype]

print("Reading examples...")
eval_examples = read_squad_examples(input_file=args.val_data,
is_training=False, version_2_with_negative=False)
@@ -374,13 +387,14 @@ def append_feature(feature):
            pickle.dump(eval_features, cache_file)

    print("Loading LoadGen logs...")
-    results = load_loadgen_log(args.log_file, eval_features, args.output_transposed)
+    results = load_loadgen_log(args.log_file, eval_features, output_dtype, args.output_transposed)

    print("Post-processing predictions...")
    write_predictions(eval_examples, eval_features, results, 20, 30, True, args.out_file)

    print("Evaluating predictions...")
-    cmd = "python3 build/data/evaluate-v1.1.py build/data/dev-v1.1.json build/result/predictions.json"
+    cmd = "python3 {:}/evaluate.py {:} {:}".format(os.path.dirname(__file__),
+                                                   args.val_data, args.out_file)
    subprocess.check_call(cmd, shell=True)

if __name__ == "__main__":
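For reviewers who want to sanity-check the new --output_dtype path outside the harness, the following minimal sketch mirrors what load_loadgen_log now does with each accuracy-log entry; the hex payload is fabricated for illustration, whereas real entries come from the LoadGen accuracy log.

import numpy as np

dtype = np.float16  # what dtype_map["float16"], the new default, resolves to

# Fabricated "data" field: six float16 logits, hex-encoded the way they appear in the log.
data_hex = np.arange(6, dtype=dtype).tobytes().hex()

logits = np.frombuffer(bytes.fromhex(data_hex), dtype).reshape(-1, 2)   # (seq_len, 2)
start_logits, end_logits = logits[:, 0], logits[:, 1]
print(start_logits, end_logits)

With --output_transposed the same buffer is instead reshaped to (2, -1) and transposed before splitting, as in the diff above.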
94 changes: 94 additions & 0 deletions v0.7/language/bert/evaluate.py
@@ -0,0 +1,94 @@
""" Official evaluation script for v1.1 of the SQuAD dataset. """
from __future__ import print_function
from collections import Counter
import string
import re
import argparse
import json
import sys


def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)

def white_space_fix(text):
return ' '.join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1


def exact_match_score(prediction, ground_truth):
return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
f1 = exact_match = total = 0
for article in dataset:
for paragraph in article['paragraphs']:
for qa in paragraph['qas']:
total += 1
if qa['id'] not in predictions:
message = 'Unanswered question ' + qa['id'] + \
' will receive score 0.'
print(message, file=sys.stderr)
continue
ground_truths = list(map(lambda x: x['text'], qa['answers']))
prediction = predictions[qa['id']]
exact_match += metric_max_over_ground_truths(
exact_match_score, prediction, ground_truths)
f1 += metric_max_over_ground_truths(
f1_score, prediction, ground_truths)

exact_match = 100.0 * exact_match / total
f1 = 100.0 * f1 / total

return {'exact_match': exact_match, 'f1': f1}


if __name__ == '__main__':
expected_version = '1.1'
parser = argparse.ArgumentParser(
description='Evaluation for SQuAD ' + expected_version)
parser.add_argument('dataset_file', help='Dataset file')
parser.add_argument('prediction_file', help='Prediction File')
args = parser.parse_args()
with open(args.dataset_file) as dataset_file:
dataset_json = json.load(dataset_file)
if (dataset_json['version'] != expected_version):
print('Evaluation expects v-' + expected_version +
', but got dataset with v-' + dataset_json['version'],
file=sys.stderr)
dataset = dataset_json['data']
with open(args.prediction_file) as prediction_file:
predictions = json.load(prediction_file)
print(json.dumps(evaluate(dataset, predictions)))
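Because accuracy-squad.py now shells out to this local copy, it can also be exercised in isolation by calling evaluate() directly; the paths below are just the defaults used elsewhere in the harness and may differ in a given setup.

import json
from evaluate import evaluate   # the new v0.7/language/bert/evaluate.py

with open("build/data/dev-v1.1.json") as f:          # SQuAD v1.1 dev set
    dataset = json.load(f)["data"]
with open("build/result/predictions.json") as f:     # output of accuracy-squad.py
    predictions = json.load(f)

print(json.dumps(evaluate(dataset, predictions)))    # {"exact_match": ..., "f1": ...}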
16 changes: 11 additions & 5 deletions v0.7/speech_recognition/rnnt/accuracy_eval.py
@@ -6,9 +6,9 @@
import sys
import os

-from QSL import AudioQSL
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "pytorch"))

-sys.path.insert(0, os.path.join(os.getcwd(), "pytorch"))
+from QSL import AudioQSL
from helpers import process_evaluation_epoch, __gather_predictions
from parts.manifest import Manifest

@@ -31,13 +31,19 @@ def main():
    hypotheses = []
    references = []
    for result in results:
-        hypotheses.append(array.array('q', bytes.fromhex(result["data"])).tolist())
+        hypotheses.append(array.array('b', bytes.fromhex(result["data"])).tolist())
        references.append(manifest[result["qsl_idx"]]["transcript"])
    hypotheses = __gather_predictions([hypotheses], labels=labels)

+    # Convert ASCII output into string
+    for idx in range(len(hypotheses)):
+        hypotheses[idx] = ''.join([chr(c) for c in hypotheses[idx]])

    references = __gather_predictions([references], labels=labels)

    d = dict(predictions=hypotheses,
             transcripts=references)
-    print("Word Error Rate:", process_evaluation_epoch(d))
+    wer = process_evaluation_epoch(d)
+    print("Word Error Rate: {:}%, accuracy={:}%".format(wer * 100, (1 - wer) * 100))

if __name__ == '__main__':
    main()
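The switch from array.array('q', ...) to array.array('b', ...) reads each hypothesis from the accuracy log as int8 values rather than int64, and the new loop joins those values into a string with chr(). A simplified sketch of that decode step, with a fabricated payload, without the __gather_predictions pass, and assuming the values are ASCII character codes:

import array

# Fabricated accuracy-log entry: "hello world" as int8 character codes, hex-encoded.
data_hex = b"hello world".hex()

codes = array.array('b', bytes.fromhex(data_hex)).tolist()   # [104, 101, 108, ...]
hypothesis = ''.join(chr(c) for c in codes)                  # "hello world"
print(hypothesis)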
