From d81480b530ec7851246bd5555f572c17893e6cd2 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 14 Nov 2024 14:01:53 -0800 Subject: [PATCH] A couple of bug fixes (#1945) Fixes a couple of bugs that show up in GPT2 optimization. --- onnxscript/optimizer/_inliner.py | 4 ++-- onnxscript/rewriter/collapse_slices.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxscript/optimizer/_inliner.py b/onnxscript/optimizer/_inliner.py index c35926301..798bc302a 100644 --- a/onnxscript/optimizer/_inliner.py +++ b/onnxscript/optimizer/_inliner.py @@ -236,9 +236,9 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl # Identify call-stack for node, used to generate unique names. call_stack = self.node_context.get(node, []) - call_stack.append(call_site_id) + new_call_stack = [*call_stack, call_site_id] - cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, call_stack) + cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, new_call_stack) # iterate over the nodes in the function, creating a copy of each node # and replacing inputs with the corresponding values in the value map. diff --git a/onnxscript/rewriter/collapse_slices.py b/onnxscript/rewriter/collapse_slices.py index 57d9baf28..2615432e7 100644 --- a/onnxscript/rewriter/collapse_slices.py +++ b/onnxscript/rewriter/collapse_slices.py @@ -28,6 +28,10 @@ def _check_if_redundant_slice( axes_const = axes.const_value steps_const = steps.const_value + if starts_const is None or ends_const is None or axes_const is None or steps_const is None: + logger.info("The value 'start', 'end', 'axis', 'step' is not statically known.") + return False + # Check if the values are scalar if starts_const.numpy().size != 1: # type: ignore[union-attr] logger.info("The value 'start' is not a scalar.") @@ -42,9 +46,6 @@ def _check_if_redundant_slice( logger.info("The value 'step' is not a scalar.") return False - if starts_const is None or ends_const is None or axes_const is None or steps_const is None: - logger.info("The value 'start', 'end', 'axis', 'step' is not statically known.") - return False if steps_const.numpy().item() != 1: logger.info("The value 'step' is not 1.") return False