diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index e0516abdfcf0c..c30782a25e40f 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -360,6 +360,43 @@ def TilingInterface : OpInterface<"TilingInterface"> {
         /*defaultImplementation=*/[{
           return failure();
         }]
+      >,
+      //===------------------------------------------------------------------===//
+      // Interface methods for querying fusability.
+      //===------------------------------------------------------------------===//
+      InterfaceMethod<
+        /*desc=*/[{
+          Indicates whether it is possible to fuse this operation with the given
+          result slice. This method is not allowed to generate any IR.
+        }],
+        /*retTy=*/"bool",
+        /*methodName=*/"isOpFusableWithConsumerSlice",
+        /*args=*/(ins
+          "unsigned":$resultNumber,
+          "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
+          "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return false;
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Indicates whether it is possible to fuse this operation with the given
+          list of operand slices. This method is not allowed to generate any IR.
+        }],
+        /*retTy=*/"bool",
+        /*methodName=*/"isOpFusableWithProducerSlices",
+        /*args=*/(ins
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return false;
+        }]
       >
   ];
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 57b610b31e964..527878786f50f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -359,6 +359,60 @@ struct LinalgOpTilingInterface
     /// Inline the op payload and store the result.
     return inlinePayload(builder, linalgOp, ivs, indexedValues);
   }
+
+  /// Returns true if this op can be fused with a consumer of the given slice of
+  /// result `resultNumber`. The slice itself imposes no extra constraint; the
+  /// only requirement checked is that the shapes-to-loops map exists, i.e. the
+  /// iteration domain can be recovered from the operand shapes. Creates no IR.
+  bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes) const {
+    // A null AffineMap converts to false, so this is true only when the map
+    // exists (same condition as in isOpFusableWithProducerSlices).
+    return static_cast<bool>(cast<LinalgOp>(op).getShapesToLoopsMap());
+  }
+
+  /// Returns true if this op can be fused with producers of the given operand
+  /// slices. `allOffsets[i]` / `allSizes[i]` describe the slice of operand
+  /// `operandNumbers[i]`. The slices are compatible when every indexing-map
+  /// result is a plain loop dimension and every loop dimension is given the
+  /// same offset and size by all operands that access it. Creates no IR.
+  bool isOpFusableWithProducerSlices(
+      Operation *op, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    // First verify that the iteration domain on operand subranges is well
+    // defined. Checked before any other work since it does not depend on the
+    // slices.
+    if (!linalgOp.getShapesToLoopsMap())
+      return false;
+    SmallVector<AffineMap> indexingMaps =
+        llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+          OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+          return linalgOp.getMatchingIndexingMap(&opOperand);
+        });
+    // Next verify that operand slices are consistent: remember the (offset,
+    // size) first seen for each loop dimension and reject any disagreement.
+    DenseMap<unsigned, std::pair<OpFoldResult, OpFoldResult>> seenSlices;
+    for (auto [indexingMap, offsets, sizes] :
+         llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
+      for (auto [resultExpr, offset, size] :
+           llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+        auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+        // Only results that are a plain loop dimension are handled; any other
+        // expression does not tie the slice to a single loop dimension.
+        if (!dimExpr)
+          return false;
+        auto [it, inserted] =
+            seenSlices.try_emplace(dimExpr.getPosition(), offset, size);
+        if (!inserted &&
+            (it->second.first != offset || it->second.second != size))
+          return false;
+      }
+    }
+    return true;
+  }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/query-fusability.mlir b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir
new file mode 100644
index 0000000000000..1fa828c9cd868
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+func.func @fusable_with_matching_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
+  %c0 = arith.constant 0 : index
+  %c10 = arith.constant 10 : index
+  %c20 = arith.constant 20 : index
+
+  %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+  %slice1 = tensor.insert_slice %arg1 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+
+  // expected-remark @+1 {{can be fused with producer tensor.insert_slice ops}}
+  %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
+                       outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
+
+  return %result : tensor<100x200xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
+    transform.test.query_producer_fusability %add : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @not_fusable_with_different_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
+  %c0 = arith.constant 0 : index
+  %c10 = arith.constant 10 : index
+  %c20 = arith.constant 20 : index
+
+  %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+  %slice1 = tensor.insert_slice %arg1 into %dest[%c10, %c20] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
+
+  // expected-remark @+1 {{cannot be fused with producer tensor.insert_slice ops}}
+  %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
+                       outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>
+
+  return %result : tensor<100x200xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
+    transform.test.query_producer_fusability %add : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 326fec3ee5cf0..d6bb178505d2b 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
@@ -622,6 +623,67 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestQueryProducerFusability
+//===----------------------------------------------------------------------===//
+
+/// For every payload op of the target handle, gathers the offsets and sizes of
+/// each operand produced by a tensor.insert_slice, passes them to
+/// TilingInterface::isOpFusableWithProducerSlices and reports the answer as a
+/// remark on the payload op. Ops without such operands get no remark.
+DiagnosedSilenceableFailure transform::TestQueryProducerFusability::apply(
+    TransformRewriter &rewriter, TransformResults &transformResults,
+    TransformState &state) {
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+    if (!tilingInterfaceOp) {
+      return emitSilenceableError()
+             << "target operation does not implement TilingInterface";
+    }
+
+    // Collect operand numbers and their corresponding producer insert_slice
+    // offsets and sizes.
+    SmallVector<unsigned> operandNumbers;
+    SmallVector<SmallVector<OpFoldResult>> allOffsets;
+    SmallVector<SmallVector<OpFoldResult>> allSizes;
+
+    for (OpOperand &operand : target->getOpOperands()) {
+      // Look for a producer tensor.insert_slice. This is only for testing
+      // purposes and otherwise is not a useful transformation.
+      auto insertSliceOp =
+          operand.get().getDefiningOp<tensor::InsertSliceOp>();
+      if (!insertSliceOp)
+        continue;
+      operandNumbers.push_back(operand.getOperandNumber());
+      allOffsets.push_back(insertSliceOp.getMixedOffsets());
+      allSizes.push_back(insertSliceOp.getMixedSizes());
+    }
+
+    // Nothing to query when no operand comes from an insert_slice.
+    if (operandNumbers.empty())
+      continue;
+
+    bool isFusable = tilingInterfaceOp.isOpFusableWithProducerSlices(
+        operandNumbers, allOffsets, allSizes);
+    if (isFusable) {
+      target->emitRemark()
+          << "can be fused with producer tensor.insert_slice ops";
+    } else {
+      target->emitRemark()
+          << "cannot be fused with producer tensor.insert_slice ops";
+    }
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Declares that the op only reads the target handle and the payload.
+void transform::TestQueryProducerFusability::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsPayload(effects);
+}
+
 #define GET_OP_CLASSES
 #include "TestTilingInterfaceTransformOps.cpp.inc"
 
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 694c4229eef62..4d0998052ba79 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -166,11 +166,33 @@ def TestTileUsingCustomLoopOp : Op<
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs TransformHandleTypeInterface:$tiled_ops,
                       Variadic<TransformHandleTypeInterface>:$loops);
-  
+
   let assemblyFormat = [{
     $root_op `tile_sizes` `=` $tile_sizes
     attr-dict `:` functional-type(operands, results)
   }];
 }
 
+def TestQueryProducerFusability : Op<
+    Transform_Dialect, "test.query_producer_fusability",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    Test operation for the producer fusability query method in the
+    TilingInterface.
+
+    For each operation in the target handle, this looks for tensor.insert_slice
+    ops that produce operands to the tilable op. The offsets/sizes from those
+    inserts are used as the arguments to `isOpFusableWithProducerSlices`, and a
+    remark with the result of the query is emitted on the operation.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` type($target)
+  }];
+}
+
 #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS