@@ -1182,6 +1182,60 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
11821182 } } ;
11831183 }
11841184
/// Evaluates to the bit width of the `$ty` argument when it is an `Int` kind.
///
/// If the integer type reports no fixed width (`bit_width()` is `None`, presumably the
/// pointer-sized `isize`), the pointer size from `bx.data_layout()` is used instead.
/// For every other kind, `$diag` is handed to `return_error!`.
macro_rules! require_int_ty {
    ($ty: expr, $diag: expr) => {
        match $ty {
            ty::Int(int_ty) => match int_ty.bit_width() {
                Some(bits) => bits,
                None => bx.data_layout().pointer_size.bits(),
            },
            _ => return_error!($diag),
        }
    };
}
1196+
/// Evaluates to the bit width of the `$ty` argument when it is an `Int` or `Uint` kind.
///
/// If the integer type reports no fixed width (`bit_width()` is `None`, presumably the
/// pointer-sized `isize` / `usize`), the pointer size from `bx.data_layout()` is used instead.
/// For every other kind, `$diag` is handed to `return_error!`.
macro_rules! require_int_or_uint_ty {
    ($ty: expr, $diag: expr) => {
        match $ty {
            ty::Int(int_ty) => match int_ty.bit_width() {
                Some(bits) => bits,
                None => bx.data_layout().pointer_size.bits(),
            },
            ty::Uint(uint_ty) => match uint_ty.bit_width() {
                Some(bits) => bits,
                None => bx.data_layout().pointer_size.bits(),
            },
            _ => return_error!($diag),
        }
    };
}
1211+
1212+ /// Converts a vector mask, where each element has a bit width equal to the data elements it is used with,
1213+ /// down to an i1 based mask that can be used by llvm intrinsics.
1214+ ///
1215+ /// The rust simd semantics are that each element should either consist of all ones or all zeroes,
1216+ /// but this information is not available to llvm. Truncating the vector effectively uses the lowest bit,
1217+ /// but codegen for several targets is better if we consider the highest bit by shifting.
1218+ ///
1219+ /// For x86 SSE/AVX targets this is beneficial since most instructions with mask parameters only consider the highest bit.
1220+ /// So even though on llvm level we have an additional shift, in the final assembly there is no shift or truncate and
1221+ /// instead the mask can be used as is.
1222+ ///
1223+ /// For aarch64 and other targets there is a benefit because a mask from the sign bit can be more
1224+ /// efficiently converted to an all ones / all zeroes mask by comparing whether each element is negative.
1225+ fn vector_mask_to_bitmask < ' a , ' ll , ' tcx > (
1226+ bx : & mut Builder < ' a , ' ll , ' tcx > ,
1227+ i_xn : & ' ll Value ,
1228+ in_elem_bitwidth : u64 ,
1229+ in_len : u64 ,
1230+ ) -> & ' ll Value {
1231+ // Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
1232+ let shift_idx = bx. cx . const_int ( bx. type_ix ( in_elem_bitwidth) , ( in_elem_bitwidth - 1 ) as _ ) ;
1233+ let shift_indices = vec ! [ shift_idx; in_len as _] ;
1234+ let i_xn_msb = bx. lshr ( i_xn, bx. const_vector ( shift_indices. as_slice ( ) ) ) ;
1235+ // Truncate vector to an <i1 x N>
1236+ bx. trunc ( i_xn_msb, bx. type_vector ( bx. type_i1 ( ) , in_len) )
1237+ }
1238+
11851239 let tcx = bx. tcx ( ) ;
11861240 let sig = tcx. normalize_erasing_late_bound_regions ( bx. typing_env ( ) , callee_ty. fn_sig ( tcx) ) ;
11871241 let arg_tys = sig. inputs ( ) ;
@@ -1433,14 +1487,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14331487 m_len,
14341488 v_len
14351489 } ) ;
1436- match m_elem_ty. kind ( ) {
1437- ty:: Int ( _) => { }
1438- _ => return_error ! ( InvalidMonomorphization :: MaskType { span, name, ty: m_elem_ty } ) ,
1439- }
1440- // truncate the mask to a vector of i1s
1441- let i1 = bx. type_i1 ( ) ;
1442- let i1xn = bx. type_vector ( i1, m_len as u64 ) ;
1443- let m_i1s = bx. trunc ( args[ 0 ] . immediate ( ) , i1xn) ;
1490+ let in_elem_bitwidth =
1491+ require_int_ty ! ( m_elem_ty. kind( ) , InvalidMonomorphization :: MaskType {
1492+ span,
1493+ name,
1494+ ty: m_elem_ty
1495+ } ) ;
1496+ let m_i1s = vector_mask_to_bitmask ( bx, args[ 0 ] . immediate ( ) , in_elem_bitwidth, m_len) ;
14441497 return Ok ( bx. select ( m_i1s, args[ 1 ] . immediate ( ) , args[ 2 ] . immediate ( ) ) ) ;
14451498 }
14461499
@@ -1457,33 +1510,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14571510 let expected_bytes = in_len. div_ceil ( 8 ) ;
14581511
14591512 // Integer vector <i{in_bitwidth} x in_len>:
1460- let ( i_xn, in_elem_bitwidth) = match in_elem. kind ( ) {
1461- ty:: Int ( i) => (
1462- args[ 0 ] . immediate ( ) ,
1463- i. bit_width ( ) . unwrap_or_else ( || bx. data_layout ( ) . pointer_size . bits ( ) ) ,
1464- ) ,
1465- ty:: Uint ( i) => (
1466- args[ 0 ] . immediate ( ) ,
1467- i. bit_width ( ) . unwrap_or_else ( || bx. data_layout ( ) . pointer_size . bits ( ) ) ,
1468- ) ,
1469- _ => return_error ! ( InvalidMonomorphization :: VectorArgument {
1513+ let in_elem_bitwidth =
1514+ require_int_or_uint_ty ! ( in_elem. kind( ) , InvalidMonomorphization :: VectorArgument {
14701515 span,
14711516 name,
14721517 in_ty,
14731518 in_elem
1474- } ) ,
1475- } ;
1519+ } ) ;
14761520
1477- // LLVM doesn't always know the inputs are `0` or `!0`, so we shift here so it optimizes to
1478- // `pmovmskb` and similar on x86.
1479- let shift_indices =
1480- vec ! [
1481- bx. cx. const_int( bx. type_ix( in_elem_bitwidth) , ( in_elem_bitwidth - 1 ) as _) ;
1482- in_len as _
1483- ] ;
1484- let i_xn_msb = bx. lshr ( i_xn, bx. const_vector ( shift_indices. as_slice ( ) ) ) ;
1485- // Truncate vector to an <i1 x N>
1486- let i1xn = bx. trunc ( i_xn_msb, bx. type_vector ( bx. type_i1 ( ) , in_len) ) ;
1521+ let i1xn = vector_mask_to_bitmask ( bx, args[ 0 ] . immediate ( ) , in_elem_bitwidth, in_len) ;
14871522 // Bitcast <i1 x N> to iN:
14881523 let i_ = bx. bitcast ( i1xn, bx. type_ix ( in_len) ) ;
14891524
@@ -1704,28 +1739,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
17041739 }
17051740 ) ;
17061741
1707- match element_ty2. kind ( ) {
1708- ty:: Int ( _) => ( ) ,
1709- _ => {
1710- return_error ! ( InvalidMonomorphization :: ThirdArgElementType {
1711- span,
1712- name,
1713- expected_element: element_ty2,
1714- third_arg: arg_tys[ 2 ]
1715- } ) ;
1716- }
1717- }
1742+ let mask_elem_bitwidth =
1743+ require_int_ty ! ( element_ty2. kind( ) , InvalidMonomorphization :: ThirdArgElementType {
1744+ span,
1745+ name,
1746+ expected_element: element_ty2,
1747+ third_arg: arg_tys[ 2 ]
1748+ } ) ;
17181749
17191750 // Alignment of T, must be a constant integer value:
17201751 let alignment_ty = bx. type_i32 ( ) ;
17211752 let alignment = bx. const_i32 ( bx. align_of ( in_elem) . bytes ( ) as i32 ) ;
17221753
17231754 // Truncate the mask vector to a vector of i1s:
1724- let ( mask, mask_ty) = {
1725- let i1 = bx. type_i1 ( ) ;
1726- let i1xn = bx. type_vector ( i1, in_len) ;
1727- ( bx. trunc ( args[ 2 ] . immediate ( ) , i1xn) , i1xn)
1728- } ;
1755+ let mask = vector_mask_to_bitmask ( bx, args[ 2 ] . immediate ( ) , mask_elem_bitwidth, in_len) ;
1756+ let mask_ty = bx. type_vector ( bx. type_i1 ( ) , in_len) ;
17291757
17301758 // Type of the vector of pointers:
17311759 let llvm_pointer_vec_ty = llvm_vector_ty ( bx, element_ty1, in_len) ;
@@ -1810,27 +1838,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18101838 }
18111839 ) ;
18121840
1813- require ! (
1814- matches!( mask_elem. kind( ) , ty:: Int ( _) ) ,
1815- InvalidMonomorphization :: ThirdArgElementType {
1841+ let m_elem_bitwidth =
1842+ require_int_ty ! ( mask_elem. kind( ) , InvalidMonomorphization :: ThirdArgElementType {
18161843 span,
18171844 name,
18181845 expected_element: values_elem,
18191846 third_arg: mask_ty,
1820- }
1821- ) ;
1847+ } ) ;
1848+
1849+ let mask = vector_mask_to_bitmask ( bx, args[ 0 ] . immediate ( ) , m_elem_bitwidth, mask_len) ;
1850+ let mask_ty = bx. type_vector ( bx. type_i1 ( ) , mask_len) ;
18221851
18231852 // Alignment of T, must be a constant integer value:
18241853 let alignment_ty = bx. type_i32 ( ) ;
18251854 let alignment = bx. const_i32 ( bx. align_of ( values_elem) . bytes ( ) as i32 ) ;
18261855
1827- // Truncate the mask vector to a vector of i1s:
1828- let ( mask, mask_ty) = {
1829- let i1 = bx. type_i1 ( ) ;
1830- let i1xn = bx. type_vector ( i1, mask_len) ;
1831- ( bx. trunc ( args[ 0 ] . immediate ( ) , i1xn) , i1xn)
1832- } ;
1833-
18341856 let llvm_pointer = bx. type_ptr ( ) ;
18351857
18361858 // Type of the vector of elements:
@@ -1901,27 +1923,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19011923 }
19021924 ) ;
19031925
1904- require ! (
1905- matches!( mask_elem. kind( ) , ty:: Int ( _) ) ,
1906- InvalidMonomorphization :: ThirdArgElementType {
1926+ let m_elem_bitwidth =
1927+ require_int_ty ! ( mask_elem. kind( ) , InvalidMonomorphization :: ThirdArgElementType {
19071928 span,
19081929 name,
19091930 expected_element: values_elem,
19101931 third_arg: mask_ty,
1911- }
1912- ) ;
1932+ } ) ;
1933+
1934+ let mask = vector_mask_to_bitmask ( bx, args[ 0 ] . immediate ( ) , m_elem_bitwidth, mask_len) ;
1935+ let mask_ty = bx. type_vector ( bx. type_i1 ( ) , mask_len) ;
19131936
19141937 // Alignment of T, must be a constant integer value:
19151938 let alignment_ty = bx. type_i32 ( ) ;
19161939 let alignment = bx. const_i32 ( bx. align_of ( values_elem) . bytes ( ) as i32 ) ;
19171940
1918- // Truncate the mask vector to a vector of i1s:
1919- let ( mask, mask_ty) = {
1920- let i1 = bx. type_i1 ( ) ;
1921- let i1xn = bx. type_vector ( i1, in_len) ;
1922- ( bx. trunc ( args[ 0 ] . immediate ( ) , i1xn) , i1xn)
1923- } ;
1924-
19251941 let ret_t = bx. type_void ( ) ;
19261942
19271943 let llvm_pointer = bx. type_ptr ( ) ;
@@ -1995,28 +2011,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19952011 ) ;
19962012
19972013 // The element type of the third argument must be a signed integer type of any width:
1998- match element_ty2. kind ( ) {
1999- ty:: Int ( _) => ( ) ,
2000- _ => {
2001- return_error ! ( InvalidMonomorphization :: ThirdArgElementType {
2002- span,
2003- name,
2004- expected_element: element_ty2,
2005- third_arg: arg_tys[ 2 ]
2006- } ) ;
2007- }
2008- }
2014+ let mask_elem_bitwidth =
2015+ require_int_ty ! ( element_ty2. kind( ) , InvalidMonomorphization :: ThirdArgElementType {
2016+ span,
2017+ name,
2018+ expected_element: element_ty2,
2019+ third_arg: arg_tys[ 2 ]
2020+ } ) ;
20092021
20102022 // Alignment of T, must be a constant integer value:
20112023 let alignment_ty = bx. type_i32 ( ) ;
20122024 let alignment = bx. const_i32 ( bx. align_of ( in_elem) . bytes ( ) as i32 ) ;
20132025
20142026 // Truncate the mask vector to a vector of i1s:
2015- let ( mask, mask_ty) = {
2016- let i1 = bx. type_i1 ( ) ;
2017- let i1xn = bx. type_vector ( i1, in_len) ;
2018- ( bx. trunc ( args[ 2 ] . immediate ( ) , i1xn) , i1xn)
2019- } ;
2027+ let mask = vector_mask_to_bitmask ( bx, args[ 2 ] . immediate ( ) , mask_elem_bitwidth, in_len) ;
2028+ let mask_ty = bx. type_vector ( bx. type_i1 ( ) , in_len) ;
20202029
20212030 let ret_t = bx. type_void ( ) ;
20222031
@@ -2164,8 +2173,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
21642173 } ) ;
21652174 args[ 0 ] . immediate( )
21662175 } else {
2167- match in_elem. kind( ) {
2168- ty:: Int ( _) | ty:: Uint ( _) => { }
2176+ let bitwidth = match in_elem. kind( ) {
2177+ ty:: Int ( i) => {
2178+ i. bit_width( ) . unwrap_or_else( || bx. data_layout( ) . pointer_size. bits( ) )
2179+ }
2180+ ty:: Uint ( i) => {
2181+ i. bit_width( ) . unwrap_or_else( || bx. data_layout( ) . pointer_size. bits( ) )
2182+ }
21692183 _ => return_error!( InvalidMonomorphization :: UnsupportedSymbol {
21702184 span,
21712185 name,
@@ -2174,12 +2188,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
21742188 in_elem,
21752189 ret_ty
21762190 } ) ,
2177- }
2191+ } ;
21782192
2179- // boolean reductions operate on vectors of i1s:
2180- let i1 = bx. type_i1( ) ;
2181- let i1xn = bx. type_vector( i1, in_len as u64 ) ;
2182- bx. trunc( args[ 0 ] . immediate( ) , i1xn)
2193+ vector_mask_to_bitmask( bx, args[ 0 ] . immediate( ) , bitwidth, in_len as _)
21832194 } ;
21842195 return match in_elem. kind( ) {
21852196 ty:: Int ( _) | ty:: Uint ( _) => {
0 commit comments