|
13 | 13 |
|
14 | 14 | _SINGLE_INDENT = " " |
15 | 15 |
|
| 16 | +_SMALL_TENSOR_SIZE = 4 |
| 17 | + |
16 | 18 | kwlist = { |
17 | 19 | "False", |
18 | 20 | "None", |
@@ -119,7 +121,7 @@ def renamer(name): |
119 | 121 |
|
120 | 122 | def _translate_type(onnx_type): |
121 | 123 | """Converts a onnx type into a type defined by *onnxscript*.""" |
122 | | - return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type) |
| 124 | + return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type, reversible=False) |
123 | 125 |
|
124 | 126 |
|
125 | 127 | def _translate_signature(inputs, outputs): |
@@ -350,25 +352,33 @@ def _translate_graph_body(self, graph, opsets, indent=0): |
350 | 352 | if hasattr(graph, "initializer"): |
351 | 353 | for init in graph.initializer: |
352 | 354 | if self.skip_initializers: |
353 | | - init_py_name = self._translate_onnx_var(init.name) |
354 | | - if init_py_name in self.skipped_initializers: |
355 | | - raise RuntimeError( |
356 | | - f"Initializer {init.name!r} is already present in skipped_initializers." |
357 | | - ) |
358 | | - self.skipped_initializers[init_py_name] = init |
359 | | - continue |
| 355 | + size = 1 |
| 356 | + for d in init.dims: |
| 357 | + size *= d |
| 358 | + if size > _SMALL_TENSOR_SIZE: |
| 359 | + init_py_name = self._translate_onnx_var(init.name) |
| 360 | + if init_py_name in self.skipped_initializers: |
| 361 | + raise RuntimeError( |
| 362 | + f"Initializer {init.name!r} is already present in skipped_initializers." |
| 363 | + ) |
| 364 | + self.skipped_initializers[init_py_name] = init |
| 365 | + continue |
360 | 366 | node = onnx.helper.make_node( # noqa: TID251 |
361 | 367 | "Constant", |
362 | 368 | [], |
363 | 369 | [self._translate_onnx_var(init.name)], # type: ignore[list-item] |
364 | 370 | value=init, |
365 | 371 | ) |
366 | | - code.append(self._translate_node(node, opsets, indent=indent)) |
| 372 | + pyinit = self._translate_node(node, opsets, indent=indent) |
| 373 | + if pyinit: |
| 374 | + code.append(pyinit) |
367 | 375 | if hasattr(graph, "sparse_initializer") and len(graph.sparse_initializer) > 0: |
368 | 376 | raise NotImplementedError("Unable to convert sparse_initilizer into python.") |
369 | 377 | for node in graph.node: |
370 | 378 | pynode = self._translate_node(node, opsets, indent=indent) |
371 | 379 | if pynode: |
| 380 | + if node.name: |
| 381 | + pynode += f" # {node.name}" |
372 | 382 | code.append(pynode) |
373 | 383 |
|
374 | 384 | final = "\n".join(code) |
@@ -418,7 +428,8 @@ def _translate_attributes(self, node): |
418 | 428 | def _translate_if(self, node, opsets, indent=0): |
419 | 429 | """Translates a node If into python.""" |
420 | 430 | sindent = _SINGLE_INDENT * indent |
421 | | - code = [f"{sindent}if {node.input[0]}:"] |
| 431 | + cond = self._translate_onnx_var_ref(node.input[0]) |
| 432 | + code = [f"{sindent}if {cond}:"] |
422 | 433 | if len(node.attribute) != 2: |
423 | 434 | raise RuntimeError( |
424 | 435 | f"Node {node.op_type!r} expected two attributes not {len(node.attribute)}." |
@@ -502,17 +513,21 @@ def _translate_loop(self, node, opsets, indent=0): |
502 | 513 |
|
503 | 514 | rows.extend(self._emit_assign(formal_ins, actual_ins, indent)) |
504 | 515 |
|
| 516 | + if node.name: |
| 517 | + node_name = " # " + node.name |
| 518 | + else: |
| 519 | + node_name = "" |
505 | 520 | if use_iter_var and not use_loop_cond: |
506 | | - rows.append(f"{sindent}for {iter_var} in range({n_iter}):") |
| 521 | + rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}") |
507 | 522 | # The following is a hacky way to suppress the generation of |
508 | 523 | # "cond_out = cond_in", which ONNX forces for a FOR loop. |
509 | 524 | # TODO: a cleaner solution for this. |
510 | 525 | self._name_remappings[-1][cond_out] = self._translate_onnx_var(cond_in) |
511 | 526 | elif not use_iter_var and use_loop_cond: |
512 | | - rows.append(f"{sindent}while {py_cond}:") |
| 527 | + rows.append(f"{sindent}while {py_cond}:{node_name}") |
513 | 528 | elif use_iter_var and use_loop_cond: |
514 | 529 | # TODO: This needs fixing |
515 | | - rows.append(f"{sindent}for {iter_var} in range({n_iter}):") |
| 530 | + rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}") |
516 | 531 | rows.append(f"{sindent}{_SINGLE_INDENT}if not {py_cond}:") |
517 | 532 | rows.append(f"{sindent}{_SINGLE_INDENT * 2}break") |
518 | 533 | else: |
@@ -734,11 +749,13 @@ def _substitute_initializers( |
734 | 749 |
|
735 | 750 | def generate_rand(name: str, value: TensorProto) -> str: |
736 | 751 | shape = ",".join(str(d) for d in value.dims) |
737 | | - if value.data_type != TensorProto.FLOAT: |
738 | | - raise NotImplementedError( |
739 | | - f"Unable to generate random initializer for data type {value.data_type}." |
740 | | - ) |
741 | | - return f"{__}{name} = np.random.rand({shape}).astype(np.float32)" |
| 752 | + if value.data_type == TensorProto.FLOAT: |
| 753 | + return f"{__}{name} = np.random.rand({shape}).astype(np.float32)" |
| 754 | + if value.data_type == TensorProto.INT8: |
| 755 | + return f"{__}{name} = np.random.randint(-128, 127, size=({shape},), dtype=np.int8)" |
| 756 | + raise NotImplementedError( |
| 757 | + f"Unable to generate random initializer for data type {value.data_type}." |
| 758 | + ) |
742 | 759 |
|
743 | 760 | random_initializer_values = "\n".join( |
744 | 761 | generate_rand(key, value) for key, value in self.skipped_initializers.items() |
|
0 commit comments