diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 0bc93f711ad6..04519995cd10 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2304,9 +2304,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) return rewriter.notifyMatchFailure( op, "Unimplemented: non-constant value for transposed not supported"); - if (transposed) - return rewriter.notifyMatchFailure( - op, "Unimplemented: transposed convolution not supported"); auto input = adaptor.getInput(); auto weight = adaptor.getWeight(); @@ -2338,12 +2335,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto bias = adaptor.getBias(); if (isa(bias.getType())) { - auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy, - outputElemTy, weightShape); - if (failed(bias_result)) + // ConvTranspose weights use IOHW; the helper expects OIHW, so swap + // dims 0/1 before we synthesize the bias. + SmallVector biasWeightShape = + transposed ? SmallVector{weightShape[1], weightShape[0], + weightShape[2], weightShape[3]} + : weightShape; + + auto biasResult = tosa::getConvBiasForNoneType( + op, rewriter, inputElemTy, outputElemTy, biasWeightShape); + if (failed(biasResult)) return rewriter.notifyMatchFailure( op, "Failed to create bias tensor for none type."); - bias = bias_result.value(); + bias = biasResult.value(); } else { if (!isa(bias.getType())) return rewriter.notifyMatchFailure( @@ -2370,8 +2374,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( m_TorchListOfConstantInts(padding_2d))) return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); - // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D - // padding {height, width}. The Torch OFM computation uses 2*pad in each + // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D + // padding {height, width}. The PyTorch OFM computation uses 2*pad in each // spatial direction, implying the same top=bottom=height and left=right=width // values for TOSA. SmallVector padding( @@ -2388,20 +2392,128 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "failed to get accumulator type for convolution ops"); + // Weight layout reference: + // Conv : PyTorch OIHW -> TOSA OHWI + // Depthwise : PyTorch OIHW* -> TOSA HWIM + // (PyTorch depthwise uses out_ch=in_ch*depth_multiplier) + // Grouped : PyTorch O(I/G)HW -> N/A + // Transposed : PyTorch IOHW -> TOSA OHWI // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights. // Perform the necessary transformations. SmallVector nchwToNhwcDims({0, 2, 3, 1}); - SmallVector transposedInputShape( - {inputShape[0], inputShape[2], inputShape[3], inputShape[1]}); + SmallVector nhwcToNchwDims({0, 3, 1, 2}); + SmallVector transposedInputShape; + for (int32_t dim : nchwToNhwcDims) + transposedInputShape.push_back(inputShape[dim]); auto transposedInputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedInputShape), inputElemTy); - auto transposedInput = - rewriter - .create( - op->getLoc(), - getTypeConverter()->convertType(transposedInputType), input, - rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) - .getResult(); + auto createTransposedInput = [&]() { + return rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(transposedInputType), + input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) + .getResult(); + }; + + if (transposed) { + if (groups != 1) + return rewriter.notifyMatchFailure( + op, "Unimplemented: grouped transposed convolution not supported by " + "TOSA"); + if (dilation[0] != 1 || dilation[1] != 1) + return rewriter.notifyMatchFailure( + op, "Unimplemented: dilated transposed convolution not supported by " + "TOSA"); + + SmallVector iohwToOhwi({1, 2, 3, 0}); + + // TOSA 'out_pad' is a 4D array {top,bottom,left,right}. + // Map from PyTorch's (padding, output_padding): + // out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W) + // Negative values are allowed and will be handled by the TOSA + // decomposition. + SmallVector outPadding2D; + if (!matchPattern(adaptor.getOutputPadding(), + m_TorchListOfConstantInts(outPadding2D))) + return rewriter.notifyMatchFailure( + op, "non-const output_padding list unsupported for transposed conv"); + + int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0]; + int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1]; + int64_t outPadTop = outPadH / 2; + int64_t outPadBottom = outPadH - outPadTop; + int64_t outPadLeft = outPadW / 2; + int64_t outPadRight = outPadW - outPadLeft; + SmallVector outPad( + {outPadTop, outPadBottom, outPadLeft, outPadRight}); + + Value nhwcInput = createTransposedInput(); + SmallVector ohwiWeightShape; + for (int32_t dim : iohwToOhwi) + ohwiWeightShape.push_back(weightShape[dim]); + auto ohwiWeightType = RankedTensorType::get( + makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy); + Value transformedWeight = + rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(ohwiWeightType), + weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi)) + .getResult(); + + // Result type is NHWC (we'll transpose back). + auto outNCHW = makeShapeTorchCompatible(outputTy.getShape()); + SmallVector outNHWC; + for (int32_t dim : nchwToNhwcDims) + outNHWC.push_back(outNCHW[dim]); + auto transConvOpTy = + RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy); + + // Zero-points. + auto zps = tosa::createZPsAsConst(rewriter, input, weight); + Value inputZp = zps.first ? zps.first + : tosa::createZeroPointTensor( + rewriter, op->getLoc(), inputElemTy, 0) + .value(); + Value weightZp = zps.second ? zps.second + : tosa::createZeroPointTensor( + rewriter, op->getLoc(), weightElemTy, 0) + .value(); + + Value convTOut = + rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(transConvOpTy), + nhwcInput, transformedWeight, bias, inputZp, weightZp, + rewriter.getDenseI64ArrayAttr(outPad), + rewriter.getDenseI64ArrayAttr(stride), accType) + .getResult(); + + SmallVector transposedOutputShape; + for (int32_t dim : nhwcToNchwDims) + transposedOutputShape.push_back(outNHWC[dim]); + auto transposedOutputType = RankedTensorType::get( + makeShapeLLVMCompatible(transposedOutputShape), biasElemTy); + Value transposedOutput = + rewriter + .create( + op->getLoc(), + getTypeConverter()->convertType(transposedOutputType), convTOut, + rewriter.getDenseI32ArrayAttr(nhwcToNchwDims)) + .getResult(); + + // Quantized rescale. + Value rescaledResult = transposedOutput; + if (isa(inputElemTy)) { + rescaledResult = tosa::buildRescaleOpConvOutput( + rewriter, op, transposedOutput, inputTy, weightTy, outputTy); + } + + // Final cast to requested output type. + rewriter.replaceOp( + op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy) + .value()}); + return success(); + } SmallVector transformedWeightShape; RankedTensorType transformedWeightType; @@ -2427,6 +2539,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector transposedDims({2, 3, 0, 1}); SmallVector transposedWeightShape = { weightShape[2], weightShape[3], weightShape[0], weightShape[1]}; + + // reshape: HWO(I/G) -> HWIM + outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1]; + if (outputCDim == kUnknownSize) { + return rewriter.notifyMatchFailure( + op, "number of output channels must be statically known for " + "depthwise convolutions"); + } + auto transposedWeightType = RankedTensorType::get( makeShapeLLVMCompatible(transposedWeightShape), weightElemTy); auto transposedWeight = @@ -2437,13 +2558,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getDenseI32ArrayAttr(transposedDims)) .getResult(); - // reshape: HWO(I/G) -> HWIM - outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1]; - if (outputCDim == kUnknownSize) { - return rewriter.notifyMatchFailure( - op, "number of output channels must be statically known for " - "depthwise convolutions"); - } transformedWeightShape = { transposedWeightShape[0], transposedWeightShape[1], @@ -2465,6 +2579,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm_unreachable("Unhandled convolution type"); } + Value transposedInput = createTransposedInput(); + int64_t outputHDim, outputWDim; int64_t inputHDim = inputShape[2]; int64_t inputWDim = inputShape[3]; @@ -2487,7 +2603,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (remainderHDim != 0) { if (remainderHDim > padding[1]) { SmallVector startHSlice(inputTy.getRank(), 0); - SmallVector sizeHSlice(transposedInputShape); + SmallVector sizeHSlice(transposedInputShape); // TOSA uses NHWC, so we will slice dim 1 for Height value sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]); transposedInput = tosa::CreateOpAndInfer( @@ -2583,7 +2699,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm_unreachable("Unhandled convolution type"); } - SmallVector nhwcToNchwDims({0, 3, 1, 2}); SmallVector transposedOutputShape( {outputShape[0], outputShape[3], outputShape[1], outputShape[2]}); auto transposedOutputType = RankedTensorType::get( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 81071c6ab058..59910bf2692c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3588,7 +3588,6 @@ "AvgPool3dCountIncludePadFalseWithoutPadding_basic", "Conv_Transpose1dModule_basic", "Conv_Transpose1dStaticModule_basic", - "Conv_Transpose2dStaticModule_basic", "Conv_Transpose3dModule_basic", "Conv_Transpose3dStaticModule_basic", "IndexPutWithNoneAndBroadcastModule_basic", @@ -3713,16 +3712,11 @@ "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", - "Conv_Transpose2dModule_basic", "ConvolutionBackwardModule2DPadded_basic", - "ConvolutionBackwardModule2DStatic_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "ConvolutionModule2DGroups_basic", "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", - "ConvolutionModule2DTransposeStridedStatic_basic", - "ConvolutionModule2DTransposeStrided_basic", - "ConvolutionModule2DTranspose_basic", "ConvolutionModule2DGroupedTranspose_basic", "ConvolutionModule3DGroups_basic", "ConvolutionModule3DGroupsStrided_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py index c9273c1f46c4..f2d148ec466e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py @@ -29,6 +29,8 @@ # that depend on TOSA as well as TOSA-to-Standard. "tosa-to-arith", "tosa-to-scf", + # Required for transposed convolution support (decomposes to conv ops). + "tosa-optional-decompositions", # Named ops must be legalized prior to general tosa-to-linalg "tosa-to-linalg-named", # TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them diff --git a/test/Conversion/TorchToTosa/conv2d_transpose.mlir b/test/Conversion/TorchToTosa/conv2d_transpose.mlir index 7c24dc896630..ba78ba865d5b 100644 --- a/test/Conversion/TorchToTosa/conv2d_transpose.mlir +++ b/test/Conversion/TorchToTosa/conv2d_transpose.mlir @@ -1,8 +1,23 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file | FileCheck %s -// The following test ensures that a tranposed convolution op is not -// lowered in the torch-to-tosa conversion pass. +// The lowering now legalizes transpose convolutions into the TOSA dialect. +// Verify that we emit tosa.transpose_conv2d with the expected reshapes/ +// permutations. +// CHECK-LABEL: func.func @forward +// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> { +// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[1,64,1,100],f32> -> tensor<1x64x1x100xf32> +// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64x64x3x3xf32>}> : () -> tensor<64x64x3x3xf32> +// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64xf32>}> : () -> tensor<64xf32> +// CHECK: %[[TRANS_IN:.*]] = tosa.transpose %[[IN_TENSOR]] {perms = array} : (tensor<1x64x1x100xf32>) -> tensor<1x1x100x64xf32> +// CHECK: %[[W_OHWI:.*]] = tosa.transpose %[[WEIGHT]] {perms = array} : (tensor<64x64x3x3xf32>) -> tensor<64x3x3x64xf32> +// CHECK: %[[ZP0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[ZP1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[TCONV:.*]] = tosa.transpose_conv2d %[[TRANS_IN]], %[[W_OHWI]], %[[BIAS]], %[[ZP0]], %[[ZP1]] {acc_type = f32, out_pad = array, stride = array} : (tensor<1x1x100x64xf32>, tensor<64x3x3x64xf32>, tensor<64xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x200x64xf32> +// CHECK: %[[TRANS_OUT:.*]] = tosa.transpose %[[TCONV]] {perms = array} : (tensor<1x2x200x64xf32>) -> tensor<1x64x2x200xf32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[TRANS_OUT]] : tensor<1x64x2x200xf32> -> !torch.vtensor<[1,64,2,200],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[1,64,2,200],f32> +// CHECK: } func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> { %true = torch.constant.bool true %int1 = torch.constant.int 1 @@ -11,7 +26,6 @@ func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[ %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32> %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list - // expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}} %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,64,2,200],f32> return %output : !torch.vtensor<[1,64,2,200],f32> }