Skip to content

Commit a38e094

Browse files
[mlir] Dialect Conversion: Add support for post-order legalization order (#166292)
By default, the dialect conversion driver processes operations in pre-order: the initial worklist is populated pre-order. (New/modified operations are immediately legalized recursively.) This commit adds a new API for selective post-order legalization. Patterns can request an operation / region legalization via `ConversionPatternRewriter::legalize`. They can call these helper functions on nested regions before rewriting the operation itself. Note: In rollback mode, a failed recursive legalization typically leads to a conversion failure. Since recursive legalization is performed by separate pattern applications, there is no way for the original pattern to recover from such a failure.
1 parent c1dc064 commit a38e094

File tree

6 files changed

+199
-37
lines changed

6 files changed

+199
-37
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,28 @@ class ConversionPatternRewriter final : public PatternRewriter {
981981
/// Return a reference to the internal implementation.
982982
detail::ConversionPatternRewriterImpl &getImpl();
983983

984+
/// Attempt to legalize the given operation. This can be used within
985+
/// conversion patterns to change the default pre-order legalization order.
986+
/// Returns "success" if the operation was legalized, "failure" otherwise.
987+
///
988+
/// Note: In a partial conversion, this function returns "success" even if
989+
/// the operation could not be legalized, as long as it was not explicitly
990+
/// marked as illegal in the conversion target.
991+
LogicalResult legalize(Operation *op);
992+
993+
/// Attempt to legalize the given region. This can be used within
994+
/// conversion patterns to change the default pre-order legalization order.
995+
/// Returns "success" if the region was legalized, "failure" otherwise.
996+
///
997+
/// If the current pattern runs with a type converter, the entry block
998+
/// signature will be converted before legalizing the operations in the
999+
/// region.
1000+
///
1001+
/// Note: In a partial conversion, this function returns "success" even if
1002+
/// an operation could not be legalized, as long as it was not explicitly
1003+
/// marked as illegal in the conversion target.
1004+
LogicalResult legalize(Region *r);
1005+
9841006
private:
9851007
// Allow OperationConverter to construct new rewriters.
9861008
friend struct OperationConverter;
@@ -989,7 +1011,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
9891011
/// conversions. They apply some IR rewrites in a delayed fashion and could
9901012
/// bring the IR into an inconsistent state when used standalone.
9911013
explicit ConversionPatternRewriter(MLIRContext *ctx,
992-
const ConversionConfig &config);
1014+
const ConversionConfig &config,
1015+
OperationConverter &converter);
9931016

9941017
// Hide unsupported pattern rewriter API.
9951018
using OpBuilder::setListener;

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 85 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
9292
return pt;
9393
}
9494

95+
namespace {
96+
enum OpConversionMode {
97+
/// In this mode, the conversion will ignore failed conversions to allow
98+
/// illegal operations to co-exist in the IR.
99+
Partial,
100+
101+
/// In this mode, all operations must be legal for the given target for the
102+
/// conversion to succeed.
103+
Full,
104+
105+
/// In this mode, operations are analyzed for legality. No actual rewrites are
106+
/// applied to the operations on success.
107+
Analysis,
108+
};
109+
} // namespace
110+
95111
//===----------------------------------------------------------------------===//
96112
// ConversionValueMapping
97113
//===----------------------------------------------------------------------===//
@@ -866,8 +882,9 @@ namespace mlir {
866882
namespace detail {
867883
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
868884
explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
869-
const ConversionConfig &config)
870-
: rewriter(rewriter), config(config),
885+
const ConversionConfig &config,
886+
OperationConverter &opConverter)
887+
: rewriter(rewriter), config(config), opConverter(opConverter),
871888
notifyingRewriter(rewriter.getContext(), config.listener) {}
872889

873890
//===--------------------------------------------------------------------===//
@@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11241141
/// Dialect conversion configuration.
11251142
const ConversionConfig &config;
11261143

1144+
/// The operation converter to use for recursive legalization.
1145+
OperationConverter &opConverter;
1146+
11271147
/// A set of erased operations. This set is utilized only if
11281148
/// `allowPatternRollback` is set to "false". Conceptually, this set is
11291149
/// similar to `replacedOps` (which is maintained when the flag is set to
@@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
20842104
//===----------------------------------------------------------------------===//
20852105

20862106
ConversionPatternRewriter::ConversionPatternRewriter(
2087-
MLIRContext *ctx, const ConversionConfig &config)
2088-
: PatternRewriter(ctx),
2089-
impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
2107+
MLIRContext *ctx, const ConversionConfig &config,
2108+
OperationConverter &opConverter)
2109+
: PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl(
2110+
*this, config, opConverter)) {
20902111
setListener(impl.get());
20912112
}
20922113

@@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
22072228
return success();
22082229
}
22092230

2231+
LogicalResult ConversionPatternRewriter::legalize(Region *r) {
2232+
// Fast path: If the region is empty, there is nothing to legalize.
2233+
if (r->empty())
2234+
return success();
2235+
2236+
// Gather a list of all operations to legalize. This is done before
2237+
// converting the entry block signature because unrealized_conversion_cast
2238+
// ops should not be included.
2239+
SmallVector<Operation *> ops;
2240+
for (Block &b : *r)
2241+
for (Operation &op : b)
2242+
ops.push_back(&op);
2243+
2244+
// If the current pattern runs with a type converter, convert the entry block
2245+
// signature.
2246+
if (const TypeConverter *converter = impl->currentTypeConverter) {
2247+
std::optional<TypeConverter::SignatureConversion> conversion =
2248+
converter->convertBlockSignature(&r->front());
2249+
if (!conversion)
2250+
return failure();
2251+
applySignatureConversion(&r->front(), *conversion, converter);
2252+
}
2253+
2254+
// Legalize all operations in the region.
2255+
for (Operation *op : ops)
2256+
if (failed(legalize(op)))
2257+
return failure();
2258+
2259+
return success();
2260+
}
2261+
22102262
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
22112263
Block::iterator before,
22122264
ValueRange argValues) {
@@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts(
31923244
// OperationConverter
31933245
//===----------------------------------------------------------------------===//
31943246

3195-
namespace {
3196-
enum OpConversionMode {
3197-
/// In this mode, the conversion will ignore failed conversions to allow
3198-
/// illegal operations to co-exist in the IR.
3199-
Partial,
3200-
3201-
/// In this mode, all operations must be legal for the given target for the
3202-
/// conversion to succeed.
3203-
Full,
3204-
3205-
/// In this mode, operations are analyzed for legality. No actual rewrites are
3206-
/// applied to the operations on success.
3207-
Analysis,
3208-
};
3209-
} // namespace
3210-
32113247
namespace mlir {
32123248
// This class converts operations to a given conversion target via a set of
32133249
// rewrite patterns. The conversion behaves differently depending on the
@@ -3217,16 +3253,20 @@ struct OperationConverter {
32173253
const FrozenRewritePatternSet &patterns,
32183254
const ConversionConfig &config,
32193255
OpConversionMode mode)
3220-
: rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
3256+
: rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
32213257
mode(mode) {}
32223258

32233259
/// Converts the given operations to the conversion target.
32243260
LogicalResult convertOperations(ArrayRef<Operation *> ops);
32253261

3226-
private:
3227-
/// Converts an operation with the given rewriter.
3228-
LogicalResult convert(Operation *op);
3262+
/// Converts a single operation. If `isRecursiveLegalization` is "true", the
3263+
/// conversion is a recursive legalization request, triggered from within a
3264+
/// pattern. In that case, do not emit errors because there will be another
3265+
/// attempt at legalizing the operation later (via the regular pre-order
3266+
/// legalization mechanism).
3267+
LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
32293268

3269+
private:
32303270
/// The rewriter to use when converting operations.
32313271
ConversionPatternRewriter rewriter;
32323272

@@ -3238,32 +3278,42 @@ struct OperationConverter {
32383278
};
32393279
} // namespace mlir
32403280

3241-
LogicalResult OperationConverter::convert(Operation *op) {
3281+
LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3282+
return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
3283+
}
3284+
3285+
LogicalResult OperationConverter::convert(Operation *op,
3286+
bool isRecursiveLegalization) {
32423287
const ConversionConfig &config = rewriter.getConfig();
32433288

32443289
// Legalize the given operation.
32453290
if (failed(opLegalizer.legalize(op))) {
32463291
// Handle the case of a failed conversion for each of the different modes.
32473292
// Full conversions expect all operations to be converted.
3248-
if (mode == OpConversionMode::Full)
3249-
return op->emitError()
3250-
<< "failed to legalize operation '" << op->getName() << "'";
3293+
if (mode == OpConversionMode::Full) {
3294+
if (!isRecursiveLegalization)
3295+
op->emitError() << "failed to legalize operation '" << op->getName()
3296+
<< "'";
3297+
return failure();
3298+
}
32513299
// Partial conversions allow conversions to fail iff the operation was not
32523300
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
32533301
// set, non-legalizable ops are added to that set.
32543302
if (mode == OpConversionMode::Partial) {
3255-
if (opLegalizer.isIllegal(op))
3256-
return op->emitError()
3257-
<< "failed to legalize operation '" << op->getName()
3258-
<< "' that was explicitly marked illegal";
3259-
if (config.unlegalizedOps)
3303+
if (opLegalizer.isIllegal(op)) {
3304+
if (!isRecursiveLegalization)
3305+
op->emitError() << "failed to legalize operation '" << op->getName()
3306+
<< "' that was explicitly marked illegal";
3307+
return failure();
3308+
}
3309+
if (config.unlegalizedOps && !isRecursiveLegalization)
32603310
config.unlegalizedOps->insert(op);
32613311
}
32623312
} else if (mode == OpConversionMode::Analysis) {
32633313
// Analysis conversions don't fail if any operations fail to legalize,
32643314
// they are only interested in the operations that were successfully
32653315
// legalized.
3266-
if (config.legalizableOps)
3316+
if (config.legalizableOps && !isRecursiveLegalization)
32673317
config.legalizableOps->insert(op);
32683318
}
32693319
return success();

mlir/test/Transforms/test-legalizer-full.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,21 @@ builtin.module {
7272
}
7373

7474
}
75+
76+
// -----
77+
78+
// The region of "test.post_order_legalization" is converted before the op.
79+
80+
// expected-remark@+1 {{applyFullConversion failed}}
81+
builtin.module {
82+
func.func @test_preorder_legalization() {
83+
// expected-error@+1 {{failed to legalize operation 'test.post_order_legalization'}}
84+
"test.post_order_legalization"() ({
85+
^bb0(%arg0: i64):
86+
// Not-explicitly-legal ops are not allowed to survive.
87+
"test.remaining_consumer"(%arg0) : (i64) -> ()
88+
"test.invalid"(%arg0) : (i64) -> ()
89+
}) : () -> ()
90+
return
91+
}
92+
}

mlir/test/Transforms/test-legalizer-rollback.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
163163
"test.return"(%0) : (i32) -> ()
164164
}
165165
}
166+
167+
// -----
168+
169+
// CHECK-LABEL: func @test_failed_preorder_legalization
170+
// CHECK: "test.post_order_legalization"() ({
171+
// CHECK: %[[r:.*]] = "test.illegal_op_g"() : () -> i32
172+
// CHECK: "test.return"(%[[r]]) : (i32) -> ()
173+
// CHECK: }) : () -> ()
174+
// expected-remark @+1 {{applyPartialConversion failed}}
175+
module {
176+
func.func @test_failed_preorder_legalization() {
177+
// expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}}
178+
"test.post_order_legalization"() ({
179+
%0 = "test.illegal_op_g"() : () -> (i32)
180+
"test.return"(%0) : (i32) -> ()
181+
}) : () -> ()
182+
return
183+
}
184+
}

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,35 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
448448
"test.type_consumer"(%arg0) : (f16) -> ()
449449
"test.return"() : () -> ()
450450
}
451+
452+
// -----
453+
454+
// The region of "test.post_order_legalization" is converted before the op.
455+
456+
// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
457+
// CHECK: notifyOperationInserted: test.invalid
458+
// CHECK: notifyBlockErased
459+
// CHECK: notifyOperationInserted: test.valid, was unlinked
460+
// CHECK: notifyOperationReplaced: test.invalid
461+
// CHECK: notifyOperationErased: test.invalid
462+
// CHECK: notifyOperationModified: test.post_order_legalization
463+
464+
// CHECK-LABEL: func @test_preorder_legalization
465+
// CHECK: "test.post_order_legalization"() ({
466+
// CHECK: ^{{.*}}(%[[arg0:.*]]: f64):
467+
// Note: The survival of a not-explicitly-invalid operation does *not* cause
468+
// a conversion failure in when applying a partial conversion.
469+
// CHECK: %[[cast:.*]] = "test.cast"(%[[arg0]]) : (f64) -> i64
470+
// CHECK: "test.remaining_consumer"(%[[cast]]) : (i64) -> ()
471+
// CHECK: "test.valid"(%[[arg0]]) : (f64) -> ()
472+
// CHECK: }) {is_legal} : () -> ()
473+
func.func @test_preorder_legalization() {
474+
"test.post_order_legalization"() ({
475+
^bb0(%arg0: i64):
476+
// expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
477+
"test.remaining_consumer"(%arg0) : (i64) -> ()
478+
"test.invalid"(%arg0) : (i64) -> ()
479+
}) : () -> ()
480+
// expected-remark @+1 {{'func.return' is not legalizable}}
481+
return
482+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,22 @@ class TestTypeConsumerOpPattern
14181418
}
14191419
};
14201420

1421+
class TestPostOrderLegalization : public ConversionPattern {
1422+
public:
1423+
TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
1424+
: ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
1425+
LogicalResult
1426+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1427+
ConversionPatternRewriter &rewriter) const final {
1428+
for (Region &r : op->getRegions())
1429+
if (failed(rewriter.legalize(&r)))
1430+
return failure();
1431+
rewriter.modifyOpInPlace(
1432+
op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
1433+
return success();
1434+
}
1435+
};
1436+
14211437
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
14221438
/// function is just to trigger compiler errors. It is never executed.
14231439
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver
15321548
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
15331549
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
15341550
TestValueReplace, TestReplaceWithValidConsumer,
1535-
TestTypeConsumerOpPattern>(&getContext(), converter);
1551+
TestTypeConsumerOpPattern, TestPostOrderLegalization>(
1552+
&getContext(), converter);
15361553
patterns.add<TestConvertBlockArgs>(converter, &getContext());
15371554
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
15381555
converter);
@@ -1560,6 +1577,9 @@ struct TestLegalizePatternDriver
15601577
target.addDynamicallyLegalOp(
15611578
OperationName("test.value_replace", &getContext()),
15621579
[](Operation *op) { return op->hasAttr("is_legal"); });
1580+
target.addDynamicallyLegalOp(
1581+
OperationName("test.post_order_legalization", &getContext()),
1582+
[](Operation *op) { return op->hasAttr("is_legal"); });
15631583

15641584
// TestCreateUnregisteredOp creates `arith.constant` operation,
15651585
// which was not added to target intentionally to test

0 commit comments

Comments
 (0)