-
Notifications
You must be signed in to change notification settings - Fork 351
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add Selective ATen decompositions (#2173)
- Loading branch information
Showing
6 changed files
with
314 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
200 changes: 200 additions & 0 deletions
200
py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
from typing import Any, Callable, Dict, Set | ||
|
||
import torch | ||
from torch._decomp import core_aten_decompositions | ||
from torch._decomp import get_decompositions as get_torch_decompositions | ||
from torch._ops import OpOverload | ||
|
||
aten = torch.ops.aten | ||
|
||
_core_aten_decompositions: Dict[ | ||
OpOverload, Callable[[Any], Any] | ||
] = core_aten_decompositions() | ||
torch_enabled_decompositions: Set[OpOverload] = { | ||
aten._adaptive_avg_pool2d_backward, | ||
aten.addcdiv, | ||
aten.addcdiv_, | ||
aten.addcmul, | ||
aten.addcmul_, | ||
aten.addr, | ||
aten.aminmax, | ||
aten.arange.default, | ||
aten.arange.start, | ||
aten.avg_pool2d_backward, | ||
aten.binary_cross_entropy, | ||
aten.binary_cross_entropy_backward, | ||
aten.binary_cross_entropy_with_logits, | ||
aten.celu, | ||
aten.col2im, | ||
aten.count_nonzero, | ||
aten.cudnn_batch_norm, | ||
aten.cudnn_batch_norm_backward, | ||
aten.deg2rad, | ||
aten.detach, | ||
aten.diag_embed, | ||
aten.diagonal_backward, | ||
aten.dot, | ||
aten.elu, | ||
aten.elu_backward, | ||
aten._embedding_bag, | ||
aten.embedding_dense_backward, | ||
aten._euclidean_dist.default, | ||
aten.expand_as, | ||
aten.eye, | ||
aten.fill, | ||
aten.frac, | ||
aten._fused_moving_avg_obs_fq_helper, | ||
aten.gelu, | ||
aten.gelu_backward, | ||
aten.glu_backward, | ||
aten.grid_sampler_2d, | ||
aten.hardshrink, | ||
aten.hardshrink_backward, | ||
aten.hardsigmoid, | ||
aten.hardsigmoid_backward, | ||
aten.hardswish, | ||
aten.hardswish_, | ||
aten.hardswish_backward, | ||
aten.hardtanh, | ||
aten.hardtanh_, | ||
aten.hardtanh_backward, | ||
aten.heaviside, | ||
aten.huber_loss, | ||
aten.huber_loss_backward, | ||
aten.im2col, | ||
aten.index_add, | ||
aten.index_add_, | ||
aten.index_copy, | ||
aten.index_copy_, | ||
aten.index_fill, | ||
aten.index_fill_, | ||
aten.index_select, | ||
aten.isneginf, | ||
aten.isposinf, | ||
aten.l1_loss, | ||
aten.leaky_relu, | ||
aten.leaky_relu_, | ||
aten.leaky_relu_backward, | ||
aten.lerp, | ||
aten.linspace, | ||
aten.logaddexp, | ||
aten.logaddexp2, | ||
aten.logit, | ||
aten.logit_backward, | ||
aten.log_sigmoid_backward, | ||
aten.log_sigmoid_forward, | ||
aten._log_softmax, | ||
aten._log_softmax_backward_data, | ||
aten.logspace, | ||
aten.logsumexp.default, | ||
aten.masked_fill, | ||
aten.masked_fill_, | ||
aten.max_pool2d_with_indices_backward, | ||
aten.mish, | ||
aten.mse_loss, | ||
aten.mse_loss_backward, | ||
aten.mv, | ||
aten.mvlgamma, | ||
aten.nansum, | ||
aten.nan_to_num, | ||
aten.narrow, | ||
# TODO: Disable the below operators once freezing is done | ||
aten.native_batch_norm, | ||
aten.native_batch_norm_backward, | ||
aten._native_batch_norm_legit, | ||
aten._native_batch_norm_legit_functional, | ||
aten._native_batch_norm_legit_no_training, | ||
aten.native_dropout_backward, | ||
aten.native_group_norm, | ||
aten.native_group_norm_backward, | ||
aten.native_layer_norm, | ||
aten.native_layer_norm_backward, | ||
aten.new_empty, | ||
aten.new_full, | ||
aten.new_ones, | ||
aten.new_zeros, | ||
aten.nll_loss_backward, | ||
aten.nll_loss_forward, | ||
aten.norm, | ||
aten.ones, | ||
aten.ones_like, | ||
aten._prelu_kernel, | ||
aten._prelu_kernel_backward, | ||
aten._reshape_alias, | ||
aten.rad2deg, | ||
aten.renorm, | ||
aten.renorm_, | ||
aten.rot90, | ||
aten.rsub.Scalar, | ||
aten.rsub.Tensor, | ||
aten.select_backward, | ||
aten.select_scatter, | ||
aten.sgn, | ||
aten.sigmoid_backward, | ||
aten.silu, | ||
aten.silu_, | ||
aten.silu_backward, | ||
aten.sinc, | ||
aten.slice_backward, | ||
aten.smooth_l1_loss, | ||
aten.smooth_l1_loss_backward, | ||
aten.soft_margin_loss, | ||
aten.soft_margin_loss_backward, | ||
aten._softmax, | ||
aten._softmax_backward_data, | ||
aten.softplus, | ||
aten.softplus_backward, | ||
aten.softshrink, | ||
aten.softshrink_backward, | ||
aten.special_entr, | ||
aten.special_log_ndtr, | ||
aten.special_xlog1py, | ||
aten.stack, | ||
aten.t, | ||
aten.tanh_backward, | ||
aten.threshold, | ||
aten.threshold_backward, | ||
aten.trace, | ||
aten.transpose.int, | ||
aten.tril.default, | ||
aten.triu.default, | ||
aten.unfold, | ||
aten.unfold_backward, | ||
aten.unfold_copy, | ||
aten.upsample_bilinear2d, | ||
aten.upsample_bilinear2d.vec, | ||
aten.upsample_nearest2d_backward, | ||
aten.xlogy, | ||
aten.zero, | ||
aten.zero_, | ||
aten.zeros, | ||
aten.zeros_like, | ||
# Non-default convenience decompositions | ||
aten.clamp_min, | ||
aten.clamp_max, | ||
aten.linalg_vector_norm, | ||
aten.full, | ||
aten.repeat, | ||
} | ||
torch_disabled_decompositions: Set[OpOverload] = set() | ||
|
||
|
||
ENABLED_TORCH_DECOMPOSITIONS: Dict[ | ||
OpOverload, Callable[[Any], Any] | ||
] = get_torch_decompositions(torch_enabled_decompositions) | ||
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {} | ||
|
||
|
||
def check_decomp_set_invariants() -> None: | ||
"""Validates no overlap between enabled and disabled decomposition sets""" | ||
overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions) | ||
|
||
if overlap: | ||
raise AssertionError( | ||
f"Detected {overlap} registered in both torch_enabled_decompositions " | ||
"and torch_disabled_decompositions. Ensure all operator(s) are in " | ||
"at most one of the two sets." | ||
) | ||
|
||
|
||
check_decomp_set_invariants() |
Oops, something went wrong.