Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,43 @@ def TilingInterface : OpInterface<"TilingInterface"> {
/*defaultImplementation=*/[{
return failure();
}]
>,
//===------------------------------------------------------------------===//
// Interface methods for querying fusability.
//===------------------------------------------------------------------===//
// Pure query hook taking a result number plus the offsets/sizes of a slice of
// that result. The default implementation below answers `false`, so an op
// only reports itself as fusable if it overrides this method.
InterfaceMethod<
/*desc=*/[{
Indicates whether it is possible to fuse this operation with the given
result slice. This method is not allowed to generate any IR.
}],
/*retTy=*/"bool",
/*methodName=*/"isOpFusableWithConsumerSlice",
/*args=*/(ins
"unsigned":$resultNumber,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes
),
/*methodBody=*/"",
// Conservative default: not fusable unless the implementing op says so.
/*defaultImplementation=*/[{
return false;
}]
>,
// Pure query hook taking a list of operand numbers together with one
// offsets list and one sizes list per entry. The default implementation below
// answers `false`, so an op only reports itself as fusable if it overrides
// this method.
InterfaceMethod<
/*desc=*/[{
Indicates whether it is possible to fuse this operation with the given
list of operand slices. This method is not allowed to generate any IR.
}],
/*retTy=*/"bool",
/*methodName=*/"isOpFusableWithProducerSlices",
/*args=*/(ins
"::mlir::ArrayRef<unsigned>":$operandNumbers,
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes
),
/*methodBody=*/"",
// Conservative default: not fusable unless the implementing op says so.
/*defaultImplementation=*/[{
return false;
}]
>
];
}
Expand Down
46 changes: 46 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,52 @@ struct LinalgOpTilingInterface
/// Inline the op payload and store the result.
return inlinePayload(builder, linalgOp, ivs, indexedValues);
}

/// Returns true if `op` can be fused with a consumer slice of one of its
/// results. This is a pure query and creates no IR.
///
/// `resultNumber`, `offsets` and `sizes` are currently not inspected: the
/// answer only depends on whether the op has a shapes-to-loops map. A null
/// map is treated as "not fusable", which matches the same check performed by
/// `isOpFusableWithProducerSlices`. (The previous implementation negated the
/// map and therefore reported fusability only when the map was missing.)
bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) const {
  // AffineMap converts to `true` only when it is non-null.
  return static_cast<bool>(cast<LinalgOp>(op).getShapesToLoopsMap());
}

/// Returns true if `op` can be fused with the given producer slices of its
/// operands. This is a pure query and creates no IR.
///
/// `operandNumbers[i]` names an operand of `op`; `allOffsets[i]` and
/// `allSizes[i]` describe the slice of that operand. The slices are accepted
/// when:
///   * the op has a shapes-to-loops map,
///   * every result of each involved indexing map is a plain loop dimension,
///   * all slices that touch the same loop dimension use the same offset and
///     the same size for it.
bool isOpFusableWithProducerSlices(
    Operation *op, ArrayRef<unsigned> operandNumbers,
    ArrayRef<SmallVector<OpFoldResult>> allOffsets,
    ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
  auto linalgOp = cast<LinalgOp>(op);

  // Indexing map of every operand that has a slice attached, in the order
  // given by `operandNumbers`.
  SmallVector<AffineMap> operandMaps;
  operandMaps.reserve(operandNumbers.size());
  for (unsigned operandNumber : operandNumbers) {
    OpOperand &slicedOperand = linalgOp->getOpOperand(operandNumber);
    operandMaps.push_back(linalgOp.getMatchingIndexingMap(&slicedOperand));
  }

  // Without a shapes-to-loops map the iteration domain on operand subranges
  // is not well defined, so nothing can be fused.
  if (!linalgOp.getShapesToLoopsMap())
    return false;

  // Offset/size first recorded for each loop dimension; any later slice that
  // maps to the same dimension has to agree with it.
  DenseMap<unsigned, OpFoldResult> offsetForLoop, sizeForLoop;
  for (auto [operandMap, offsets, sizes] :
       llvm::zip_equal(operandMaps, allOffsets, allSizes)) {
    for (auto [expr, offset, size] :
         llvm::zip_equal(operandMap.getResults(), offsets, sizes)) {
      // Only results that are a bare loop dimension are supported.
      auto loopDim = dyn_cast<AffineDimExpr>(expr);
      if (!loopDim)
        return false;

      unsigned loop = loopDim.getPosition();
      auto [offsetIt, firstVisit] = offsetForLoop.try_emplace(loop, offset);
      if (firstVisit) {
        sizeForLoop[loop] = size;
        continue;
      }
      if (offsetIt->second != offset || sizeForLoop.lookup(loop) != size)
        return false;
    }
  }
  return true;
}
};

//===----------------------------------------------------------------------===//
Expand Down
49 changes: 49 additions & 0 deletions mlir/test/Interfaces/TilingInterface/query-fusability.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics

// Both operands of the linalg.add are produced by tensor.insert_slice ops that
// use identical offsets ([%c0, %c0]) and sizes ([10, 20]), so the producer
// fusability query is expected to report the op as fusable.
func.func @fusable_with_matching_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
  %c0 = arith.constant 0 : index

  %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
  %slice1 = tensor.insert_slice %arg1 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>

  // expected-remark @+1 {{can be fused with producer tensor.insert_slice ops}}
  %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
    outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>

  return %result : tensor<100x200xf32>
}

// Transform script for the test case above: match every linalg.add in the
// payload and run the producer fusability query on it. The query op has no
// results; its outcome is checked through the remark placed on the linalg.add.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
transform.test.query_producer_fusability %add : !transform.any_op
transform.yield
}
}

// -----

// The two tensor.insert_slice producers use the same sizes but different
// offsets ([%c0, %c0] vs. [%c10, %c20]), so the producer fusability query is
// expected to report the linalg.add as not fusable.
func.func @not_fusable_with_different_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
%c0 = arith.constant 0 : index
%c10 = arith.constant 10 : index
%c20 = arith.constant 20 : index

%slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
%slice1 = tensor.insert_slice %arg1 into %dest[%c10, %c20] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>

// expected-remark @+1 {{cannot be fused with producer tensor.insert_slice ops}}
%result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>

return %result : tensor<100x200xf32>
}

// Transform script for the test case above: match every linalg.add in the
// payload and run the producer fusability query on it. The query op has no
// results; its outcome is checked through the remark placed on the linalg.add.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
transform.test.query_producer_fusability %add : !transform.any_op
transform.yield
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
Expand Down Expand Up @@ -622,6 +623,63 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TestQueryProducerFusability
//===----------------------------------------------------------------------===//

/// Runs the producer fusability query on every payload op of the target
/// handle and reports the answer as a remark attached to that op.
///
/// For each payload op, the operands defined by a `tensor.insert_slice` are
/// collected together with the mixed offsets and sizes of that insert. If at
/// least one such operand exists, they are handed to
/// `TilingInterface::isOpFusableWithProducerSlices`. Payload ops without any
/// `tensor.insert_slice` producer get no remark. A payload op that does not
/// implement `TilingInterface` yields a silenceable error.
DiagnosedSilenceableFailure transform::TestQueryProducerFusability::apply(
    TransformRewriter &rewriter, TransformResults &transformResults,
    TransformState &state) {
  for (Operation *payloadOp : state.getPayloadOps(getTarget())) {
    auto tilableOp = dyn_cast<TilingInterface>(payloadOp);
    if (!tilableOp)
      return emitSilenceableError()
             << "target operation does not implement TilingInterface";

    // Operand numbers fed by a tensor.insert_slice, plus the offsets and
    // sizes of the corresponding insert. This pairing exists purely to
    // exercise the query in tests; it is not a meaningful transformation.
    SmallVector<unsigned> slicedOperands;
    SmallVector<SmallVector<OpFoldResult>> sliceOffsets;
    SmallVector<SmallVector<OpFoldResult>> sliceSizes;
    for (OpOperand &use : payloadOp->getOpOperands()) {
      auto producer = use.get().getDefiningOp<tensor::InsertSliceOp>();
      if (!producer)
        continue;
      slicedOperands.push_back(use.getOperandNumber());
      sliceOffsets.push_back(producer.getMixedOffsets());
      sliceSizes.push_back(producer.getMixedSizes());
    }

    // Nothing to query when no operand comes from an insert_slice.
    if (slicedOperands.empty())
      continue;

    const bool fusable = tilableOp.isOpFusableWithProducerSlices(
        slicedOperands, sliceOffsets, sliceSizes);
    payloadOp->emitRemark()
        << (fusable ? "can be fused with producer tensor.insert_slice ops"
                    : "cannot be fused with producer tensor.insert_slice ops");
  }

  return DiagnosedSilenceableFailure::success();
}

/// Appends the memory effects of this transform op to `effects`. The op is
/// declared as only reading its target handle and only reading the payload,
/// which fits `apply`: it inspects the payload ops and emits remarks without
/// rewriting anything.
void transform::TestQueryProducerFusability::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getTargetMutable(), effects);
onlyReadsPayload(effects);
}

#define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.cpp.inc"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,33 @@ def TestTileUsingCustomLoopOp : Op<
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
let results = (outs TransformHandleTypeInterface:$tiled_ops,
Variadic<TransformHandleTypeInterface>:$loops);

let assemblyFormat = [{
$root_op `tile_sizes` `=` $tile_sizes
attr-dict `:` functional-type(operands, results)
}];
}

// Test-only transform op `transform.test.query_producer_fusability`.
// It takes a single handle operand and produces no results. The
// TransformOpInterface and MemoryEffectsOpInterface methods are only declared
// here and are implemented by hand in the accompanying C++ file.
def TestQueryProducerFusability : Op<
Transform_Dialect, "test.query_producer_fusability",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
Test operation for the producer fusability query method in the
TilingInterface.

For each operation in the target handle, this looks for tensor.insert_slice
ops that produce operands to the tilable op. The offset/sizes from those
inserts is used as the arguments to `isOpFusableWithProducerSlices` and
emits a remark with the result of the query.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs);

// Custom syntax: `<op-name> $target attr-dict : type($target)`.
let assemblyFormat = [{
$target attr-dict `:` type($target)
}];
}

#endif // TEST_TILINGINTERFACE_TRANSFORM_OPS