@@ -35,6 +35,13 @@ def pytest_addoption(parser):
3535 default = [],
3636 help = "disable testing for Array API extension(s)" ,
3737 )
38+ # data-dependent shape
39+ parser .addoption (
40+ "--disable-data-dependent-shapes" ,
41+ "--disable-dds" ,
42+ action = "store_true" ,
43+ help = "disable testing functions with output shapes dependent on input" ,
44+ )
3845 # CI
3946 parser .addoption (
4047 "--ci" ,
@@ -47,6 +54,9 @@ def pytest_configure(config):
4754 config .addinivalue_line (
4855 "markers" , "xp_extension(ext): tests an Array API extension"
4956 )
57+ config .addinivalue_line (
58+ "markers" , "data_dependent_shapes: output shapes are dependent on inputs"
59+ )
5060 config .addinivalue_line ("markers" , "ci: primary test" )
5161 # Hypothesis
5262 hypothesis_max_examples = config .getoption ("--hypothesis-max-examples" )
@@ -83,9 +93,15 @@ def xp_has_ext(ext: str) -> bool:
8393
8494def pytest_collection_modifyitems (config , items ):
8595 disabled_exts = config .getoption ("--disable-extension" )
96+ disabled_dds = config .getoption ("--disable-data-dependent-shapes" )
8697 ci = config .getoption ("--ci" )
8798 for item in items :
8899 markers = list (item .iter_markers ())
100+ # skip if specified in skips.txt
101+ for id_ in skip_ids :
102+ if item .nodeid .startswith (id_ ):
103+ item .add_marker (mark .skip (reason = "skips.txt" ))
104+ break
89105 # skip if disabled or non-existent extension
90106 ext_mark = next ((m for m in markers if m .name == "xp_extension" ), None )
91107 if ext_mark is not None :
@@ -96,11 +112,14 @@ def pytest_collection_modifyitems(config, items):
96112 )
97113 elif not xp_has_ext (ext ):
98114 item .add_marker (mark .skip (reason = f"{ ext } not found in array module" ))
99- # skip if specified in skips.txt
100- for id_ in skip_ids :
101- if item .nodeid .startswith (id_ ):
102- item .add_marker (mark .skip (reason = "skips.txt" ))
103- break
115+ # skip if disabled by dds flag
116+ if disabled_dds :
117+ for m in markers :
118+ if m .name == "data_dependent_shapes" :
119+ item .add_marker (
120+ mark .skip (reason = "disabled via --disable-data-dependent-shapes" )
121+ )
122+ break
104123 # skip if test not appropiate for CI
105124 if ci :
106125 ci_mark = next ((m for m in markers if m .name == "ci" ), None )
0 commit comments