diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 82d11ec0737a..89ce7c9ad000 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3171,6 +3171,100 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [
   }];
 }
 
+def Torch_AtenBitwiseOrScalarOp : Torch_Op<"aten.bitwise_or.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::bitwise_or.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$other
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseOrScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseOrScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenBitwiseOr_ScalarOp : Torch_Op<"aten.bitwise_or_.Scalar", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::bitwise_or_.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    AnyTorchScalarType:$other
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseOr_ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseOr_ScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenBitwiseXorScalarOp : Torch_Op<"aten.bitwise_xor.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::bitwise_xor.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$other
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseXorScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseXorScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenBitwiseXor_ScalarOp : Torch_Op<"aten.bitwise_xor_.Scalar", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::bitwise_xor_.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    AnyTorchScalarType:$other
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseXor_ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseXor_ScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index f76331cbe6a8..aba96da937db 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -2312,6 +2312,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp);
   INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp);
   INSERT_BINARY_LOGICAL_PATTERN(AtenBitwiseAndScalarOp, chlo::BroadcastAndOp);
+  INSERT_BINARY_LOGICAL_PATTERN(AtenBitwiseOrScalarOp, chlo::BroadcastOrOp);
+  INSERT_BINARY_LOGICAL_PATTERN(AtenBitwiseXorScalarOp, chlo::BroadcastXorOp);
 
 #undef INSERT_BINARY_LOGICAL_PATTERN
 
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 7d5e65c21cef..5a59e11927a7 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -364,6 +364,8 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)",
         "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)",
+        "aten::bitwise_or.Scalar : (Tensor, Scalar) -> (Tensor)",
+        "aten::bitwise_xor.Scalar : (Tensor, Scalar) -> (Tensor)",
         "aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index b8fa177bb93d..c644553ddccf 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -7147,3 +7147,49 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: AtenAsStridedUnknownSizeModule())
 def AtenAsStridedUnknownSizeModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(12, 13))
+
+
+# ==============================================================================
+
+
+class BitwiseOrScalarIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4, 6], torch.int64, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.bitwise_or(x, 3)
+
+
+@register_test_case(module_factory=lambda: BitwiseOrScalarIntModule())
+def BitwiseOrScalarIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(4, 6, low=-16, high=16))
+
+
+# ==============================================================================
+
+
+class BitwiseXorScalarIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 3], torch.int64, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.bitwise_xor(x, 15)
+
+
+@register_test_case(module_factory=lambda: BitwiseXorScalarIntModule())
+def BitwiseXorScalarIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(5, 3, low=-16, high=16))
diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir
index c46328095440..cd6a006e7569 100644
--- a/test/Conversion/TorchToStablehlo/basic.mlir
+++ b/test/Conversion/TorchToStablehlo/basic.mlir
@@ -339,3 +339,42 @@ func.func @torch.aten.tril(%arg0: !torch.vtensor<[2,3,5],f32>, %arg1: !torch.int
   %0 = torch.aten.tril %arg0, %arg1:!torch.vtensor<[2,3,5],f32>, !torch.int -> !torch.vtensor<[2,3,5],f32>
   return %0 : !torch.vtensor<[2,3,5],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.bitwise_or.Scalar(
+// CHECK-SAME:    %[[INPUT:.*]]: !torch.vtensor<[2,8],si32>) -> !torch.vtensor<[2,8],si32> {
+// CHECK:         %[[BUILTIN:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[2,8],si32> -> tensor<2x8xi32>
+// CHECK:         %[[UNUSED:.*]] = torch.constant.int 3
+// CHECK:         %[[CST:.*]] = arith.constant 3 : i64
+// CHECK:         %[[ELEM:.*]] = tensor.from_elements %[[CST]] : tensor<1xi64>
+// CHECK:         %[[CONV:.*]] = stablehlo.convert %[[ELEM]] : (tensor<1xi64>) -> tensor<1xi32>
+// CHECK:         %[[SCALAR:.*]] = stablehlo.reshape %[[CONV]] : (tensor<1xi32>) -> tensor<i32>
+// CHECK:         %[[RES:.*]] = chlo.broadcast_or %[[BUILTIN]], %[[SCALAR]] : (tensor<2x8xi32>, tensor<i32>) -> tensor<2x8xi32>
+// CHECK:         %[[OUT:.*]] = torch_c.from_builtin_tensor %[[RES]] : tensor<2x8xi32> -> !torch.vtensor<[2,8],si32>
+// CHECK:         return %[[OUT]] : !torch.vtensor<[2,8],si32>
+// CHECK:       }
+func.func @torch.aten.bitwise_or.Scalar(%arg0: !torch.vtensor<[2,8],si32>) -> !torch.vtensor<[2,8],si32> {
+  %int3 = torch.constant.int 3
+  %0 = torch.aten.bitwise_or.Scalar %arg0, %int3 : !torch.vtensor<[2,8],si32>, !torch.int -> !torch.vtensor<[2,8],si32>
+  return %0 : !torch.vtensor<[2,8],si32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.bitwise_xor.Scalar(
+// CHECK-SAME:    %[[INPUT:.*]]: !torch.vtensor<[3,7],si64>) -> !torch.vtensor<[3,7],si64> {
+// CHECK:         %[[BUILTIN:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[3,7],si64> -> tensor<3x7xi64>
+// CHECK:         %[[UNUSED:.*]] = torch.constant.int 15
+// CHECK:         %[[CST:.*]] = arith.constant 15 : i64
+// CHECK:         %[[ELEM:.*]] = tensor.from_elements %[[CST]] : tensor<1xi64>
+// CHECK:         %[[SCALAR:.*]] = stablehlo.reshape %[[ELEM]] : (tensor<1xi64>) -> tensor<i64>
+// CHECK:         %[[RES:.*]] = chlo.broadcast_xor %[[BUILTIN]], %[[SCALAR]] : (tensor<3x7xi64>, tensor<i64>) -> tensor<3x7xi64>
+// CHECK:         %[[OUT:.*]] = torch_c.from_builtin_tensor %[[RES]] : tensor<3x7xi64> -> !torch.vtensor<[3,7],si64>
+// CHECK:         return %[[OUT]] : !torch.vtensor<[3,7],si64>
+// CHECK:       }
+func.func @torch.aten.bitwise_xor.Scalar(%arg0: !torch.vtensor<[3,7],si64>) -> !torch.vtensor<[3,7],si64> {
+  %int15 = torch.constant.int 15
+  %0 = torch.aten.bitwise_xor.Scalar %arg0, %int15 : !torch.vtensor<[3,7],si64>, !torch.int -> !torch.vtensor<[3,7],si64>
+  return %0 : !torch.vtensor<[3,7],si64>
+}