Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added OpenVINO inference SUT #2

Open
wants to merge 9 commits into
base: dev-3dunet-reference
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion v0.7/medical_imaging/3d-unet/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ RUN cd /tmp \
&& rm -rf inference

# Install dependencies
RUN python3 -m pip install onnx onnxruntime numpy==1.18.0 Pillow==7.0.0
RUN python3 -m pip install wrapt --upgrade --ignore-installed
RUN python3 -m pip install onnx onnxruntime numpy==1.18.0 Pillow==7.0.0 tensorflow
RUN python3 -m pip install tensorflow-addons https://github.com/onnx/onnx-tensorflow/archive/master.zip

# Install nnUnet
COPY nnUnet /workspace/nnUnet
Expand Down
74 changes: 52 additions & 22 deletions v0.7/medical_imaging/3d-unet/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ POSTPROCESSED_DATA_DIR := $(BUILD_DIR)/postprocessed_data
MODEL_DIR := $(BUILD_DIR)/model
RESULT_DIR := $(BUILD_DIR)/result
MLPERF_CONF := $(BUILD_DIR)/mlperf.conf
PYTORCH_MODEL := $(RESULT_DIR)/fold_4.zip
ONNX_MODEL := $(MODEL_DIR)/192_224_192.onnx
PYTORCH_MODEL := $(RESULT_DIR)/fold_1.zip
ONNX_MODEL := $(MODEL_DIR)/224_224_160.onnx
ONNX_DYNAMIC_BS_MODEL := $(MODEL_DIR)/224_224_160_dynamic_bs.onnx
TF_MODEL := $(MODEL_DIR)/224_224_160.pb

# Env variables needed by nnUnet
export nnUNet_raw_data_base=$(RAW_DATA_DIR)
Expand All @@ -43,16 +45,15 @@ export RESULTS_FOLDER=$(RESULT_DIR)

HAS_GPU := $(shell command -v nvidia-smi 2> /dev/null)

ifndef $HAS_GPU
DOCKER_RUN_CMD := docker run
ifeq ($(HAS_GPU),)
DOCKER_RUN_CMD := docker run
else
# Handle different nvidia-docker version
ifneq ($(wildcard /usr/bin/nvidia-docker),)
DOCKER_RUN_CMD := nvidia-docker run
else
DOCKER_RUN_CMD := docker run --gpus=all
endif

# Handle different nvidia-docker version
ifneq ($(wildcard /usr/bin/nvidia-docker),)
DOCKER_RUN_CMD := nvidia-docker run
else
DOCKER_RUN_CMD := docker run --gpus=all
endif
endif

.PHONY: setup
Expand Down Expand Up @@ -86,24 +87,45 @@ download_model:
@echo "Download models..."
@$(MAKE) -f $(MAKEFILE_NAME) download_pytorch_model
@$(MAKE) -f $(MAKEFILE_NAME) download_onnx_model
@$(MAKE) -f $(MAKEFILE_NAME) download_tf_model

.PHONY: download_pytorch_model
download_pytorch_model:
# Will download model from Zenodo
# @if [ ! -e $(PYTORCH_MODEL)/model.pytorch ]; then \
# wget -O ; \
# fi
# For now, assume that fold_4.zip is in build/result
@echo "Downloading PyTorch model from Zenodo..."
@if [ ! -e $(PYTORCH_MODEL) ]; then \
echo "For now, please manually download PyTorch model to $(PYTORCH_MODEL)/"; \
wget -O $(PYTORCH_MODEL) https://zenodo.org/record/3904106/files/fold_1.zip?download=1 \
&& cd $(RESULT_DIR) && unzip -o fold_1.zip; \
fi
@cd $(RESULT_DIR) && unzip -o fold_4.zip

.PHONY: download_onnx_model
download_onnx_model:
# Will download model from Zenodo
@echo "Downloading ONNX model from Zenodo..."
@if [ ! -e $(ONNX_MODEL) ]; then \
echo "For now, please manually download ONNX model to $(ONNX_MODEL)"; \
wget -O $(ONNX_MODEL) https://zenodo.org/record/3904138/files/224_224_160.onnx?download=1; \
fi
@if [ ! -e $(ONNX_DYNAMIC_BS_MODEL) ]; then \
wget -O $(ONNX_DYNAMIC_BS_MODEL) https://zenodo.org/record/3904138/files/224_224_160_dyanmic_bs.onnx?download=1; \
fi

.PHONY: download_tf_model
download_tf_model:
@echo "Downloading TF model from Zenodo..."
@if [ ! -e $(TF_MODEL) ]; then \
wget -O $(TF_MODEL) https://zenodo.org/record/3904146/files/224_224_160.pb?download=1; \
fi

.PHONY: convert_onnx_model
convert_onnx_model: download_pytorch_model
@echo "Converting PyTorch model to ONNX model..."
@if [ ! -e $(ONNX_MODEL) ]; then \
python3 unet_pytorch_to_onnx.py; \
fi

.PHONY: convert_tf_model
convert_tf_model: convert_onnx_model
@echo "Converting ONNX model to TF model..."
@if [ ! -e $(TF_MODEL) ]; then \
python3 unet_onnx_to_tf.py; \
fi

.PHONY: build_docker
Expand Down Expand Up @@ -154,11 +176,19 @@ run_pytorch_accuracy: mkdir_postprocessed_data

.PHONY: run_onnxruntime_performance
run_onnxruntime_performance:
@python3 run.py --backend=onnxruntime
@python3 run.py --backend=onnxruntime --model=build/model/224_224_160.onnx

.PHONY: run_onnxruntime_accuracy
run_onnxruntime_accuracy: mkdir_postprocessed_data
@python3 run.py --backend=onnxruntime --accuracy
@python3 run.py --backend=onnxruntime --model=build/model/224_224_160.onnx --accuracy

.PHONY: run_tf_performance
run_tf_performance:
@python3 run.py --backend=tf --model=build/model/224_224_160.pb

.PHONY: run_tf_accuracy
run_tf_accuracy: mkdir_postprocessed_data
@python3 run.py --backend=tf --model=build/model/224_224_160.pb --accuracy

.PHONY: evaluate
evaluate:
Expand Down
24 changes: 5 additions & 19 deletions v0.7/medical_imaging/3d-unet/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# MLPerf Inference Benchmarks for Medical Image 3D Segmentation

This is the reference implementation for MLPerf Inference benchmarks for Medical Image 3D Segmentation.

The chosen model is 3D-Unet in [nnUnet](https://github.com/MIC-DKFZ/nnUNet) performing [BraTS 2019](https://www.med.upenn.edu/cbica/brats2019/data.html) brain tumor segmentation task.

## Prerequisites
Expand All @@ -13,18 +11,15 @@ The chosen model is 3D-Unet in [nnUnet](https://github.com/MIC-DKFZ/nnUNet) perf

| model | framework | accuracy | dataset | model link | model source | precision | notes |
| ----- | --------- | -------- | ------- | ---------- | ------------ | --------- | ----- |
| 3D-Unet | PyTorch | mean = 0.82400, whole tumor = 0.8922, tumor core = 0.8158, enhancing tumor = 0.7640 | last 20% of BraTS 2019 Training Dataset (67 samples) | [from zenodo](???) | Trained in PyTorch using codes from [nnUnet](https://github.com/MIC-DKFZ/nnUNet) on the first 80% of BraTS 2019 Training Dataset. | fp32 | |
| 3D-Unet | ONNX | mean = 0.82400, whole tumor = 0.8922, tumor core = 0.8158, enhancing tumor = 0.7640 | last 20% of BraTS 2019 Training Dataset (67 samples) | [from zenodo](???) | Converted from the PyTorch model using ??? script. | fp32 | |
| 3D-Unet | PyTorch | mean = 0.85203, whole tumor = 0.9147, tumor core = 0.8645, enhancing tumor = 0.7769 | The second 20% of BraTS 2019 Training Dataset (fold 1, 67 samples) | [from zenodo](https://zenodo.org/record/3904106) | Trained in PyTorch using codes from [nnUnet](https://github.com/MIC-DKFZ/nnUNet) on the first 20% and the last 60% (fold 1) of BraTS 2019 Training Dataset. | fp32 | |
| 3D-Unet | ONNX | mean = 0.85203, whole tumor = 0.9147, tumor core = 0.8645, enhancing tumor = 0.7769 | The second 20% of BraTS 2019 Training Dataset (fold 1, 67 samples) | [from zenodo](https://zenodo.org/record/3904138) | Converted from the PyTorch model using [script](unet_pytorch_to_onnx.py). | fp32 | |
| 3D-Unet | Tensorflow | mean = 0.85203, whole tumor = 0.9147, tumor core = 0.8645, enhancing tumor = 0.7769 | The second 20% of BraTS 2019 Training Dataset (fold 1, 67 samples) | [from zenodo](https://zenodo.org/record/3904146) | Converted from the PyTorch model using [script](unet_onnx_to_tf.py). | fp32 | |

## Disclaimer
This benchmark app is a reference implementation that is not meant to be the fastest implementation possible.

## TODO

[ ] Update the models (PyTorch and ONNX) to the final volume size (160, 224, 224).
[ ] Upload the models to Zenodo, and fill in Zenodo link, and modify Makefile so that it downloads models from Zenodo.
[ ] Update the accuracy metric.
[ ] Add PyTorch -> ONNX script.
[ ] Update the onnxruntime in the docker container to a version which supports 3D ConvTranspose op.

## Commands
Expand All @@ -34,21 +29,12 @@ Please download [BraTS 2019](https://www.med.upenn.edu/cbica/brats2019/data.html
Please run the following commands:

- `export DOWNLOAD_DATA_DIR=<path/to/MICCAI_BraTS_2019_Data_Training>`: point to location of downloaded BraTS 2019 Training dataset.
- **Temporary:** Download the (192, 224, 192) PyTorch model named `fold_4.zip` to `build/result/`.
- **Temporary:** Download the (192, 224, 192) ONNX model named `192_224_192.onnx` to `build/`.
- `make setup`: initialize submodule and download models.
- `make build_docker`: build docker image.
- `make launch_docker`: launch docker container with an interaction session.
- `make preprocess_data`: preprocess the BraTS 2019 dataset.
- `python3 run.py --backend=[pytorch|onnxruntime] --scenario=[Offline|SingleStream|MultiStream|Server] [--accuracy]`: run the harness inside the docker container. Performance or Accuracy results will be printed in console.
- `python3 run.py --backend=[tf|pytorch|onnxruntime] --scenario=[Offline|SingleStream|MultiStream|Server] [--accuracy] --model=[path/to/model_file(tf/onnx only)]`: run the harness inside the docker container. Performance or Accuracy results will be printed in console.

## Details

- SUT implementations are in [pytorch_SUT.py](pytorch_SUT.py) and [onnxruntime_SUT.py](onnxruntime_SUT.py). QSL implementation is in [brats_QSL.py](brats_QSL.py).
- The script [brats_eval.py](brats_eval.py) parses LoadGen accuracy log, post-processes it, and computes the accuracy.
- Preprocessing and evaluation (including post-processing) are not included in the timed path.
- The input to the SUT is a volume of size `[4, 192, 224, 192]`. The output from SUT is a volume of size `[4, 192, 224, 192]` with predicted label logits for each voxel.

## License

Apache License 2.0
- SUT implementations are in [pytorch_SUT.py](pytorch_SUT.py), [onnxruntime_
3 changes: 1 addition & 2 deletions v0.7/medical_imaging/3d-unet/brats_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def load_loadgen_log(log_file, result_dtype, dictionaries):

assert len(predictions) == len(dictionaries), "Number of predictions does not match number of samples in validation set!"

# TODO: need to change to [160, 224, 224]
padded_shape = [192, 224, 192]
padded_shape = [224, 224, 160]
results = [None for i in range(len(predictions))]
for prediction in predictions:
qsl_idx = prediction["qsl_idx"]
Expand Down
Binary file added v0.7/medical_imaging/3d-unet/fold1_validation.npy
Binary file not shown.
Binary file removed v0.7/medical_imaging/3d-unet/fold4_validation.npy
Binary file not shown.
76 changes: 76 additions & 0 deletions v0.7/medical_imaging/3d-unet/ov_SUT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# coding=utf-8
# Copyright (c) 2020 INTEL CORPORATION. All rights reserved.
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import array
import json
import os
import sys
sys.path.insert(0, os.getcwd())

import mlperf_loadgen as lg
import numpy as np

from brats_QSL import get_brats_QSL

from openvino.inference_engine import IECore
from scipy.special import softmax

class _3DUNET_OV_SUT():
def __init__(self, model_path, preprocessed_data_dir, performance_count):
print("Loading OV model...")

model_xml = model_path
model_bin = os.path.splitext(model_xml)[0] + '.bin'

ie = IECore()
net = ie.read_network(model=model_xml, weights=model_bin)

self.input_name = next(iter(net.inputs))
self.output_name = 'output'

self.exec_net = ie.load_network(network=net, device_name='CPU')

print("Constructing SUT...")
self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries,
self.process_latencies)
self.qsl = get_brats_QSL(preprocessed_data_dir, performance_count)
print("Finished constructing SUT.")

def issue_queries(self, query_samples):
for i in range(len(query_samples)):
data = self.qsl.get_features(query_samples[i].index)

print("Processing sample id {:d} with shape = {:}".format(
query_samples[i].index, data.shape))

before_softmax = self.exec_net.infer(inputs={self.input_name: data[np.newaxis, ...]})[self.output_name]
after_softmax = softmax(before_softmax, axis=1).astype(np.float16)

response_array = array.array("B", after_softmax.tobytes())
bi = response_array.buffer_info()
response = lg.QuerySampleResponse(query_samples[i].id, bi[0],
bi[1])
lg.QuerySamplesComplete([response])

def flush_queries(self):
pass

def process_latencies(self, latencies_ns):
pass


def get_ov_sut(model_path, preprocessed_data_dir, performance_count):
return _3DUNET_OV_SUT(model_path, preprocessed_data_dir, performance_count)
8 changes: 4 additions & 4 deletions v0.7/medical_imaging/3d-unet/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_args():
parser.add_argument("--raw_data_dir", default="build/raw_data/nnUNet_raw_data/Task043_BraTS2019/imagesTr",
help="Path to the directory containing raw nii.gz files")
parser.add_argument("--preprocessed_data_dir", default="build/preprocessed_data", help="Path to the directory containing preprocessed data")
parser.add_argument("--validation_fold_file", default="fold4_validation.npy", help="Path to the npy file storing all the sample names for the validation fold")
parser.add_argument("--validation_fold_file", default="fold1_validation.npy", help="Path to the npy file storing all the sample names for the validation fold")
parser.add_argument("--num_threads_preprocessing", type=int, default=12, help="Number of threads to run the preprocessing with")
args = parser.parse_args()
return args
Expand Down Expand Up @@ -75,8 +75,8 @@ def main():

print("Preparing for preprocessing data...")

# Validation set is fold 4
fold = 4
# Validation set is fold 1
fold = 1
validation_fold_file = args.validation_fold_file

# Make sure the model exists
Expand All @@ -91,7 +91,7 @@ def main():
raw_data_dir = args.raw_data_dir
preprocessed_data_dir = args.preprocessed_data_dir

# Open npy containing validation images from specific fold (e.g. 4)
# Open npy containing validation images from specific fold (e.g. 1)
with open(validation_fold_file, "rb") as f:
validation_files = numpy.load(f)

Expand Down
2 changes: 1 addition & 1 deletion v0.7/medical_imaging/3d-unet/pytorch_SUT.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,5 @@ def flush_queries(self):
def process_latencies(self, latencies_ns):
pass

def get_pytorch_sut(model_dir, preprocessed_data_dir, performance_count, folds=4, checkpoint_name="model_best"):
def get_pytorch_sut(model_dir, preprocessed_data_dir, performance_count, folds=1, checkpoint_name="model_best"):
return _3DUNET_PyTorch_SUT(model_dir, preprocessed_data_dir, performance_count, folds, checkpoint_name)
57 changes: 46 additions & 11 deletions v0.7/medical_imaging/3d-unet/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,71 @@
import mlperf_loadgen as lg
import subprocess


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--backend", choices=["pytorch","onnxruntime"], default="pytorch", help="Backend")
parser.add_argument("--scenario", choices=["SingleStream", "Offline", "Server", "MultiStream"], default="Offline", help="Scenario")
parser.add_argument("--accuracy", action="store_true", help="enable accuracy pass")
parser.add_argument("--mlperf_conf", default="build/mlperf.conf", help="mlperf rules config")
parser.add_argument("--user_conf", default="user.conf", help="mlperf rules config")
parser.add_argument("--model_dir", default="build/result/nnUNet/3d_fullres/Task043_BraTS2019/nnUNetTrainerV2__nnUNetPlansv2.mlperf.1",
parser.add_argument("--backend",
choices=["pytorch", "onnxruntime", "tf", "ov"],
default="pytorch",
help="Backend")
parser.add_argument(
"--scenario",
choices=["SingleStream", "Offline", "Server", "MultiStream"],
default="Offline",
help="Scenario")
parser.add_argument("--accuracy",
action="store_true",
help="enable accuracy pass")
parser.add_argument("--mlperf_conf",
default="build/mlperf.conf",
help="mlperf rules config")
parser.add_argument("--user_conf",
default="user.conf",
help="mlperf rules config")
parser.add_argument(
"--model_dir",
default=
"build/result/nnUNet/3d_fullres/Task043_BraTS2019/nnUNetTrainerV2__nnUNetPlansv2.mlperf.1",
help="Path to the directory containing plans.pkl")
parser.add_argument("--onnx_model", default="build/model/192_224_192.onnx", help="Path to the ONNX model")
parser.add_argument("--preprocessed_data_dir", default="build/preprocessed_data", help="path to preprocessed data")
parser.add_argument("--performance_count", type=int, default=16, help="performance count")
parser.add_argument("--model", help="Path to the ONNX or TF model")
parser.add_argument("--preprocessed_data_dir",
default="build/preprocessed_data",
help="path to preprocessed data")
parser.add_argument("--performance_count",
type=int,
default=16,
help="performance count")
args = parser.parse_args()
return args


scenario_map = {
"SingleStream": lg.TestScenario.SingleStream,
"Offline": lg.TestScenario.Offline,
"Server": lg.TestScenario.Server,
"MultiStream": lg.TestScenario.MultiStream
}


def main():
args = get_args()

if args.backend == "pytorch":
from pytorch_SUT import get_pytorch_sut
sut = get_pytorch_sut(args.model_dir, args.preprocessed_data_dir, args.performance_count)
sut = get_pytorch_sut(args.model_dir, args.preprocessed_data_dir,
args.performance_count)
elif args.backend == "onnxruntime":
from onnxruntime_SUT import get_onnxruntime_sut
sut = get_onnxruntime_sut(args.onnx_model, args.preprocessed_data_dir, args.performance_count)
sut = get_onnxruntime_sut(args.model, args.preprocessed_data_dir,
args.performance_count)
elif args.backend == "tf":
from tf_SUT import get_tf_sut
sut = get_tf_sut(args.model, args.preprocessed_data_dir,
args.performance_count)
elif args.backend == "ov":
from ov_SUT import get_ov_sut
sut = get_ov_sut(args.model, args.preprocessed_data_dir,
args.performance_count)
else:
raise ValueError("Unknown backend: {:}".format(args.backend))

Expand Down Expand Up @@ -91,5 +125,6 @@ def main():
print("Destroying QSL...")
lg.DestroyQSL(sut.qsl.qsl)


if __name__ == "__main__":
main()
Loading