Skip to content

Conversation

@matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 5, 2025

Fixes a bug in VectorConvertToLLVMPattern, which converted operations with unsupported FP types. E.g., arith.addf ... : f4E2M1FN was lowered to llvm.fadd ... : i4, which does not verify. There are a few more patterns that have the same bug. Those will be fixed in follow-up PRs.

This commit is in preparation of adding an APFloat-based lowering for arith operations with unsupported floating-point types.

@llvmbot
Copy link
Member

llvmbot commented Nov 5, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

Fixes a bug in VectorConvertToLLVMPattern, which converted operations with unsupported FP types. E.g., arith.addf ... : f4E2M1FN was lowered to llvm.fadd ... : i4, which does not verify.


Full diff: https://github.com/llvm/llvm-project/pull/166513.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (+25)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+26)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 964281592cc65..d8483114f1137 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -92,12 +92,37 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
+  /// Return "true" if the given type (or its element type) is a floating point
+  /// type.
+  static FloatType getFloatingPointType(Type type) {
+    if (auto floatType = dyn_cast<FloatType>(type))
+      return floatType;
+    if (auto vecType = dyn_cast<VectorType>(type))
+      return dyn_cast<FloatType>(vecType.getElementType());
+    return nullptr;
+  }
+
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     static_assert(
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
+
+    // The pattern should not apply if a floating-point operand is converted to
+    // a non-floating-point type. This indicates that the floating point type
+    // is not supported by the LLVM lowering. (Such types are converted to
+    // integers.)
+    for (Value operand : op->getOperands()) {
+      FloatType floatType = getFloatingPointType(operand.getType());
+      if (!floatType)
+        continue;
+      Type convertedType = this->getTypeConverter()->convertType(floatType);
+      if (!isa<FloatType>(convertedType))
+        return rewriter.notifyMatchFailure(op,
+                                           "unsupported floating point type");
+    }
+
     // Determine attributes for the target op
     AttrConvert<SourceOp, TargetOp> attrConvert(op);
 
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index ba12ff29ebef9..b5dcb01d3dc6b 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -747,3 +747,29 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
   %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
   func.return %2 : memref<?xbf16>
 }
+
+// -----
+
+// CHECK-LABEL: func @unsupported_fp_type
+//       CHECK:   arith.addf {{.*}} : f4E2M1FN
+//       CHECK:   arith.addf {{.*}} : vector<4xf4E2M1FN>
+//       CHECK:   arith.addf {{.*}} : vector<8x4xf4E2M1FN>
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+  %0 = arith.addf %arg0, %arg0 : f4E2M1FN
+  %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
+  %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
+  return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+}
+
+// -----
+
+//   CHECK-LABEL: func @supported_fp_type
+//         CHECK:   llvm.fadd {{.*}} : f32
+//         CHECK:   llvm.fadd {{.*}} : vector<4xf32>
+// CHECK-COUNT-4:   llvm.fadd {{.*}} : vector<8xf32>
+func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
+  %0 = arith.addf %arg0, %arg0 : f32
+  %1 = arith.addf %arg1, %arg1 : vector<4xf32>
+  %2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
+  return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
+}

@matthias-springer matthias-springer force-pushed the users/matthias-springer/vec_to_llvm_fp branch from 6eb0454 to b8c8d79 Compare November 5, 2025 07:17
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@matthias-springer matthias-springer force-pushed the users/matthias-springer/vec_to_llvm_fp branch from b8c8d79 to 7e53773 Compare November 5, 2025 12:02
@matthias-springer matthias-springer enabled auto-merge (squash) November 5, 2025 12:03
@matthias-springer matthias-springer merged commit 6c640b8 into main Nov 5, 2025
10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/vec_to_llvm_fp branch November 5, 2025 12:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants