Commit c99c982

[TOSA] Add legalization for avg_pool2d and avg_pool1d with count_include_pad=True

Before this patch, the `avg_pool2d` and `avg_pool1d` legalizations lacked support
for pooling with count_include_pad=True. This patch introduces that support.

Signed-off-by: Vitalii Shutov <vitalii.shutov@arm.com>
Change-Id: I73fa26a58379e2c021929ade81c983ff91c59667

1 parent 2c989a2
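
The lowering leans on a simple arithmetic identity: with count_include_pad=True every pooling window is divided by the full kernel size, so explicitly zero-padding the input and then pooling with pad = 0 (where every tap is now a real input element, so the divisor is again the full kernel size) produces exactly the same values. The standalone C++ sketch below is an illustration only, not torch-mlir code; the function names in it are invented for the example. It checks the identity on a small single-channel map with kernel 3, stride 1, pad 1.

#include <cassert>
#include <cstdio>
#include <vector>

// Average pooling over an H x W map with symmetric zero padding `pad`,
// kernel `k`, stride 1, always dividing by k*k (count_include_pad=true).
static std::vector<float> avgPoolCountIncludePad(const std::vector<float> &in,
                                                 int H, int W, int k, int pad) {
  int outH = H + 2 * pad - k + 1, outW = W + 2 * pad - k + 1;
  std::vector<float> out(outH * outW, 0.0f);
  for (int oy = 0; oy < outH; ++oy)
    for (int ox = 0; ox < outW; ++ox) {
      float sum = 0.0f;
      for (int ky = 0; ky < k; ++ky)
        for (int kx = 0; kx < k; ++kx) {
          int iy = oy + ky - pad, ix = ox + kx - pad;
          if (iy >= 0 && iy < H && ix >= 0 && ix < W)
            sum += in[iy * W + ix]; // out-of-bounds taps contribute zero
        }
      out[oy * outW + ox] = sum / (k * k); // divisor is the full kernel size
    }
  return out;
}

// The shape of the rewrite in this patch: materialize the zero padding
// explicitly, then pool with pad = 0.
static std::vector<float> explicitPadThenAvgPool(const std::vector<float> &in,
                                                 int H, int W, int k, int pad) {
  int pH = H + 2 * pad, pW = W + 2 * pad;
  std::vector<float> padded(pH * pW, 0.0f);
  for (int y = 0; y < H; ++y)
    for (int x = 0; x < W; ++x)
      padded[(y + pad) * pW + (x + pad)] = in[y * W + x];
  return avgPoolCountIncludePad(padded, pH, pW, k, /*pad=*/0);
}

int main() {
  const int H = 4, W = 4, k = 3, pad = 1;
  std::vector<float> in(H * W);
  for (int i = 0; i < H * W; ++i)
    in[i] = static_cast<float>(i + 1);

  std::vector<float> a = avgPoolCountIncludePad(in, H, W, k, pad);
  std::vector<float> b = explicitPadThenAvgPool(in, H, W, k, pad);

  assert(a.size() == b.size());
  for (size_t i = 0; i < a.size(); ++i)
    assert(a[i] == b[i]); // identical: the zeros enter the sum either way
  std::printf("count_include_pad=true == explicit zero pad + pad=0\n");
  return 0;
}

This is why the patch can zero out the avg_pool2d pad attribute once a tosa.pad has supplied the zeros: with pad = 0 the TOSA op sees only real input elements in every window, so its divisor is the full kernel size, which is exactly what count_include_pad=True requires.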

File tree (4 files changed: +179 lines, -51 lines)

  include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
  lib/Conversion/TorchToTosa/TorchToTosa.cpp
  lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
  test/Conversion/TorchToTosa/basic.mlir

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 7 additions & 0 deletions
@@ -106,6 +106,13 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
                                         Type inputElemTy, Type outputElemTy,
                                         ArrayRef<int64_t> weightShape);
 
+// Emit an explicit zero-valued `tosa.pad` around an NHWC tensor so that later
+// avg_pool lowering can run with `pad = 0`. `padExtents` is ordered as
+// {top, bottom, left, right}. Returns the padded tensor value.
+Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
+                              Operation *op, Value inputNHWC,
+                              ArrayRef<int64_t> padExtents);
+
 } // namespace tosa
 } // namespace mlir

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 56 additions & 18 deletions
@@ -6075,7 +6075,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
     AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
-    DenseI64ArrayAttr &pad) {
+    DenseI64ArrayAttr &pad,
+    SmallVectorImpl<int64_t> *explicitNHWCPad = nullptr) {
 
   RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
   if (!inputTy)
@@ -6115,21 +6116,43 @@ getOutputTypeAndPoolingParameters
 
   if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
-    // Currently, we can not represent `count_include_pad` with the existing
-    // TOSA AvgPool2d specification. Without the below check, we produce silent
-    // wrong answer (SWA) when the `count_include_pad` value is `true.`
-    //
-    // Note: We need to check for `count_include_pad` only when the `padding`
-    // value is non-zero.
+    // When count_include_pad=true with non-zero padding, we will materialize an
+    // explicit pad after transposing to NHWC. Track the padding extents and
+    // zero out the TOSA op padding so the divisor matches the full kernel size.
     bool countIncludePad;
     if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
         (!matchPattern(op.getCountIncludePad(),
                        m_TorchConstantBool(&countIncludePad)) ||
 
         countIncludePad)) {
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported `count_include_pad` value, for tosa AvgPool "
-              "`count_include_pad` value should be `False`.");
+      if (!explicitNHWCPad)
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported `count_include_pad` value, for tosa AvgPool "
+                "`count_include_pad` value should be `False`.");
+
+      // Remember the spatial padding so we can emit an NHWC tosa.pad right
+      // after the transpose.
+      explicitNHWCPad->assign(
+          {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});
+
+      auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
+        if (ShapedType::isDynamic(dim))
+          return ShapedType::kDynamic;
+        return dim + before + after;
+      };
+
+      // Update the logical input type used for shape computations to include
+      // the extra zeros supplied by the explicit pad.
+      SmallVector<int64_t> paddedShape(inputTy.getShape().begin(),
+                                       inputTy.getShape().end());
+      // Height stored at rank-2, width at rank-1 for NCHW shapes.
+      paddedShape[inputRank - 2] =
+          addPad(paddedShape[inputRank - 2], paddingInts[0], paddingInts[0]);
+      paddedShape[inputRank - 1] =
+          addPad(paddedShape[inputRank - 1], paddingInts[1], paddingInts[1]);
+      inputTy = RankedTensorType::get(paddedShape, inputTy.getElementType());
+
+      paddingInts.assign(/*Count=*/2, /*Value=*/0);
     }
   }
 

@@ -6275,15 +6298,23 @@ class ConvertAtenAvgPool2dOp
     }
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
+    SmallVector<int64_t, 4> explicitNHWCPad;
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad,
+            &explicitNHWCPad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
-    // Transpose to xHWC
-    input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, self);
+    Value transposed =
+        ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
+            transposePoolingInputToHwc(op, rewriter, self);
+
+    if (!explicitNHWCPad.empty())
+      transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
+                                                 transposed, explicitNHWCPad);
+
+    input = transposed;
 
     return success();
   }
@@ -6328,16 +6359,23 @@ class ConvertAtenAvgPool1dOp
             .getResult();
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
+    SmallVector<int64_t, 4> explicitNHWCPad;
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
             op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
-            pad)))
+            pad, &explicitNHWCPad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
-    // Transpose to xHWC
-    input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
+    Value transposed =
+        ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
+            transposePoolingInputToHwc(op, rewriter, reshapedSelf);
+
+    if (!explicitNHWCPad.empty())
+      transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
+                                                 transposed, explicitNHWCPad);
+
+    input = transposed;
 
     return success();
  }

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 37 additions & 0 deletions
@@ -595,5 +595,42 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
   }
 }
 
+Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
+                              Operation *op, Value inputNHWC,
+                              ArrayRef<int64_t> padExtents) {
+  assert(padExtents.size() == 4 && "expected [top, bottom, left, right]");
+
+  if (llvm::all_of(padExtents, [](int64_t v) { return v == 0; }))
+    return inputNHWC;
+
+  SmallVector<int64_t, 8> nhwcPadding = {
+      0, 0, padExtents[0], padExtents[1], padExtents[2], padExtents[3], 0, 0};
+  Value nhwcPadShape = tosa::getTosaConstShape(rewriter, loc, nhwcPadding);
+
+  auto inputTy = cast<RankedTensorType>(inputNHWC.getType());
+  SmallVector<int64_t, 4> resultShape(inputTy.getShape().begin(),
+                                      inputTy.getShape().end());
+  auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
+    if (ShapedType::isDynamic(dim))
+      return ShapedType::kDynamic;
+    return dim + before + after;
+  };
+  resultShape[1] = addPad(resultShape[1], padExtents[0], padExtents[1]);
+  resultShape[2] = addPad(resultShape[2], padExtents[2], padExtents[3]);
+
+  auto resultTy = RankedTensorType::get(resultShape, inputTy.getElementType());
+
+  Type elemTy = inputTy.getElementType();
+  Value padConst;
+  if (isa<mlir::FloatType>(elemTy)) {
+    padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
+  } else {
+    padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
+  }
+
+  return rewriter.create<tosa::PadOp>(loc, resultTy, inputNHWC, nhwcPadShape,
+                                      padConst);
+}
+
 } // namespace tosa
 } // namespace mlir
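
To make the helper's bookkeeping concrete, the short standalone C++ sketch below (an illustration only; the names are invented and no MLIR types are used) shows how {top, bottom, left, right} extents map to the 8-entry NHWC padding and to the padded result shape. The two cases mirror the new tests further down: 1x35x35x192 padded to 1x37x37x192 for avg_pool2d, and 1x10x1x512 padded to 1x12x1x512 for avg_pool1d.

#include <array>
#include <cstdio>

// NHWC padding layout used by the helper:
// {N_before, N_after, H_before, H_after, W_before, W_after, C_before, C_after}.
static std::array<long, 8> toNhwcPadding(const std::array<long, 4> &extents) {
  // extents = {top, bottom, left, right}
  return {0, 0, extents[0], extents[1], extents[2], extents[3], 0, 0};
}

static std::array<long, 4> padShape(std::array<long, 4> nhwc,
                                    const std::array<long, 4> &extents) {
  nhwc[1] += extents[0] + extents[1]; // H grows by top + bottom
  nhwc[2] += extents[2] + extents[3]; // W grows by left + right
  return nhwc;
}

int main() {
  // avg_pool2d test: NHWC 1x35x35x192 with extents {1, 1, 1, 1} -> 1x37x37x192.
  auto s2d = padShape({1, 35, 35, 192}, {1, 1, 1, 1});
  // avg_pool1d test: NHWC 1x10x1x512 with extents {1, 1, 0, 0} -> 1x12x1x512.
  auto s1d = padShape({1, 10, 1, 512}, {1, 1, 0, 0});
  auto p1d = toNhwcPadding({1, 1, 0, 0}); // dense<[0, 0, 1, 1, 0, 0, 0, 0]>
  std::printf("2d padded: %ldx%ldx%ldx%ld\n", s2d[0], s2d[1], s2d[2], s2d[3]);
  std::printf("1d padded: %ldx%ldx%ldx%ld\n", s1d[0], s1d[1], s1d[2], s1d[3]);
  std::printf("1d pad vector head: %ld %ld %ld %ld\n", p1d[0], p1d[1], p1d[2],
              p1d[3]);
  return 0;
}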

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 79 additions & 33 deletions
@@ -2265,24 +2265,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
 
 // -----
 
-func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false= torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %divisor_override = torch.constant.none
-
-  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
-  return %3 : !torch.vtensor<[1,192,35,35],f32>
-}
-
-// -----
-
 func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
@@ -2802,21 +2784,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
 
 // -----
 
-func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
-  return %3 : !torch.vtensor<[1,512,10],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4328,3 +4295,82 @@ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch
   %0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
   return %0 : !torch.vtensor<[2,3],f16>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_7:.*]] = torch.constant.none
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x35x35xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x35x35x192xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x37x37x192xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
+// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false= torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %divisor_override = torch.constant.none
+
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
+  return %3 : !torch.vtensor<[1,192,35,35],f32>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_10]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x10x1x512xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x12x1x512xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
+// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
+// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
+  return %3 : !torch.vtensor<[1,512,10],f32>
+}
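
As a sanity check on the CHECK lines above: with stride 1 the usual pooling size formula H_out = \lfloor (H + p_before + p_after - k) / s \rfloor + 1 gives the same output extent before and after the rewrite. 2-D case: (35 + 1 + 1 - 3)/1 + 1 = 35 for the original op, and (37 - 3)/1 + 1 = 35 after the explicit pad, so the lowered avg_pool2d still produces tensor<1x35x35x192xf32>. 1-D case: (10 + 1 + 1 - 3)/1 + 1 = 10, and (12 - 3)/1 + 1 = 10, matching the tensor<1x10x1x512xf32> produced before the transpose and reshape back to [1,512,10].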
