@@ -153,6 +153,41 @@ def test_setitem(shape, dtypes, data):
153153 )
154154
155155
156+ class AwkwardIndexable :
157+ def __init__ (self , value : int ):
158+ self ._value = value
159+
160+ def __int__ (self ):
161+ raise TypeError ("__int__() should not be called" )
162+
163+ def __index__ (self ):
164+ return self ._value
165+
166+
167+ @pytest .mark .parametrize (
168+ "x, key" ,
169+ [
170+ (xp .asarray ([0 , 1 ]), AwkwardIndexable (1 )),
171+ (xp .asarray ([[0 , 1 ], [2 , 3 ]]), (0 , AwkwardIndexable (1 ))),
172+ ]
173+ )
174+ def test_getitem_supports_index (x , key ):
175+ out = x [key ]
176+ assert out == xp .asarray (1 )
177+
178+
179+ @pytest .mark .parametrize (
180+ "x, key, expected" ,
181+ [
182+ (xp .asarray ([0 , 1 ]), AwkwardIndexable (1 ), xp .asarray ([0 , 42 ])),
183+ (xp .asarray ([[0 , 1 ], [2 , 3 ]]), (0 , AwkwardIndexable (1 )), xp .asarray ([[0 , 42 ], [2 , 3 ]])),
184+ ]
185+ )
186+ def test_setitem_supports_index (x , key , expected ):
187+ x [key ] = xp .asarray (42 )
188+ ph .assert_array_elements ("__setitem__" , out = x , expected = expected , out_repr = "x" )
189+
190+
156191@pytest .mark .unvectorized
157192@pytest .mark .data_dependent_shapes
158193@given (hh .shapes (), st .data ())
0 commit comments