Commit 34828bf

Correctly track modified bit for a variety of replace ops
Differential Revision: D86725366
Pull Request resolved: #15727
1 parent 47aca69 commit 34828bf
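
Each pass converted below follows the same shape: it declares the ops it rewrites through a `targets` property and performs the rewrite in `maybe_remove_or_replace`, returning True only when the graph actually changed. `RemoveOrReplacePassInterface` itself is imported near the top of replace_ops.py and is not part of this diff, so the sketch below is an assumed reconstruction of that contract (everything except the `targets` / `maybe_remove_or_replace` names is hypothetical); it is included only to show how the per-node booleans roll up into the modified bit that this commit fixes.

from abc import ABC, abstractmethod

import torch
from torch.fx.passes.infra.pass_base import PassResult


class RemoveOrReplacePassSketch(ABC):
    """Hypothetical stand-in for RemoveOrReplacePassInterface (assumed API)."""

    @property
    @abstractmethod
    def targets(self) -> list:
        """Op overloads this pass wants to rewrite."""

    @abstractmethod
    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Rewrite `node` in place; return True only if the graph changed."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        for node in list(graph_module.graph.nodes):
            if node.op == "call_function" and node.target in self.targets:
                # Aggregating the per-node booleans is what makes the
                # reported modified bit accurate for these replace ops.
                modified |= self.maybe_remove_or_replace(node)
        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, modified)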

File tree

2 files changed: +161 −160 lines

backends/cadence/aot/replace_ops.py

Lines changed: 149 additions & 133 deletions
@@ -34,6 +34,7 @@
     CadencePassAttribute,
     none_throws,
     register_cadence_pass,
+    RemoveOrReplacePassInterface,
 )
 from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -115,84 +116,84 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceSafeSoftmaxWithSoftmax(ExportPass):  # keep
+class ReplaceSafeSoftmaxWithSoftmax(RemoveOrReplacePassInterface):  # keep
     """
     Replace _safe_softmax with _softmax
     """
 
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op != torch.ops.aten._safe_softmax.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [torch.ops.aten._safe_softmax.default]
 
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # Add False for the half_to_float argument of softmax
-        softmax_args = list(args) + [False]
+        softmax_args = tuple(list(node.args) + [False])
 
-        return super().call_operator(
-            torch.ops.aten._softmax.default,
-            tuple(softmax_args),
-            kwargs,
-            meta,
-        )
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                torch.ops.aten._softmax.default,
+                args=softmax_args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
+class ReplacePT2QuantWithCadenceQuantPass(RemoveOrReplacePassInterface):
     """
     Replace the pt2 quantization ops with cadence quantization ops.
     We do not link kernels to the PT2 quantization ops, so we need to
     replace them with cadence ops at all optimization levels.
     """
 
-    def call_operator(
-        self,
-        op,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops
-        if op != ns.quantized_decomposed.quantize_per_tensor.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        ]
 
-        return super().call_operator(
-            ns.cadence.quantize_per_tensor.default,
-            args,
-            kwargs,
-            meta,
-        )
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                ns.cadence.quantize_per_tensor.default,
+                args=node.args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
+class ReplacePT2DequantWithCadenceDequantPass(RemoveOrReplacePassInterface):
     """
     Replace the pt2 dequantization ops with cadence dequantization ops.
     We do not link kernels to the PT2 quantization ops, so we need to
     replace them with cadence ops at all optimization levels.
     """
 
-    def call_operator(
-        self,
-        op,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops
-        if op != ns.quantized_decomposed.dequantize_per_tensor.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        ]
 
-        return super().call_operator(
-            ns.cadence.dequantize_per_tensor.default,
-            args,
-            kwargs,
-            meta,
-        )
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                ns.cadence.dequantize_per_tensor.default,
+                args=node.args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -232,18 +233,34 @@ def call_operator(
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceFunctionallyEquivalentOpTargets(ExportPass):
+class ReplaceFunctionallyEquivalentOpTargets(RemoveOrReplacePassInterface):
     """
     Replace an op with a functionally equivalent op by just switching the op
     target, but without incurring any change to the op args.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in functionally_equivalent_op_targets:
-            return super().call_operator(op, args, kwargs, meta)
-        return super().call_operator(
-            functionally_equivalent_op_targets[op], args, kwargs, meta
-        )
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return list(functionally_equivalent_op_targets.keys())
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        assert isinstance(node.target, EdgeOpOverload)
+        target_op = functionally_equivalent_op_targets[node.target]
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                target_op,
+                args=node.args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+
+        # RemoveOrReplacePassInterface calls eliminate_dead_code, but this doesn't
+        # remove impure nodes (nodes which have side effects). Not sure if that is
+        # generally safe, so instead of modifying the interface, just erasing
+        # these nodes for this pass.
+        node.graph.erase_node(node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
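
The erase_node call above deserves a note: torch.fx's eliminate_dead_code() skips nodes whose is_impure() returns True, so a replaced-but-not-erased node with side effects would survive dead-code elimination. Below is a minimal, self-contained sketch (ordinary FX tracing, not the Cadence pipeline; the module and op choice are illustrative only) of the replace-then-erase recipe the pass uses.

import operator

import torch
from torch import fx


class Add(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = fx.symbolic_trace(Add())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is operator.add:
        # Insert the replacement, copy metadata, and rewire users.
        with gm.graph.inserting_before(node):
            new_node = gm.graph.call_function(torch.add, args=node.args, kwargs=node.kwargs)
        new_node.meta = node.meta
        node.replace_all_uses_with(new_node)
        # eliminate_dead_code() only drops user-less nodes whose is_impure()
        # is False; erasing explicitly makes the removal unconditional.
        gm.graph.erase_node(node)
gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.code)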
@@ -1438,82 +1455,95 @@ def call_operator(self, op, args, kwargs, meta):
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceScalarTensorWithFullPass(ExportPass):
+class ReplaceScalarTensorWithFullPass(RemoveOrReplacePassInterface):
     """
     aten.scalar_tensor can be replaced by aten.full with a shape of [1].
     scalar_tensor is not supported, so this is an opt_level=0 pass.
     """
 
-    def call_operator(
-        self,
-        op,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.aten.scalar_tensor.default,
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             torch.ops.aten.scalar_tensor.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+            exir_ops.edge.aten.scalar_tensor.default,
+        ]
 
-        return super().call_operator(
-            exir_ops.edge.aten.full.default,
-            (
-                [1],
-                args[0],
-            ),
-            {"dtype": torch.float32},
-            meta,
-        )
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.full.default,
+                args=(
+                    [1],
+                    node.args[0],
+                ),
+                kwargs={"dtype": torch.float32},
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceFullLikeWithFullPass(ExportPass):
+class ReplaceFullLikeWithFullPass(RemoveOrReplacePassInterface):
     """
     aten.full_like can be replaced by aten.full with the shape of the arg tensor.
     full_like is not supported, so this is an opt_level=0 pass.
    """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in {
-            exir_ops.edge.aten.full_like.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.full_like.default]
 
-        # Get the shape of the "like" tensor, and pass that in to the full op.
-        return super().call_operator(
-            exir_ops.edge.aten.full.default,
-            (
-                args[0].to_tensor().shape,
-                args[1],
-            ),
-            {},
-            meta,
-        )
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        input_arg = node.args[0]
+        assert isinstance(input_arg, torch.fx.Node)
+        shape = input_arg.meta["val"].shape
+        fill_value = node.args[1]
+
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.full.default,
+                args=(shape, fill_value),
+                kwargs={},
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceInfArgInFullWithValuePass(ExportPass):
+class ReplaceInfArgInFullWithValuePass(RemoveOrReplacePassInterface):
     """
     aten.full allows "-inf" and "inf" as inputs. The profiler cannot
     handle that, so replace them with the maximum value of the type.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in {
-            exir_ops.edge.aten.full.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.full.default]
 
-        new_args = list(args)
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
 
-        if args[1] == float("-inf"):
+        new_args = list(node.args)
+        fill_value = node.args[1]
+        if fill_value == float("-inf"):
             new_args[1] = torch.finfo(torch.float32).min
-        elif args[1] == float("inf"):
+        elif fill_value == float("inf"):
             new_args[1] = torch.finfo(torch.float32).max
+        else:
+            return False
 
-        return super().call_operator(op, tuple(new_args), kwargs, meta)
+        new_args = tuple(new_args)
+
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.full.default,
+                args=new_args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
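
One detail of the ReplaceFullLikeWithFullPass conversion: the old call_operator read the shape from a ProxyValue via args[0].to_tensor().shape, while the node-based version reads node.meta["val"], the FakeTensor that export-time tracing stores on tensor-producing nodes. A small standalone illustration (plain torch.export rather than the Cadence flow; the module is made up):

import torch


class FullLike(torch.nn.Module):
    def forward(self, x):
        return torch.full_like(x, 0.5)


ep = torch.export.export(FullLike(), (torch.randn(2, 3),))
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function":
        arg0 = node.args[0]
        if isinstance(arg0, torch.fx.Node) and "val" in arg0.meta:
            # meta["val"] holds a FakeTensor, so the static shape is available
            # without materializing any data.
            print(node.target, arg0.meta["val"].shape)  # e.g. torch.Size([2, 3])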
@@ -1713,26 +1743,6 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)
 
 
-@register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass):
-    """
-    Replace the aten gelu op with an approximate arg with an approximate gelu op.
-    """
-
-    def call_operator(
-        self,
-        op,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.aten.gelu.default,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-        return super().call_operator(op, args, kwargs, meta)
-
-
 # Adapted from fbcode/pyspeech/opt_passes/replace_ops.py
 @register_cadence_pass(CadencePassAttribute(opt_level=2))
 class ReplaceSplitWithSlicePass(ExportPass):
@@ -2122,18 +2132,25 @@ class CommonReplacePasses:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass):
+class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(RemoveOrReplacePassInterface):
     """
     Replace aten linalg svd op with cadence custom op.
     """
 
-    def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten._linalg_svd.default:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten._linalg_svd.default]
 
-        return super().call_operator(
-            exir_ops.edge.cadence.linalg_svd.default, args, kwargs, meta
-        )
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.cadence.linalg_svd.default,
+                args=node.args,
+                kwargs=node.kwargs,
+            )
+        new_node.meta = node.meta
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 # This class encapsulates all the functions that replace/switch one op in the
@@ -2165,6 +2182,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
         ReplaceAtenAvgPoolWithCadenceAvgPoolPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
-        ReplaceAtenApproxGeluWithApproxGeluPass,
         ReplaceMulTensorWithMulAndFullOpsPass,
     ]
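
A hedged usage sketch (not part of the commit): one way to exercise a converted pass and observe the modified bit is to build a tiny FX graph containing the target op and run the pass over it. The invocation below assumes the interface keeps the usual pass-returns-PassResult calling convention; the exact API is defined outside this diff, so treat the last few lines as an assumption.

import torch
from torch import fx

from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax


def make_graph_with_safe_softmax() -> fx.GraphModule:
    g = fx.Graph()
    x = g.placeholder("x")
    # _safe_softmax(Tensor self, int dim, ScalarType? dtype=None)
    out = g.call_function(torch.ops.aten._safe_softmax.default, args=(x, -1))
    g.output(out)
    return fx.GraphModule(torch.nn.Module(), g)


gm = make_graph_with_safe_softmax()
result = ReplaceSafeSoftmaxWithSoftmax()(gm)  # assumed ExportPass-style __call__
assert result.modified  # the bit this commit makes accurate
assert not any(
    n.target == torch.ops.aten._safe_softmax.default
    for n in result.graph_module.graph.nodes
    if n.op == "call_function"
)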
