
Support aten._log_softmax dynamo converter
HolyWu committed Sep 1, 2024
1 parent 0f8f23d commit be15018
Showing 3 changed files with 48 additions and 1 deletion.
14 changes: 14 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -693,6 +693,20 @@ def aten_ops_softmax(
)


@dynamo_tensorrt_converter(
torch.ops.aten._log_softmax.default, supports_dynamic_shapes=True
)
def aten_ops_log_softmax(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
softmax = aten_ops_softmax(ctx, target, args, kwargs, name)
return impl.unary.log(ctx, target, SourceIR.ATEN, name, softmax)


@dynamo_tensorrt_converter(
torch.ops.aten.split.Tensor,
capability_validator=has_static_shapes_in_args([1]),
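The new converter reuses the existing softmax converter and applies a unary log on top, following the identity log_softmax(x, dim) = log(softmax(x, dim)). A minimal PyTorch sketch of that equivalence (the shapes and dim value are illustrative, not part of the change):

import torch

x = torch.randn(2, 3, 4)
dim = 1

# _log_softmax(self, dim, half_to_float) is the op the converter targets.
expected = torch.ops.aten._log_softmax.default(x, dim, False)
composed = torch.log(torch.softmax(x, dim=dim))

# The composed form should agree with the fused op up to floating-point tolerance.
assert torch.allclose(expected, composed, atol=1e-6)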
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -77,7 +77,6 @@
aten.logit_backward,
aten.log_sigmoid_backward,
aten.log_sigmoid_forward,
aten._log_softmax,
aten._log_softmax_backward_data,
aten.logspace,
aten.logsumexp.default,
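Removing the aten._log_softmax entry from this decomposition group means the op is no longer broken down during lowering and instead reaches the converter registered above. A hedged sketch of how one might check that the op survives into the exported ATen graph (node names can vary across PyTorch versions):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.log_softmax(x, dim=1)

ep = torch.export.export(M(), (torch.randn(2, 4),))
# Expect an aten._log_softmax.default call among the graph nodes.
print(ep.graph)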
34 changes: 34 additions & 0 deletions tests/py/dynamo/conversion/test_log_softmax_aten.py
@@ -0,0 +1,34 @@
import torch
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestLogSoftmaxConverter(DispatchTestCase):
def test_log_softmax(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._log_softmax.default(x, 1, False)

inputs = [torch.randn(1, 3, 5, 7)]
self.run_test(TestModule(), inputs)

def test_log_softmax_with_dynamic_shape(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._log_softmax.default(x, 2, False)

input_specs = [
Input(
min_shape=(1, 1, 1, 1),
opt_shape=(2, 4, 6, 8),
max_shape=(8, 8, 8, 8),
dtype=torch.float32,
),
]
self.run_test_with_dynamic_shape(TestModule(), input_specs)


if __name__ == "__main__":
run_tests()
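Beyond the harness-based tests above, a minimal end-to-end sketch of compiling a model that exercises the new converter through the dynamo frontend (assumes a CUDA GPU with TensorRT installed; the module name and shapes are illustrative):

import torch
import torch_tensorrt

class LogSoftmaxModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.log_softmax(x, dim=1)

model = LogSoftmaxModel().eval().cuda()
inputs = [torch.randn(1, 3, 5, 7, device="cuda")]

# aten._log_softmax.default should now be converted natively rather than
# being decomposed during lowering.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))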
