81 changes: 69 additions & 12 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -234,6 +234,28 @@ static Value getScalarFloatValue(Value input, Location loc,
return nullptr;
}

// Common helper for splat-only rounding-based folders.
static OpFoldResult foldSplatRounding(ValueTensorType resultType,
Review comment (Member): I think naming it foldFloatSplatWithRounding will be more appropriate since it only handles float data.

Attribute selfAttr,
APFloat::roundingMode mode) {
auto elems = dyn_cast_or_null<DenseElementsAttr>(selfAttr);
if (!elems || !elems.isSplat())
return {};

if (!isa<mlir::FloatType>(resultType.getDtype()))
return {};

auto outShaped = resultType.toBuiltinTensor();
if (!outShaped.hasStaticShape())
Review comment (Member): why is static shape a requirement?

return {};

APFloat v = elems.getSplatValue<APFloat>();
// NaNs and infs are dealt with consistently with torch, so side-effects
// can be discarded.
Review comment (Member), on lines +253 to +254: Clarification: Can you elaborate on what you meant by "NaNs and infs are dealt with consistently with torch"?

(void)v.roundToIntegral(mode);
return DenseElementsAttr::get(outShaped, v);
}
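
As context for the clarification requested above: a minimal standalone sketch (an illustration assuming only LLVM's APFloat from llvm/ADT/APFloat.h; it is not part of this change) showing that roundToIntegral leaves NaN and infinity operands unchanged, matching torch's rounding ops, which is why the returned opStatus can be discarded when folding splats.

#include "llvm/ADT/APFloat.h"
#include <cassert>

using llvm::APFloat;

int main() {
  // +inf rounds to +inf regardless of mode; the opStatus carries no extra info we need.
  APFloat inf = APFloat::getInf(APFloat::IEEEsingle());
  (void)inf.roundToIntegral(APFloat::rmNearestTiesToEven);
  assert(inf.isInfinity() && !inf.isNegative());

  // NaN stays NaN, again matching torch.round/floor/ceil/trunc.
  APFloat nan = APFloat::getNaN(APFloat::IEEEsingle());
  (void)nan.roundToIntegral(APFloat::rmNearestTiesToEven);
  assert(nan.isNaN());
  return 0;
}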

//===----------------------------------------------------------------------===//
// MethodOp
//===----------------------------------------------------------------------===//
@@ -2064,10 +2086,19 @@ OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {

OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() &&
isa<mlir::IntegerType>(resultType.getDtype())) {

if (!resultType || !resultType.hasDtype())
return {};

// No-op if the result is int, fold.
if (isa<mlir::IntegerType>(resultType.getDtype()))
return getSelf();
}

// Fold float splats.
if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
/*mode*/ APFloat::rmTowardNegative))
return res;

return {};
}

@@ -2077,10 +2108,19 @@ OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {

OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() &&
isa<mlir::IntegerType>(resultType.getDtype())) {

if (!resultType || !resultType.hasDtype())
return {};

// No-op if the result is int, fold.
if (isa<mlir::IntegerType>(resultType.getDtype()))
return getSelf();
}

// Fold float splats.
if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
/*mode*/ APFloat::rmTowardPositive))
return res;

return {};
}

@@ -2103,10 +2143,18 @@ OpFoldResult AtenRoundDecimalsOp::fold(FoldAdaptor adaptor) {

OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() &&
isa<mlir::IntegerType>(resultType.getDtype())) {
if (!resultType || !resultType.hasDtype())
return {};

// No-op if the result is int, fold.
if (isa<mlir::IntegerType>(resultType.getDtype()))
return getSelf();
}

// Fold float splats.
if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
/*mode*/ APFloat::rmNearestTiesToEven))
return res;

return {};
}

@@ -2116,10 +2164,19 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {

OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
auto resultType = dyn_cast<ValueTensorType>(getType());
if (resultType && resultType.hasDtype() &&
isa<mlir::IntegerType>(resultType.getDtype())) {

if (!resultType || !resultType.hasDtype())
return {};

// No-op if the result is int, fold.
if (isa<mlir::IntegerType>(resultType.getDtype()))
return getSelf();
}

// Fold float splats.
if (auto res = foldSplatRounding(resultType, /*selfAttr*/ adaptor.getSelf(),
/*mode*/ APFloat::rmTowardZero))
return res;

return {};
}
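
A similar standalone sketch (again only an illustration under the assumption that LLVM's APFloat is available; not part of the diff) checking that the modes chosen above reproduce torch's semantics, including round's half-to-even behavior exercised by the tests below.

#include "llvm/ADT/APFloat.h"
#include <cassert>

using llvm::APFloat;

// Round a single-precision value with the given mode and return the result.
static float roundWith(float x, APFloat::roundingMode mode) {
  APFloat v(x);
  (void)v.roundToIntegral(mode);
  return v.convertToFloat();
}

int main() {
  assert(roundWith(1.9f, APFloat::rmTowardNegative) == 1.0f);    // floor
  assert(roundWith(-1.1f, APFloat::rmTowardPositive) == -1.0f);  // ceil
  assert(roundWith(-3.7f, APFloat::rmTowardZero) == -3.0f);      // trunc
  assert(roundWith(2.5f, APFloat::rmNearestTiesToEven) == 2.0f); // round: ties to even
  assert(roundWith(-2.5f, APFloat::rmNearestTiesToEven) == -2.0f);
  return 0;
}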

64 changes: 64 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -6844,6 +6844,70 @@ def AtenRoundFloatDecimalsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 5, low=-3.0, high=3.0))


class AtenRoundNegFloatHalfToEvenSplatModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([-1.5, -1.5], dtype=torch.float32)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.round(self.const)
Review comment (Member): I am guessing the other ops are already covered in e2e tests?


@register_test_case(module_factory=lambda: AtenRoundNegFloatHalfToEvenSplatModule())
def AtenRoundNegFloatHalfToEvenSplatModule_basic(module, tu: TestUtils):
module.forward()


class AtenRoundPosFloatHalfToEvenSplatModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([1.5, 1.5], dtype=torch.float32)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.round(self.const)


@register_test_case(module_factory=lambda: AtenRoundPosFloatHalfToEvenSplatModule())
def AtenRoundPosFloatHalfToEvenSplatModule_basic(module, tu: TestUtils):
module.forward()


class AtenRoundInfSplatModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([float("+inf")], dtype=torch.float32)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.round(self.const)


@register_test_case(module_factory=lambda: AtenRoundInfSplatModule())
def AtenRoundInfSplatModule_basic(module, tu: TestUtils):
module.forward()


class AtenRoundNanSplatModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.tensor([float("nan")], dtype=torch.float32)

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.round(self.const)


@register_test_case(module_factory=lambda: AtenRoundNanSplatModule())
def AtenRoundNanSplatModule_basic(module, tu: TestUtils):
module.forward()


# ==============================================================================


58 changes: 58 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -3596,3 +3596,61 @@ func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
%1 = torch.aten.full %0, %int-Inf, %none, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],si64>
return %1 : !torch.vtensor<[2,1,4],si64>
}

// -----

// CHECK-LABEL: @torch.aten.ceil$fold
// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-1.000000e+00> : tensor<2x2xf32>)
// CHECK: return %[[C]]
func.func @torch.aten.ceil$fold() -> !torch.vtensor<[2,2],f32> {
%cst = torch.vtensor.literal(dense<-1.100000e+00> : tensor<2x2xf32>)
Review comment (Member): Can you add a +ve value too for completeness?

: !torch.vtensor<[2,2],f32>
%r = torch.aten.ceil %cst : !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
return %r : !torch.vtensor<[2,2],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.floor$fold
// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<3x4xf32>)
// CHECK: return %[[C]]
func.func @torch.aten.floor$fold() -> !torch.vtensor<[3,4],f32> {
%cst = torch.vtensor.literal(dense<1.900000e+00> : tensor<3x4xf32>)
Review comment (Member): Similarly, adding a negative value will give full coverage.

: !torch.vtensor<[3,4],f32>
%r = torch.aten.floor %cst : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
return %r : !torch.vtensor<[3,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.trunc$fold
// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-3.000000e+00> : tensor<1x3xf32>)
// CHECK: return %[[C]]
func.func @torch.aten.trunc$fold() -> !torch.vtensor<[1,3],f32> {
%cst = torch.vtensor.literal(dense<-3.700000e+00> : tensor<1x3xf32>)
: !torch.vtensor<[1,3],f32>
%r = torch.aten.trunc %cst : !torch.vtensor<[1,3],f32> -> !torch.vtensor<[1,3],f32>
return %r : !torch.vtensor<[1,3],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.round$fold
// CHECK-DAG: %[[POS:.*]] = torch.vtensor.literal(dense<2.000000e+00> : tensor<4x5xf32>)
// CHECK-DAG: %[[NEG:.*]] = torch.vtensor.literal(dense<-2.000000e+00> : tensor<2x3xf32>)
// CHECK: return %[[POS]], %[[NEG]]
func.func @torch.aten.round$fold()
-> (!torch.vtensor<[4,5],f32>, !torch.vtensor<[2,3],f32>) {
%cpos = torch.vtensor.literal(dense<2.500000e+00> : tensor<4x5xf32>)
: !torch.vtensor<[4,5],f32>
%rpos = torch.aten.round %cpos
: !torch.vtensor<[4,5],f32> -> !torch.vtensor<[4,5],f32>

%cneg = torch.vtensor.literal(dense<-2.500000e+00> : tensor<2x3xf32>)
: !torch.vtensor<[2,3],f32>
%rneg = torch.aten.round %cneg
: !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>

return %rpos, %rneg
: !torch.vtensor<[4,5],f32>, !torch.vtensor<[2,3],f32>
}