Skip to content

Commit

Permalink
Address review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
nvzhihanj committed Aug 4, 2020
1 parent bb04fc9 commit 08fba19
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
4 changes: 2 additions & 2 deletions v0.7/language/bert/accuracy-squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def main():
parser.add_argument("--out_file", default="build/result/predictions.json", help="Path to output predictions file")
parser.add_argument("--features_cache_file", default="eval_features.pickle", help="Path to features' cache file")
parser.add_argument("--output_transposed", action="store_true", help="Transpose the output")
parser.add_argument("--output_dtype", default="float16", choices=dtype_map.keys(), help="Output data type")
parser.add_argument("--output_dtype", default="float32", choices=dtype_map.keys(), help="Output data type")
args = parser.parse_args()

output_dtype = dtype_map[args.output_dtype]
Expand Down Expand Up @@ -393,7 +393,7 @@ def append_feature(feature):
write_predictions(eval_examples, eval_features, results, 20, 30, True, args.out_file)

print("Evaluating predictions...")
cmd = "python3 {:}/evaluate.py {:} {:}".format(os.path.dirname(__file__),
cmd = "python3 {:}/evaluate-v1.1.py {:} {:}".format(os.path.dirname(__file__),
args.val_data, args.out_file)
subprocess.check_call(cmd, shell=True)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Source: https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py

""" Official evaluation script for v1.1 of the SQuAD dataset. """
from __future__ import print_function
from collections import Counter
Expand Down
9 changes: 8 additions & 1 deletion v0.7/speech_recognition/rnnt/accuracy_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,19 @@
from helpers import process_evaluation_epoch, __gather_predictions
from parts.manifest import Manifest

dtype_map = {
"int8": 'b',
"int16": 'h',
"int32": 'l',
"int64": 'q',
}

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--log_dir", required=True)
parser.add_argument("--dataset_dir", required=True)
parser.add_argument("--manifest", required=True)
parser.add_argument("--output_dtype", default="int64", choices=dtype_map.keys(), help="Output data type")
args = parser.parse_args()
return args

Expand All @@ -31,7 +38,7 @@ def main():
hypotheses = []
references = []
for result in results:
hypotheses.append(array.array('b', bytes.fromhex(result["data"])).tolist())
hypotheses.append(array.array(dtype_map[args.output_dtype], bytes.fromhex(result["data"])).tolist())
references.append(manifest[result["qsl_idx"]]["transcript"])

# Convert ASCII output into string
Expand Down

0 comments on commit 08fba19

Please sign in to comment.