Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,43 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;

/// Return the given type if it's a floating point type. If the given type is
/// a vector type, return its element type if it's a floating point type.
static FloatType getFloatingPointType(Type type) {
if (auto floatType = dyn_cast<FloatType>(type))
return floatType;
if (auto vecType = dyn_cast<VectorType>(type))
return dyn_cast<FloatType>(vecType.getElementType());
return nullptr;
}

LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");

// The pattern should not apply if a floating-point operand is converted to
// a non-floating-point type. This indicates that the floating point type
// is not supported by the LLVM lowering. (Such types are converted to
// integers.)
auto checkType = [&](Value v) -> LogicalResult {
FloatType floatType = getFloatingPointType(v.getType());
if (!floatType)
return success();
Type convertedType = this->getTypeConverter()->convertType(floatType);
if (!isa_and_nonnull<FloatType>(convertedType))
return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
return success();
};
for (Value operand : op->getOperands())
if (failed(checkType(operand)))
return failure();
if (failed(checkType(op->getResult(0))))
return failure();

// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);

Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -747,3 +747,29 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
%2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
func.return %2 : memref<?xbf16>
}

// -----

// CHECK-LABEL: func @unsupported_fp_type
// CHECK: arith.addf {{.*}} : f4E2M1FN
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
}

// -----

// CHECK-LABEL: func @supported_fp_type
// CHECK: llvm.fadd {{.*}} : f32
// CHECK: llvm.fadd {{.*}} : vector<4xf32>
// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
%0 = arith.addf %arg0, %arg0 : f32
%1 = arith.addf %arg1, %arg1 : vector<4xf32>
%2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
}