@@ -411,72 +411,47 @@ def _matmul_array_vals():
411411 x .__imatmul__ (y )
412412
413413
414- @pytest .mark .parametrize (
415- "op" ,
416- [
417- op for op , dtypes in binary_op_dtypes .items ()
418- if dtypes not in ("real numeric" , "floating-point" )
419- ],
420- )
421- def test_binary_operators_vs_numpy_int (op ):
422- """np.int64 is not a subclass of int and must be disallowed"""
423- a = asarray (1 )
424- i64 = np .int64 (1 )
425- with pytest .raises (TypeError , match = "Expected Array or Python scalar" ):
426- getattr (a , op )(i64 )
427-
428-
429- @pytest .mark .parametrize (
430- "op" ,
431- [
432- op for op , dtypes in binary_op_dtypes .items ()
433- if dtypes not in ("integer" , "integer or boolean" )
434- ],
435- )
436- def test_binary_operators_vs_numpy_float (op ):
437- """
438- np.float64 is a subclass of float and must be allowed.
439- np.float32 is not and must be rejected.
440- """
441- a = asarray (1. )
442- f64 = np .float64 (1. )
443- f32 = np .float32 (1. )
444- func = getattr (a , op )
445- for op in binary_op_dtypes :
446- assert isinstance (func (f64 ), Array )
447- with pytest .raises (TypeError , match = "Expected Array or Python scalar" ):
448- func (f32 )
449-
450-
451- @pytest .mark .parametrize (
452- "op" ,
453- [
454- op for op , dtypes in binary_op_dtypes .items ()
455- if dtypes not in ("integer" , "integer or boolean" , "real numeric" )
456- ],
457- )
458- def test_binary_operators_vs_numpy_complex (op ):
459- """
460- np.complex128 is a subclass of complex and must be allowed.
461- np.complex64 is not and must be rejected.
414+ @pytest .mark .parametrize ("op,dtypes" , binary_op_dtypes .items ())
415+ def test_binary_operators_vs_numpy_generics (op , dtypes ):
416+ """Test that np.bool_, np.int64, np.float32, np.float64, np.complex64, np.complex128
417+ are disallowed in binary operators.
418+ np.float64 and np.complex128 are subclasses of float and complex, so they need
419+ special treatment in order to be rejected.
462420 """
463- a = asarray (1. )
464- c64 = np .complex64 (1. )
465- c128 = np .complex128 (1. )
466- func = getattr (a , op )
467- for op in binary_op_dtypes :
468- assert isinstance (func (c128 ), Array )
469- with pytest .raises (TypeError , match = "Expected Array or Python scalar" ):
470- func (c64 )
421+ match = "Expected Array or Python scalar"
422+
423+ if dtypes not in ("numeric" , "integer" , "real numeric" , "floating-point" ):
424+ a = asarray (True )
425+ func = getattr (a , op )
426+ with pytest .raises (TypeError , match = match ):
427+ func (np .bool_ (True ))
428+
429+ if dtypes != "floating-point" :
430+ a = asarray (1 )
431+ func = getattr (a , op )
432+ with pytest .raises (TypeError , match = match ):
433+ func (np .int64 (1 ))
434+
435+ if dtypes not in ("integer" , "integer or boolean" ):
436+ a = asarray (1. ,)
437+ func = getattr (a , op )
438+ with pytest .raises (TypeError , match = match ):
439+ func (np .float32 (1. ))
440+ with pytest .raises (TypeError , match = match ):
441+ func (np .float64 (1. ))
442+
443+ if dtypes not in ("integer" , "integer or boolean" , "real numeric" ):
444+ a = asarray (1. ,)
445+ func = getattr (a , op )
446+ with pytest .raises (TypeError , match = match ):
447+ func (np .complex64 (1. ))
448+ with pytest .raises (TypeError , match = match ):
449+ func (np .complex128 (1. ))
471450
472451
473452@pytest .mark .parametrize ("op,dtypes" , binary_op_dtypes .items ())
474453def test_binary_operators_device_mismatch (op , dtypes ):
475- if dtypes in ("real numeric" , "floating-point" ):
476- dtype = float64
477- else :
478- dtype = int64
479-
454+ dtype = float64 if dtypes == "floating-point" else int64
480455 a = asarray (1 , dtype = dtype , device = CPU_DEVICE )
481456 b = asarray (1 , dtype = dtype , device = Device ("device1" ))
482457 with pytest .raises (ValueError , match = "different devices" ):
0 commit comments