diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 0bc93f711ad6..f0465aa801f8 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" @@ -91,6 +92,16 @@ class ConvertAtenUnaryOp : public OpConversionPattern { self = tosa::tosaCastTensorToType(rewriter, self, outType).value(); + if constexpr (std::is_same_v) { + if (auto intTy = dyn_cast(outType.getElementType())) { + if (intTy.getWidth() == 1) { + rewriter.replaceOpWithNewOp(op, outType, self); + return success(); + } + } + // otherwise fall through to standard emission + } + rewriter.replaceOpWithNewOp(op, outType, self); return success(); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index c00e48f39e88..40b6b1b873af 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -5175,6 +5175,29 @@ def ElementwiseBitwiseNotInt32Module_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseBitwiseNotBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) + def forward(self, x): + return torch.bitwise_not(x) + + +@register_test_case(module_factory=lambda: ElementwiseBitwiseNotBoolModule()) +def ElementwiseBitwiseNotBoolModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=0, high=2).to(torch.bool)) + + +# ============================================================================== + + class ElementwiseSubTensorInt8Module(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index d100fe9dcfde..f0f1f4b9ed6c 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -135,6 +135,20 @@ func.func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // ----- +// CHECK-LABEL: func.func @torch.aten.bitwise_not$bool( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],i1>) -> !torch.vtensor<[2,3],i1> { +// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3],i1> -> tensor<2x3xi1> +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.logical_not %[[ARG_BUILTIN]] : (tensor<2x3xi1>) -> tensor<2x3xi1> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<2x3xi1> -> !torch.vtensor<[2,3],i1> +// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3],i1> +// CHECK: } +func.func @torch.aten.bitwise_not$bool(%arg0: !torch.vtensor<[2,3],i1>) -> !torch.vtensor<[2,3],i1> { + %0 = torch.aten.bitwise_not %arg0 : !torch.vtensor<[2,3],i1> -> !torch.vtensor<[2,3],i1> + return %0 : !torch.vtensor<[2,3],i1> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.ceil$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor