diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 9f47c6680..de293a74d 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -11,11 +11,12 @@ def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5): """Display the subgraph computing a given value or node upto a certain depth.""" slice = [] + def visit(node: ir.Node, depth): if node in slice: return slice.append(node) - if (depth < depth_limit): + if depth < depth_limit: if backward: for inp in node.inputs: if inp is not None and inp.producer() is not None: @@ -24,12 +25,23 @@ def visit(node: ir.Node, depth): for out in node.outputs: for consumer, _ in out.uses(): visit(consumer, depth + 1) + if isinstance(x, ir.Node): visit(x, 0) elif isinstance(x, ir.Value) and x.producer() is not None: visit(x.producer(), 0) - for node in reversed(slice): - node.display() + if slice: + graph = slice[0].graph + if graph: + # Display nodes in same order as in graph: + # Currently doesn't handle (control-flow) subgraphs + for node in graph: + if node in slice: + node.display() + else: + for node in reversed(slice): + node.display() + def get_const_value(value: ir.Value) -> ir.TensorProtocol | None: node = value.producer()