Skip to content

Conversation

@Lallapallooza
Copy link
Contributor

Lower aten.conv_transpose2d into tosa.transpose_conv2d. Refresh FX importer TOSA xfails to drop the transpose-conv cases that now pass, and document the weight layout mapping.

Lower aten.conv_transpose2d into tosa.transpose_conv2d. Refresh FX importer
TOSA xfails to drop the transpose-conv cases that now pass, and document
the weight layout mapping.

Change-Id: I23be2230a0948784402dca574597db1d979d5aee
@Lallapallooza
Copy link
Contributor Author

@sahas3 @sjarus can you please take a look?

Copy link
Member

@sahas3 sahas3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change looks good to me. Some minor comments and clarifying questions.


// Weight layout reference:
// Conv : PyTorch OIHW -> TOSA OHWI
// Depthwise : PyTorch OIHW* -> TOSA HWIM (*out = in * multiplier)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this reference, though it's not clear what you are trying to imply with (*out = in * multiplier) here.


SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
SmallVector<int64_t> ohwiWeightShape(
{weightShape[1], weightShape[2], weightShape[3], weightShape[0]});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#nit: weightShape[iohwToOhwi[0]], weightShape[iohwToOhwi[1]], ... and so on will help readability I think

// Result type is NHWC (we'll transpose back).
auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
SmallVector<int64_t> outNHWC(
{outNCHW[0], outNCHW[2], outNCHW[3], outNCHW[1]});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar comment to above to use the nchwToNhwcDims instead of hardcoding 0,2,3,1

@@ -1,8 +1,23 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is verify-diagnostics needed anymore unless you plan to add tests to verify the failure cases?

// Quantized rescale.
Value rescaledResult = transposedOutput;
if (isa<quant::QuantizedType>(inputElemTy)) {
rescaledResult = tosa::buildRescaleOpConvOutput(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConvTranspose2DQInt8_basic e2e test likely triggers this code path -- any idea why that is still failing? If that's not the correct test, are there any e2e tests that triggers this code?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants