diff --git a/v0.7/language/bert/accuracy-squad.py b/v0.7/language/bert/accuracy-squad.py index 2723a2692..113e1c8d8 100644 --- a/v0.7/language/bert/accuracy-squad.py +++ b/v0.7/language/bert/accuracy-squad.py @@ -346,7 +346,7 @@ def main(): parser.add_argument("--out_file", default="build/result/predictions.json", help="Path to output predictions file") parser.add_argument("--features_cache_file", default="eval_features.pickle", help="Path to features' cache file") parser.add_argument("--output_transposed", action="store_true", help="Transpose the output") - parser.add_argument("--output_dtype", default="float16", choices=dtype_map.keys(), help="Output data type") + parser.add_argument("--output_dtype", default="float32", choices=dtype_map.keys(), help="Output data type") args = parser.parse_args() output_dtype = dtype_map[args.output_dtype] @@ -393,7 +393,7 @@ def append_feature(feature): write_predictions(eval_examples, eval_features, results, 20, 30, True, args.out_file) print("Evaluating predictions...") - cmd = "python3 {:}/evaluate.py {:} {:}".format(os.path.dirname(__file__), + cmd = "python3 {:}/evaluate-v1.1.py {:} {:}".format(os.path.dirname(__file__), args.val_data, args.out_file) subprocess.check_call(cmd, shell=True) diff --git a/v0.7/language/bert/evaluate.py b/v0.7/language/bert/evaluate-v1.1.py similarity index 84% rename from v0.7/language/bert/evaluate.py rename to v0.7/language/bert/evaluate-v1.1.py index 0137fbca0..c582e6877 100644 --- a/v0.7/language/bert/evaluate.py +++ b/v0.7/language/bert/evaluate-v1.1.py @@ -1,3 +1,17 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Source: https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py + """ Official evaluation script for v1.1 of the SQuAD dataset. """ from __future__ import print_function from collections import Counter diff --git a/v0.7/speech_recognition/rnnt/accuracy_eval.py b/v0.7/speech_recognition/rnnt/accuracy_eval.py index 6ac13e98c..efb6a7927 100644 --- a/v0.7/speech_recognition/rnnt/accuracy_eval.py +++ b/v0.7/speech_recognition/rnnt/accuracy_eval.py @@ -12,12 +12,19 @@ from helpers import process_evaluation_epoch, __gather_predictions from parts.manifest import Manifest +dtype_map = { + "int8": 'b', + "int16": 'h', + "int32": 'i',  # C int (4 bytes on mainstream platforms); 'l' is C long, 8 bytes on LP64 + "int64": 'q', +} def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--log_dir", required=True) parser.add_argument("--dataset_dir", required=True) parser.add_argument("--manifest", required=True) + parser.add_argument("--output_dtype", default="int64", choices=dtype_map.keys(), help="Output data type") args = parser.parse_args() return args @@ -31,7 +38,7 @@ def main(): hypotheses = [] references = [] for result in results: - hypotheses.append(array.array('b', bytes.fromhex(result["data"])).tolist()) + hypotheses.append(array.array(dtype_map[args.output_dtype], bytes.fromhex(result["data"])).tolist()) references.append(manifest[result["qsl_idx"]]["transcript"]) # Convert ASCII output into string