diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 1d734bca03..13191db936 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -693,6 +693,20 @@ def aten_ops_softmax(
     )
 
 
+@dynamo_tensorrt_converter(
+    torch.ops.aten._log_softmax.default, supports_dynamic_shapes=True
+)
+def aten_ops_log_softmax(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    softmax = aten_ops_softmax(ctx, target, args, kwargs, name)
+    return impl.unary.log(ctx, target, SourceIR.ATEN, name, softmax)
+
+
 @dynamo_tensorrt_converter(
     torch.ops.aten.split.Tensor,
     capability_validator=has_static_shapes_in_args([1]),
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index ae3e7e1ffa..7552b16685 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -77,7 +77,6 @@
     aten.logit_backward,
     aten.log_sigmoid_backward,
     aten.log_sigmoid_forward,
-    aten._log_softmax,
     aten._log_softmax_backward_data,
     aten.logspace,
     aten.logsumexp.default,
diff --git a/tests/py/dynamo/conversion/test_log_softmax_aten.py b/tests/py/dynamo/conversion/test_log_softmax_aten.py
new file mode 100644
index 0000000000..306d225fb1
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_log_softmax_aten.py
@@ -0,0 +1,34 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestLogSoftmaxConverter(DispatchTestCase):
+    def test_log_softmax(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten._log_softmax.default(x, 1, False)
+
+        inputs = [torch.randn(1, 3, 5, 7)]
+        self.run_test(TestModule(), inputs)
+
+    def test_log_softmax_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten._log_softmax.default(x, 2, False)
+
+        input_specs = [
+            Input(
+                min_shape=(1, 1, 1, 1),
+                opt_shape=(2, 4, 6, 8),
+                max_shape=(8, 8, 8, 8),
+                dtype=torch.float32,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(TestModule(), input_specs)
+
+
+if __name__ == "__main__":
+    run_tests()