diff --git a/exir/graph_module.py b/exir/graph_module.py index e26d22d8145..2adf62ab0b8 100644 --- a/exir/graph_module.py +++ b/exir/graph_module.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,6 +11,7 @@ from typing import Callable, Dict, List, Tuple, Union import torch +from torch._ops import HigherOrderOperator LeafValue = Union[ @@ -46,14 +48,15 @@ def _get_submodule( return submod_node.target, submodule, node -def get_control_flow_submodules( +def _get_control_flow_submodules( graph_module: torch.fx.GraphModule, + op_to_submodule_arg_index: dict[HigherOrderOperator, list[int]], ) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]: """ Returns a list of submodules used for control flow operations - (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look - into submodules). Specifically, the returned value is a list containing a - tuple of (name of the submodule that's stored in the graph module, the + that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing + tuples of (name of the submodule that's stored in the graph module, the submodule itself, and the fx node that uses this submodule). """ control_flow_submodules = [] @@ -61,15 +64,50 @@ def get_control_flow_submodules( if node.op != "call_function": continue - if node.target is torch.ops.higher_order.cond: - control_flow_submodules.append(_get_submodule(graph_module, node, 1)) - control_flow_submodules.append(_get_submodule(graph_module, node, 2)) - if node.target is torch.ops.higher_order.map_impl: - control_flow_submodules.append(_get_submodule(graph_module, node, 0)) + for op in op_to_submodule_arg_index: + if node.target is not op: + continue + for i in op_to_submodule_arg_index[op]: + control_flow_submodules.append(_get_submodule(graph_module, node, i)) return control_flow_submodules +def get_control_flow_submodules( + graph_module: torch.fx.GraphModule, +) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]: + """ + Returns a list of submodules used for control flow operations + (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing + tuples of (name of the submodule that's stored in the graph module, the + submodule itself, and the fx node that uses this submodule). + """ + return _get_control_flow_submodules( + graph_module, + {torch.ops.higher_order.cond: [1, 2], torch.ops.higher_order.map_impl: [0]}, + ) + + +def get_cond_while_submodules( + graph_module: torch.fx.GraphModule, +) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]: + """ + Returns a list of submodules used for control flow operations + (torch.ops.higher_order.cond/while_loop) that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing + tuples of (name of the submodule that's stored in the graph module, the + submodule itself, and the fx node that uses this submodule). + """ + return _get_control_flow_submodules( + graph_module, + { + torch.ops.higher_order.cond: [1, 2], + torch.ops.higher_order.while_loop: [0, 1], + }, + ) + + def bfs_trace_with_node_process( gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None] ) -> None: