@@ -544,6 +544,7 @@ def setdiff1d(
544544 / ,
545545 * ,
546546 assume_unique : bool = False ,
547+ size : int | None = None ,
547548 fill_value : object | None = None ,
548549 xp : ModuleType | None = None ,
549550) -> Array :
@@ -561,11 +562,16 @@ def setdiff1d(
561562 assume_unique : bool
562563 If ``True``, the input arrays are both assumed to be unique, which
563564 can speed up the calculation. Default is ``False``.
564- fill_value : object, optional
565- Pad the output array with this value.
565+ size : int, optional
566+ The size of the output array. This is exclusively used inside the JAX JIT, and
567+ only for as long as JAX does not support arrays of unknown size inside it. In
568+ all other cases, it is disregarded.
569+ Returned elements will be clipped if they are more than size, and padded with
570+ `fill_value` if they are less. Default: raise if inside ``jax.jit``.
566571
567- This is exclusively used for JAX arrays when running inside ``jax.jit``,
568- where all array shapes need to be known in advance.
572+ fill_value : object, optional
573+ Pad the output array with this value. This is exclusively used for JAX arrays
574+ when running inside ``jax.jit``. Default: 0.
569575 xp : array_namespace, optional
570576 The standard-compatible namespace for `x1` and `x2`. Default: infer.
571577
@@ -630,7 +636,7 @@ def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
630636 return x1 if assume_unique else xp .unique_values (x1 )
631637
632638 def _jax_jit_impl (
633- x1 : Array , x2 : Array , fill_value : object | None
639+ x1 : Array , x2 : Array , size : int | None , fill_value : object | None
634640 ) -> Array : # numpydoc ignore=PR01,RT01
635641 """
636642 JAX implementation inside jax.jit.
@@ -639,9 +645,9 @@ def _jax_jit_impl(
639645 and not being able to filter by a boolean mask.
640646 Returns array the same size as x1, padded with fill_value.
641647 """
642- # unique_values inside jax.jit is not supported unless it's got a fixed size
643- mask = _x1_not_in_x2 ( x1 , x2 )
644-
648+ if size is None :
649+ msg = "`size` is mandatory when running inside `jax.jit`."
650+ raise ValueError ( msg )
645651 if fill_value is None :
646652 fill_value = xp .zeros ((), dtype = x1 .dtype )
647653 else :
@@ -650,9 +656,13 @@ def _jax_jit_impl(
650656 msg = "`fill_value` must be a scalar."
651657 raise ValueError (msg )
652658
659+ # unique_values inside jax.jit is not supported unless it's got a fixed size
660+ mask = _x1_not_in_x2 (x1 , x2 )
653661 x1 = xp .where (mask , x1 , fill_value )
654- # Note: jnp.unique_values sorts
655- return xp .unique_values (x1 , size = x1 .size , fill_value = fill_value )
662+ # Move fill_value to the right
663+ x1 = xp .take (x1 , xp .argsort (~ mask , stable = True ))
664+ x1 = x1 [:size ]
665+ x1 = xp .unique_values (x1 , size = size , fill_value = fill_value )
656666
657667 if is_dask_namespace (xp ):
658668 return _dask_impl (x1 , x2 )
@@ -666,7 +676,7 @@ def _jax_jit_impl(
666676 jax .errors .ConcretizationTypeError ,
667677 jax .errors .NonConcreteBooleanIndexError ,
668678 ):
669- return _jax_jit_impl (x1 , x2 , fill_value ) # inside jax.jit
679+ return _jax_jit_impl (x1 , x2 , size , fill_value ) # inside jax.jit
670680
671681 return _generic_impl (x1 , x2 )
672682
0 commit comments