Skip to content

Commit 1108c97

Browse files
committed
Add test_out1_overlap
1 parent 85530e7 commit 1108c97

File tree

1 file changed

+11
-18
lines changed

1 file changed

+11
-18
lines changed

dpnp/tests/test_mathematical.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,7 @@ class TestFrexp:
716716
ALL_DTYPES_NO_COMPLEX = get_all_dtypes(
717717
no_none=True, no_float16=False, no_complex=True
718718
)
719+
ALL_FLOAT_DTYPES = get_float_dtypes(no_float16=False)
719720

720721
@pytest.mark.parametrize("dt", ALL_DTYPES_NO_COMPLEX)
721722
def test_basic(self, dt):
@@ -727,7 +728,7 @@ def test_basic(self, dt):
727728
assert_array_equal(res1, exp1)
728729
assert_array_equal(res2, exp2)
729730

730-
@pytest.mark.parametrize("dt", get_float_dtypes())
731+
@pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES)
731732
def test_out(self, dt):
732733
a = numpy.array(5.7, dtype=dt)
733734
ia = dpnp.array(a)
@@ -784,7 +785,7 @@ def test_out_all_dtypes(self, dt, out1_dt, out2_dt):
784785
reason="numpy.frexp gives different answers for NAN/INF on Windows and Linux",
785786
)
786787
@pytest.mark.parametrize("stride", [-4, -2, -1, 1, 2, 4])
787-
@pytest.mark.parametrize("dt", get_float_dtypes())
788+
@pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES)
788789
def test_strides_out(self, stride, dt):
789790
a = numpy.array(
790791
[numpy.nan, numpy.nan, numpy.inf, -numpy.inf, 0.0, -0.0, 1.0, -1.0],
@@ -808,26 +809,18 @@ def test_strides_out(self, stride, dt):
808809
assert_array_equal(iout_mant, out_mant)
809810
assert_array_equal(iout_exp, out_exp)
810811

811-
@pytest.mark.parametrize("dt", get_float_dtypes())
812+
@pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES)
812813
def test_out_overlap(self, dt):
813-
a = numpy.ones(15, dtype=dt)
814+
size = 15
815+
a = numpy.ones(2 * size, dtype=dt)
814816
ia = dpnp.array(a)
815817

816-
out_mant = numpy.ones_like(a)
817-
out_exp = 2 * numpy.ones_like(a, dtype="i")
818-
iout_mant, iout_exp = dpnp.array(out_mant), dpnp.array(out_exp)
819-
820-
res1, res2 = dpnp.frexp(ia, out=(iout_mant, iout_exp))
821-
exp1, exp2 = numpy.frexp(a, out=(out_mant, out_exp))
822-
assert_array_equal(res1, exp1)
823-
assert_array_equal(res2, exp2)
818+
# out1 overlaps memory of input array
819+
_ = dpnp.frexp(ia[size::], ia[::2])
820+
_ = numpy.frexp(a[size::], a[::2])
821+
assert_array_equal(ia, a)
824822

825-
assert_array_equal(iout_mant, out_mant)
826-
assert_array_equal(iout_exp, out_exp)
827-
assert res1 is iout_mant
828-
assert res2 is iout_exp
829-
830-
@pytest.mark.parametrize("dt", get_float_dtypes())
823+
@pytest.mark.parametrize("dt", ALL_FLOAT_DTYPES)
831824
def test_empty(self, dt):
832825
a = numpy.empty(0, dtype=dt)
833826
ia = dpnp.array(a)

0 commit comments

Comments
 (0)