Skip to content

Commit

Permalink
mlflow support
Browse files Browse the repository at this point in the history
  • Loading branch information
alexdrydew committed Mar 17, 2022
1 parent 3795d2c commit 8f0eef0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
5 changes: 4 additions & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ RUN pip3 install jupyterlab==1.2.6 \
scipy==1.4.1 \
Pillow==7.0.0 \
PyYAML==5.3 \
tf2onnx==1.9.2
tf2onnx==1.9.2 \
mlflow==1.24.0 \
onnxruntime==1.10.0 \
boto3==1.21.21
29 changes: 28 additions & 1 deletion export_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import argparse
import tensorflow as tf
import tf2onnx
Expand All @@ -23,8 +24,23 @@ def main():

parser.add_argument('--export_format', choices=['pbtxt', 'onnx'], default='pbtxt')

parser.add_argument('--upload_to_mlflow', action='store_true')
parser.add_argument('--aws_access_key_id', type=str, required=False)
parser.add_argument('--aws_secret_access_key', type=str, required=False)
parser.add_argument('--mlflow_url', type=str, required=False)
parser.add_argument('--mlflow_s3_port', type=int, default=9000)
parser.add_argument('--mlflow_tracking_port', type=int, default=5000)
parser.add_argument('--mlflow_model_name', type=str, required=False)

args, _ = parser.parse_known_args()

if args.upload_to_mlflow:
assert args.export_format == 'onnx', 'Only onnx export format is supported when uploading to MLFlow'
assert args.aws_access_key_id, 'You need to specify aws_access_key_id to upload model to MLFlow'
assert args.aws_secret_access_key, 'You need to specify aws_secret_access_key to upload model to MLFlow'
assert args.mlflow_url, 'You need to specify mlflow_url to upload model to MLFlow'
assert args.mlflow_model_name, 'You need to specify mlflow_url to upload model to MLFlow'

if args.output_path is None:
if args.export_format == 'pbtxt':
args.output_path = Path('model_export/model_v4')
Expand Down Expand Up @@ -65,12 +81,23 @@ def to_save(x):
hack_upsampling=not args.dont_hack_upsampling_op,
)
else:
tf2onnx.convert.from_function(
onnx_model, _ = tf2onnx.convert.from_function(
to_save,
input_signature=input_signature,
output_path=Path(args.output_path) / f'{args.checkpoint_name}.onnx',
)

if args.upload_to_mlflow:
import mlflow

os.environ['AWS_ACCESS_KEY_ID'] = args.aws_access_key_id
os.environ['AWS_SECRET_ACCESS_KEY'] = args.aws_secret_access_key
os.environ['MLFLOW_S3_ENDPOINT_URL'] = f'{args.mlflow_url}:{args.mlflow_s3_port}'
mlflow.set_tracking_uri(f'{args.mlflow_url}:{args.mlflow_tracking_port}')
mlflow.set_experiment('model_export')

mlflow.onnx.log_model(onnx_model, artifact_path='model_onnx', registered_model_name=args.mlflow_model_name)


def construct_preprocess(args):
latent_input_gen = None
Expand Down

0 comments on commit 8f0eef0

Please sign in to comment.