@@ -368,7 +368,7 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No
368368
369369 arg_repset = op_repsets .get_arg_repset (arg_i )
370370 if arg_repset .is_constrained ():
371- return arg_repset
371+ return
372372
373373 arg_node = op_repsets .op_node .args [arg_i ]
374374
@@ -378,21 +378,33 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No
378378 arg_repset = self .trace_node_users_to_constrain_repset (arg_node , arg_repset )
379379 op_repsets .try_constrain_with_arg_repset (arg_i , arg_repset )
380380
381+ def constrain_op_out_repset (self , op_repsets : utils .OpRepSets ) -> None :
382+ """
383+ Constrain output repset 0 of the operator via `trace_node_users_to_constrain_repset`
384+ on the op node, mirroring `constrain_op_arg_repset`; returns early if already constrained.
385+ """
386+ out_repset = op_repsets .get_out_repset (0 )
387+ if out_repset .is_constrained ():
388+ return
389+
390+ op_node = op_repsets .op_node
391+ out_repset = self .trace_node_users_to_constrain_repset (op_node , out_repset )
392+
393+ op_repsets .try_constrain_with_out_repset (out_repset )
394+
381395 def constrain_op_repsets (self , op_repsets : utils .OpRepSets ) -> None :
382396 # For most ops, constraining the argument repsets will also constrain the output
383397 # repset due to OpRepSets maintaining synchronization rules.
384398 for i in range (len (op_repsets .op_node .args )):
385399 if utils .is_tensor_arg_node (op_repsets .op_node .args [i ]):
386400 self .constrain_op_arg_repset (i , op_repsets )
387401
388- # TODO(ssjia): For most ops, inputs and outputs must be synchronized, so there
389- # is no need to constrain output repsets explicitly. Currently, the exceptions
390- # (i.e. choose qparams) already define constrined repsets for the output, so
391- # there is again no need to explicitly constrain the outputs. If an operator
392- # appears later on that does not sync input and output representations, and
393- # defines ambiguous repsets for the output tensor(s), then we will need to add
394- # additional logic to this function to constrain the output repsets separately
395- # from the input repsets.
402+ # However, some operators do not sync input and output representations and also
403+ # define ambiguous repsets for the output tensor(s). In those cases we will need
404+ # to execute additional logic to constrain the output repsets separately from
405+ # the input repsets.
406+ if not op_repsets .sync_primary_io_repr and op_repsets .sync_outs_repr :
407+ self .constrain_op_out_repset (op_repsets )
396408
397409 def set_op_node_tensor_reprs (
398410 self , graph_module : torch .fx .GraphModule , op_node : torch .fx .Node
0 commit comments