@@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
9292 return pt;
9393}
9494
95+ namespace {
96+ enum OpConversionMode {
97+ // / In this mode, the conversion will ignore failed conversions to allow
98+ // / illegal operations to co-exist in the IR.
99+ Partial,
100+
101+ // / In this mode, all operations must be legal for the given target for the
102+ // / conversion to succeed.
103+ Full,
104+
105+ // / In this mode, operations are analyzed for legality. No actual rewrites are
106+ // / applied to the operations on success.
107+ Analysis,
108+ };
109+ } // namespace
110+
95111// ===----------------------------------------------------------------------===//
96112// ConversionValueMapping
97113// ===----------------------------------------------------------------------===//
@@ -866,8 +882,9 @@ namespace mlir {
866882namespace detail {
867883struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
868884 explicit ConversionPatternRewriterImpl (ConversionPatternRewriter &rewriter,
869- const ConversionConfig &config)
870- : rewriter(rewriter), config(config),
885+ const ConversionConfig &config,
886+ OperationConverter &opConverter)
887+ : rewriter(rewriter), config(config), opConverter(opConverter),
871888 notifyingRewriter(rewriter.getContext(), config.listener) {}
872889
873890 // ===--------------------------------------------------------------------===//
@@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11241141 // / Dialect conversion configuration.
11251142 const ConversionConfig &config;
11261143
1144+ // / The operation converter to use for recursive legalization.
1145+ OperationConverter &opConverter;
1146+
11271147 // / A set of erased operations. This set is utilized only if
11281148 // / `allowPatternRollback` is set to "false". Conceptually, this set is
11291149 // / similar to `replacedOps` (which is maintained when the flag is set to
@@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
20842104// ===----------------------------------------------------------------------===//
20852105
20862106ConversionPatternRewriter::ConversionPatternRewriter (
2087- MLIRContext *ctx, const ConversionConfig &config)
2088- : PatternRewriter(ctx),
2089- impl(new detail::ConversionPatternRewriterImpl(*this , config)) {
2107+ MLIRContext *ctx, const ConversionConfig &config,
2108+ OperationConverter &opConverter)
2109+ : PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl(
2110+ *this , config, opConverter)) {
20902111 setListener (impl.get ());
20912112}
20922113
@@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
22072228 return success ();
22082229}
22092230
2231+ LogicalResult ConversionPatternRewriter::legalize (Region *r) {
2232+ // Fast path: If the region is empty, there is nothing to legalize.
2233+ if (r->empty ())
2234+ return success ();
2235+
2236+ // Gather a list of all operations to legalize. This is done before
2237+ // converting the entry block signature because unrealized_conversion_cast
2238+ // ops should not be included.
2239+ SmallVector<Operation *> ops;
2240+ for (Block &b : *r)
2241+ for (Operation &op : b)
2242+ ops.push_back (&op);
2243+
2244+ // If the current pattern runs with a type converter, convert the entry block
2245+ // signature.
2246+ if (const TypeConverter *converter = impl->currentTypeConverter ) {
2247+ std::optional<TypeConverter::SignatureConversion> conversion =
2248+ converter->convertBlockSignature (&r->front ());
2249+ if (!conversion)
2250+ return failure ();
2251+ applySignatureConversion (&r->front (), *conversion, converter);
2252+ }
2253+
2254+ // Legalize all operations in the region.
2255+ for (Operation *op : ops)
2256+ if (failed (legalize (op)))
2257+ return failure ();
2258+
2259+ return success ();
2260+ }
2261+
22102262void ConversionPatternRewriter::inlineBlockBefore (Block *source, Block *dest,
22112263 Block::iterator before,
22122264 ValueRange argValues) {
@@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts(
31923244// OperationConverter
31933245// ===----------------------------------------------------------------------===//
31943246
3195- namespace {
3196- enum OpConversionMode {
3197- // / In this mode, the conversion will ignore failed conversions to allow
3198- // / illegal operations to co-exist in the IR.
3199- Partial,
3200-
3201- // / In this mode, all operations must be legal for the given target for the
3202- // / conversion to succeed.
3203- Full,
3204-
3205- // / In this mode, operations are analyzed for legality. No actual rewrites are
3206- // / applied to the operations on success.
3207- Analysis,
3208- };
3209- } // namespace
3210-
32113247namespace mlir {
32123248// This class converts operations to a given conversion target via a set of
32133249// rewrite patterns. The conversion behaves differently depending on the
@@ -3217,16 +3253,20 @@ struct OperationConverter {
32173253 const FrozenRewritePatternSet &patterns,
32183254 const ConversionConfig &config,
32193255 OpConversionMode mode)
3220- : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
3256+ : rewriter(ctx, config, * this ), opLegalizer(rewriter, target, patterns),
32213257 mode(mode) {}
32223258
32233259 // / Converts the given operations to the conversion target.
32243260 LogicalResult convertOperations (ArrayRef<Operation *> ops);
32253261
3226- private:
3227- // / Converts an operation with the given rewriter.
3228- LogicalResult convert (Operation *op);
3262+ // / Converts a single operation. If `isRecursiveLegalization` is "true", the
3263+ // / conversion is a recursive legalization request, triggered from within a
3264+ // / pattern. In that case, do not emit errors because there will be another
3265+ // / attempt at legalizing the operation later (via the regular pre-order
3266+ // / legalization mechanism).
3267+ LogicalResult convert (Operation *op, bool isRecursiveLegalization = false );
32293268
3269+ private:
32303270 // / The rewriter to use when converting operations.
32313271 ConversionPatternRewriter rewriter;
32323272
@@ -3238,32 +3278,42 @@ struct OperationConverter {
32383278};
32393279} // namespace mlir
32403280
3241- LogicalResult OperationConverter::convert (Operation *op) {
3281+ LogicalResult ConversionPatternRewriter::legalize (Operation *op) {
3282+ return impl->opConverter .convert (op, /* isRecursiveLegalization=*/ true );
3283+ }
3284+
3285+ LogicalResult OperationConverter::convert (Operation *op,
3286+ bool isRecursiveLegalization) {
32423287 const ConversionConfig &config = rewriter.getConfig ();
32433288
32443289 // Legalize the given operation.
32453290 if (failed (opLegalizer.legalize (op))) {
32463291 // Handle the case of a failed conversion for each of the different modes.
32473292 // Full conversions expect all operations to be converted.
3248- if (mode == OpConversionMode::Full)
3249- return op->emitError ()
3250- << " failed to legalize operation '" << op->getName () << " '" ;
3293+ if (mode == OpConversionMode::Full) {
3294+ if (!isRecursiveLegalization)
3295+ op->emitError () << " failed to legalize operation '" << op->getName ()
3296+ << " '" ;
3297+ return failure ();
3298+ }
32513299 // Partial conversions allow conversions to fail iff the operation was not
32523300 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
32533301 // set, non-legalizable ops are added to that set.
32543302 if (mode == OpConversionMode::Partial) {
3255- if (opLegalizer.isIllegal (op))
3256- return op->emitError ()
3257- << " failed to legalize operation '" << op->getName ()
3258- << " ' that was explicitly marked illegal" ;
3259- if (config.unlegalizedOps )
3303+ if (opLegalizer.isIllegal (op)) {
3304+ if (!isRecursiveLegalization)
3305+ op->emitError () << " failed to legalize operation '" << op->getName ()
3306+ << " ' that was explicitly marked illegal" ;
3307+ return failure ();
3308+ }
3309+ if (config.unlegalizedOps && !isRecursiveLegalization)
32603310 config.unlegalizedOps ->insert (op);
32613311 }
32623312 } else if (mode == OpConversionMode::Analysis) {
32633313 // Analysis conversions don't fail if any operations fail to legalize,
32643314 // they are only interested in the operations that were successfully
32653315 // legalized.
3266- if (config.legalizableOps )
3316+ if (config.legalizableOps && !isRecursiveLegalization )
32673317 config.legalizableOps ->insert (op);
32683318 }
32693319 return success ();
0 commit comments