@@ -59,28 +59,26 @@ impl Barrett {
5959///
6060/// * `a` `0 <= a < m`
6161/// * `b` `0 <= b < m`
62- /// * `m` `1 <= m <= 2^31 `
63- /// * `im` = ceil(2^64 / `m`)
62+ /// * `m` `1 <= m < 2^32 `
63+ /// * `im` = ceil(2^64 / `m`) = floor((2^64 - 1) / `m`) + 1
6464#[ allow( clippy:: many_single_char_names) ]
6565pub ( crate ) fn mul_mod ( a : u32 , b : u32 , m : u32 , im : u64 ) -> u32 {
6666 // [1] m = 1
6767 // a = b = im = 0, so okay
6868
6969 // [2] m >= 2
70- // im = ceil(2^64 / m)
70+ // im = ceil(2^64 / m) = floor((2^64 - 1) / m) + 1
7171 // -> im * m = 2^64 + r (0 <= r < m)
7272 // let z = a*b = c*m + d (0 <= c, d < m)
7373 // a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
7474 // c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
7575 // ((ab * im) >> 64) == c or c + 1
76- let mut z = a as u64 ;
77- z *= b as u64 ;
76+ let z = ( a as u64 ) * ( b as u64 ) ;
7877 let x = ( ( ( z as u128 ) * ( im as u128 ) ) >> 64 ) as u64 ;
79- let mut v = z . wrapping_sub ( x. wrapping_mul ( m as u64 ) ) as u32 ;
80- if m <= v {
81- v = v . wrapping_add ( m ) ;
78+ match z . overflowing_sub ( x. wrapping_mul ( m as u64 ) ) {
79+ ( v , true ) => ( v as u32 ) . wrapping_add ( m ) ,
80+ ( v , false ) => v as u32 ,
8281 }
83- v
8482}
8583
8684/// # Parameters
@@ -320,6 +318,17 @@ mod tests {
320318 let b = Barrett :: new ( 2147483647 ) ;
321319 assert_eq ! ( b. umod( ) , 2147483647 ) ;
322320 assert_eq ! ( b. mul( 1073741824 , 2147483645 ) , 2147483646 ) ;
321+
322+ // test `2^31 < self._m < 2^32` case.
323+ // https://github.com/rust-lang-ja/ac-library-rs/pull/112
324+ // https://github.com/atcoder/ac-library/issues/149
325+ // https://github.com/atcoder/ac-library/pull/163
326+ let b = Barrett :: new ( 3221225471 ) ;
327+ assert_eq ! ( b. umod( ) , 3221225471 ) ;
328+ assert_eq ! ( b. mul( 3188445886 , 2844002853 ) , 1840468257 ) ;
329+ assert_eq ! ( b. mul( 2834869488 , 2779159607 ) , 2084027561 ) ;
330+ assert_eq ! ( b. mul( 3032263594 , 3039996727 ) , 2130247251 ) ;
331+ assert_eq ! ( b. mul( 3029175553 , 3140869278 ) , 1892378237 ) ;
323332 }
324333
325334 #[ test]
0 commit comments