diff --git a/utils/RunONNXModel.py b/utils/RunONNXModel.py index 139b9b82e3..07d44681ff 100755 --- a/utils/RunONNXModel.py +++ b/utils/RunONNXModel.py @@ -160,6 +160,10 @@ def check_non_negative(argname, value): help="Path to a folder containing reference inputs and outputs stored in protobuf." " If --verify=ref, inputs and outputs are reference data for verification", ) +data_group.add_argument( + "--inputs-from-arrays", + help="List of numpy arrays used as inputs for inference" +) data_group.add_argument( "--load-ref-from-numpy", metavar="PATH", @@ -730,6 +734,8 @@ def main(): inputs = read_input_from_refs(len(input_names), args.load_ref) elif args.load_ref_from_numpy: inputs = read_input_from_refs(len(input_names), args.load_ref_from_numpy) + elif args.inputs_from_arrays: + inputs = args.inputs_from_arrays else: inputs = generate_random_input(input_signature, input_shapes) @@ -861,7 +867,19 @@ def main(): "using atol={}, rtol={} ...".format(args.atol, args.rtol), ) verify_outs(outs[i], ref_outs[i]) - - + return outs + +# Python function inteface for RunONNXModel +# Arguments are passed as named parameters for the function +# Extra functionality is to directly pass a list of arrays as inference input +def onnxmlirrun(onnx_model=None, compiled_so=None, inputs=None): + if onnx_model : + args.model = onnx_model + if compiled_so : + args.load_so = compiled_so + if inputs : + args.inputs_from_arrays = inputs + return main() + if __name__ == "__main__": main()