169 changes: 142 additions & 27 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2304,9 +2304,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
return rewriter.notifyMatchFailure(
op, "Unimplemented: non-constant value for transposed not supported");
if (transposed)
return rewriter.notifyMatchFailure(
op, "Unimplemented: transposed convolution not supported");

auto input = adaptor.getInput();
auto weight = adaptor.getWeight();
@@ -2338,12 +2335,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
auto bias = adaptor.getBias();

if (isa<Torch::NoneType>(bias.getType())) {
auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
outputElemTy, weightShape);
if (failed(bias_result))
// ConvTranspose weights use IOHW; the helper expects OIHW, so swap
// dims 0/1 before we synthesize the bias.
SmallVector<int64_t, 4> biasWeightShape =
transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
weightShape[2], weightShape[3]}
: weightShape;

auto biasResult = tosa::getConvBiasForNoneType(
op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
if (failed(biasResult))
return rewriter.notifyMatchFailure(
op, "Failed to create bias tensor for none type.");
bias = bias_result.value();
bias = biasResult.value();
} else {
if (!isa<RankedTensorType>(bias.getType()))
return rewriter.notifyMatchFailure(
@@ -2370,8 +2374,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
m_TorchListOfConstantInts(padding_2d)))
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");
// TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
// padding {height, width}. The Torch OFM computation uses 2*pad in each
// TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
// padding {height, width}. The PyTorch OFM computation uses 2*pad in each
// spatial direction, implying the same top=bottom=height and left=right=width
// values for TOSA.
SmallVector<int64_t> padding(
@@ -2388,20 +2392,128 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "failed to get accumulator type for convolution ops");

// Weight layout reference:
// Conv : PyTorch OIHW -> TOSA OHWI
// Depthwise : PyTorch OIHW* -> TOSA HWIM
// (PyTorch depthwise uses out_ch=in_ch*depth_multiplier)
// Grouped : PyTorch O(I/G)HW -> N/A
// Transposed : PyTorch IOHW -> TOSA OHWI
// TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
// Perform the necessary transformations.
SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
SmallVector<int64_t> transposedInputShape(
{inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
SmallVector<int64_t, 4> transposedInputShape;
for (int32_t dim : nchwToNhwcDims)
transposedInputShape.push_back(inputShape[dim]);
auto transposedInputType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
auto transposedInput =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(),
getTypeConverter()->convertType(transposedInputType), input,
rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
.getResult();
auto createTransposedInput = [&]() {
return rewriter
.create<tosa::TransposeOp>(
op->getLoc(), getTypeConverter()->convertType(transposedInputType),
input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
.getResult();
};

if (transposed) {
if (groups != 1)
return rewriter.notifyMatchFailure(
op, "Unimplemented: grouped transposed convolution not supported by "
"TOSA");
if (dilation[0] != 1 || dilation[1] != 1)
return rewriter.notifyMatchFailure(
op, "Unimplemented: dilated transposed convolution not supported by "
"TOSA");

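// Note: keep every notifyMatchFailure guard above this point, before any
// tosa ops are created; bailing out after IR has already been emitted can
// send the pattern rewriter into a loop re-applying this pattern (see the
// review discussion below).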
Comment on lines +2420 to +2427

Member:

I think these notify failures need to happen before any IR rewrites take place, otherwise the pattern rewriter ends up in a recursive loop. For example, on line 2410 we'd have already introduced tosa.transpose and then we'll bail out from here for a failure case.

Contributor Author:

Fixed for transpose and depthwise paths.

SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});

// TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
// Map from PyTorch's (padding, output_padding):
// out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W)
// Negative values are allowed and will be handled by the TOSA
// decomposition.
SmallVector<int64_t, 2> outPadding2D;
if (!matchPattern(adaptor.getOutputPadding(),
m_TorchListOfConstantInts(outPadding2D)))
return rewriter.notifyMatchFailure(
op, "non-const output_padding list unsupported for transposed conv");

int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0];
int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1];
int64_t outPadTop = outPadH / 2;
int64_t outPadBottom = outPadH - outPadTop;
int64_t outPadLeft = outPadW / 2;
int64_t outPadRight = outPadW - outPadLeft;
SmallVector<int64_t, 4> outPad(
{outPadTop, outPadBottom, outPadLeft, outPadRight});
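// Worked example (matches the conv2d_transpose.mlir test in this PR):
// with stride {2,2}, padding {1,1}, output_padding {1,1} we get
// outPadH = outPadW = 1 - 2*1 = -1; C++ division truncates toward zero,
// so top/left = 0 and bottom/right = -1, i.e. out_pad = {0, -1, 0, -1}
// on the emitted tosa.transpose_conv2d.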

Value nhwcInput = createTransposedInput();
SmallVector<int64_t, 4> ohwiWeightShape;
for (int32_t dim : iohwToOhwi)
ohwiWeightShape.push_back(weightShape[dim]);
auto ohwiWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
Value transformedWeight =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
.getResult();

// Result type is NHWC (we'll transpose back).
auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
SmallVector<int64_t, 4> outNHWC;
for (int32_t dim : nchwToNhwcDims)
outNHWC.push_back(outNCHW[dim]);
auto transConvOpTy =
RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy);

// Zero-points.
auto zps = tosa::createZPsAsConst(rewriter, input, weight);
Value inputZp = zps.first ? zps.first
: tosa::createZeroPointTensor(
rewriter, op->getLoc(), inputElemTy, 0)
.value();
Value weightZp = zps.second ? zps.second
: tosa::createZeroPointTensor(
rewriter, op->getLoc(), weightElemTy, 0)
.value();

Value convTOut =
rewriter
.create<tosa::TransposeConv2DOp>(
op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
nhwcInput, transformedWeight, bias, inputZp, weightZp,
rewriter.getDenseI64ArrayAttr(outPad),
rewriter.getDenseI64ArrayAttr(stride), accType)
.getResult();

SmallVector<int64_t, 4> transposedOutputShape;
for (int32_t dim : nhwcToNchwDims)
transposedOutputShape.push_back(outNHWC[dim]);
auto transposedOutputType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
Value transposedOutput =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(),
getTypeConverter()->convertType(transposedOutputType), convTOut,
rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
.getResult();

// Quantized rescale.
Value rescaledResult = transposedOutput;
if (isa<quant::QuantizedType>(inputElemTy)) {
rescaledResult = tosa::buildRescaleOpConvOutput(
rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
}

// Final cast to requested output type.
rewriter.replaceOp(
op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
.value()});
return success();
}

SmallVector<int64_t> transformedWeightShape;
RankedTensorType transformedWeightType;
@@ -2427,6 +2539,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
SmallVector<int32_t> transposedDims({2, 3, 0, 1});
SmallVector<int64_t> transposedWeightShape = {
weightShape[2], weightShape[3], weightShape[0], weightShape[1]};

// reshape: HWO(I/G) -> HWIM
outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
if (outputCDim == kUnknownSize) {
return rewriter.notifyMatchFailure(
op, "number of output channels must be statically known for "
"depthwise convolutions");
}
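// Note: this static-shape check is done before the weight transpose below
// so the pattern can bail out without having created any ops, per the
// guard-ordering point raised in review.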

auto transposedWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
auto transposedWeight =
@@ -2437,13 +2558,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
rewriter.getDenseI32ArrayAttr(transposedDims))
.getResult();

// reshape: HWO(I/G) -> HWIM
outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
if (outputCDim == kUnknownSize) {
return rewriter.notifyMatchFailure(
op, "number of output channels must be statically known for "
"depthwise convolutions");
}
transformedWeightShape = {
transposedWeightShape[0],
transposedWeightShape[1],
@@ -2465,6 +2579,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
llvm_unreachable("Unhandled convolution type");
}

Value transposedInput = createTransposedInput();

int64_t outputHDim, outputWDim;
int64_t inputHDim = inputShape[2];
int64_t inputWDim = inputShape[3];
@@ -2487,7 +2603,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
if (remainderHDim != 0) {
if (remainderHDim > padding[1]) {
SmallVector<int64_t> startHSlice(inputTy.getRank(), 0);
SmallVector<int64_t> sizeHSlice(transposedInputShape);
SmallVector<int64_t, 4> sizeHSlice(transposedInputShape);
// TOSA uses NHWC, so we will slice dim 1 for Height value
sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
@@ -2583,7 +2699,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
llvm_unreachable("Unhandled convolution type");
}

SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
SmallVector<int64_t> transposedOutputShape(
{outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
auto transposedOutputType = RankedTensorType::get(
6 changes: 0 additions & 6 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -3588,7 +3588,6 @@
"AvgPool3dCountIncludePadFalseWithoutPadding_basic",
"Conv_Transpose1dModule_basic",
"Conv_Transpose1dStaticModule_basic",
"Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic",
"IndexPutWithNoneAndBroadcastModule_basic",
@@ -3713,16 +3712,11 @@
"Conv3dWithValidPaddingModule_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStatic_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
"ConvolutionModule2DGroups_basic",
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ConvolutionModule2DTransposeStrided_basic",
"ConvolutionModule2DTranspose_basic",
"ConvolutionModule2DGroupedTranspose_basic",
"ConvolutionModule3DGroups_basic",
"ConvolutionModule3DGroupsStrided_basic",
@@ -29,6 +29,8 @@
# that depend on TOSA as well as TOSA-to-Standard.
"tosa-to-arith",
"tosa-to-scf",
# Required for transposed convolution support (decomposes to conv ops).
"tosa-optional-decompositions",
# Named ops must be legalized prior to general tosa-to-linalg
"tosa-to-linalg-named",
# TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them
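For reference, the decomposition this pipeline entry enables can also be exercised on its own; a minimal invocation, assuming an MLIR build with the TOSA passes registered (the input file name is illustrative):

mlir-opt --tosa-optional-decompositions transpose_conv.mlir

This rewrites tosa.transpose_conv2d into equivalent IR built from ordinary tosa.conv2d ops, which "tosa-to-linalg-named" can then lower.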
22 changes: 18 additions & 4 deletions test/Conversion/TorchToTosa/conv2d_transpose.mlir
@@ -1,8 +1,23 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file | FileCheck %s

// The following test ensures that a transposed convolution op is not
// lowered in the torch-to-tosa conversion pass.
// The lowering now legalizes transpose convolutions into the TOSA dialect.
// Verify that we emit tosa.transpose_conv2d with the expected reshapes/
// permutations.

// CHECK-LABEL: func.func @forward
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[1,64,1,100],f32> -> tensor<1x64x1x100xf32>
// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64x64x3x3xf32>}> : () -> tensor<64x64x3x3xf32>
// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64xf32>}> : () -> tensor<64xf32>
// CHECK: %[[TRANS_IN:.*]] = tosa.transpose %[[IN_TENSOR]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x64x1x100xf32>) -> tensor<1x1x100x64xf32>
// CHECK: %[[W_OHWI:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 1, 2, 3, 0>} : (tensor<64x64x3x3xf32>) -> tensor<64x3x3x64xf32>
// CHECK: %[[ZP0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[ZP1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[TCONV:.*]] = tosa.transpose_conv2d %[[TRANS_IN]], %[[W_OHWI]], %[[BIAS]], %[[ZP0]], %[[ZP1]] {acc_type = f32, out_pad = array<i64: 0, -1, 0, -1>, stride = array<i64: 2, 2>} : (tensor<1x1x100x64xf32>, tensor<64x3x3x64xf32>, tensor<64xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x200x64xf32>
// CHECK: %[[TRANS_OUT:.*]] = tosa.transpose %[[TCONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x2x200x64xf32>) -> tensor<1x64x2x200xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[TRANS_OUT]] : tensor<1x64x2x200xf32> -> !torch.vtensor<[1,64,2,200],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,64,2,200],f32>
// CHECK: }
func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
%true = torch.constant.bool true
%int1 = torch.constant.int 1
@@ -11,7 +26,6 @@ func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[
%bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
%stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}}
%output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
return %output : !torch.vtensor<[1,64,2,200],f32>
}