@@ -92,12 +92,43 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
9292 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
9393 using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
9494
95+ // / Return the given type if it's a floating point type. If the given type is
96+ // / a vector type, return its element type if it's a floating point type.
97+ static FloatType getFloatingPointType (Type type) {
98+ if (auto floatType = dyn_cast<FloatType>(type))
99+ return floatType;
100+ if (auto vecType = dyn_cast<VectorType>(type))
101+ return dyn_cast<FloatType>(vecType.getElementType ());
102+ return nullptr ;
103+ }
104+
95105 LogicalResult
96106 matchAndRewrite (SourceOp op, typename SourceOp::Adaptor adaptor,
97107 ConversionPatternRewriter &rewriter) const override {
98108 static_assert (
99109 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
100110 " expected single result op" );
111+
112+ // The pattern should not apply if a floating-point operand is converted to
113+ // a non-floating-point type. This indicates that the floating point type
114+ // is not supported by the LLVM lowering. (Such types are converted to
115+ // integers.)
116+ auto checkType = [&](Value v) -> LogicalResult {
117+ FloatType floatType = getFloatingPointType (v.getType ());
118+ if (!floatType)
119+ return success ();
120+ Type convertedType = this ->getTypeConverter ()->convertType (floatType);
121+ if (!isa_and_nonnull<FloatType>(convertedType))
122+ return rewriter.notifyMatchFailure (op,
123+ " unsupported floating point type" );
124+ return success ();
125+ };
126+ for (Value operand : op->getOperands ())
127+ if (failed (checkType (operand)))
128+ return failure ();
129+ if (failed (checkType (op->getResult (0 ))))
130+ return failure ();
131+
101132 // Determine attributes for the target op
102133 AttrConvert<SourceOp, TargetOp> attrConvert (op);
103134
0 commit comments