55from typing import Any
66
77from arrayfire_wrapper .lib ._broadcast import bcast_var
8- from arrayfire_wrapper .lib .create_and_modify_array .manage_array import release_array
9-
8+ from arrayfire_wrapper .lib .create_and_modify_array .manage_array import release_array , retain_array
9+ from arrayfire_wrapper . defines import AFArray
1010
1111class _IndexSequence (ctypes .Structure ):
1212 """
@@ -186,7 +186,7 @@ class IndexStructure(ctypes.Structure):
186186 -----------
187187
188188 idx: key
189- - If of type af.Array , self.idx.arr = idx, self.isSeq = False
189+ - If of type AFArray , self.idx.arr = idx, self.isSeq = False
190190 - If of type af.ParallelRange, self.idx.seq = idx, self.isBatch = True
191191 - Default:, self.idx.seq = af._IndexSequence(idx)
192192
@@ -197,26 +197,21 @@ class IndexStructure(ctypes.Structure):
197197
198198 """
199199
200- def __init__ (self , idx : Any ) -> None :
200+ def __init__ (self , idx : int | slice | AFArray ) -> None :
201201 self .idx = _IndexUnion ()
202202 self .isBatch = False
203203 self .isSeq = True
204204
205- # BUG cyclic reimport
206- # if isinstance(idx, Array):
207- # if idx.dtype == af_bool:
208- # self.idx.arr = everything.where(idx.arr)
209- # else:
210- # self.idx.arr = everything.retain_array(idx.arr)
211-
212- # self.isSeq = False
213-
214- if isinstance (idx , ParallelRange ):
205+ if isinstance (idx , int ) or isinstance (idx , slice ):
206+ self .idx .seq = _IndexSequence (idx )
207+ elif isinstance (idx , ParallelRange ):
215208 self .idx .seq = idx
216209 self .isBatch = True
217-
210+ elif isinstance (idx , AFArray ):
211+ self .idx .arr = retain_array (idx )
212+ self .isSeq = False
218213 else :
219- self . idx . seq = _IndexSequence ( idx )
214+ raise IndexError ( "Invalid type while indexing arrayfire.array" )
220215
221216 def __del__ (self ) -> None :
222217 if not self .isSeq :
@@ -247,7 +242,7 @@ def __setitem__(self, idx: int, value: IndexStructure) -> None:
247242 self .idxs [idx ] = value
248243
249244
250- def get_indices (key : int | slice | tuple [int | slice , ...]) -> CIndexStructure : # BUG
245+ def get_indices (key : int | slice | tuple [int | slice | AFArray , ...] | AFArray ) -> CIndexStructure : # BUG
251246 indices = CIndexStructure ()
252247
253248 if isinstance (key , tuple ):
0 commit comments