@@ -552,6 +552,7 @@ def setdiff1d(
552552 / ,
553553 * ,
554554 assume_unique : bool = False ,
555+ size : int | None = None ,
555556 fill_value : object | None = None ,
556557 xp : ModuleType | None = None ,
557558) -> Array :
@@ -569,11 +570,16 @@ def setdiff1d(
569570 assume_unique : bool
570571 If ``True``, the input arrays are both assumed to be unique, which
571572 can speed up the calculation. Default is ``False``.
572- fill_value : object, optional
573- Pad the output array with this value.
573+ size : int, optional
574+ The size of the output array. This is exclusively used inside the JAX JIT, and
575+ only for as long as JAX does not support arrays of unknown size inside it. In
576+ all other cases, it is disregarded.
577+ Returned elements will be clipped if they are more than size, and padded with
578+ `fill_value` if they are less. Default: raise if inside ``jax.jit``.
574579
575- This is exclusively used for JAX arrays when running inside ``jax.jit``,
576- where all array shapes need to be known in advance.
580+ fill_value : object, optional
581+ Pad the output array with this value. This is exclusively used for JAX arrays
582+ when running inside ``jax.jit``. Default: 0.
577583 xp : array_namespace, optional
578584 The standard-compatible namespace for `x1` and `x2`. Default: infer.
579585
@@ -639,7 +645,7 @@ def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
639645 return x1 if assume_unique else xp .unique_values (x1 )
640646
641647 def _jax_jit_impl (
642- x1 : Array , x2 : Array , fill_value : object | None
648+ x1 : Array , x2 : Array , size : int | None , fill_value : object | None
643649 ) -> Array : # numpydoc ignore=PR01,RT01
644650 """
645651 JAX implementation inside jax.jit.
@@ -648,9 +654,9 @@ def _jax_jit_impl(
648654 and not being able to filter by a boolean mask.
649655 Returns array the same size as x1, padded with fill_value.
650656 """
651- # unique_values inside jax.jit is not supported unless it's got a fixed size
652- mask = _x1_not_in_x2 ( x1 , x2 )
653-
657+ if size is None :
658+ msg = "`size` is mandatory when running inside `jax.jit`."
659+ raise ValueError ( msg )
654660 if fill_value is None :
655661 fill_value = xp .zeros ((), dtype = x1 .dtype )
656662 else :
@@ -659,9 +665,13 @@ def _jax_jit_impl(
659665 msg = "`fill_value` must be a scalar."
660666 raise ValueError (msg )
661667
668+ # unique_values inside jax.jit is not supported unless it's got a fixed size
669+ mask = _x1_not_in_x2 (x1 , x2 )
662670 x1 = xp .where (mask , x1 , fill_value )
663- # Note: jnp.unique_values sorts
664- return xp .unique_values (x1 , size = x1 .size , fill_value = fill_value )
671+ # Move fill_value to the right
672+ x1 = xp .take (x1 , xp .argsort (~ mask , stable = True ))
673+ x1 = x1 [:size ]
674+ x1 = xp .unique_values (x1 , size = size , fill_value = fill_value )
665675
666676 if is_dask_namespace (xp ):
667677 return _dask_impl (x1 , x2 )
@@ -675,7 +685,7 @@ def _jax_jit_impl(
675685 jax .errors .ConcretizationTypeError ,
676686 jax .errors .NonConcreteBooleanIndexError ,
677687 ):
678- return _jax_jit_impl (x1 , x2 , fill_value ) # inside jax.jit
688+ return _jax_jit_impl (x1 , x2 , size , fill_value ) # inside jax.jit
679689
680690 return _generic_impl (x1 , x2 )
681691
0 commit comments