From 6c732649a7ffda442d9c3beaf12427f07e2f0375 Mon Sep 17 00:00:00 2001 From: Joe Date: Wed, 13 Jul 2022 00:45:45 +0900 Subject: [PATCH] Generate name for each member of list arg --- pytorch_pfn_extras/onnx/export_testcase.py | 25 ++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export_testcase.py b/pytorch_pfn_extras/onnx/export_testcase.py index 924f1e8c7..8849ccb40 100644 --- a/pytorch_pfn_extras/onnx/export_testcase.py +++ b/pytorch_pfn_extras/onnx/export_testcase.py @@ -279,10 +279,23 @@ def export_testcase( os.makedirs(out_dir, exist_ok=True) if isinstance(args, torch.Tensor): args = args, - input_names = kwargs.pop( - 'input_names', - ['input_{}'.format(i) for i in range(len(args))]) - assert len(input_names) == len(args) + + # We unroll list args and generate names for each tensor. + gen_input_names = [] + unrolled_args = [] + + def append_input_name(prefix: str, arg: Any) -> None: + if isinstance(arg, list): + for i, a in enumerate(arg): + append_input_name(prefix + f"_{i}", a) + else: + gen_input_names.append(prefix) + unrolled_args.append(arg) + for i, arg in enumerate(args): + append_input_name(f"input_{i}", arg) + + input_names = kwargs.pop('input_names', gen_input_names) + assert len(input_names) == len(unrolled_args) assert not isinstance(args, torch.Tensor) onnx_graph, outs = _export( @@ -302,7 +315,7 @@ def export_testcase( if used_input.name not in initializer_names: used_input_index_list.append(input_names.index(used_input.name)) input_names = [input_names[i] for i in used_input_index_list] - args = [args[i] for i in used_input_index_list] + unrolled_args = [unrolled_args[i] for i in used_input_index_list] output_path = os.path.join(out_dir, 'model.onnx') is_on_memory = True @@ -341,7 +354,7 @@ def write_to_pb(f: str, tensor: torch.Tensor, name: Optional[str] = None) -> Non os.makedirs(data_set_path, exist_ok=True) for pb_name in glob.glob(os.path.join(data_set_path, "*.pb")): os.remove(pb_name) - for i, (arg, name) in enumerate(zip(args, input_names)): + for i, (arg, name) in enumerate(zip(unrolled_args, input_names)): f = os.path.join(data_set_path, 'input_{}.pb'.format(i)) write_to_pb(f, arg, name)