@@ -21,6 +21,14 @@ def parse_input_args():
2121 help = 'Perform fp16 quantizaton in addition to int8' ,
2222 )
2323
24+ parser .add_argument (
25+ "--fp32" ,
26+ action = "store_true" ,
27+ required = False ,
28+ default = False ,
29+ help = 'Perform no quantization' ,
30+ )
31+
2432 parser .add_argument (
2533 "--image_dir" ,
2634 required = False ,
@@ -56,10 +64,10 @@ def __init__(self,
5664 '''
5765 :param image_folder: image dataset folder
5866 :param width: image width
59- :param height: image height
67+ :param height: image height
6068 :param start_index: start index of images
6169 :param end_index: end index of images
62- :param stride: image size of each data get
70+ :param stride: image size of each data get
6371 :param batch_size: batch size of inference
6472 :param model_path: model name and path
6573 :param input_name: model input name
@@ -153,12 +161,14 @@ def preprocess_imagenet(self, images_folder, height, width, start_index=0, size_
153161 parameter images_folder: path to folder storing images
154162 parameter height: image height in pixels
155163 parameter width: image width in pixels
156- parameter start_index: image index to start with
164+ parameter start_index: image index to start with
157165 parameter size_limit: number of images to load. Default is 0 which means all images are picked.
158166 return: list of matrices characterizing multiple images
159167 '''
160168 def preprocess_images (input , channels = 3 , height = 224 , width = 224 ):
161169 image = input .resize ((width , height ), Image .Resampling .LANCZOS )
170+ if image .mode in ["CMYK" , "RGBA" ]:
171+ image = image .convert ("RGB" )
162172 input_data = np .asarray (image ).astype (np .float32 )
163173 if len (input_data .shape ) != 2 :
164174 input_data = input_data .transpose ([2 , 0 , 1 ])
@@ -249,7 +259,7 @@ def __init__(self,
249259 providers = ["MIGraphXExecutionProvider" ]):
250260 '''
251261 :param model_path: ONNX model to validate
252- :param synset_id: ILSVRC2012 synset id
262+ :param synset_id: ILSVRC2012 synset id
253263 :param data_reader: user implemented object to read in and preprocess calibration dataset
254264 based on CalibrationDataReader Interface
255265 :param providers: ORT execution provider type
@@ -281,9 +291,8 @@ def predict(self):
281291 self .prediction_result_list = inference_outputs_list
282292
283293 def top_k_accuracy (self , truth , prediction , k = 1 ):
284- '''From https://github.com/chainer/chainer/issues/606
294+ '''From https://github.com/chainer/chainer/issues/606
285295 '''
286-
287296 y = np .argsort (prediction )[:, - k :]
288297 return np .any (y .T == truth .argmax (axis = 1 ), axis = 0 ).mean ()
289298
@@ -293,7 +302,7 @@ def evaluate(self, prediction_results):
293302 y_prediction = np .empty ((total_val_images , 1000 ), dtype = np .float32 )
294303 i = 0
295304 for res in prediction_results :
296- y_prediction [i :i + batch_size , :] = res [0 ]
305+ y_prediction [i :i + res [ 0 ]. shape [ 0 ] , :] = res [0 ]
297306 i = i + batch_size
298307 print ("top 1: " , self .top_k_accuracy (self .synset_id , y_prediction , k = 1 ))
299308 print ("top 5: " , self .top_k_accuracy (self .synset_id , y_prediction , k = 5 ))
@@ -344,8 +353,9 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
344353 2. Download ILSVRC2012 validation dataset and development kit from http://www.image-net.org/challenges/LSVRC/2012/downloads.
345354 3. Extract validation dataset JPEG files to 'ILSVRC2012/val'.
346355 4. Extract development kit to 'ILSVRC2012/devkit'. Two files in the development kit are used, 'ILSVRC2012_validation_ground_truth.txt' and 'meta.mat'.
356+ These are also available to download at https://github.com/miraclewkf/MobileNetV2-PyTorch/tree/master/ImageNet/ILSVRC2012_devkit_t12/data
347357 5. Download 'synset_words.txt' from https://github.com/HoldenCaulfieldRye/caffe/blob/master/data/ilsvrc12/synset_words.txt into 'ILSVRC2012/'.
348-
358+
349359 Please download Resnet50 model from ONNX model zoo https://github.com/onnx/models/blob/master/vision/classification/resnet/model/resnet50-v2-7.tar.gz
350360 Untar the model into the workspace
351361 '''
@@ -356,15 +366,18 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
356366 ilsvrc2012_dataset_path = flags .image_dir
357367 augmented_model_path = "./augmented_model.onnx"
358368 batch_size = flags .batch
359- calibration_dataset_size = flags .cal_size # Size of dataset for calibration
369+ calibration_dataset_size = 0 if flags .fp32 else flags .cal_size # Size of dataset for calibration
370+
371+ calibration_table_generation_enable = False
372+ if not flags .fp32 :
373+ # INT8 calibration setting
374+ calibration_table_generation_enable = True # Enable/Disable INT8 calibration
360375
361- # INT8 calibration setting
362- calibration_table_generation_enable = True # Enable/Disable INT8 calibration
376+ # MIGraphX EP INT8 settings
377+ os .environ ["ORT_MIGRAPHX_INT8_ENABLE" ] = "1" # Enable INT8 precision
378+ os .environ ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME" ] = "calibration.flatbuffers" # Calibration table name
379+ os .environ ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE" ] = "0" # Calibration table name
363380
364- # MIGraphX EP INT8 settings
365- os .environ ["ORT_MIGRAPHX_INT8_ENABLE" ] = "1" # Enable INT8 precision
366- os .environ ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME" ] = "calibration.flatbuffers" # Calibration table name
367- os .environ ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE" ] = "0" # Calibration table name
368381 execution_provider = ["MIGraphXExecutionProvider" ]
369382
370383 # Convert static batch to dynamic batch
@@ -378,7 +391,7 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
378391 if calibration_table_generation_enable :
379392 print ("Generating Calibration Table" )
380393 calibrator = create_calibrator (new_model_path , [], augmented_model_path = augmented_model_path )
381- calibrator .set_execution_providers (["ROCMExecutionProvider" ])
394+ calibrator .set_execution_providers (["ROCMExecutionProvider" ])
382395 data_reader = ImageNetDataReader (ilsvrc2012_dataset_path ,
383396 start_index = 0 ,
384397 end_index = calibration_dataset_size ,
@@ -391,7 +404,7 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
391404
392405 serial_cal_tensors = {}
393406 for keys , values in cal_tensors .data .items ():
394- serial_cal_tensors [keys ] = values .range_value
407+ serial_cal_tensors [keys ] = [ float ( x [ 0 ]) for x in values .range_value ]
395408
396409 print ("Writing calibration table" )
397410 write_calibration_table (serial_cal_tensors )
0 commit comments