Skip to content

Commit 6c640b8

Browse files
[mlir][LLVM] Fix unsupported FP lowering in VectorConvertToLLVMPattern (#166513)
Fixes a bug in `VectorConvertToLLVMPattern`, which converted operations with unsupported FP types. E.g., `arith.addf ... : f4E2M1FN` was lowered to `llvm.fadd ... : i4`, which does not verify. There are a few more patterns that have the same bug. Those will be fixed in follow-up PRs. This commit is in preparation of adding an `APFloat`-based lowering for `arith` operations with unsupported floating-point types.
1 parent a38e094 commit 6c640b8

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,43 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
9292
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
9393
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
9494

95+
/// Return the given type if it's a floating point type. If the given type is
96+
/// a vector type, return its element type if it's a floating point type.
97+
static FloatType getFloatingPointType(Type type) {
98+
if (auto floatType = dyn_cast<FloatType>(type))
99+
return floatType;
100+
if (auto vecType = dyn_cast<VectorType>(type))
101+
return dyn_cast<FloatType>(vecType.getElementType());
102+
return nullptr;
103+
}
104+
95105
LogicalResult
96106
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
97107
ConversionPatternRewriter &rewriter) const override {
98108
static_assert(
99109
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
100110
"expected single result op");
111+
112+
// The pattern should not apply if a floating-point operand is converted to
113+
// a non-floating-point type. This indicates that the floating point type
114+
// is not supported by the LLVM lowering. (Such types are converted to
115+
// integers.)
116+
auto checkType = [&](Value v) -> LogicalResult {
117+
FloatType floatType = getFloatingPointType(v.getType());
118+
if (!floatType)
119+
return success();
120+
Type convertedType = this->getTypeConverter()->convertType(floatType);
121+
if (!isa_and_nonnull<FloatType>(convertedType))
122+
return rewriter.notifyMatchFailure(op,
123+
"unsupported floating point type");
124+
return success();
125+
};
126+
for (Value operand : op->getOperands())
127+
if (failed(checkType(operand)))
128+
return failure();
129+
if (failed(checkType(op->getResult(0))))
130+
return failure();
131+
101132
// Determine attributes for the target op
102133
AttrConvert<SourceOp, TargetOp> attrConvert(op);
103134

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,3 +747,29 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
747747
%2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
748748
func.return %2 : memref<?xbf16>
749749
}
750+
751+
// -----
752+
753+
// CHECK-LABEL: func @unsupported_fp_type
754+
// CHECK: arith.addf {{.*}} : f4E2M1FN
755+
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
756+
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
757+
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
758+
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
759+
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
760+
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
761+
return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
762+
}
763+
764+
// -----
765+
766+
// CHECK-LABEL: func @supported_fp_type
767+
// CHECK: llvm.fadd {{.*}} : f32
768+
// CHECK: llvm.fadd {{.*}} : vector<4xf32>
769+
// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
770+
func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
771+
%0 = arith.addf %arg0, %arg0 : f32
772+
%1 = arith.addf %arg1, %arg1 : vector<4xf32>
773+
%2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
774+
return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
775+
}

0 commit comments

Comments
 (0)