@@ -532,6 +532,22 @@ def test_nanargmin_zero_size_axis1(self, xp, dtype):
532532 a = testing .shaped_random ((0 , 1 ), xp , dtype )
533533 return xp .nanargmin (a , axis = 1 )
534534
535+ @testing .for_all_dtypes (no_complex = True )
536+ @testing .numpy_cupy_allclose ()
537+ def test_nanargmin_out_float_dtype (self , xp , dtype ):
538+ a = xp .array ([[0.0 ]])
539+ b = xp .empty ((1 ), dtype = "int64" )
540+ xp .nanargmin (a , axis = 1 , out = b )
541+ return b
542+
543+ @testing .for_all_dtypes (no_complex = True )
544+ @testing .numpy_cupy_array_equal ()
545+ def test_nanargmin_out_int_dtype (self , xp , dtype ):
546+ a = xp .array ([1 , 0 ])
547+ b = xp .empty ((), dtype = "int64" )
548+ xp .nanargmin (a , out = b )
549+ return b
550+
535551
536552class TestNanArgMax :
537553
@@ -623,6 +639,22 @@ def test_nanargmax_zero_size_axis1(self, xp, dtype):
623639 a = testing .shaped_random ((0 , 1 ), xp , dtype )
624640 return xp .nanargmax (a , axis = 1 )
625641
642+ @testing .for_all_dtypes (no_complex = True )
643+ @testing .numpy_cupy_allclose ()
644+ def test_nanargmax_out_float_dtype (self , xp , dtype ):
645+ a = xp .array ([[0.0 ]])
646+ b = xp .empty ((1 ), dtype = "int64" )
647+ xp .nanargmax (a , axis = 1 , out = b )
648+ return b
649+
650+ @testing .for_all_dtypes (no_complex = True )
651+ @testing .numpy_cupy_array_equal ()
652+ def test_nanargmax_out_int_dtype (self , xp , dtype ):
653+ a = xp .array ([0 , 1 ])
654+ b = xp .empty ((), dtype = "int64" )
655+ xp .nanargmax (a , out = b )
656+ return b
657+
626658
627659@testing .parameterize (
628660 * testing .product (
@@ -771,7 +803,7 @@ def test_invalid_sorter(self):
771803
772804 def test_nonint_sorter (self ):
773805 for xp in (numpy , cupy ):
774- x = testing .shaped_arange ((12 ,), xp , xp .float32 )
806+ x = testing .shaped_arange ((12 ,), xp , xp .float64 )
775807 bins = xp .array ([10 , 4 , 2 , 1 , 8 ])
776808 sorter = xp .array ([], dtype = xp .float32 )
777809 with pytest .raises ((TypeError , ValueError )):
@@ -865,7 +897,7 @@ def test_invalid_sorter(self):
865897
866898 def test_nonint_sorter (self ):
867899 for xp in (numpy , cupy ):
868- x = testing .shaped_arange ((12 ,), xp , xp .float32 )
900+ x = testing .shaped_arange ((12 ,), xp , xp .float64 )
869901 bins = xp .array ([10 , 4 , 2 , 1 , 8 ])
870902 sorter = xp .array ([], dtype = xp .float32 )
871903 with pytest .raises ((TypeError , ValueError )):
0 commit comments