Skip to content

Commit

Permalink
support for different mlflow and s3 hosts
Browse files Browse the repository at this point in the history
  • Loading branch information
alexdrydew committed May 1, 2022
1 parent 8f0eef0 commit baab9d8
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ def main():
parser.add_argument('--aws_access_key_id', type=str, required=False)
parser.add_argument('--aws_secret_access_key', type=str, required=False)
parser.add_argument('--mlflow_url', type=str, required=False)
parser.add_argument('--mlflow_s3_port', type=int, default=9000)
parser.add_argument('--mlflow_tracking_port', type=int, default=5000)
parser.add_argument('--s3_url', type=str, required=False)
parser.add_argument('--mlflow_model_name', type=str, required=False)

args, _ = parser.parse_known_args()
Expand All @@ -39,7 +38,8 @@ def main():
assert args.aws_access_key_id, 'You need to specify aws_access_key_id to upload model to MLFlow'
assert args.aws_secret_access_key, 'You need to specify aws_secret_access_key to upload model to MLFlow'
assert args.mlflow_url, 'You need to specify mlflow_url to upload model to MLFlow'
assert args.mlflow_model_name, 'You need to specify mlflow_url to upload model to MLFlow'
assert args.s3_url, 'You need to specify s3_url to upload model to MLFlow'
assert args.mlflow_model_name, 'You need to specify mlflow_model_name to upload model to MLFlow'

if args.output_path is None:
if args.export_format == 'pbtxt':
Expand Down Expand Up @@ -92,8 +92,8 @@ def to_save(x):

os.environ['AWS_ACCESS_KEY_ID'] = args.aws_access_key_id
os.environ['AWS_SECRET_ACCESS_KEY'] = args.aws_secret_access_key
os.environ['MLFLOW_S3_ENDPOINT_URL'] = f'{args.mlflow_url}:{args.mlflow_s3_port}'
mlflow.set_tracking_uri(f'{args.mlflow_url}:{args.mlflow_tracking_port}')
os.environ['MLFLOW_S3_ENDPOINT_URL'] = args.s3_url
mlflow.set_tracking_uri(args.mlflow_url)
mlflow.set_experiment('model_export')

mlflow.onnx.log_model(onnx_model, artifact_path='model_onnx', registered_model_name=args.mlflow_model_name)
Expand Down

0 comments on commit baab9d8

Please sign in to comment.