Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No

arg_repset = op_repsets.get_arg_repset(arg_i)
if arg_repset.is_constrained():
return arg_repset
return

arg_node = op_repsets.op_node.args[arg_i]

Expand All @@ -378,21 +378,33 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No
arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset)
op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset)

def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None:
    """
    Output-side counterpart of `constrain_op_arg_repset`: tries to narrow the
    repset of the operator's first output (index 0).

    Nothing is done when that repset already reports itself as constrained.
    """
    current_repset = op_repsets.get_out_repset(0)
    if not current_repset.is_constrained():
        # Derive a candidate repset from the op node's users, then let
        # OpRepSets decide whether it can actually be applied to the output.
        candidate_repset = self.trace_node_users_to_constrain_repset(
            op_repsets.op_node, current_repset
        )
        op_repsets.try_constrain_with_out_repset(candidate_repset)

def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None:
    """
    Constrain the repsets of every tensor argument of the op and, for ops that
    require it, the output repset as well.
    """
    # Constraining the argument repsets usually constrains the output repset
    # too, because OpRepSets maintains synchronization rules between them.
    for arg_i, arg in enumerate(op_repsets.op_node.args):
        if not utils.is_tensor_arg_node(arg):
            continue
        self.constrain_op_arg_repset(arg_i, op_repsets)

    # Ops that synchronize their primary input and output representations need
    # no explicit output handling; the output step is likewise skipped when
    # `sync_outs_repr` is not set.
    if op_repsets.sync_primary_io_repr or not op_repsets.sync_outs_repr:
        return

    # The remaining ops do not sync input and output representations and may
    # define ambiguous repsets for the output tensor(s), so the output repset
    # is constrained separately from the input repsets.
    self.constrain_op_out_repset(op_repsets)

def set_op_node_tensor_reprs(
self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node
Expand Down
19 changes: 19 additions & 0 deletions backends/vulkan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,25 @@ def try_constrain_with_arg_repset(
self.assert_sync_contraints()
return True

def try_constrain_with_out_repset(self, repset: TensorRepSet):
    """
    Attempt to narrow the first output repset by intersecting it with `repset`.

    Returns False, leaving state untouched, when the operator must synchronize
    its input and output representations, when it has more than one output
    repset, when the current output repset already equals `repset`, or when
    `any_in_common` reports no overlap between the two. Otherwise the first
    output repset is replaced by the intersection, the sync constraints are
    re-asserted, and True is returned.
    """
    # Ops that tie input and output representations are not handled here.
    if self.sync_primary_io_repr:
        return False
    # Only the single-output case is supported.
    if len(self.outs_repset_list) > 1:
        return False

    current = self.outs_repset_list[0]

    # An identical repset cannot narrow anything further, and one without any
    # overlap cannot be intersected.
    if current == repset or not current.any_in_common(repset):
        return False

    self.outs_repset_list[0] = current.make_intersect(repset)

    self.assert_sync_contraints()
    return True

def pick_representations(self) -> Tuple[TensorReprList, TensorReprList]:
"""
For each tensor participating in the op, pick a representation for it among the
Expand Down
Loading