From bb04fc96265d3067e9931854a45b1668e616832f Mon Sep 17 00:00:00 2001 From: Zhihan Date: Mon, 3 Aug 2020 22:45:26 -0700 Subject: [PATCH] Fix RNNT and BERT script --- v0.7/language/bert/accuracy-squad.py | 24 ++++- v0.7/language/bert/evaluate.py | 94 +++++++++++++++++++ v0.7/speech_recognition/rnnt/accuracy_eval.py | 16 +++- 3 files changed, 124 insertions(+), 10 deletions(-) create mode 100644 v0.7/language/bert/evaluate.py diff --git a/v0.7/language/bert/accuracy-squad.py b/v0.7/language/bert/accuracy-squad.py index f1365f4e6..2723a2692 100644 --- a/v0.7/language/bert/accuracy-squad.py +++ b/v0.7/language/bert/accuracy-squad.py @@ -45,6 +45,16 @@ RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) +dtype_map = { + "int8": np.int8, + "int16": np.int16, + "int32": np.int32, + "int64": np.int64, + "float16": np.float16, + "float32": np.float32, + "float64": np.float64 +} + def get_final_text(pred_text, orig_text, do_lower_case): """Project the tokenized prediction back to the original text.""" @@ -302,7 +312,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, with open(output_prediction_file, "w") as writer: writer.write(json.dumps(all_predictions, indent=4) + "\n") -def load_loadgen_log(log_path, eval_features, output_transposed=False): +def load_loadgen_log(log_path, eval_features, dtype=np.float32, output_transposed=False): with open(log_path) as f: predictions = json.load(f) @@ -310,10 +320,10 @@ def load_loadgen_log(log_path, eval_features, output_transposed=False): for prediction in predictions: qsl_idx = prediction["qsl_idx"] if output_transposed: - logits = np.frombuffer(bytes.fromhex(prediction["data"]), np.float32).reshape(2, -1) + logits = np.frombuffer(bytes.fromhex(prediction["data"]), dtype).reshape(2, -1) logits = np.transpose(logits) else: - logits = np.frombuffer(bytes.fromhex(prediction["data"]), np.float32).reshape(-1, 2) + logits = np.frombuffer(bytes.fromhex(prediction["data"]), dtype).reshape(-1, 2) # Pad logits to max_seq_length seq_length = logits.shape[0] start_logits = np.ones(max_seq_length) * -10000.0 @@ -336,8 +346,11 @@ def main(): parser.add_argument("--out_file", default="build/result/predictions.json", help="Path to output predictions file") parser.add_argument("--features_cache_file", default="eval_features.pickle", help="Path to features' cache file") parser.add_argument("--output_transposed", action="store_true", help="Transpose the output") + parser.add_argument("--output_dtype", default="float16", choices=dtype_map.keys(), help="Output data type") args = parser.parse_args() + output_dtype = dtype_map[args.output_dtype] + print("Reading examples...") eval_examples = read_squad_examples(input_file=args.val_data, is_training=False, version_2_with_negative=False) @@ -374,13 +387,14 @@ def append_feature(feature): pickle.dump(eval_features, cache_file) print("Loading LoadGen logs...") - results = load_loadgen_log(args.log_file, eval_features, args.output_transposed) + results = load_loadgen_log(args.log_file, eval_features, output_dtype, args.output_transposed) print("Post-processing predictions...") write_predictions(eval_examples, eval_features, results, 20, 30, True, args.out_file) print("Evaluating predictions...") - cmd = "python3 build/data/evaluate-v1.1.py build/data/dev-v1.1.json build/result/predictions.json" + cmd = "python3 {:}/evaluate.py {:} {:}".format(os.path.dirname(__file__), + args.val_data, args.out_file) subprocess.check_call(cmd, shell=True) if __name__ == "__main__": diff --git a/v0.7/language/bert/evaluate.py b/v0.7/language/bert/evaluate.py new file mode 100644 index 000000000..0137fbca0 --- /dev/null +++ b/v0.7/language/bert/evaluate.py @@ -0,0 +1,94 @@ +""" Official evaluation script for v1.1 of the SQuAD dataset. """ +from __future__ import print_function +from collections import Counter +import string +import re +import argparse +import json +import sys + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return (normalize_answer(prediction) == normalize_answer(ground_truth)) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def evaluate(dataset, predictions): + f1 = exact_match = total = 0 + for article in dataset: + for paragraph in article['paragraphs']: + for qa in paragraph['qas']: + total += 1 + if qa['id'] not in predictions: + message = 'Unanswered question ' + qa['id'] + \ + ' will receive score 0.' + print(message, file=sys.stderr) + continue + ground_truths = list(map(lambda x: x['text'], qa['answers'])) + prediction = predictions[qa['id']] + exact_match += metric_max_over_ground_truths( + exact_match_score, prediction, ground_truths) + f1 += metric_max_over_ground_truths( + f1_score, prediction, ground_truths) + + exact_match = 100.0 * exact_match / total + f1 = 100.0 * f1 / total + + return {'exact_match': exact_match, 'f1': f1} + + +if __name__ == '__main__': + expected_version = '1.1' + parser = argparse.ArgumentParser( + description='Evaluation for SQuAD ' + expected_version) + parser.add_argument('dataset_file', help='Dataset file') + parser.add_argument('prediction_file', help='Prediction File') + args = parser.parse_args() + with open(args.dataset_file) as dataset_file: + dataset_json = json.load(dataset_file) + if (dataset_json['version'] != expected_version): + print('Evaluation expects v-' + expected_version + + ', but got dataset with v-' + dataset_json['version'], + file=sys.stderr) + dataset = dataset_json['data'] + with open(args.prediction_file) as prediction_file: + predictions = json.load(prediction_file) + print(json.dumps(evaluate(dataset, predictions))) diff --git a/v0.7/speech_recognition/rnnt/accuracy_eval.py b/v0.7/speech_recognition/rnnt/accuracy_eval.py index a1a12a7ad..6ac13e98c 100644 --- a/v0.7/speech_recognition/rnnt/accuracy_eval.py +++ b/v0.7/speech_recognition/rnnt/accuracy_eval.py @@ -6,9 +6,9 @@ import sys import os -from QSL import AudioQSL +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "pytorch")) -sys.path.insert(0, os.path.join(os.getcwd(), "pytorch")) +from QSL import AudioQSL from helpers import process_evaluation_epoch, __gather_predictions from parts.manifest import Manifest @@ -31,13 +31,19 @@ def main(): hypotheses = [] references = [] for result in results: - hypotheses.append(array.array('q', bytes.fromhex(result["data"])).tolist()) + hypotheses.append(array.array('b', bytes.fromhex(result["data"])).tolist()) references.append(manifest[result["qsl_idx"]]["transcript"]) - hypotheses = __gather_predictions([hypotheses], labels=labels) + + # Convert ASCII output into string + for idx in range(len(hypotheses)): + hypotheses[idx] = ''.join([chr(c) for c in hypotheses[idx]]) + references = __gather_predictions([references], labels=labels) + d = dict(predictions=hypotheses, transcripts=references) - print("Word Error Rate:", process_evaluation_epoch(d)) + wer = process_evaluation_epoch(d) + print("Word Error Rate: {:}%, accuracy={:}%".format(wer * 100, (1 - wer) * 100)) if __name__ == '__main__': main()