Skip to content

Commit 4dde949

Browse files
authored
[ET-VK][ez] Constrain out repsets individually
Differential Revision: D86674164 Pull Request resolved: #15704
1 parent bff7bbf commit 4dde949

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No
368368

369369
arg_repset = op_repsets.get_arg_repset(arg_i)
370370
if arg_repset.is_constrained():
371-
return arg_repset
371+
return
372372

373373
arg_node = op_repsets.op_node.args[arg_i]
374374

@@ -378,21 +378,33 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No
378378
arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset)
379379
op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset)
380380

381+
def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None:
382+
"""
383+
Similar to the `constrain_op_arg_repset` function, but for the output repset of
384+
the operator.
385+
"""
386+
out_repset = op_repsets.get_out_repset(0)
387+
if out_repset.is_constrained():
388+
return
389+
390+
op_node = op_repsets.op_node
391+
out_respset = self.trace_node_users_to_constrain_repset(op_node, out_repset)
392+
393+
op_repsets.try_constrain_with_out_repset(out_respset)
394+
381395
def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None:
382396
# For most ops, constraining the argument repsets will also constrain the output
383397
# repset due to OpRepSets maintaining synchronization rules.
384398
for i in range(len(op_repsets.op_node.args)):
385399
if utils.is_tensor_arg_node(op_repsets.op_node.args[i]):
386400
self.constrain_op_arg_repset(i, op_repsets)
387401

388-
# TODO(ssjia): For most ops, inputs and outputs must be synchronized, so there
389-
# is no need to constrain output repsets explicitly. Currently, the exceptions
390-
# (i.e. choose qparams) already define constrained repsets for the output, so
391-
# there is again no need to explicitly constrain the outputs. If an operator
392-
# appears later on that does not sync input and output representations, and
393-
# defines ambiguous repsets for the output tensor(s), then we will need to add
394-
# additional logic to this function to constrain the output repsets separately
395-
# from the input repsets.
402+
# However, some operators do not sync input and output representations and also
403+
# define ambiguous repsets for the output tensor(s). In those cases we will need
404+
# to execute additional logic to constrain the output repsets separately from
405+
# the input repsets.
406+
if not op_repsets.sync_primary_io_repr and op_repsets.sync_outs_repr:
407+
self.constrain_op_out_repset(op_repsets)
396408

397409
def set_op_node_tensor_reprs(
398410
self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node

backends/vulkan/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,6 +1106,25 @@ def try_constrain_with_arg_repset(
11061106
self.assert_sync_contraints()
11071107
return True
11081108

1109+
def try_constrain_with_out_repset(self, repset: TensorRepSet):
1110+
# Skip for operators that must synchronize the input and output representations
1111+
# or operators that have more than one output repset
1112+
if self.sync_primary_io_repr or len(self.outs_repset_list) > 1:
1113+
return False
1114+
1115+
out_current_repset = self.outs_repset_list[0]
1116+
1117+
if out_current_repset == repset:
1118+
return False
1119+
1120+
if not out_current_repset.any_in_common(repset):
1121+
return False
1122+
1123+
self.outs_repset_list[0] = out_current_repset.make_intersect(repset)
1124+
1125+
self.assert_sync_contraints()
1126+
return True
1127+
11091128
def pick_representations(self) -> Tuple[TensorReprList, TensorReprList]:
11101129
"""
11111130
For each tensor participating in the op, pick a representation for it among the

0 commit comments

Comments
 (0)