From 80ae125409053062d7ad03978c1fb3fa3813a24b Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Tue, 21 Oct 2025 18:08:14 -0400 Subject: [PATCH] [MLIR] Add fusability query to TilingInterface This introduces the `isOpFusableWithProducerSlices` and `isOpFusableWithConsumerSlice` methods to the TilingInterface, which enable querying whether a tilable op can be fused with a given set of producer slices or a given consumer slice without generating IR. This is needed to enable use of the tiling interface in pattern rewrites: without this query, a pattern rewrite that invokes the tiling methods directly may generate IR and then fail. --- .../mlir/Interfaces/TilingInterface.td | 37 ++++++++++++ .../Linalg/Transforms/TilingInterfaceImpl.cpp | 46 +++++++++++++++ .../TilingInterface/query-fusability.mlir | 49 ++++++++++++++++ .../TestTilingInterfaceTransformOps.cpp | 58 +++++++++++++++++++ .../TestTilingInterfaceTransformOps.td | 24 +++++++- 5 files changed, 213 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Interfaces/TilingInterface/query-fusability.mlir diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index e0516abdfcf0c..c30782a25e40f 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -360,6 +360,43 @@ def TilingInterface : OpInterface<"TilingInterface"> { /*defaultImplementation=*/[{ return failure(); }] + >, + //===------------------------------------------------------------------===// + // Interface methods for querying fusability. + //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/[{ + Indicates whether it is possible to fuse this operation with the given + result slice. This method is not allowed to generate any IR. 
+ }], + /*retTy=*/"bool", + /*methodName=*/"isOpFusableWithConsumerSlice", + /*args=*/(ins + "unsigned":$resultNumber, + "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets, + "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes + ), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Indicates whether it is possible to fuse this operation with the given + list of operand slices. This method is not allowed to generate any IR. + }], + /*retTy=*/"bool", + /*methodName=*/"isOpFusableWithProducerSlices", + /*args=*/(ins + "::mlir::ArrayRef<unsigned>":$operandNumbers, + "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets, + "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes + ), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }] > ]; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 57b610b31e964..527878786f50f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -359,6 +359,52 @@ struct LinalgOpTilingInterface /// Inline the op payload and store the result. return inlinePayload(builder, linalgOp, ivs, indexedValues); } + /// Fusable iff the shapes-to-loops map exists; the slice is not inspected. + bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) const { + return static_cast<bool>(cast<LinalgOp>(op).getShapesToLoopsMap()); + } + /// Fusable iff operand slices agree on offset/size per iteration dimension. + bool isOpFusableWithProducerSlices( + Operation *op, ArrayRef<unsigned> operandNumbers, + ArrayRef<SmallVector<OpFoldResult>> allOffsets, + ArrayRef<SmallVector<OpFoldResult>> allSizes) const { + + auto linalgOp = cast<LinalgOp>(op); + SmallVector<AffineMap> indexingMaps = + llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) { + OpOperand &opOperand = linalgOp->getOpOperand(operandNumber); + return linalgOp.getMatchingIndexingMap(&opOperand); + }); + // First verify that the iteration domain on operand subranges is well + // defined. 
+ if (!linalgOp.getShapesToLoopsMap()) + return false; + // Next verify that operand slices are consistent. + DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes; + for (auto [indexingMap, offsets, sizes] : + llvm::zip_equal(indexingMaps, allOffsets, allSizes)) { + for (auto [resultExpr, offset, size] : + llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) { + auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr); + if (!dimExpr) + return false; + unsigned position = dimExpr.getPosition(); + auto it = mappedOffsets.find(position); + if (it != mappedOffsets.end()) { + OpFoldResult seenOffset = it->second; + OpFoldResult seenSize = mappedSizes.lookup(position); + if (seenOffset != offset || seenSize != size) + return false; + } else { + mappedOffsets[position] = offset; + mappedSizes[position] = size; + } + } + } + return true; + } }; //===----------------------------------------------------------------------===// diff --git a/mlir/test/Interfaces/TilingInterface/query-fusability.mlir b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir new file mode 100644 index 0000000000000..1fa828c9cd868 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir @@ -0,0 +1,49 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics + +func.func @fusable_with_matching_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> { + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + + %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + %slice1 = tensor.insert_slice %arg1 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + + // expected-remark @+1 {{can be fused with producer tensor.insert_slice ops}} + %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>) + outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32> + + 
return %result : tensor<100x200xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op + transform.test.query_producer_fusability %add : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @not_fusable_with_different_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> { + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + + %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + %slice1 = tensor.insert_slice %arg1 into %dest[%c10, %c20] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + + // expected-remark @+1 {{cannot be fused with producer tensor.insert_slice ops}} + %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>) + outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32> + + return %result : tensor<100x200xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op + transform.test.query_producer_fusability %add : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 326fec3ee5cf0..d6bb178505d2b 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include 
"mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" @@ -622,6 +623,63 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply( return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// TestQueryProducerFusability +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TestQueryProducerFusability::apply( + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { + for (Operation *target : state.getPayloadOps(getTarget())) { + auto tilingInterfaceOp = dyn_cast(target); + if (!tilingInterfaceOp) { + return emitSilenceableError() + << "target operation does not implement TilingInterface"; + } + + // Collect operand numbers and their corresponding producer insert_slice + // offsets and sizes. + SmallVector operandNumbers; + SmallVector> allOffsets; + SmallVector> allSizes; + + for (OpOperand &operand : target->getOpOperands()) { + Value operandValue = operand.get(); + Operation *definingOp = operandValue.getDefiningOp(); + + // Look for a producer tensor.insert_slice. This is only for testing + // purposes and otherwise is not a useful transformation. 
+ if (auto insertSliceOp = + dyn_cast_or_null(definingOp)) { + operandNumbers.push_back(operand.getOperandNumber()); + allOffsets.push_back(insertSliceOp.getMixedOffsets()); + allSizes.push_back(insertSliceOp.getMixedSizes()); + } + } + + if (!operandNumbers.empty()) { + bool isFusable = tilingInterfaceOp.isOpFusableWithProducerSlices( + operandNumbers, allOffsets, allSizes); + + if (isFusable) { + target->emitRemark() + << "can be fused with producer tensor.insert_slice ops"; + } else { + target->emitRemark() + << "cannot be fused with producer tensor.insert_slice ops"; + } + } + } + + return DiagnosedSilenceableFailure::success(); +} + +void transform::TestQueryProducerFusability::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsPayload(effects); +} + #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.cpp.inc" diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 694c4229eef62..4d0998052ba79 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -166,11 +166,33 @@ def TestTileUsingCustomLoopOp : Op< DefaultValuedAttr:$tile_sizes); let results = (outs TransformHandleTypeInterface:$tiled_ops, Variadic:$loops); - + let assemblyFormat = [{ $root_op `tile_sizes` `=` $tile_sizes attr-dict `:` functional-type(operands, results) }]; } +def TestQueryProducerFusability : Op< + Transform_Dialect, "test.query_producer_fusability", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let description = [{ + Test operation for the producer fusability query method in the + TilingInterface. + + For each operation in the target handle, this looks for tensor.insert_slice + ops that produce operands to the tilable op. 
The offsets/sizes from those + inserts are used as the arguments to `isOpFusableWithProducerSlices`, and + a remark is emitted with the result of the query. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target attr-dict `:` type($target) + }]; +} + #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS