Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions exir/graph_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -10,6 +11,7 @@
from typing import Callable, Dict, List, Tuple, Union

import torch
from torch._ops import HigherOrderOperator


LeafValue = Union[
Expand Down Expand Up @@ -46,30 +48,66 @@ def _get_submodule(
return submod_node.target, submodule, node


def _get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
    op_to_submodule_arg_index: dict[HigherOrderOperator, list[int]],
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Returns a list of submodules used for control flow operations
    that are in the given toplevel graph (does not look
    into submodules). Specifically, the returned value is a list containing
    tuples of (name of the submodule that's stored in the graph module, the
    submodule itself, and the fx node that uses this submodule).

    Args:
        graph_module: The toplevel graph module to scan (only its own graph
            nodes are inspected; nested submodule graphs are not traversed).
        op_to_submodule_arg_index: Maps each higher-order operator of interest
            to the argument indices at which its call_function nodes carry a
            submodule (e.g. cond's true/false branches at indices 1 and 2).
    """
    control_flow_submodules = []
    for node in graph_module.graph.nodes:
        # Only call_function nodes can invoke a higher-order operator.
        if node.op != "call_function":
            continue

        # Identity comparison (`is`) is intentional: higher-order operator
        # objects are singletons, so a dict lookup by identity is equivalent.
        arg_indices = op_to_submodule_arg_index.get(node.target)
        if arg_indices is None:
            continue
        for i in arg_indices:
            control_flow_submodules.append(_get_submodule(graph_module, node, i))

    return control_flow_submodules


def get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Returns a list of submodules used for control flow operations
    (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
    into submodules). Specifically, the returned value is a list containing
    tuples of (name of the submodule that's stored in the graph module, the
    submodule itself, and the fx node that uses this submodule).
    """
    # cond stores its true/false branch submodules at arg indices 1 and 2;
    # map_impl stores its body submodule at arg index 0.
    op_to_arg_indices = {
        torch.ops.higher_order.cond: [1, 2],
        torch.ops.higher_order.map_impl: [0],
    }
    return _get_control_flow_submodules(graph_module, op_to_arg_indices)


def get_cond_while_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Returns a list of submodules used for control flow operations
    (torch.ops.higher_order.cond/while_loop) that are in the given toplevel graph (does not look
    into submodules). Specifically, the returned value is a list containing
    tuples of (name of the submodule that's stored in the graph module, the
    submodule itself, and the fx node that uses this submodule).
    """
    # cond carries its branch submodules at arg indices 1 and 2;
    # while_loop carries its cond/body submodules at indices 0 and 1.
    op_to_arg_indices = {
        torch.ops.higher_order.cond: [1, 2],
        torch.ops.higher_order.while_loop: [0, 1],
    }
    return _get_control_flow_submodules(graph_module, op_to_arg_indices)


def bfs_trace_with_node_process(
gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None]
) -> None:
Expand Down
Loading