diff --git a/docker/Dockerfile b/docker/Dockerfile index d432003..b7e0e90 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -14,4 +14,8 @@ RUN pip3 install jupyterlab==1.2.6 \ seaborn==0.10.0 \ scipy==1.4.1 \ Pillow==7.0.0 \ - PyYAML==5.3 + PyYAML==5.3 \ + tf2onnx==1.9.2 \ + mlflow==1.24.0 \ + onnxruntime==1.10.0 \ + boto3==1.21.21 diff --git a/dump_graph_model_v4.py b/dump_graph_model_v4.py deleted file mode 100644 index ed8eb50..0000000 --- a/dump_graph_model_v4.py +++ /dev/null @@ -1,71 +0,0 @@ -import argparse -from pathlib import Path - -import tensorflow as tf - -from cuda_gpu_config import setup_gpu -from model_export import dump_graph -from models.model_v4 import preprocess_features, Model_v4 -from models.utils import load_weights -from run_model_v4 import load_config - - -def main(): - parser = argparse.ArgumentParser(fromfile_prefix_chars='@') - parser.add_argument('--checkpoint_name', type=str, required=True) - parser.add_argument('--output_path', type=str, default='model_export/model_v4/graph.pbtxt') - parser.add_argument('--latent_dim', type=int, default=32, required=False) - parser.add_argument('--dont_hack_upsampling_op', default=True, action='store_true') - parser.add_argument('--test_input', type=float, nargs=4, default=None) - parser.add_argument('--constant_seed', type=float, default=None) - parser.add_argument('--gpu_num', type=str, default=None) - - args, _ = parser.parse_known_args() - - setup_gpu(args.gpu_num) - - print("") - print("----" * 10) - print("Arguments:") - for k, v in vars(args).items(): - print(f" {k} : {v}") - print("----" * 10) - print("") - - model_path = Path('saved_models') / args.checkpoint_name - - full_model = Model_v4(load_config(model_path / 'config.yaml')) - load_weights(full_model, model_path) - model = full_model.generator - - if args.constant_seed is None: - - def preprocess(x): - size = tf.shape(x)[0] - latent_input = tf.random.normal(shape=(size, args.latent_dim), dtype='float32') - return tf.concat([preprocess_features(x), latent_input], axis=-1) - - else: - - def preprocess(x): - size = tf.shape(x)[0] - latent_input = tf.ones(shape=(size, args.latent_dim), dtype='float32') * args.constant_seed - return tf.concat([preprocess_features(x), latent_input], axis=-1) - - def postprocess(x): - x = 10**x - 1 - return tf.where(x < 1.0, 0.0, x) - - dump_graph.model_to_graph( - model, - preprocess, - postprocess, - input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)], - output_file=args.output_path, - test_input=args.test_input, - hack_upsampling=not args.dont_hack_upsampling_op, - ) - - -if __name__ == '__main__': - main() diff --git a/export_model.py b/export_model.py new file mode 100644 index 0000000..553fd2d --- /dev/null +++ b/export_model.py @@ -0,0 +1,135 @@ +import os +import argparse +import tensorflow as tf +import tf2onnx + +from pathlib import Path + +from model_export import dump_graph +from models.model_v4 import Model_v4, preprocess_features +from models.utils import load_weights +from run_model_v4 import load_config + + +def main(): + parser = argparse.ArgumentParser(fromfile_prefix_chars='@') + parser.add_argument('--checkpoint_name', type=str, required=True) + parser.add_argument('--output_path', type=str, default=None) + parser.add_argument('--dont_hack_upsampling_op', default=True, action='store_true') + parser.add_argument('--test_input', type=float, nargs=4, default=None) + + parser.add_argument('--latent_space', choices=['normal', 'uniform', 'constant', 'none'], default='normal') + parser.add_argument('--latent_dim', type=int, default=32, required=False) + parser.add_argument('--constant_latent', type=float, default=None) + + parser.add_argument('--export_format', choices=['pbtxt', 'onnx'], default='pbtxt') + + parser.add_argument('--upload_to_mlflow', action='store_true') + parser.add_argument('--aws_access_key_id', type=str, required=False) + parser.add_argument('--aws_secret_access_key', type=str, required=False) + parser.add_argument('--mlflow_url', type=str, required=False) + parser.add_argument('--s3_url', type=str, required=False) + parser.add_argument('--mlflow_model_name', type=str, required=False) + + args, _ = parser.parse_known_args() + + if args.upload_to_mlflow: + assert args.export_format == 'onnx', 'Only onnx export format is supported when uploading to MLFlow' + assert args.aws_access_key_id, 'You need to specify aws_access_key_id to upload model to MLFlow' + assert args.aws_secret_access_key, 'You need to specify aws_secret_access_key to upload model to MLFlow' + assert args.mlflow_url, 'You need to specify mlflow_url to upload model to MLFlow' + assert args.s3_url, 'You need to specify s3_url to upload model to MLFlow' + assert args.mlflow_model_name, 'You need to specify mlflow_model_name to upload model to MLFlow' + + if args.output_path is None: + if args.export_format == 'pbtxt': + args.output_path = Path('model_export/model_v4') + else: + args.output_path = Path('model_export/onnx') + + args.output_path.mkdir(parents=True, exist_ok=True) + + print("") + print("----" * 10) + print("Arguments:") + for k, v in vars(args).items(): + print(f" {k} : {v}") + print("----" * 10) + print("") + + model_path = Path('saved_models') / args.checkpoint_name + + full_model = Model_v4(load_config(model_path / 'config.yaml')) + load_weights(full_model, model_path) + model = full_model.generator + + input_signature, preprocess = construct_preprocess(args) + + def postprocess(x): + x = 10**x - 1 + return tf.where(x < 1.0, 0.0, x) + + @tf.function(input_signature=input_signature) + def to_save(x): + return postprocess(model(preprocess(x))) + + if args.export_format == 'pbtxt': + dump_graph.model_to_graph( + to_save, + output_file=Path(args.output_path) / "graph.pbtxt", + test_input=args.test_input, + hack_upsampling=not args.dont_hack_upsampling_op, + ) + else: + onnx_model, _ = tf2onnx.convert.from_function( + to_save, + input_signature=input_signature, + output_path=Path(args.output_path) / f'{args.checkpoint_name}.onnx', + ) + + if args.upload_to_mlflow: + import mlflow + + os.environ['AWS_ACCESS_KEY_ID'] = args.aws_access_key_id + os.environ['AWS_SECRET_ACCESS_KEY'] = args.aws_secret_access_key + os.environ['MLFLOW_S3_ENDPOINT_URL'] = args.s3_url + mlflow.set_tracking_uri(args.mlflow_url) + mlflow.set_experiment('model_export') + + mlflow.log_artifact(str(model_path / 'config.yaml'), artifact_path='model_onnx') + mlflow.onnx.log_model(onnx_model, artifact_path='model_onnx', registered_model_name=args.mlflow_model_name) + + +def construct_preprocess(args): + latent_input_gen = None + predefined_batch_size = None if args.export_format == 'pbtxt' else 1 + + if args.latent_space == 'normal': + + def latent_input_gen(batch_size): + return tf.random.normal(shape=(batch_size, args.latent_dim), dtype='float32') + + elif args.latent_space == 'uniform': + + def latent_input_gen(batch_size): + return tf.random.uniform(shape=(batch_size, args.latent_dim), dtype='float32') + + if latent_input_gen is None: + input_signature = [tf.TensorSpec(shape=[predefined_batch_size, 36], dtype=tf.float32)] + + def preprocess(x): + return tf.concat([preprocess_features(x[..., :4]), x[..., 4:]], axis=-1) + + else: + input_signature = [tf.TensorSpec(shape=[predefined_batch_size, 4], dtype=tf.float32)] + + def preprocess(x): + size = tf.shape(x)[0] + latent_input = latent_input_gen(size) + return tf.concat([preprocess_features(x), latent_input], axis=-1) + + return input_signature, preprocess + + +if __name__ == '__main__': + main() diff --git a/model_export/dump_graph.py b/model_export/dump_graph.py index 02dd65b..97cb255 100644 --- a/model_export/dump_graph.py +++ b/model_export/dump_graph.py @@ -7,15 +7,11 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.tools import optimize_for_inference_lib - -from . import tf2xla_pb2 +from model_export import tf2xla_pb2 def model_to_graph( model, - preprocess, - postprocess, - input_signature, output_file, test_input=None, hack_upsampling=False, @@ -24,11 +20,7 @@ def model_to_graph( ): tf.keras.backend.set_learning_phase(0) - @tf.function(input_signature=input_signature) - def to_save(x): - return postprocess(model(preprocess(x))) - - constant_graph = convert_to_constants.convert_variables_to_constants_v2(to_save.get_concrete_function()) + constant_graph = convert_to_constants.convert_variables_to_constants_v2(model.get_concrete_function()) if hack_upsampling: print("Warning: hacking upsampling operations") @@ -68,7 +60,7 @@ def to_save(x): f.write(str(config)) if test_input is not None: - print(to_save(tf.convert_to_tensor([test_input]))) + print(model(tf.convert_to_tensor([test_input]))) for batch_size in batch_sizes[::-1]: timings = [] @@ -76,7 +68,7 @@ def to_save(x): for i in range(iterations): batched_input = tf.random.normal(shape=(batch_size, len(test_input)), dtype='float32') t0 = perf_counter() - to_save(batched_input).numpy() + model(batched_input).numpy() t1 = perf_counter() timings.append((t1 - t0) * 1000.0 / batch_size)