|
34 | 34 | CadencePassAttribute, |
35 | 35 | none_throws, |
36 | 36 | register_cadence_pass, |
| 37 | + RemoveOrReplacePassInterface, |
37 | 38 | ) |
38 | 39 | from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass |
39 | 40 | from executorch.backends.cadence.aot.utils import get_edge_overload_packet |
@@ -115,84 +116,84 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
115 | 116 |
|
116 | 117 |
|
117 | 118 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
118 | | -class ReplaceSafeSoftmaxWithSoftmax(ExportPass): # keep |
| 119 | +class ReplaceSafeSoftmaxWithSoftmax(RemoveOrReplacePassInterface): # keep |
119 | 120 | """ |
120 | 121 | Replace _safe_softmax with _softmax |
121 | 122 | """ |
122 | 123 |
|
123 | | - def call_operator( |
124 | | - self, |
125 | | - op, |
126 | | - args: tuple[Argument, ...], |
127 | | - kwargs: dict[str, Argument], |
128 | | - meta: NodeMetadata, |
129 | | - ) -> ProxyValue: |
130 | | - if op != torch.ops.aten._safe_softmax.default: |
131 | | - return super().call_operator(op, args, kwargs, meta) |
| 124 | + @property |
| 125 | + def targets(self) -> list[EdgeOpOverload]: |
| 126 | + return [torch.ops.aten._safe_softmax.default] |
132 | 127 |
|
| 128 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
133 | 129 | # Add False for the half_to_float argument of softmax |
134 | | - softmax_args = list(args) + [False] |
| 130 | + softmax_args = tuple(list(node.args) + [False]) |
135 | 131 |
|
136 | | - return super().call_operator( |
137 | | - torch.ops.aten._softmax.default, |
138 | | - tuple(softmax_args), |
139 | | - kwargs, |
140 | | - meta, |
141 | | - ) |
| 132 | + with node.graph.inserting_before(node): |
| 133 | + new_node = node.graph.call_function( |
| 134 | + torch.ops.aten._softmax.default, |
| 135 | + args=softmax_args, |
| 136 | + kwargs=node.kwargs, |
| 137 | + ) |
| 138 | + new_node.meta = node.meta |
| 139 | + node.replace_all_uses_with(new_node) |
| 140 | + return True |
142 | 141 |
|
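All of the migrated passes share one FX rewrite pattern: create the replacement `call_function` node just before the matched node, copy its metadata, and redirect every use. A minimal sketch of that pattern against plain `torch.fx` (the `RemoveOrReplacePassInterface` machinery, which is assumed to iterate over matching nodes and run dead-code elimination afterwards, is not shown):

```python
import torch.fx


def swap_target(node: torch.fx.Node, new_target) -> None:
    # Hypothetical helper mirroring the maybe_remove_or_replace bodies above.
    graph = node.graph
    with graph.inserting_before(node):
        # Re-emit the op under the new target with identical args/kwargs.
        new_node = graph.call_function(new_target, args=node.args, kwargs=node.kwargs)
    # Carry over shape/dtype metadata ("val") so later passes keep working.
    new_node.meta = node.meta
    # Point all consumers at the replacement; the old node becomes dead.
    node.replace_all_uses_with(new_node)
```

Once the old node has no users, `graph.eliminate_dead_code()` can drop it, with the impure-node caveat handled further down.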
143 | 142 |
|
144 | 143 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
145 | | -class ReplacePT2QuantWithCadenceQuantPass(ExportPass): |
| 144 | +class ReplacePT2QuantWithCadenceQuantPass(RemoveOrReplacePassInterface): |
146 | 145 | """ |
147 | 146 | Replace the pt2 quantization ops with cadence quantization ops. |
148 | 147 | We do not link kernels to the PT2 quantization ops, so we need to |
149 | 148 | replace them with cadence ops at all optimization levels. |
150 | 149 | """ |
151 | 150 |
|
152 | | - def call_operator( |
153 | | - self, |
154 | | - op, |
155 | | - args: Tuple[Argument, ...], |
156 | | - kwargs: Dict[str, Argument], |
157 | | - meta: NodeMetadata, |
158 | | - ) -> ProxyValue: |
159 | | - ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops |
160 | | - if op != ns.quantized_decomposed.quantize_per_tensor.default: |
161 | | - return super().call_operator(op, args, kwargs, meta) |
| 151 | + @property |
| 152 | + def targets(self) -> list[EdgeOpOverload]: |
| 153 | + return [ |
| 154 | + torch.ops.quantized_decomposed.quantize_per_tensor.default, |
| 155 | + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 156 | + ] |
162 | 157 |
|
163 | | - return super().call_operator( |
164 | | - ns.cadence.quantize_per_tensor.default, |
165 | | - args, |
166 | | - kwargs, |
167 | | - meta, |
168 | | - ) |
| 158 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 159 | + ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops |
| 160 | + with node.graph.inserting_before(node): |
| 161 | + new_node = node.graph.call_function( |
| 162 | + ns.cadence.quantize_per_tensor.default, |
| 163 | + args=node.args, |
| 164 | + kwargs=node.kwargs, |
| 165 | + ) |
| 166 | + new_node.meta = node.meta |
| 167 | + node.replace_all_uses_with(new_node) |
| 168 | + return True |
169 | 169 |
|
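The `ns` selection in the quantize/dequantize passes keeps the replacement in the same dialect as the original node: after `to_edge`, targets are `EdgeOpOverload` objects living under `exir_ops.edge`, while pre-edge graphs still carry plain ATen overloads under `torch.ops`. A hedged sketch (the helper name is illustrative):

```python
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload


def cadence_namespace_for(target):
    # Edge-dialect targets resolve to exir_ops.edge.cadence.*; anything
    # else (e.g. torch.ops.quantized_decomposed.*) to torch.ops.cadence.*.
    return exir_ops.edge if isinstance(target, EdgeOpOverload) else torch.ops
```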
170 | 170 |
|
171 | 171 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
172 | | -class ReplacePT2DequantWithCadenceDequantPass(ExportPass): |
| 172 | +class ReplacePT2DequantWithCadenceDequantPass(RemoveOrReplacePassInterface): |
173 | 173 | """ |
174 | 174 | Replace the pt2 dequantization ops with cadence dequantization ops. |
175 | 175 | We do not link kernels to the PT2 quantization ops, so we need to |
176 | 176 | replace them with cadence ops at all optimization levels. |
177 | 177 | """ |
178 | 178 |
|
179 | | - def call_operator( |
180 | | - self, |
181 | | - op, |
182 | | - args: Tuple[Argument, ...], |
183 | | - kwargs: Dict[str, Argument], |
184 | | - meta: NodeMetadata, |
185 | | - ) -> ProxyValue: |
186 | | - ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops |
187 | | - if op != ns.quantized_decomposed.dequantize_per_tensor.default: |
188 | | - return super().call_operator(op, args, kwargs, meta) |
| 179 | + @property |
| 180 | + def targets(self) -> list[EdgeOpOverload]: |
| 181 | + return [ |
| 182 | + torch.ops.quantized_decomposed.dequantize_per_tensor.default, |
| 183 | + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 184 | + ] |
189 | 185 |
|
190 | | - return super().call_operator( |
191 | | - ns.cadence.dequantize_per_tensor.default, |
192 | | - args, |
193 | | - kwargs, |
194 | | - meta, |
195 | | - ) |
| 186 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 187 | + ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops |
| 188 | + with node.graph.inserting_before(node): |
| 189 | + new_node = node.graph.call_function( |
| 190 | + ns.cadence.dequantize_per_tensor.default, |
| 191 | + args=node.args, |
| 192 | + kwargs=node.kwargs, |
| 193 | + ) |
| 194 | + new_node.meta = node.meta |
| 195 | + node.replace_all_uses_with(new_node) |
| 196 | + return True |
196 | 197 |
|
197 | 198 |
|
198 | 199 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
@@ -232,18 +233,34 @@ def call_operator( |
232 | 233 |
|
233 | 234 |
|
234 | 235 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
235 | | -class ReplaceFunctionallyEquivalentOpTargets(ExportPass): |
| 236 | +class ReplaceFunctionallyEquivalentOpTargets(RemoveOrReplacePassInterface): |
236 | 237 | """ |
237 | 238 | Replace an op with a functionally equivalent op by just switching the op |
238 | 239 | target, but without incurring any change to the op args. |
239 | 240 | """ |
240 | 241 |
|
241 | | - def call_operator(self, op, args, kwargs, meta): |
242 | | - if op not in functionally_equivalent_op_targets: |
243 | | - return super().call_operator(op, args, kwargs, meta) |
244 | | - return super().call_operator( |
245 | | - functionally_equivalent_op_targets[op], args, kwargs, meta |
246 | | - ) |
| 242 | + @property |
| 243 | + def targets(self) -> list[EdgeOpOverload]: |
| 244 | + return list(functionally_equivalent_op_targets.keys()) |
| 245 | + |
| 246 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 247 | + assert isinstance(node.target, EdgeOpOverload) |
| 248 | + target_op = functionally_equivalent_op_targets[node.target] |
| 249 | + with node.graph.inserting_before(node): |
| 250 | + new_node = node.graph.call_function( |
| 251 | + target_op, |
| 252 | + args=node.args, |
| 253 | + kwargs=node.kwargs, |
| 254 | + ) |
| 255 | + new_node.meta = node.meta |
| 256 | + node.replace_all_uses_with(new_node) |
| 257 | + |
| 258 | + # RemoveOrReplacePassInterface calls eliminate_dead_code, but that does |
| 259 | + # not remove impure nodes (nodes with side effects). It is unclear whether |
| 260 | + # erasing impure nodes is generally safe, so rather than changing the |
| 261 | + # interface, this pass erases the replaced node explicitly. |
| 262 | + node.graph.erase_node(node) |
| 263 | + return True |
247 | 264 |
|
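The explicit `erase_node` works around a `torch.fx` detail: `Graph.eliminate_dead_code()` refuses to drop nodes whose `is_impure()` is true, so a side-effecting op left without users would otherwise linger in the graph. A small sketch of the distinction:

```python
import torch.fx


def drop_dead_node(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
    # Only valid once the node has no remaining users.
    assert len(node.users) == 0
    if node.is_impure():
        # eliminate_dead_code() skips impure nodes, so erase by hand,
        # as ReplaceFunctionallyEquivalentOpTargets does above.
        graph.erase_node(node)
    else:
        # Pure, user-less nodes are removed by ordinary DCE.
        graph.eliminate_dead_code()
```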
248 | 265 |
|
249 | 266 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
@@ -1438,82 +1455,95 @@ def call_operator(self, op, args, kwargs, meta): |
1438 | 1455 |
|
1439 | 1456 |
|
1440 | 1457 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
1441 | | -class ReplaceScalarTensorWithFullPass(ExportPass): |
| 1458 | +class ReplaceScalarTensorWithFullPass(RemoveOrReplacePassInterface): |
1442 | 1459 | """ |
1443 | 1460 | aten.scalar_tensor can be replaced by aten.full with a shape of [1]. |
1444 | 1461 | scalar_tensor is not supported, so this is an opt_level=0 pass. |
1445 | 1462 | """ |
1446 | 1463 |
|
1447 | | - def call_operator( |
1448 | | - self, |
1449 | | - op, |
1450 | | - args: Tuple[Argument, ...], |
1451 | | - kwargs: Dict[str, Argument], |
1452 | | - meta: NodeMetadata, |
1453 | | - ) -> ProxyValue: |
1454 | | - if op not in { |
1455 | | - exir_ops.edge.aten.scalar_tensor.default, |
| 1464 | + @property |
| 1465 | + def targets(self) -> list[EdgeOpOverload]: |
| 1466 | + return [ |
1456 | 1467 | torch.ops.aten.scalar_tensor.default, |
1457 | | - }: |
1458 | | - return super().call_operator(op, args, kwargs, meta) |
| 1468 | + exir_ops.edge.aten.scalar_tensor.default, |
| 1469 | + ] |
1459 | 1470 |
|
1460 | | - return super().call_operator( |
1461 | | - exir_ops.edge.aten.full.default, |
1462 | | - ( |
1463 | | - [1], |
1464 | | - args[0], |
1465 | | - ), |
1466 | | - {"dtype": torch.float32}, |
1467 | | - meta, |
1468 | | - ) |
| 1471 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 1472 | + with node.graph.inserting_before(node): |
| 1473 | + new_node = node.graph.call_function( |
| 1474 | + exir_ops.edge.aten.full.default, |
| 1475 | + args=( |
| 1476 | + [1], |
| 1477 | + node.args[0], |
| 1478 | + ), |
| 1479 | + kwargs={"dtype": torch.float32}, |
| 1480 | + ) |
| 1481 | + new_node.meta = node.meta |
| 1482 | + node.replace_all_uses_with(new_node) |
| 1483 | + return True |
1469 | 1484 |
|
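One semantic wrinkle worth noting: `aten.scalar_tensor` yields a 0-d tensor, while the replacement `aten.full([1], ...)` yields a one-element rank-1 tensor with dtype pinned to float32, so the rewrite assumes downstream Cadence kernels accept the rank-1 form. A quick check of what does and does not change:

```python
import torch

a = torch.ops.aten.scalar_tensor(2.5)
b = torch.ops.aten.full([1], 2.5, dtype=torch.float32)
assert a.item() == b.item()            # same value
assert a.shape == torch.Size([])       # 0-d before...
assert b.shape == torch.Size([1])      # ...rank-1 after
```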
1470 | 1485 |
|
1471 | 1486 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
1472 | | -class ReplaceFullLikeWithFullPass(ExportPass): |
| 1487 | +class ReplaceFullLikeWithFullPass(RemoveOrReplacePassInterface): |
1473 | 1488 | """ |
1474 | 1489 | aten.full_like can be replaced by aten.full with the shape of the arg tensor. |
1475 | 1490 | full_like is not supported, so this is an opt_level=0 pass. |
1476 | 1491 | """ |
1477 | 1492 |
|
1478 | | - def call_operator(self, op, args, kwargs, meta): |
1479 | | - if op not in { |
1480 | | - exir_ops.edge.aten.full_like.default, |
1481 | | - }: |
1482 | | - return super().call_operator(op, args, kwargs, meta) |
| 1493 | + @property |
| 1494 | + def targets(self) -> list[EdgeOpOverload]: |
| 1495 | + return [exir_ops.edge.aten.full_like.default] |
1483 | 1496 |
|
1484 | | - # Get the shape of the "like" tensor, and pass that in to the full op. |
1485 | | - return super().call_operator( |
1486 | | - exir_ops.edge.aten.full.default, |
1487 | | - ( |
1488 | | - args[0].to_tensor().shape, |
1489 | | - args[1], |
1490 | | - ), |
1491 | | - {}, |
1492 | | - meta, |
1493 | | - ) |
| 1497 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 1498 | + input_arg = node.args[0] |
| 1499 | + assert isinstance(input_arg, torch.fx.Node) |
| 1500 | + shape = input_arg.meta["val"].shape |
| 1501 | + fill_value = node.args[1] |
| 1502 | + |
| 1503 | + with node.graph.inserting_before(node): |
| 1504 | + new_node = node.graph.call_function( |
| 1505 | + exir_ops.edge.aten.full.default, |
| 1506 | + args=(shape, fill_value), |
| 1507 | + kwargs={}, |
| 1508 | + ) |
| 1509 | + new_node.meta = node.meta |
| 1510 | + node.replace_all_uses_with(new_node) |
| 1511 | + return True |
1494 | 1512 |
|
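Here the shape comes from `node.meta["val"]`, the FakeTensor that export attaches to every tensor-producing node, rather than from a live tensor. The underlying identity being exploited:

```python
import torch

# full_like(x, v) == full(x.shape, v) when dtype/device defaults match;
# in the pass, x.shape is read from the FakeTensor in node.meta["val"].
x = torch.zeros(2, 3)
assert torch.equal(torch.full_like(x, 7.0), torch.full(x.shape, 7.0))
```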
1495 | 1513 |
|
1496 | 1514 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
1497 | | -class ReplaceInfArgInFullWithValuePass(ExportPass): |
| 1515 | +class ReplaceInfArgInFullWithValuePass(RemoveOrReplacePassInterface): |
1498 | 1516 | """ |
1499 | 1517 | aten.full allows "-inf" and "inf" as fill values. The profiler cannot |
1500 | 1518 | handle those, so replace them with the finite min/max of the dtype. |
1501 | 1519 | """ |
1502 | 1520 |
|
1503 | | - def call_operator(self, op, args, kwargs, meta): |
1504 | | - if op not in { |
1505 | | - exir_ops.edge.aten.full.default, |
1506 | | - }: |
1507 | | - return super().call_operator(op, args, kwargs, meta) |
| 1521 | + @property |
| 1522 | + def targets(self) -> list[EdgeOpOverload]: |
| 1523 | + return [exir_ops.edge.aten.full.default] |
1508 | 1524 |
|
1509 | | - new_args = list(args) |
| 1525 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
1510 | 1526 |
|
1511 | | - if args[1] == float("-inf"): |
| 1527 | + new_args = list(node.args) |
| 1528 | + fill_value = node.args[1] |
| 1529 | + if fill_value == float("-inf"): |
1512 | 1530 | new_args[1] = torch.finfo(torch.float32).min |
1513 | | - elif args[1] == float("inf"): |
| 1531 | + elif fill_value == float("inf"): |
1514 | 1532 | new_args[1] = torch.finfo(torch.float32).max |
| 1533 | + else: |
| 1534 | + return False |
1515 | 1535 |
|
1516 | | - return super().call_operator(op, tuple(new_args), kwargs, meta) |
| 1536 | + new_args = tuple(new_args) |
| 1537 | + |
| 1538 | + with node.graph.inserting_before(node): |
| 1539 | + new_node = node.graph.call_function( |
| 1540 | + exir_ops.edge.aten.full.default, |
| 1541 | + args=new_args, |
| 1542 | + kwargs=node.kwargs, |
| 1543 | + ) |
| 1544 | + new_node.meta = node.meta |
| 1545 | + node.replace_all_uses_with(new_node) |
| 1546 | + return True |
1517 | 1547 |
|
1518 | 1548 |
|
1519 | 1549 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
@@ -1713,26 +1743,6 @@ def call_operator( |
1713 | 1743 | return super().call_operator(op, args, kwargs, meta) |
1714 | 1744 |
|
1715 | 1745 |
|
1716 | | -@register_cadence_pass(CadencePassAttribute(opt_level=0)) |
1717 | | -class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass): |
1718 | | - """ |
1719 | | - Replace the aten gelu op with an approximate arg with an approximate gelu op. |
1720 | | - """ |
1721 | | - |
1722 | | - def call_operator( |
1723 | | - self, |
1724 | | - op, |
1725 | | - args: Tuple[Argument, ...], |
1726 | | - kwargs: Dict[str, Argument], |
1727 | | - meta: NodeMetadata, |
1728 | | - ) -> ProxyValue: |
1729 | | - if op not in { |
1730 | | - exir_ops.edge.aten.gelu.default, |
1731 | | - }: |
1732 | | - return super().call_operator(op, args, kwargs, meta) |
1733 | | - return super().call_operator(op, args, kwargs, meta) |
1734 | | - |
1735 | | - |
1736 | 1746 | # Adapted from fbcode/pyspeech/opt_passes/replace_ops.py |
1737 | 1747 | @register_cadence_pass(CadencePassAttribute(opt_level=2)) |
1738 | 1748 | class ReplaceSplitWithSlicePass(ExportPass): |
@@ -2122,18 +2132,25 @@ class CommonReplacePasses: |
2122 | 2132 |
|
2123 | 2133 |
|
2124 | 2134 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
2125 | | -class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass): |
| 2135 | +class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(RemoveOrReplacePassInterface): |
2126 | 2136 | """ |
2127 | 2137 | Replace aten linalg svd op with cadence custom op. |
2128 | 2138 | """ |
2129 | 2139 |
|
2130 | | - def call_operator(self, op, args, kwargs, meta): |
2131 | | - if op != exir_ops.edge.aten._linalg_svd.default: |
2132 | | - return super().call_operator(op, args, kwargs, meta) |
| 2140 | + @property |
| 2141 | + def targets(self) -> list[EdgeOpOverload]: |
| 2142 | + return [exir_ops.edge.aten._linalg_svd.default] |
2133 | 2143 |
|
2134 | | - return super().call_operator( |
2135 | | - exir_ops.edge.cadence.linalg_svd.default, args, kwargs, meta |
2136 | | - ) |
| 2144 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 2145 | + with node.graph.inserting_before(node): |
| 2146 | + new_node = node.graph.call_function( |
| 2147 | + exir_ops.edge.cadence.linalg_svd.default, |
| 2148 | + args=node.args, |
| 2149 | + kwargs=node.kwargs, |
| 2150 | + ) |
| 2151 | + new_node.meta = node.meta |
| 2152 | + node.replace_all_uses_with(new_node) |
| 2153 | + return True |
2137 | 2154 |
|
2138 | 2155 |
|
2139 | 2156 | # This class encapsulates all the functions that replace/switch one op in the |
@@ -2165,6 +2182,5 @@ class CadenceReplaceOpsInGraph: |
2165 | 2182 | ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, |
2166 | 2183 | ReplaceAtenAvgPoolWithCadenceAvgPoolPass, |
2167 | 2184 | ReplaceWhereWithFullArgsWithWhereScalar, |
2168 | | - ReplaceAtenApproxGeluWithApproxGeluPass, |
2169 | 2185 | ReplaceMulTensorWithMulAndFullOpsPass, |
2170 | 2186 | ] |