From baab9d893ab73d8d5e8a550acdfc81ca5e0070d5 Mon Sep 17 00:00:00 2001 From: alexdrydew Date: Sun, 1 May 2022 15:31:41 +0300 Subject: [PATCH] support for different mlflow and s3 hosts --- export_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/export_model.py b/export_model.py index 395a958..e3e5baa 100644 --- a/export_model.py +++ b/export_model.py @@ -28,8 +28,7 @@ def main(): parser.add_argument('--aws_access_key_id', type=str, required=False) parser.add_argument('--aws_secret_access_key', type=str, required=False) parser.add_argument('--mlflow_url', type=str, required=False) - parser.add_argument('--mlflow_s3_port', type=int, default=9000) - parser.add_argument('--mlflow_tracking_port', type=int, default=5000) + parser.add_argument('--s3_url', type=str, required=False) parser.add_argument('--mlflow_model_name', type=str, required=False) args, _ = parser.parse_known_args() @@ -39,7 +38,8 @@ def main(): assert args.aws_access_key_id, 'You need to specify aws_access_key_id to upload model to MLFlow' assert args.aws_secret_access_key, 'You need to specify aws_secret_access_key to upload model to MLFlow' assert args.mlflow_url, 'You need to specify mlflow_url to upload model to MLFlow' - assert args.mlflow_model_name, 'You need to specify mlflow_url to upload model to MLFlow' + assert args.s3_url, 'You need to specify s3_url to upload model to MLFlow' + assert args.mlflow_model_name, 'You need to specify mlflow_model_name to upload model to MLFlow' if args.output_path is None: if args.export_format == 'pbtxt': @@ -92,8 +92,8 @@ def to_save(x): os.environ['AWS_ACCESS_KEY_ID'] = args.aws_access_key_id os.environ['AWS_SECRET_ACCESS_KEY'] = args.aws_secret_access_key - os.environ['MLFLOW_S3_ENDPOINT_URL'] = f'{args.mlflow_url}:{args.mlflow_s3_port}' - mlflow.set_tracking_uri(f'{args.mlflow_url}:{args.mlflow_tracking_port}') + os.environ['MLFLOW_S3_ENDPOINT_URL'] = args.s3_url + mlflow.set_tracking_uri(args.mlflow_url) mlflow.set_experiment('model_export') mlflow.onnx.log_model(onnx_model, artifact_path='model_onnx', registered_model_name=args.mlflow_model_name)