Merge pull request mlcommons#710 from mnaumovfb/master
DLRM: Fixing --mlperf-bin-loader and adding ONNX implementation for fixed batch size
christ1ne authored Sep 1, 2020
2 parents 0b2f0c2 + cf0d0b3 commit 2f42f34
Showing 4 changed files with 161 additions and 3 deletions.
5 changes: 4 additions & 1 deletion recommendation/dlrm/pytorch/python/criteo.py
@@ -253,7 +253,10 @@ def load_query_samples(self, sample_list):
             s = self.random_offsets[l]
             e = self.random_offsets[l+1]

-            ls = [self.test_data[i] for i in range(s, e)]
+            if self.use_mlperf_bin_loader and self.samples_to_aggregate > 1:
+                ls = [self.test_data[l]]
+            else:
+                ls = [self.test_data[i] for i in range(s, e)]
             if self.use_mlperf_bin_loader:
                 # NOTE: in binary dataset the values are transformed
                 ls_t = list(zip(*ls))
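For context on the criteo.py change: with --mlperf-bin-loader and samples_to_aggregate > 1, each entry of test_data is already one pre-aggregated sample, so slicing the raw offset range [s, e) would mix records from neighboring samples. A minimal sketch of the two indexing paths, using made-up stand-in data rather than the benchmark's real loader:

# Hypothetical illustration of the --mlperf-bin-loader fix; names mirror
# criteo.py, but the data below is invented.
random_offsets = [0, 4, 8]             # sample l spans raw records [s, e)
test_data_raw = list(range(8))         # loader output without aggregation
test_data_binary = ["agg0", "agg1"]    # binary loader: one record per sample

l = 1
s, e = random_offsets[l], random_offsets[l + 1]
old_path = [test_data_raw[i] for i in range(s, e)]   # slice raw records
new_path = [test_data_binary[l]]                     # index by sample id
print(old_path, new_path)                            # [4, 5, 6, 7] ['agg1']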
55 changes: 55 additions & 0 deletions recommendation/dlrm/pytorch/python/main.py
@@ -77,6 +77,24 @@
         "model": "dlrm",
         "max-batchsize": 2048,
     },
+    "dlrm-kaggle-onnxruntime": {
+        "dataset": "kaggle",
+        "inputs": "continuous and categorical features",
+        "outputs": "probability",
+        "backend": "onnxruntime",
+        "model": "dlrm",
+        "max-batchsize": 128,
+    },
+    "dlrm-terabyte-onnxruntime": {
+        "dataset": "terabyte",
+        "inputs": "continuous and categorical features",
+        "outputs": "probability",
+        "backend": "onnxruntime",
+        "model": "dlrm",
+        "max-batchsize": 2048,
+    },
+
+
 }

 SCENARIO_MAP = {
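The profile entries above act as argument defaults. A rough sketch of how a --profile selection is typically overlaid onto parsed flags; apply_profile is illustrative, not the benchmark's actual helper:

SUPPORTED_PROFILES = {
    "dlrm-terabyte-onnxruntime": {
        "dataset": "terabyte",
        "backend": "onnxruntime",
        "max-batchsize": 2048,
    },
}

def apply_profile(args_dict, profile_name):
    # overlay profile defaults onto already-parsed CLI arguments (sketch)
    for key, value in SUPPORTED_PROFILES[profile_name].items():
        args_dict.setdefault(key.replace("-", "_"), value)
    return args_dict

print(apply_profile({"threads": 8}, "dlrm-terabyte-onnxruntime"))
# {'threads': 8, 'dataset': 'terabyte', 'backend': 'onnxruntime', 'max_batchsize': 2048}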
@@ -196,6 +214,43 @@ def get_backend(backend, dataset, max_ind_range, data_sub_sample_rate, use_gpu):
         else:
             raise ValueError("only kaggle|terabyte dataset options are supported")

+    elif backend == "onnxruntime":
+        from backend_onnxruntime import BackendOnnxruntime
+
+        # NOTE: pass model parameters here; the following options are available
+        if dataset == "kaggle":
+            # 1. Criteo Kaggle Display Advertisement Challenge Dataset (see ./bench/dlrm_s_criteo_kaggle.sh)
+            backend = BackendOnnxruntime(
+                m_spa=16,
+                ln_emb=np.array([1460, 583, 10131227, 2202608, 305, 24, 12517, 633, 3, 93145, 5683, 8351593, 3194, 27, 14992, 5461306, 10, 5652, 2173, 4, 7046547, 18, 15, 286181, 105, 142572]),
+                ln_bot=np.array([13, 512, 256, 64, 16]),
+                ln_top=np.array([367, 512, 256, 1]),
+                use_gpu=use_gpu
+            )
+        elif dataset == "terabyte":
+            if max_ind_range == 10000000:
+                # 2. Criteo Terabyte (see ./bench/dlrm_s_criteo_terabyte.sh [--sub-sample=0.875] --max-ind-range=10000000)
+                backend = BackendOnnxruntime(
+                    m_spa=64,
+                    ln_emb=np.array([9980333, 36084, 17217, 7378, 20134, 3, 7112, 1442, 61, 9758201, 1333352, 313829, 10, 2208, 11156, 122, 4, 970, 14, 9994222, 7267859, 9946608, 415421, 12420, 101, 36]),
+                    ln_bot=np.array([13, 512, 256, 64]),
+                    ln_top=np.array([415, 512, 512, 256, 1]),
+                    use_gpu=use_gpu
+                )
+            elif max_ind_range == 40000000:
+                # 3. Criteo Terabyte MLPerf training (see ./bench/run_and_time.sh --max-ind-range=40000000)
+                backend = BackendOnnxruntime(
+                    m_spa=128,
+                    ln_emb=np.array([39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, 38532951, 2953546, 403346, 10, 2208, 11938, 155, 4, 976, 14, 39979771, 25641295, 39664984, 585935, 12972, 108, 36]),
+                    ln_bot=np.array([13, 512, 256, 128]),
+                    ln_top=np.array([479, 1024, 1024, 512, 256, 1]),
+                    use_gpu=use_gpu
+                )
+            else:
+                raise ValueError("only --max-ind-range 10M or 40M is supported")
+        else:
+            raise ValueError("only kaggle|terabyte dataset options are supported")
+
     else:
         raise ValueError("unknown backend: " + backend)
     return backend
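A hedged usage sketch of the new branch; the model path follows the run_common.sh naming further below, and the weights file itself is assumed to exist:

# Sketch: build the ONNX backend for the 10M-index Terabyte configuration.
backend = get_backend(
    backend="onnxruntime",
    dataset="terabyte",
    max_ind_range=10000000,
    data_sub_sample_rate=0.875,
    use_gpu=False,
)
backend.load("model/dlrm_terabyte.onnxruntime")  # assumed $MODEL_DIR layout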
14 changes: 12 additions & 2 deletions recommendation/dlrm/pytorch/run_common.sh
@@ -23,7 +23,7 @@ device="cpu"

 for i in $* ; do
     case $i in
-        pytorch) backend=$i; shift;;
+        pytorch|onnxruntime) backend=$i; shift;;
         dlrm) model=$i; shift;;
         kaggle|terabyte) dataset=$i; shift;;
         cpu|gpu) device=$i; shift;;
@@ -46,7 +46,8 @@ else
     extra_args="--use-gpu"
 fi
 name="$model-$dataset-$backend"
-
+# debugging
+# echo $name

 #
 # pytorch
Expand All @@ -59,6 +60,15 @@ if [ $name == "dlrm-terabyte-pytorch" ] ; then
model_path="$MODEL_DIR/dlrm_terabyte.pytorch"
profile=dlrm-terabyte-pytorch
fi
if [ $name == "dlrm-kaggle-onnxruntime" ] ; then
model_path="$MODEL_DIR/dlrm_kaggle.onnxruntime"
profile=dlrm-kaggle-onnxruntime
fi
if [ $name == "dlrm-terabyte-onnxruntime" ] ; then
model_path="$MODEL_DIR/dlrm_terabyte.onnxruntime"
profile=dlrm-terabyte-onnxruntime
fi

# debuging
# echo $model_path
# echo $profile
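With these additions, an invocation along the lines of ./run_local.sh onnxruntime dlrm terabyte cpu (assuming the repository's usual run_local.sh wrapper around run_common.sh) resolves name to dlrm-terabyte-onnxruntime and picks up the matching model_path and profile.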
90 changes: 90 additions & 0 deletions v0.5/recommendation/python/backend_onnxruntime.py
@@ -0,0 +1,90 @@
"""
onnxruntime backend (https://github.com/microsoft/onnxruntime)
"""

# pylint: disable=unused-argument,missing-docstring,useless-super-delegation

import onnxruntime as rt
import numpy as np
import backend
import torch


class BackendOnnxruntime(backend.Backend):
def __init__(self, m_spa, ln_emb, ln_bot, ln_top, use_gpu=False, mini_batch_size=1):
super(BackendOnnxruntime, self).__init__()
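        # NOTE (editorial): m_spa, ln_emb, ln_bot, and ln_top are accepted for
        # interface parity with the other backends but are unused here; the
        # exported ONNX graph already fixes the model architecture.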

    def version(self):
        return rt.__version__

    def name(self):
        """Name of the runtime."""
        return "onnxruntime"

    def load(self, model_path, inputs=None, outputs=None):
        """Load model and find input/outputs from the model file."""
        opt = rt.SessionOptions()
        # enable level 3 optimizations
        # FIXME: enable below once onnxruntime 0.5 is released
        # opt.set_graph_optimization_level(3)
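        # NOTE (editorial assumption): on onnxruntime >= 1.0 the equivalent
        # switch is exposed as a property instead, e.g.:
        #   opt.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL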
# print("onnx load", model_path, inputs, outputs)
self.sess = rt.InferenceSession(model_path, opt)
# get input and output names
if True: #not inputs:
self.inputs = [meta.name for meta in self.sess.get_inputs()]
else:
self.inputs = inputs
if True: #not outputs:
self.outputs = [meta.name for meta in self.sess.get_outputs()]
else:
self.outputs = outputs
return self

    def predict(self, batch_dense_X, batch_lS_o, batch_lS_i):
        """Run the prediction."""
        # print("onnx predict")
        # print(self.inputs)
        # print(self.outputs)

        '''
        incoming_bs = batch_dense_X.shape[0]
        model_saved_bs = 2048
        if (incoming_bs != model_saved_bs):
            print("WARNING: mismatch between incoming " + str(incoming_bs) + " and model saved " + str(model_saved_bs) + " mini-batch size")
            fake_output = torch.zeros(size=(incoming_bs,1), dtype=torch.float32)
            return fake_output
        '''

        dict_inputs = {}

        # Dmitriy's approach to build dictionaries
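        # The exported DLRM graph names its dense-feature input "input.1" and
        # the offsets tensor "lS_o"; every remaining graph input is one
        # sparse-index tensor per embedding table, consumed in order.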
        ind = 0
        for i in self.inputs:

            if "input.1" == i:
                dict_inputs[i] = batch_dense_X.numpy().astype(np.float32)

            elif "lS_o" == i:
                dict_inputs[i] = batch_lS_o.numpy().astype(np.int64)

            else:
                dict_inputs[i] = batch_lS_i[ind].numpy().astype(np.int64)
                ind = ind + 1
        '''
        # Maxim's approach to build dictionaries
        dict_inputs[self.inputs[0]] = batch_dense_X.numpy().astype(np.float32)
        dict_inputs[self.inputs[1]] = batch_lS_o.numpy().astype(np.int64)
        if False:  # torch.is_tensor(batch_lS_i): # approach 1: tensor
            dict_inputs[self.inputs[2]] = batch_lS_i.numpy().astype(np.int64)
        else:  # approach 2: list
            for j in range(26):  # 26 sparse features
                dict_inputs[self.inputs[j+2]] = batch_lS_i[j].numpy().astype(np.int64)
        '''
        # predict and return output
        # print(dict_inputs)
        output = self.sess.run(output_names=self.outputs, input_feed=dict_inputs)
        output = torch.tensor(output, requires_grad=False).view(-1, 1)
        # print("output", output)
        # print("output.shape", output.shape)

        return output
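A minimal smoke-test sketch for the class above; the model path, batch size, and 26-table layout are assumptions matching the dlrm-kaggle-onnxruntime profile, not part of this commit:

# Hypothetical smoke test; requires a DLRM ONNX export saved with batch
# size 128 and 26 sparse feature tables.
backend = BackendOnnxruntime(m_spa=16, ln_emb=None, ln_bot=None, ln_top=None)
backend.load("model/dlrm_kaggle.onnxruntime")         # illustrative path

batch = 128
dense = torch.rand(batch, 13)                         # 13 continuous features
offsets = torch.zeros(26, batch, dtype=torch.long)    # per-table offsets
indices = [torch.zeros(batch, dtype=torch.long) for _ in range(26)]

probs = backend.predict(dense, offsets, indices)      # -> (batch, 1) tensor
print(probs.shape)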

