99
1010import array_api_compat
1111from array_api_compat import array_namespace
12+ import array_api_compat .numpy
1213
1314from ._helpers import import_ , all_libraries , wrapped_libraries
1415
@@ -22,6 +23,7 @@ def test_array_namespace(library, api_version, use_compat):
2223 if use_compat is True and library in {'array_api_strict' , 'jax.numpy' , 'sparse' }:
2324 pytest .raises (ValueError , lambda : array_namespace (array , use_compat = use_compat ))
2425 return
26+ print (use_compat )
2527 namespace = array_api_compat .array_namespace (array , api_version = api_version , use_compat = use_compat )
2628
2729 if use_compat is False or use_compat is None and library not in wrapped_libraries :
@@ -36,6 +38,17 @@ def test_array_namespace(library, api_version, use_compat):
3638 assert namespace == jax .experimental .array_api
3739 else :
3840 assert namespace == xp
41+ elif use_compat is None :
42+ if library == "dask.array" :
43+ # dask should always return wrapped version
44+ # since dask.array is not array API compatible
45+ assert namespace == array_api_compat .dask .array
46+ elif library == "numpy" :
47+ assert namespace == array_api_compat .numpy
48+ elif library == "torch" :
49+ assert namespace == array_api_compat .torch
50+ else :
51+ assert namespace == xp
3952 else :
4053 if library == "dask.array" :
4154 assert namespace == array_api_compat .dask .array
0 commit comments