[Torch] Fold aten rounding ops on splat constants. #4359
base: main
Changes from all commits
First changed file (Torch dialect C++ op folders):

@@ -234,6 +234,28 @@ static Value getScalarFloatValue(Value input, Location loc,
   return nullptr;
 }
 
+// Common helper for splat-only rounding-based folders.
+static OpFoldResult foldSplatRounding(ValueTensorType resultType,
+                                      Attribute selfAttr,
+                                      APFloat::roundingMode mode) {
+  auto elems = dyn_cast_or_null<DenseElementsAttr>(selfAttr);
+  if (!elems || !elems.isSplat())
+    return {};
+
+  if (!isa<mlir::FloatType>(resultType.getDtype()))
+    return {};
+
+  auto outShaped = resultType.toBuiltinTensor();
+  if (!outShaped.hasStaticShape())
Reviewer comment: why is static shape a requirement?
+    return {};
+
+  APFloat v = elems.getSplatValue<APFloat>();
+  // NaNs and infs are dealt with consistently with torch, so side-effects
+  // can be discarded.
Reviewer comment on lines +253 to +254: Clarification: Can you elaborate on what you meant by "NaNs and infs are dealt with consistently with torch"?
+  (void)v.roundToIntegral(mode);
+  return DenseElementsAttr::get(outShaped, v);
+}
+
 //===----------------------------------------------------------------------===//
 // MethodOp
 //===----------------------------------------------------------------------===//
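As a side note on the reviewer question above: the code comment's claim can be checked directly against eager PyTorch. The snippet below is illustrative only and not part of the patch; it shows that torch's rounding ops simply propagate NaN and +/-inf, which is also the value APFloat::roundToIntegral leaves behind in those cases, so its status result can safely be ignored.

import torch

# NaN and +/-inf pass through every torch rounding op unchanged.
x = torch.tensor([float("nan"), float("inf"), float("-inf")])
print(torch.round(x))  # tensor([nan, inf, -inf])
print(torch.floor(x))  # tensor([nan, inf, -inf])
print(torch.ceil(x))   # tensor([nan, inf, -inf])
print(torch.trunc(x))  # tensor([nan, inf, -inf])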
@@ -2064,10 +2086,19 @@ OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // No-op if the result is int, fold.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmTowardNegative))
+    return res;
+
   return {};
 }
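For reference, a minimal eager-PyTorch check (not part of the patch) that aten.floor rounds toward negative infinity, which is the behavior APFloat::rmTowardNegative selects for the splat folder:

import torch

x = torch.tensor([1.9, -1.1], dtype=torch.float32)
print(torch.floor(x))  # tensor([ 1., -2.]) -- rounds toward -inf, even for negatives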
@@ -2077,10 +2108,19 @@ OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // No-op if the result is int, fold.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmTowardPositive))
+    return res;
+
   return {};
 }
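Similarly, a quick sanity check (illustrative only, not part of the patch) that aten.ceil rounds toward positive infinity, matching APFloat::rmTowardPositive:

import torch

x = torch.tensor([-1.1, 1.1], dtype=torch.float32)
print(torch.ceil(x))  # tensor([-1., 2.]) -- rounds toward +inf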
@@ -2103,10 +2143,18 @@ OpFoldResult AtenRoundDecimalsOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // No-op if the result is int, fold.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmNearestTiesToEven))
+    return res;
+
   return {};
 }
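aten.round uses round-half-to-even ("banker's rounding"), which is why APFloat::rmNearestTiesToEven is the matching mode. A short eager-PyTorch demonstration (not part of the patch):

import torch

x = torch.tensor([0.5, 1.5, 2.5, -1.5, -2.5], dtype=torch.float32)
print(torch.round(x))  # tensor([ 0., 2., 2., -2., -2.]) -- ties go to the even neighbor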
@@ -2116,10 +2164,19 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
   auto resultType = dyn_cast<ValueTensorType>(getType());
-  if (resultType && resultType.hasDtype() &&
-      isa<mlir::IntegerType>(resultType.getDtype())) {
+  if (!resultType || !resultType.hasDtype())
+    return {};
+
+  // No-op if the result is int, fold.
+  if (isa<mlir::IntegerType>(resultType.getDtype()))
     return getSelf();
-  }
+
+  // Fold float splats.
+  if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
+                                   /*mode*/ APFloat::rmTowardZero))
+    return res;
+
   return {};
 }
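Finally, aten.trunc drops the fractional part, i.e. it rounds toward zero (APFloat::rmTowardZero); note how it differs from floor on negative inputs. Illustrative only, not part of the patch:

import torch

x = torch.tensor([-3.7, 3.7], dtype=torch.float32)
print(torch.trunc(x))  # tensor([-3., 3.]) -- rounds toward zero
print(torch.floor(x))  # tensor([-4., 3.]) -- contrast: floor differs on negatives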
Second changed file (Python e2e test suite):

@@ -6844,6 +6844,70 @@ def AtenRoundFloatDecimalsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 5, low=-3.0, high=3.0))
 
 
+class AtenRoundNegFloatHalfToEvenSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([-1.5, -1.5], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
Reviewer comment: I am guessing the other ops are already covered in e2e tests?
+
+
+@register_test_case(module_factory=lambda: AtenRoundNegFloatHalfToEvenSplatModule())
+def AtenRoundNegFloatHalfToEvenSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class AtenRoundPosFloatHalfToEvenSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([1.5, 1.5], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundPosFloatHalfToEvenSplatModule())
+def AtenRoundPosFloatHalfToEvenSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class AtenRoundInfSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([float("+inf")], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundInfSplatModule())
+def AtenRoundInfSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class AtenRoundNanSplatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.const = torch.tensor([float("nan")], dtype=torch.float32)
+
+    @export
+    @annotate_args([None])
+    def forward(self):
+        return torch.ops.aten.round(self.const)
+
+
+@register_test_case(module_factory=lambda: AtenRoundNanSplatModule())
+def AtenRoundNanSplatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
 # ==============================================================================
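On the reviewer question about coverage of the other ops: if analogous splat e2e tests were wanted for floor/ceil/trunc, they could follow the same pattern as the round modules above. The module and test names below are hypothetical and not part of this PR, just a sketch of what one such test might look like:

# Hypothetical sketch (not part of this PR), mirroring the round splat tests above.
class AtenTruncNegFloatSplatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.const = torch.tensor([-3.7, -3.7], dtype=torch.float32)

    @export
    @annotate_args([None])
    def forward(self):
        return torch.ops.aten.trunc(self.const)


@register_test_case(module_factory=lambda: AtenTruncNegFloatSplatModule())
def AtenTruncNegFloatSplatModule_basic(module, tu: TestUtils):
    module.forward()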
Third changed file (Torch dialect FileCheck tests):

@@ -3596,3 +3596,61 @@ func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
   %1 = torch.aten.full %0, %int-Inf, %none, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],si64>
   return %1 : !torch.vtensor<[2,1,4],si64>
 }
+
+// -----
+
+// CHECK-LABEL: @torch.aten.ceil$fold
+// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-1.000000e+00> : tensor<2x2xf32>)
+// CHECK: return %[[C]]
+func.func @torch.aten.ceil$fold() -> !torch.vtensor<[2,2],f32> {
+  %cst = torch.vtensor.literal(dense<-1.100000e+00> : tensor<2x2xf32>)
Reviewer comment: Can you add a +ve value too for completeness?
+      : !torch.vtensor<[2,2],f32>
+  %r = torch.aten.ceil %cst : !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
+  return %r : !torch.vtensor<[2,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.floor$fold
+// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<3x4xf32>)
+// CHECK: return %[[C]]
+func.func @torch.aten.floor$fold() -> !torch.vtensor<[3,4],f32> {
+  %cst = torch.vtensor.literal(dense<1.900000e+00> : tensor<3x4xf32>)
Reviewer comment: Similarly, adding a negative value will give full coverage.
+      : !torch.vtensor<[3,4],f32>
+  %r = torch.aten.floor %cst : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %r : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.trunc$fold
+// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-3.000000e+00> : tensor<1x3xf32>)
+// CHECK: return %[[C]]
+func.func @torch.aten.trunc$fold() -> !torch.vtensor<[1,3],f32> {
+  %cst = torch.vtensor.literal(dense<-3.700000e+00> : tensor<1x3xf32>)
+      : !torch.vtensor<[1,3],f32>
+  %r = torch.aten.trunc %cst : !torch.vtensor<[1,3],f32> -> !torch.vtensor<[1,3],f32>
+  return %r : !torch.vtensor<[1,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.round$fold
+// CHECK-DAG: %[[POS:.*]] = torch.vtensor.literal(dense<2.000000e+00> : tensor<4x5xf32>)
+// CHECK-DAG: %[[NEG:.*]] = torch.vtensor.literal(dense<-2.000000e+00> : tensor<2x3xf32>)
+// CHECK: return %[[POS]], %[[NEG]]
+func.func @torch.aten.round$fold()
+    -> (!torch.vtensor<[4,5],f32>, !torch.vtensor<[2,3],f32>) {
+  %cpos = torch.vtensor.literal(dense<2.500000e+00> : tensor<4x5xf32>)
+      : !torch.vtensor<[4,5],f32>
+  %rpos = torch.aten.round %cpos
+      : !torch.vtensor<[4,5],f32> -> !torch.vtensor<[4,5],f32>
+
+  %cneg = torch.vtensor.literal(dense<-2.500000e+00> : tensor<2x3xf32>)
+      : !torch.vtensor<[2,3],f32>
+  %rneg = torch.aten.round %cneg
+      : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
+
+  return %rpos, %rneg
+      : !torch.vtensor<[4,5],f32>, !torch.vtensor<[2,3],f32>
+}
Reviewer comment: I think naming it foldFloatSplatWithRounding will be more appropriate since it only handles float data.