@@ -300,31 +300,53 @@ def test_permute_dims(x, axes):
300300def test_repeat (x , kw , data ):
301301 shape = x .shape
302302 axis = kw .get ("axis" , None )
303- dim = math .prod (shape ) if axis is None else shape [axis ]
304- repeat_strat = st .integers (1 , 4 )
303+ size = math .prod (shape ) if axis is None else shape [axis ]
304+ repeat_strat = st .integers (1 , 10 )
305305 repeats = data .draw (repeat_strat
306306 | hh .arrays (dtype = hh .int_dtypes , elements = repeat_strat ,
307- shape = st .sampled_from ([(1 ,), (dim ,)])),
307+ shape = st .sampled_from ([(1 ,), (size ,)])),
308308 label = "repeats" )
309309 if isinstance (repeats , int ):
310- n_repitions = dim * repeats
310+ n_repititions = size * repeats
311311 else :
312312 if repeats .shape == (1 ,):
313- n_repitions = dim * repeats [0 ]
313+ n_repititions = size * int ( repeats [0 ])
314314 else :
315- n_repitions = int (xp .sum (repeats ))
315+ n_repititions = int (xp .sum (repeats ))
316+
317+ assume (n_repititions <= hh .SQRT_MAX_ARRAY_SIZE )
316318
317319 out = xp .repeat (x , repeats , ** kw )
318320 ph .assert_dtype ("repeat" , in_dtype = x .dtype , out_dtype = out .dtype )
319321 if axis is None :
320- expected_shape = (n_repitions ,)
322+ expected_shape = (n_repititions ,)
321323 else :
322324 expected_shape = list (shape )
323- expected_shape [axis ] = n_repitions
325+ expected_shape [axis ] = n_repititions
324326 expected_shape = tuple (expected_shape )
325327 ph .assert_shape ("repeat" , out_shape = out .shape , expected = expected_shape )
326- # TODO: values testing
327328
329+ # Test values
330+
331+ if isinstance (repeats , int ):
332+ repeats_array = xp .full (size , repeats , dtype = xp .int32 )
333+ else :
334+ repeats_array = repeats
335+
336+ if kw .get ("axis" ) is None :
337+ x = xp .reshape (x , (- 1 ,))
338+ axis = 0
339+
340+ for idx , in sh .iter_indices (x .shape , skip_axes = axis ):
341+ x_slice = x [idx ]
342+ out_slice = out [idx ]
343+ start = 0
344+ for i , count in enumerate (repeats_array ):
345+ end = start + count
346+ ph .assert_array_elements ("repeat" , out = out_slice [start :end ],
347+ expected = xp .full ((count ,), x_slice [i ], dtype = x .dtype ),
348+ kw = kw )
349+ start = end
328350
329351@st .composite
330352def reshape_shapes (draw , shape ):
0 commit comments