[TOSA] Add transposed conv support #4360
base: main
@@ -2304,9 +2304,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
     return rewriter.notifyMatchFailure(
         op, "Unimplemented: non-constant value for transposed not supported");
-  if (transposed)
-    return rewriter.notifyMatchFailure(
-        op, "Unimplemented: transposed convolution not supported");
 
   auto input = adaptor.getInput();
   auto weight = adaptor.getWeight();
@@ -2338,12 +2335,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto bias = adaptor.getBias();
 
   if (isa<Torch::NoneType>(bias.getType())) {
-    auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
-                                                    outputElemTy, weightShape);
-    if (failed(bias_result))
+    // ConvTranspose weights use IOHW; the helper expects OIHW, so swap
+    // dims 0/1 before we synthesize the bias.
+    SmallVector<int64_t, 4> biasWeightShape =
+        transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
+                                             weightShape[2], weightShape[3]}
+                   : weightShape;
+
+    auto biasResult = tosa::getConvBiasForNoneType(
+        op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
+    if (failed(biasResult))
       return rewriter.notifyMatchFailure(
           op, "Failed to create bias tensor for none type.");
-    bias = bias_result.value();
+    bias = biasResult.value();
   } else {
     if (!isa<RankedTensorType>(bias.getType()))
       return rewriter.notifyMatchFailure(
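A note on the swap above: PyTorch stores `ConvTranspose2d` weights as IOHW, so the output-channel count sits in dim 1 rather than dim 0. Assuming the bias helper reads the output-channel count from dim 0 of the shape it receives (an assumption about `getConvBiasForNoneType`, not shown in this diff), a hypothetical standalone sketch of the shape handling:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // ConvTranspose2d weight layout is IOHW: in=16, out=8, kernel 3x3.
  std::vector<int64_t> weightShape = {16, 8, 3, 3};

  // Swap dims 0 and 1 so the output-channel count sits in dim 0,
  // matching the OIHW layout a regular convolution would pass in.
  std::vector<int64_t> biasWeightShape = {weightShape[1], weightShape[0],
                                          weightShape[2], weightShape[3]};

  // The synthesized zero bias should have one element per output channel.
  std::printf("bias length = %lld\n", (long long)biasWeightShape[0]); // 8
  return 0;
}
```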
@@ -2370,8 +2374,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                     m_TorchListOfConstantInts(padding_2d)))
     return rewriter.notifyMatchFailure(op,
                                        "non-const padding list unsupported");
-  // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
-  // padding {height, width}. The Torch OFM computation uses 2*pad in each
+  // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
+  // padding {height, width}. The PyTorch OFM computation uses 2*pad in each
   // spatial direction, implying the same top=bottom=height and left=right=width
   // values for TOSA.
   SmallVector<int64_t> padding(
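As a side note on the comment above: PyTorch's output-size formula for a regular convolution uses `2*pad` per spatial dim, which is why a single {height, width} pair expands symmetrically into TOSA's {top, bottom, left, right}. A minimal standalone sketch (plain C++, illustrative only, not part of the patch):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Expand PyTorch's 2D {height, width} padding into TOSA's
// 4D {top, bottom, left, right} padding (top == bottom, left == right).
std::array<int64_t, 4> expandPadding(int64_t padH, int64_t padW) {
  return {padH, padH, padW, padW};
}

// PyTorch output-size formula for a regular convolution along one dim.
int64_t convOutDim(int64_t in, int64_t pad, int64_t dil, int64_t k, int64_t s) {
  return (in + 2 * pad - dil * (k - 1) - 1) / s + 1;
}

int main() {
  auto p = expandPadding(/*padH=*/1, /*padW=*/2);
  std::printf("top=%lld bottom=%lld left=%lld right=%lld\n", (long long)p[0],
              (long long)p[1], (long long)p[2], (long long)p[3]);
  std::printf("H_out=%lld\n", (long long)convOutDim(56, 1, 1, 3, 1)); // 56
  return 0;
}
```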
@@ -2388,20 +2392,128 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "failed to get accumulator type for convolution ops");
 
+  // Weight layout reference:
+  //   Conv       : PyTorch OIHW     -> TOSA OHWI
+  //   Depthwise  : PyTorch OIHW*    -> TOSA HWIM
+  //                (PyTorch depthwise uses out_ch=in_ch*depth_multiplier)
+  //   Grouped    : PyTorch O(I/G)HW -> N/A
+  //   Transposed : PyTorch IOHW     -> TOSA OHWI
   // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
   // Perform the necessary transformations.
   SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
-  SmallVector<int64_t> transposedInputShape(
-      {inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
+  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
+  SmallVector<int64_t, 4> transposedInputShape;
+  for (int32_t dim : nchwToNhwcDims)
+    transposedInputShape.push_back(inputShape[dim]);
   auto transposedInputType = RankedTensorType::get(
       makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
-  auto transposedInput =
-      rewriter
-          .create<tosa::TransposeOp>(
-              op->getLoc(),
-              getTypeConverter()->convertType(transposedInputType), input,
-              rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
-          .getResult();
+  auto createTransposedInput = [&]() {
+    return rewriter
+        .create<tosa::TransposeOp>(
+            op->getLoc(), getTypeConverter()->convertType(transposedInputType),
+            input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
+        .getResult();
+  };
+
+  if (transposed) {
+    if (groups != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: grouped transposed convolution not supported by "
+              "TOSA");
+    if (dilation[0] != 1 || dilation[1] != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: dilated transposed convolution not supported by "
+              "TOSA");
Comment on lines +2420 to +2427

Member: I think these notify failures need to happen before any IR rewrites take place, otherwise the pattern rewriter ends up in a recursive loop. For example, on line 2410 we'd have already introduced …

Contributor (author): Fixed for transpose and depthwise paths.
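For reference, the layout permutations used in this hunk (NCHW→NHWC for activations, IOHW→OHWI for transposed-conv weights, NHWC→NCHW for the result) can be checked with plain shape arithmetic; a standalone sketch with an illustrative shape, independent of the MLIR code above:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

using Shape4 = std::array<int64_t, 4>;

// Apply a dimension permutation: out[i] = in[perm[i]].
Shape4 applyPerm(const Shape4 &in, const std::array<int32_t, 4> &perm) {
  Shape4 out;
  for (int i = 0; i < 4; ++i)
    out[i] = in[perm[i]];
  return out;
}

int main() {
  const std::array<int32_t, 4> nchwToNhwc = {0, 2, 3, 1};
  const std::array<int32_t, 4> iohwToOhwi = {1, 2, 3, 0};
  const std::array<int32_t, 4> nhwcToNchw = {0, 3, 1, 2};

  Shape4 input = {2, 16, 28, 28}; // NCHW
  Shape4 weight = {16, 8, 3, 3};  // IOHW (ConvTranspose2d)

  Shape4 nhwc = applyPerm(input, nchwToNhwc);  // {2, 28, 28, 16}
  Shape4 ohwi = applyPerm(weight, iohwToOhwi); // {8, 3, 3, 16}
  Shape4 back = applyPerm(nhwc, nhwcToNchw);   // {2, 16, 28, 28}

  std::printf("NHWC = {%lld, %lld, %lld, %lld}\n", (long long)nhwc[0],
              (long long)nhwc[1], (long long)nhwc[2], (long long)nhwc[3]);
  std::printf("OHWI = {%lld, %lld, %lld, %lld}\n", (long long)ohwi[0],
              (long long)ohwi[1], (long long)ohwi[2], (long long)ohwi[3]);
  std::printf("NCHW = {%lld, %lld, %lld, %lld}\n", (long long)back[0],
              (long long)back[1], (long long)back[2], (long long)back[3]);
  return 0;
}
```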
+
+    SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
+
+    // TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
+    // Map from PyTorch's (padding, output_padding):
+    //   out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W)
+    // Negative values are allowed and will be handled by the TOSA
+    // decomposition.
+    SmallVector<int64_t, 2> outPadding2D;
+    if (!matchPattern(adaptor.getOutputPadding(),
+                      m_TorchListOfConstantInts(outPadding2D)))
+      return rewriter.notifyMatchFailure(
+          op, "non-const output_padding list unsupported for transposed conv");
+
+    int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0];
+    int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1];
+    int64_t outPadTop = outPadH / 2;
+    int64_t outPadBottom = outPadH - outPadTop;
+    int64_t outPadLeft = outPadW / 2;
+    int64_t outPadRight = outPadW - outPadLeft;
+    SmallVector<int64_t, 4> outPad(
+        {outPadTop, outPadBottom, outPadLeft, outPadRight});
+
+    Value nhwcInput = createTransposedInput();
+    SmallVector<int64_t, 4> ohwiWeightShape;
+    for (int32_t dim : iohwToOhwi)
+      ohwiWeightShape.push_back(weightShape[dim]);
+    auto ohwiWeightType = RankedTensorType::get(
+        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
+    Value transformedWeight =
+        rewriter
+            .create<tosa::TransposeOp>(
+                op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
+                weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
+            .getResult();
+
+    // Result type is NHWC (we'll transpose back).
+    auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
+    SmallVector<int64_t, 4> outNHWC;
+    for (int32_t dim : nchwToNhwcDims)
+      outNHWC.push_back(outNCHW[dim]);
+    auto transConvOpTy =
+        RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy);
+
+    // Zero-points.
+    auto zps = tosa::createZPsAsConst(rewriter, input, weight);
+    Value inputZp = zps.first ? zps.first
+                              : tosa::createZeroPointTensor(
+                                    rewriter, op->getLoc(), inputElemTy, 0)
+                                    .value();
+    Value weightZp = zps.second ? zps.second
+                                : tosa::createZeroPointTensor(
+                                      rewriter, op->getLoc(), weightElemTy, 0)
+                                      .value();
+
+    Value convTOut =
+        rewriter
+            .create<tosa::TransposeConv2DOp>(
+                op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
+                nhwcInput, transformedWeight, bias, inputZp, weightZp,
+                rewriter.getDenseI64ArrayAttr(outPad),
+                rewriter.getDenseI64ArrayAttr(stride), accType)
+            .getResult();
+
+    SmallVector<int64_t, 4> transposedOutputShape;
+    for (int32_t dim : nhwcToNchwDims)
+      transposedOutputShape.push_back(outNHWC[dim]);
+    auto transposedOutputType = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
+    Value transposedOutput =
+        rewriter
+            .create<tosa::TransposeOp>(
+                op->getLoc(),
+                getTypeConverter()->convertType(transposedOutputType), convTOut,
+                rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
+            .getResult();
+
+    // Quantized rescale.
+    Value rescaledResult = transposedOutput;
+    if (isa<quant::QuantizedType>(inputElemTy)) {
+      rescaledResult = tosa::buildRescaleOpConvOutput(
+          rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
+    }
+
+    // Final cast to requested output type.
+    rewriter.replaceOp(
+        op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
+                 .value()});
+    return success();
+  }
 
   SmallVector<int64_t> transformedWeightShape;
   RankedTensorType transformedWeightType;
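A worked check of the out_pad mapping in the hunk above. With dilation fixed to 1 (the only case this pattern accepts), PyTorch computes the transposed-conv output height as `(H_in - 1)*stride - 2*padding + (KH - 1) + output_padding + 1`, and the mapping `out_pad_total = output_padding - 2*padding` makes TOSA's `(IH - 1)*stride + out_pad_top + out_pad_bottom + KH` produce the same value (assuming that is the output relation transpose_conv2d uses, which this mapping implies). A standalone sketch:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// PyTorch ConvTranspose2d output size along one spatial dim (dilation = 1).
int64_t torchOut(int64_t in, int64_t stride, int64_t pad, int64_t k,
                 int64_t outputPadding) {
  return (in - 1) * stride - 2 * pad + (k - 1) + outputPadding + 1;
}

// TOSA transpose_conv2d output size along one spatial dim, assuming the
// relation OH = (IH - 1) * stride + out_pad_a + out_pad_b + KH that the
// out_pad mapping in this patch implies.
int64_t tosaOut(int64_t in, int64_t stride, int64_t outPadA, int64_t outPadB,
                int64_t k) {
  return (in - 1) * stride + outPadA + outPadB + k;
}

int main() {
  const int64_t in = 14, stride = 2, pad = 1, k = 3, outputPadding = 1;

  // Mapping used by the lowering: out_pad_total = output_padding - 2*padding,
  // split (possibly unevenly) between the two sides; negatives are allowed.
  const int64_t outPadTotal = outputPadding - 2 * pad;
  const int64_t outPadA = outPadTotal / 2;
  const int64_t outPadB = outPadTotal - outPadA;

  const int64_t expected = torchOut(in, stride, pad, k, outputPadding);
  const int64_t actual = tosaOut(in, stride, outPadA, outPadB, k);
  std::printf("torch=%lld tosa=%lld\n", (long long)expected, (long long)actual);
  assert(expected == actual); // both 28 for these values
  return 0;
}
```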
@@ -2427,6 +2539,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     SmallVector<int32_t> transposedDims({2, 3, 0, 1});
     SmallVector<int64_t> transposedWeightShape = {
         weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
+
+    // reshape: HWO(I/G) -> HWIM
+    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
+    if (outputCDim == kUnknownSize) {
+      return rewriter.notifyMatchFailure(
+          op, "number of output channels must be statically known for "
+              "depthwise convolutions");
+    }
+
     auto transposedWeightType = RankedTensorType::get(
         makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
     auto transposedWeight =
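For the depthwise path the moved check belongs to: PyTorch stores depthwise weights as OIHW with `out_ch = in_ch * depth_multiplier` and I/G = 1, while TOSA's depthwise weights are HWIM with M the multiplier, which is why the output-channel count must be statically known. A shape-only sketch with illustrative numbers (the pass's exact computation of the multiplier is not shown in this hunk):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative depthwise convolution: 8 input channels, depth multiplier 2.
  const int64_t inCh = 8, multiplier = 2, kH = 3, kW = 3;
  const int64_t outCh = inCh * multiplier; // PyTorch: out_ch = in_ch * mult

  std::array<int64_t, 4> oihw = {outCh, 1, kH, kW};                   // OIHW
  std::array<int64_t, 4> hwoi = {oihw[2], oihw[3], oihw[0], oihw[1]}; // HWO(I/G)
  std::array<int64_t, 4> hwim = {hwoi[0], hwoi[1], inCh, outCh / inCh}; // HWIM

  std::printf("OIHW     = {%lld, %lld, %lld, %lld}\n", (long long)oihw[0],
              (long long)oihw[1], (long long)oihw[2], (long long)oihw[3]);
  std::printf("HWO(I/G) = {%lld, %lld, %lld, %lld}\n", (long long)hwoi[0],
              (long long)hwoi[1], (long long)hwoi[2], (long long)hwoi[3]);
  std::printf("HWIM     = {%lld, %lld, %lld, %lld}\n", (long long)hwim[0],
              (long long)hwim[1], (long long)hwim[2], (long long)hwim[3]);
  return 0;
}
```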
@@ -2437,13 +2558,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
             rewriter.getDenseI32ArrayAttr(transposedDims))
             .getResult();
 
-    // reshape: HWO(I/G) -> HWIM
-    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
-    if (outputCDim == kUnknownSize) {
-      return rewriter.notifyMatchFailure(
-          op, "number of output channels must be statically known for "
-              "depthwise convolutions");
-    }
     transformedWeightShape = {
         transposedWeightShape[0],
         transposedWeightShape[1],
@@ -2465,6 +2579,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
+  Value transposedInput = createTransposedInput();
+
   int64_t outputHDim, outputWDim;
   int64_t inputHDim = inputShape[2];
   int64_t inputWDim = inputShape[3];
@@ -2487,7 +2603,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (remainderHDim != 0) {
     if (remainderHDim > padding[1]) {
       SmallVector<int64_t> startHSlice(inputTy.getRank(), 0);
-      SmallVector<int64_t> sizeHSlice(transposedInputShape);
+      SmallVector<int64_t, 4> sizeHSlice(transposedInputShape);
      // TOSA uses NHWC, so we will slice dim 1 for Height value
      sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
      transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
@@ -2583,7 +2699,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
-  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
   SmallVector<int64_t> transposedOutputShape(
       {outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
   auto transposedOutputType = RankedTensorType::get(