diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index ee968daff010..91618872cdc2 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -57,7 +57,7 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { - return createComparisonTemplate( b, loc, elementalType, lhs, rhs); @@ -66,7 +66,7 @@ static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType, static Value createGreaterThanOrEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { - return createComparisonTemplate( b, loc, elementalType, lhs, rhs); @@ -74,7 +74,7 @@ static Value createGreaterThanOrEqual(OpBuilder &b, Location loc, static Value createLessThan(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { - return createComparisonTemplate( b, loc, elementalType, lhs, rhs); @@ -82,7 +82,7 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType, static Value createLessThanOrEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { - return createComparisonTemplate( b, loc, elementalType, lhs, rhs); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 57a549309c4d..ac04eeb41109 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -160,7 +160,9 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseGeFloatTensorModule()) def ElementwiseGeFloatTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 5), tu.rand(5)) + module.forward( + torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) # ============================================================================== @@ -200,7 +202,9 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseGtFloatTensorModule()) def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 5), tu.rand(5)) + module.forward( + torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) # ============================================================================== @@ -378,6 +382,28 @@ def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseLeFloatTensorNanModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, x, y): + return torch.le(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseLeFloatTensorNanModule()) +def ElementwiseLeFloatTensorNanModule_basic(module, tu: TestUtils): + module.forward( + torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) + +# ============================================================================== + class ElementwiseLeIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @@ -414,7 +440,9 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: ElementwiseLtFloatTensorModule()) def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 5), tu.rand(5)) + module.forward( + torch.tensor([[1.0, 2.2, torch.nan], [6.0, 2.0, 3.1]]).to(torch.float32), + torch.tensor([6.0, 2.1, torch.nan]).to(torch.float32)) # ==============================================================================