1515# limitations under the License.
1616import builtins
1717import operator
18+ from numbers import Integral
1819
1920import numpy as np
2021
@@ -799,6 +800,79 @@ def _nonzero_impl(ary):
799800 return res
800801
801802
803+ def _validate_indices (inds , queue_list , usm_type_list ):
804+ """
805+ Utility for validating indices are usm_ndarray of integral dtype or Python
806+ integers. At least one must be an array.
807+
808+ For each array, the queue and usm type are appended to `queue_list` and
809+ `usm_type_list`, respectively.
810+ """
811+ any_usmarray = False
812+ for ind in inds :
813+ if isinstance (ind , dpt .usm_ndarray ):
814+ any_usmarray = True
815+ if ind .dtype .kind not in "ui" :
816+ raise IndexError (
817+ "arrays used as indices must be of integer (or boolean) "
818+ "type"
819+ )
820+ queue_list .append (ind .sycl_queue )
821+ usm_type_list .append (ind .usm_type )
822+ elif not isinstance (ind , Integral ):
823+ raise TypeError (
824+ "all elements of `ind` expected to be usm_ndarrays "
825+ f"or integers, found { type (ind )} "
826+ )
827+ if not any_usmarray :
828+ raise TypeError (
829+ "at least one element of `inds` expected to be a usm_ndarray"
830+ )
831+ return inds
832+
833+
834+ def _prepare_indices_arrays (inds , q , usm_type ):
835+ """
836+ Utility taking a mix of usm_ndarray and possibly Python int scalar indices,
837+ a queue (assumed to be common to arrays in inds), and a usm type.
838+
839+ Python scalar integers are promoted to arrays on the provided queue and
840+ with the provided usm type. All arrays are then promoted to a common
841+ integral type (if possible) before being broadcast to a common shape.
842+ """
843+ # scalar integers -> arrays
844+ inds = tuple (
845+ map (
846+ lambda ind : (
847+ ind
848+ if isinstance (ind , dpt .usm_ndarray )
849+ else dpt .asarray (ind , usm_type = usm_type , sycl_queue = q )
850+ ),
851+ inds ,
852+ )
853+ )
854+
855+ # promote to a common integral type if possible
856+ ind_dt = dpt .result_type (* inds )
857+ if ind_dt .kind not in "ui" :
858+ raise ValueError (
859+ "cannot safely promote indices to an integer data type"
860+ )
861+ inds = tuple (
862+ map (
863+ lambda ind : (
864+ ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
865+ ),
866+ inds ,
867+ )
868+ )
869+
870+ # broadcast
871+ inds = dpt .broadcast_arrays (* inds )
872+
873+ return inds
874+
875+
802876def _take_multi_index (ary , inds , p , mode = 0 ):
803877 if not isinstance (ary , dpt .usm_ndarray ):
804878 raise TypeError (
@@ -819,15 +893,8 @@ def _take_multi_index(ary, inds, p, mode=0):
819893 ]
820894 if not isinstance (inds , (list , tuple )):
821895 inds = (inds ,)
822- for ind in inds :
823- if not isinstance (ind , dpt .usm_ndarray ):
824- raise TypeError ("all elements of `ind` expected to be usm_ndarrays" )
825- queues_ .append (ind .sycl_queue )
826- usm_types_ .append (ind .usm_type )
827- if ind .dtype .kind not in "ui" :
828- raise IndexError (
829- "arrays used as indices must be of integer (or boolean) type"
830- )
896+
897+ _validate_indices (inds , queues_ , usm_types_ )
831898 res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
832899 exec_q = dpctl .utils .get_execution_queue (queues_ )
833900 if exec_q is None :
@@ -837,22 +904,10 @@ def _take_multi_index(ary, inds, p, mode=0):
837904 "Use `usm_ndarray.to_device` method to migrate data to "
838905 "be associated with the same queue."
839906 )
907+
840908 if len (inds ) > 1 :
841- ind_dt = dpt .result_type (* inds )
842- # ind arrays have been checked to be of integer dtype
843- if ind_dt .kind not in "ui" :
844- raise ValueError (
845- "cannot safely promote indices to an integer data type"
846- )
847- inds = tuple (
848- map (
849- lambda ind : (
850- ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
851- ),
852- inds ,
853- )
854- )
855- inds = dpt .broadcast_arrays (* inds )
909+ inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
910+
856911 ind0 = inds [0 ]
857912 ary_sh = ary .shape
858913 p_end = p + len (inds )
@@ -968,15 +1023,9 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
9681023 ]
9691024 if not isinstance (inds , (list , tuple )):
9701025 inds = (inds ,)
971- for ind in inds :
972- if not isinstance (ind , dpt .usm_ndarray ):
973- raise TypeError ("all elements of `ind` expected to be usm_ndarrays" )
974- queues_ .append (ind .sycl_queue )
975- usm_types_ .append (ind .usm_type )
976- if ind .dtype .kind not in "ui" :
977- raise IndexError (
978- "arrays used as indices must be of integer (or boolean) type"
979- )
1026+
1027+ _validate_indices (inds , queues_ , usm_types_ )
1028+
9801029 vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
9811030 exec_q = dpctl .utils .get_execution_queue (queues_ )
9821031 if exec_q is not None :
@@ -993,22 +1042,10 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
9931042 "Use `usm_ndarray.to_device` method to migrate data to "
9941043 "be associated with the same queue."
9951044 )
1045+
9961046 if len (inds ) > 1 :
997- ind_dt = dpt .result_type (* inds )
998- # ind arrays have been checked to be of integer dtype
999- if ind_dt .kind not in "ui" :
1000- raise ValueError (
1001- "cannot safely promote indices to an integer data type"
1002- )
1003- inds = tuple (
1004- map (
1005- lambda ind : (
1006- ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
1007- ),
1008- inds ,
1009- )
1010- )
1011- inds = dpt .broadcast_arrays (* inds )
1047+ inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1048+
10121049 ind0 = inds [0 ]
10131050 ary_sh = ary .shape
10141051 p_end = p + len (inds )
0 commit comments