@@ -401,6 +401,24 @@ def test_none_shape_bool(self, xp: ModuleType):
401401 a = a [a ]
402402 xp_assert_equal (isclose (a , b ), xp .asarray ([True , False ]))
403403
404+ @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" )
405+ @pytest .mark .skip_xp_backend (Backend .TORCH , reason = "Array API 2024.12 support" )
406+ def test_python_scalar (self , xp : ModuleType ):
407+ a = xp .asarray ([0.0 , 0.1 ], dtype = xp .float32 )
408+ xp_assert_equal (isclose (a , 0.0 ), xp .asarray ([True , False ]))
409+ xp_assert_equal (isclose (0.0 , a ), xp .asarray ([True , False ]))
410+
411+ a = xp .asarray ([0 , 1 ], dtype = xp .int16 )
412+ xp_assert_equal (isclose (a , 0 ), xp .asarray ([True , False ]))
413+ xp_assert_equal (isclose (0 , a ), xp .asarray ([True , False ]))
414+
415+ xp_assert_equal (isclose (0 , 0 , xp = xp ), xp .asarray (True ))
416+ xp_assert_equal (isclose (0 , 1 , xp = xp ), xp .asarray (False ))
417+
418+ def test_all_python_scalars (self ):
419+ with pytest .raises (TypeError , match = "Unrecognized" ):
420+ isclose (0 , 0 )
421+
404422 def test_xp (self , xp : ModuleType ):
405423 a = xp .asarray ([0.0 , 0.0 ])
406424 b = xp .asarray ([1e-9 , 1e-4 ])
@@ -413,30 +431,22 @@ def test_basic(self, xp: ModuleType):
413431 # Using 0-dimensional array
414432 a = xp .asarray (1 )
415433 b = xp .asarray ([[1 , 2 ], [3 , 4 ]])
416- k = xp .asarray ([[1 , 2 ], [3 , 4 ]])
417- xp_assert_equal (kron (a , b ), k )
418- a = xp .asarray ([[1 , 2 ], [3 , 4 ]])
419- b = xp .asarray (1 )
420- xp_assert_equal (kron (a , b ), k )
434+ xp_assert_equal (kron (a , b ), b )
435+ xp_assert_equal (kron (b , a ), b )
421436
422437 # Using 1-dimensional array
423438 a = xp .asarray ([3 ])
424439 b = xp .asarray ([[1 , 2 ], [3 , 4 ]])
425440 k = xp .asarray ([[3 , 6 ], [9 , 12 ]])
426441 xp_assert_equal (kron (a , b ), k )
427- a = xp .asarray ([[1 , 2 ], [3 , 4 ]])
428- b = xp .asarray ([3 ])
429- xp_assert_equal (kron (a , b ), k )
442+ xp_assert_equal (kron (b , a ), k )
430443
431444 # Using 3-dimensional array
432445 a = xp .asarray ([[[1 ]], [[2 ]]])
433446 b = xp .asarray ([[1 , 2 ], [3 , 4 ]])
434447 k = xp .asarray ([[[1 , 2 ], [3 , 4 ]], [[2 , 4 ], [6 , 8 ]]])
435448 xp_assert_equal (kron (a , b ), k )
436- a = xp .asarray ([[1 , 2 ], [3 , 4 ]])
437- b = xp .asarray ([[[1 ]], [[2 ]]])
438- k = xp .asarray ([[[1 , 2 ], [3 , 4 ]], [[2 , 4 ], [6 , 8 ]]])
439- xp_assert_equal (kron (a , b ), k )
449+ xp_assert_equal (kron (b , a ), k )
440450
441451 def test_kron_smoke (self , xp : ModuleType ):
442452 a = xp .ones ((3 , 3 ))
@@ -474,6 +484,18 @@ def test_kron_shape(
474484 k = kron (a , b )
475485 assert k .shape == expected_shape
476486
487+ def test_python_scalar (self , xp : ModuleType ):
488+ a = 1
489+ # Test no dtype promotion to xp.asarray(a); use b.dtype
490+ b = xp .asarray ([[1 , 2 ], [3 , 4 ]], dtype = xp .int16 )
491+ xp_assert_equal (kron (a , b ), b )
492+ xp_assert_equal (kron (b , a ), b )
493+ xp_assert_equal (kron (1 , 1 , xp = xp ), xp .asarray (1 ))
494+
495+ def test_all_python_scalars (self ):
496+ with pytest .raises (TypeError , match = "Unrecognized" ):
497+ kron (1 , 1 )
498+
477499 def test_device (self , xp : ModuleType , device : Device ):
478500 x1 = xp .asarray ([1 , 2 , 3 ], device = device )
479501 x2 = xp .asarray ([4 , 5 ], device = device )
@@ -601,6 +623,28 @@ def test_shapes(
601623 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
602624 xp_assert_equal (actual , xp .empty ((0 ,)))
603625
626+ @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" )
627+ @pytest .mark .parametrize ("assume_unique" , [True , False ])
628+ def test_python_scalar (self , xp : ModuleType , assume_unique : bool ):
629+ # Test no dtype promotion to xp.asarray(x2); use x1.dtype
630+ x1 = xp .asarray ([3 , 1 , 2 ], dtype = xp .int16 )
631+ x2 = 3
632+ actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
633+ xp_assert_equal (actual , xp .asarray ([1 , 2 ], dtype = xp .int16 ))
634+
635+ actual = setdiff1d (x2 , x1 , assume_unique = assume_unique )
636+ xp_assert_equal (actual , xp .asarray ([], dtype = xp .int16 ))
637+
638+ xp_assert_equal (
639+ setdiff1d (0 , 0 , assume_unique = assume_unique , xp = xp ),
640+ xp .asarray ([0 ])[:0 ], # Default int dtype for backend
641+ )
642+
643+ @pytest .mark .parametrize ("assume_unique" , [True , False ])
644+ def test_all_python_scalars (self , assume_unique : bool ):
645+ with pytest .raises (TypeError , match = "Unrecognized" ):
646+ setdiff1d (0 , 0 , assume_unique = assume_unique )
647+
604648 def test_device (self , xp : ModuleType , device : Device ):
605649 x1 = xp .asarray ([3 , 8 , 20 ], device = device )
606650 x2 = xp .asarray ([2 , 3 , 4 ], device = device )
0 commit comments