-
Notifications
You must be signed in to change notification settings - Fork 361
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support aten._log_softmax dynamo converter
- Loading branch information
Showing
3 changed files
with
48 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
from torch.testing._internal.common_utils import run_tests | ||
from torch_tensorrt import Input | ||
|
||
from .harness import DispatchTestCase | ||
|
||
|
||
class TestLogSoftmaxConverter(DispatchTestCase): | ||
def test_log_softmax(self): | ||
class TestModule(torch.nn.Module): | ||
def forward(self, x): | ||
return torch.ops.aten._log_softmax.default(x, 1, False) | ||
|
||
inputs = [torch.randn(1, 3, 5, 7)] | ||
self.run_test(TestModule(), inputs) | ||
|
||
def test_log_softmax_with_dynamic_shape(self): | ||
class TestModule(torch.nn.Module): | ||
def forward(self, x): | ||
return torch.ops.aten._log_softmax.default(x, 2, False) | ||
|
||
input_specs = [ | ||
Input( | ||
min_shape=(1, 1, 1, 1), | ||
opt_shape=(2, 4, 6, 8), | ||
max_shape=(8, 8, 8, 8), | ||
dtype=torch.float32, | ||
), | ||
] | ||
self.run_test_with_dynamic_shape(TestModule(), input_specs) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |