1+ """
2+ Tests for special cases.
3+
4+ Most test cases for special casing are built on runtime via the parametrized
5+ tests test_unary/test_binary/test_iop. Most of this file consists of utility
6+ classes and functions, all bought together to create the test cases (pytest
7+ params), to finally be run through generalised test logic.
8+
9+ TODO: test integer arrays for relevant special cases
10+ """
# We use __future__ for forward reference type hints - this will work for even py3.8.0
# See https://stackoverflow.com/a/33533514/5193926
from __future__ import annotations

pytestmark = pytest.mark.ci

-# The special case test casess are built on runtime via the parametrized
-# test_unary and test_binary functions. Most of this file consists of utility
-# classes and functions, all bought together to create the test cases (pytest
-# params), to finally be run through the general test logic of either test_unary
-# or test_binary.
-
-
UnaryCheck = Callable[[float], bool]
BinaryCheck = Callable[[float, float], bool]

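To make the flow described in the new module docstring concrete, here is a small self-contained sketch (not part of the commit; a plain dict stands in for the UnaryCase dataclass defined further down, and the sqrt case is hypothetical) of one codified special case and the pytest id the collection loop would build from it:

import math

from hypothesis import strategies as st

# Hypothetical codification of "If x_i is less than 0, the result is NaN";
# the keys mirror the UnaryCase fields used later in this file.
example_case = dict(
    cond_expr="x_i < 0",
    cond=lambda i: i < 0,
    cond_from_dtype=lambda dtype: st.floats(max_value=-1.0),
    result_expr="NaN",
    check_result=lambda i, result: math.isnan(result),
)
# The collection loop later in this diff wraps such cases into pytest params,
# e.g. pytest.param("sqrt", xp.sqrt, case, id="sqrt(x_i < 0) -> NaN").
example_id = f"sqrt({example_case['cond_expr']}) -> {example_case['result_expr']}"
assert example_id == "sqrt(x_i < 0) -> NaN"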
@@ -170,24 +173,6 @@ def parse_value(value_str: str) -> float:
r_approx_value = re.compile(
    rf"an implementation-dependent approximation to {r_code.pattern}"
)
-
-
-def parse_inline_code(inline_code: str) -> float:
-    """
-    Parses a Sphinx code string to return a float, e.g.
-
-    >>> parse_value('``0``')
-    0.
-    >>> parse_value('``NaN``')
-    float('nan')
-
-    """
-    if m := r_code.match(inline_code):
-        return parse_value(m.group(1))
-    else:
-        raise ParseError(inline_code)
-
-
r_not = re.compile("not (.+)")
r_equal_to = re.compile(f"equal to {r_code.pattern}")
r_array_element = re.compile(r"``([+-]?)x([12])_i``")
@@ -526,6 +511,10 @@ def __repr__(self) -> str:
        return f"{self.__class__.__name__}(<{self}>)"


+r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters")
+r_case = re.compile(r"\s+-\s*(.*)\.")
+
+
class UnaryCond(Protocol):
    def __call__(self, i: float) -> bool:
        ...
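For reference, a minimal self-contained sketch of how the two regexes added above carve a spec-style docstring into individual case sentences; the docstring below is a made-up stand-in for a real stub:

import re

r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters")
r_case = re.compile(r"\s+-\s*(.*)\.")

doc = """
Calculates the square root.

**Special cases**

For floating-point operands,

- If ``x_i`` is ``NaN``, the result is ``NaN``.
- If ``x_i`` is less than ``0``, the result is ``NaN``.

Parameters
----------
x: array
"""

block = r_case_block.search(doc).group(1)  # everything between the heading and "Parameters"
case_strs = [m.group(1) for m in r_case.finditer(block)]
assert case_strs == [
    "If ``x_i`` is ``NaN``, the result is ``NaN``",
    "If ``x_i`` is less than ``0``, the result is ``NaN``",
]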
@@ -546,12 +535,34 @@ class UnaryCase(Case):


r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
+r_already_int_case = re.compile(
+    "If ``x_i`` is already integer-valued, the result is ``x_i``"
+)
r_even_round_halves_case = re.compile(
    "If two integers are equally close to ``x_i``, "
    "the result is the even integer closest to ``x_i``"
)

+def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
+    """
+    Returns a strategy that generates float-casted integers within the bounds of dtype.
+    """
+    for k in kw.keys():
+        # sanity check
+        assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
+    m, M = dh.dtype_ranges[dtype]
+    if "min_value" in kw.keys():
+        m = kw["min_value"]
+        if "exclude_min" in kw.keys():
+            m += 1
+    if "max_value" in kw.keys():
+        M = kw["max_value"]
+        if "exclude_max" in kw.keys():
+            M -= 1
+    return st.integers(math.ceil(m), math.floor(M)).map(float)
+
+
def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
    """
    Returns a strategy that generates floats that end with .5 and are within the
@@ -568,6 +579,13 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
    )


+already_int_case = UnaryCase(
+    cond_expr="x_i.is_integer()",
+    cond=lambda i: i.is_integer(),
+    cond_from_dtype=integers_from_dtype,
+    result_expr="x_i",
+    check_result=lambda i, result: i == result,
+)
even_round_halves_case = UnaryCase(
    cond_expr="modf(i)[0] == 0.5",
    cond=lambda i: math.modf(i)[0] == 0.5,
@@ -586,7 +604,7 @@ def check_result(i: float, result: float) -> bool:
    return check_result


-def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
+def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
    """
    Parses a Sphinx-formatted docstring of a unary function to return a list of
    codified unary cases, e.g.
@@ -616,7 +634,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
    ...         an array containing the square root of each element in ``x``
    ...     '''
    ...
-    >>> unary_cases = parse_unary_docstring(sqrt.__doc__)
+    >>> case_block = r_case_block.search(sqrt.__doc__).group(1)
+    >>> unary_cases = parse_unary_case_block(case_block)
    >>> for case in unary_cases:
    ...     print(repr(case))
    UnaryCase(<x_i < 0 -> NaN>)
@@ -631,19 +650,14 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
    True

    """
-
-    match = r_special_cases.search(docstring)
-    if match is None:
-        return []
-    lines = match.group(1).split("\n")[:-1]
    cases = []
-    for line in lines:
-        if m := r_case.match(line):
-            case = m.group(1)
-        else:
-            warn(f"line not machine-readable: '{line}'")
-            continue
-        if m := r_unary_case.search(case):
+    for case_m in r_case.finditer(case_block):
+        case_str = case_m.group(1)
+        if m := r_already_int_case.search(case_str):
+            cases.append(already_int_case)
+        elif m := r_even_round_halves_case.search(case_str):
+            cases.append(even_round_halves_case)
+        elif m := r_unary_case.search(case_str):
            try:
                cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
                _check_result, result_expr = parse_result(m.group(2))
@@ -662,11 +676,9 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
                check_result=check_result,
            )
            cases.append(case)
-        elif m := r_even_round_halves_case.search(case):
-            cases.append(even_round_halves_case)
        else:
-            if not r_remaining_case.search(case):
-                warn(f"case not machine-readable: '{case}'")
+            if not r_remaining_case.search(case_str):
+                warn(f"case not machine-readable: '{case_str}'")
    return cases

@@ -690,12 +702,6 @@ class BinaryCase(Case):
    check_result: BinaryResultCheck


-r_special_cases = re.compile(
-    r"\*\*Special [Cc]ases\*\*(?:\n.*)+"
-    r"For floating-point operands,\n+"
-    r"((?:\s*-\s*.*\n)+)"
-)
-r_case = re.compile(r"\s+-\s*(.*)\.\n?")
r_binary_case = re.compile("If (.+), the result (.+)")
r_remaining_case = re.compile("In the remaining cases.+")
r_cond_sep = re.compile(r"(?<!``x1_i``),? and |(?<!i\.e\.), ")
@@ -843,25 +849,6 @@ def check_result(i1: float, i2: float, result: float) -> bool:
    return check_result


-def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
-    """
-    Returns a strategy that generates float-casted integers within the bounds of dtype.
-    """
-    for k in kw.keys():
-        # sanity check
-        assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
-    m, M = dh.dtype_ranges[dtype]
-    if "min_value" in kw.keys():
-        m = kw["min_value"]
-        if "exclude_min" in kw.keys():
-            m += 1
-    if "max_value" in kw.keys():
-        M = kw["max_value"]
-        if "exclude_max" in kw.keys():
-            M -= 1
-    return st.integers(math.ceil(m), math.floor(M)).map(float)
-
-
def parse_binary_case(case_str: str) -> BinaryCase:
    """
    Parses a Sphinx-formatted binary case string to return codified binary cases, e.g.
@@ -880,8 +867,7 @@ def parse_binary_case(case_str: str) -> BinaryCase:

    """
    case_m = r_binary_case.match(case_str)
-    if case_m is None:
-        raise ParseError(case_str)
+    assert case_m is not None  # sanity check
    cond_strs = r_cond_sep.split(case_m.group(1))

    partial_conds = []
@@ -1078,7 +1064,7 @@ def cond(i1: float, i2: float) -> bool:
r_redundant_case = re.compile("result.+determined by the rule already stated above")


-def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
+def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
    """
    Parses a Sphinx-formatted docstring of a binary function to return a list of
    codified binary cases, e.g.
@@ -1108,29 +1094,21 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
    ...         an array containing the results
    ...     '''
    ...
-    >>> binary_cases = parse_binary_docstring(logaddexp.__doc__)
+    >>> case_block = r_case_block.search(logaddexp.__doc__).group(1)
+    >>> binary_cases = parse_binary_case_block(case_block)
    >>> for case in binary_cases:
    ...     print(repr(case))
    BinaryCase(<x1_i == NaN or x2_i == NaN -> NaN>)
    BinaryCase(<x1_i == +infinity and not x2_i == NaN -> +infinity>)
    BinaryCase(<not x1_i == NaN and x2_i == +infinity -> +infinity>)

    """
-
-    match = r_special_cases.search(docstring)
-    if match is None:
-        return []
-    lines = match.group(1).split("\n")[:-1]
    cases = []
-    for line in lines:
-        if m := r_case.match(line):
-            case_str = m.group(1)
-        else:
-            warn(f"line not machine-readable: '{line}'")
-            continue
+    for case_m in r_case.finditer(case_block):
+        case_str = case_m.group(1)
        if r_redundant_case.search(case_str):
            continue
-        if m := r_binary_case.match(case_str):
+        if r_binary_case.match(case_str):
            try:
                case = parse_binary_case(case_str)
                cases.append(case)
@@ -1150,6 +1128,10 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
    if stub.__doc__ is None:
        warn(f"{stub.__name__}() stub has no docstring")
        continue
+    if m := r_case_block.search(stub.__doc__):
+        case_block = m.group(1)
+    else:
+        continue
    marks = []
    try:
        func = getattr(xp, stub.__name__)
@@ -1164,40 +1146,44 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
        warn(f"{func=} has no parameters")
        continue
    if param_names[0] == "x":
-        if cases := parse_unary_docstring(stub.__doc__):
-            func_name_to_func = {stub.__name__: func}
+        if cases := parse_unary_case_block(case_block):
+            name_to_func = {stub.__name__: func}
            if stub.__name__ in func_to_op.keys():
                op_name = func_to_op[stub.__name__]
                op = getattr(operator, op_name)
-                func_name_to_func[op_name] = op
-            for func_name, func in func_name_to_func.items():
+                name_to_func[op_name] = op
+            for func_name, func in name_to_func.items():
                for case in cases:
                    id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
                    p = pytest.param(func_name, func, case, id=id_)
                    unary_params.append(p)
+        else:
+            warn(f"Special cases found for {stub.__name__} but none were parsed")
        continue
    if len(sig.parameters) == 1:
        warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
        continue
    if param_names[0] == "x1" and param_names[1] == "x2":
-        if cases := parse_binary_docstring(stub.__doc__):
-            func_name_to_func = {stub.__name__: func}
+        if cases := parse_binary_case_block(case_block):
+            name_to_func = {stub.__name__: func}
            if stub.__name__ in func_to_op.keys():
                op_name = func_to_op[stub.__name__]
                op = getattr(operator, op_name)
-                func_name_to_func[op_name] = op
-                # We collect inplaceoperator test cases seperately
+                name_to_func[op_name] = op
+                # We collect inplace operator test cases separately
                iop_name = "__i" + op_name[2:]
                iop = getattr(operator, iop_name)
                for case in cases:
                    id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}"
                    p = pytest.param(iop_name, iop, case, id=id_)
                    iop_params.append(p)
-            for func_name, func in func_name_to_func.items():
+            for func_name, func in name_to_func.items():
                for case in cases:
                    id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
                    p = pytest.param(func_name, func, case, id=id_)
                    binary_params.append(p)
+        else:
+            warn(f"Special cases found for {stub.__name__} but none were parsed")
        continue
    else:
        warn(
@@ -1206,7 +1192,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
        )


-# test_unary and test_binary naively generate arrays, i.e. arrays that might not
+# test_{unary/binary/iop} naively generate arrays, i.e. arrays that might not
# meet the condition that is being test. We then forcibly make the array meet
# the condition by picking a random index to insert an acceptable element.
#
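A rough, self-contained sketch of the forcing step described in the comment above (plain Python lists stand in for the arrays under test; force_condition is an illustrative helper, not part of the commit):

import math
import random

def force_condition(elements, cond, acceptable):
    """If no element satisfies cond, overwrite a random index with an acceptable value."""
    if not any(cond(e) for e in elements):
        idx = random.randrange(len(elements))
        elements[idx] = random.choice(acceptable)
    return elements

# e.g. guarantee at least one NaN for an "If x_i is NaN, the result is NaN" case
xs = force_condition([1.0, 2.5, -3.0], math.isnan, [float("nan")])
assert any(math.isnan(e) for e in xs)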
@@ -1343,3 +1329,46 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
            )
            break
    assume(good_example)
+
+
+@pytest.mark.parametrize(
+    "func_name, expected",
+    [
+        ("mean", float("nan")),
+        ("prod", 1),
+        ("std", float("nan")),
+        ("sum", 0),
+        ("var", float("nan")),
+    ],
+    ids=["mean", "prod", "std", "sum", "var"],
+)
+def test_empty_arrays(func_name, expected):  # TODO: parse docstrings to get expected
+    func = getattr(xp, func_name)
+    out = func(xp.asarray([], dtype=dh.default_float))
+    ph.assert_shape(func_name, out.shape, ())  # sanity check
+    msg = f"{out=!r}, but should be {expected}"
+    if math.isnan(expected):
+        assert xp.isnan(out), msg
+    else:
+        assert out == expected, msg
+
+
+@pytest.mark.parametrize(
+    "func_name", [f.__name__ for f in category_to_funcs["statistical"]]
+)
+@given(
+    x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)),
+    data=st.data(),
+)
+def test_nan_propagation(func_name, x, data):
+    func = getattr(xp, func_name)
+    set_idx = data.draw(
+        xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx"
+    )
+    x[set_idx] = float("nan")
+    note(f"{x=}")
+
+    out = func(x)
+
+    ph.assert_shape(func_name, out.shape, ())  # sanity check
+    assert xp.isnan(out), f"{out=!r}, but should be NaN"