diff --git a/docker/Dockerfile b/docker/Dockerfile
index 597a5b2..b7e0e90 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -15,4 +15,7 @@ RUN pip3 install jupyterlab==1.2.6 \
     scipy==1.4.1 \
     Pillow==7.0.0 \
     PyYAML==5.3 \
-    tf2onnx==1.9.2
+    tf2onnx==1.9.2 \
+    mlflow==1.24.0 \
+    onnxruntime==1.10.0 \
+    boto3==1.21.21
diff --git a/export_model.py b/export_model.py
index c5c99e7..395a958 100644
--- a/export_model.py
+++ b/export_model.py
@@ -1,3 +1,4 @@
+import os
 import argparse
 import tensorflow as tf
 import tf2onnx
@@ -23,8 +24,23 @@ def main():
     parser.add_argument('--export_format',
                         choices=['pbtxt', 'onnx'],
                         default='pbtxt')
+    parser.add_argument('--upload_to_mlflow', action='store_true')
+    parser.add_argument('--aws_access_key_id', type=str, required=False)
+    parser.add_argument('--aws_secret_access_key', type=str, required=False)
+    parser.add_argument('--mlflow_url', type=str, required=False)
+    parser.add_argument('--mlflow_s3_port', type=int, default=9000)
+    parser.add_argument('--mlflow_tracking_port', type=int, default=5000)
+    parser.add_argument('--mlflow_model_name', type=str, required=False)
+
     args, _ = parser.parse_known_args()
 
+    if args.upload_to_mlflow:
+        assert args.export_format == 'onnx', 'Only onnx export format is supported when uploading to MLFlow'
+        assert args.aws_access_key_id, 'You need to specify aws_access_key_id to upload model to MLFlow'
+        assert args.aws_secret_access_key, 'You need to specify aws_secret_access_key to upload model to MLFlow'
+        assert args.mlflow_url, 'You need to specify mlflow_url to upload model to MLFlow'
+        assert args.mlflow_model_name, 'You need to specify mlflow_model_name to upload model to MLFlow'
+
     if args.output_path is None:
         if args.export_format == 'pbtxt':
             args.output_path = Path('model_export/model_v4')
@@ -65,12 +81,23 @@ def to_save(x):
             hack_upsampling=not args.dont_hack_upsampling_op,
         )
     else:
-        tf2onnx.convert.from_function(
+        onnx_model, _ = tf2onnx.convert.from_function(
            to_save,
            input_signature=input_signature,
            output_path=Path(args.output_path) / f'{args.checkpoint_name}.onnx',
         )
 
+    if args.upload_to_mlflow:
+        import mlflow
+
+        os.environ['AWS_ACCESS_KEY_ID'] = args.aws_access_key_id
+        os.environ['AWS_SECRET_ACCESS_KEY'] = args.aws_secret_access_key
+        os.environ['MLFLOW_S3_ENDPOINT_URL'] = f'{args.mlflow_url}:{args.mlflow_s3_port}'
+        mlflow.set_tracking_uri(f'{args.mlflow_url}:{args.mlflow_tracking_port}')
+        mlflow.set_experiment('model_export')
+
+        mlflow.onnx.log_model(onnx_model, artifact_path='model_onnx', registered_model_name=args.mlflow_model_name)
+
 
 def construct_preprocess(args):
     latent_input_gen = None
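
For reviewers who want to sanity-check the upload path end to end, below is a minimal sketch (not part of the diff) of how a model registered by `export_model.py --export_format onnx --upload_to_mlflow ...` could be fetched back from the registry and executed with the `onnxruntime` package added to the Docker image. The tracking host `mlflow.example.com` and the registry entry `my_model`, version `1`, are placeholder assumptions; the ports 9000 and 5000 match the script's defaults for `--mlflow_s3_port` and `--mlflow_tracking_port`.

```python
# Hypothetical round-trip check for a model logged by export_model.py.
# Host name, credentials, model name and version below are placeholders.
import os

import mlflow
import onnxruntime as ort

# Same credential/endpoint variables that export_model.py sets before logging.
os.environ['AWS_ACCESS_KEY_ID'] = '<access key id>'
os.environ['AWS_SECRET_ACCESS_KEY'] = '<secret access key>'
os.environ['MLFLOW_S3_ENDPOINT_URL'] = 'http://mlflow.example.com:9000'
mlflow.set_tracking_uri('http://mlflow.example.com:5000')

# Load the registered ONNX model (an onnx.ModelProto) from the model registry
# and build an onnxruntime session from its serialized bytes.
onnx_model = mlflow.onnx.load_model('models:/my_model/1')
session = ort.InferenceSession(onnx_model.SerializeToString())
print([inp.name for inp in session.get_inputs()])
```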