-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][LLVM] Fix unsupported FP lowering in VectorConvertToLLVMPattern
#166513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesFixes a bug in Full diff: https://github.com/llvm/llvm-project/pull/166513.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 964281592cc65..d8483114f1137 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -92,12 +92,37 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
+ /// Return "true" if the given type (or its element type) is a floating point
+ /// type.
+ static FloatType getFloatingPointType(Type type) {
+ if (auto floatType = dyn_cast<FloatType>(type))
+ return floatType;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return dyn_cast<FloatType>(vecType.getElementType());
+ return nullptr;
+ }
+
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
+
+ // The pattern should not apply if a floating-point operand is converted to
+ // a non-floating-point type. This indicates that the floating point type
+ // is not supported by the LLVM lowering. (Such types are converted to
+ // integers.)
+ for (Value operand : op->getOperands()) {
+ FloatType floatType = getFloatingPointType(operand.getType());
+ if (!floatType)
+ continue;
+ Type convertedType = this->getTypeConverter()->convertType(floatType);
+ if (!isa<FloatType>(convertedType))
+ return rewriter.notifyMatchFailure(op,
+ "unsupported floating point type");
+ }
+
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index ba12ff29ebef9..b5dcb01d3dc6b 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -747,3 +747,29 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
%2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
func.return %2 : memref<?xbf16>
}
+
+// -----
+
+// CHECK-LABEL: func @unsupported_fp_type
+// CHECK: arith.addf {{.*}} : f4E2M1FN
+// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
+// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+ %0 = arith.addf %arg0, %arg0 : f4E2M1FN
+ %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
+ %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
+ return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+}
+
+// -----
+
+// CHECK-LABEL: func @supported_fp_type
+// CHECK: llvm.fadd {{.*}} : f32
+// CHECK: llvm.fadd {{.*}} : vector<4xf32>
+// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
+func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
+ %0 = arith.addf %arg0, %arg0 : f32
+ %1 = arith.addf %arg1, %arg1 : vector<4xf32>
+ %2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
+ return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
+}
|
6eb0454 to
b8c8d79
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
b8c8d79 to
7e53773
Compare
Fixes a bug in
VectorConvertToLLVMPattern, which converted operations with unsupported FP types. E.g.,arith.addf ... : f4E2M1FNwas lowered tollvm.fadd ... : i4, which does not verify. There are a few more patterns that have the same bug. Those will be fixed in follow-up PRs.This commit is in preparation of adding an
APFloat-based lowering forarithoperations with unsupported floating-point types.