1111
1212from ._at import at
1313from ._utils import _compat , _helpers
14- from ._utils ._compat import array_namespace , is_jax_array
14+ from ._utils ._compat import (
15+ array_namespace ,
16+ is_dask_namespace ,
17+ is_jax_array ,
18+ is_jax_namespace ,
19+ )
1520from ._utils ._helpers import asarrays
1621from ._utils ._typing import Array
1722
@@ -547,6 +552,7 @@ def setdiff1d(
547552 / ,
548553 * ,
549554 assume_unique : bool = False ,
555+ fill_value : object | None = None ,
550556 xp : ModuleType | None = None ,
551557) -> Array :
552558 """
@@ -563,6 +569,11 @@ def setdiff1d(
563569 assume_unique : bool
564570 If ``True``, the input arrays are both assumed to be unique, which
565571 can speed up the calculation. Default is ``False``.
572+ fill_value : object, optional
573+ Pad the output array with this value.
574+
575+ This is exclusively used for JAX arrays when running inside ``jax.jit``,
576+ where all array shapes need to be known in advance.
566577 xp : array_namespace, optional
567578 The standard-compatible namespace for `x1` and `x2`. Default: infer.
568579
@@ -587,13 +598,86 @@ def setdiff1d(
587598 xp = array_namespace (x1 , x2 )
588599 x1 , x2 = asarrays (x1 , x2 , xp = xp )
589600
590- if assume_unique :
591- x1 = xp .reshape (x1 , (- 1 ,))
592- x2 = xp .reshape (x2 , (- 1 ,))
593- else :
594- x1 = xp .unique_values (x1 )
595- x2 = xp .unique_values (x2 )
596- return x1 [_helpers .in1d (x1 , x2 , assume_unique = True , invert = True , xp = xp )]
601+ x1 = xp .reshape (x1 , (- 1 ,))
602+ x2 = xp .reshape (x2 , (- 1 ,))
603+ if x1 .shape == (0 ,) or x2 .shape == (0 ,):
604+ return x1
605+
def _x1_not_in_x2(x1: Array, x2: Array) -> Array:  # numpydoc ignore=PR01,RT01
    """Return a boolean mask, True where an element of x1 does not occur in x2."""
    # assume_unique=True only promises uniqueness, not ordering, so x2 is
    # always sorted here before the binary search.
    sorted_x2 = xp.sort(x2)
    pos = xp.searchsorted(sorted_x2, x1)

    # An insertion point equal to len(x2) would be out of bounds for take();
    # redirect it to index 0. This happens when the x1 element sorts after
    # every element of x2, so the comparison below still yields True.
    # FIXME at() is faster but needs JAX jit support for bool mask
    past_the_end = pos == sorted_x2.shape[0]
    pos = xp.where(past_the_end, xp.zeros_like(pos), pos)

    nearest = xp.take(sorted_x2, pos, axis=0)
    return nearest != x1
617+
def _generic_impl(x1: Array, x2: Array) -> Array:  # numpydoc ignore=PR01,RT01
    """Generic implementation (also used for JAX in eager mode)."""
    if assume_unique:
        # Inputs are trusted to be unique: filter only, keep x1's order.
        return x1[_x1_not_in_x2(x1, x2)]

    # Deduplicate up front so the membership test runs on smaller arrays.
    unique_x1 = xp.unique_values(x1)
    unique_x2 = xp.unique_values(x2)
    remaining = unique_x1[_x1_not_in_x2(unique_x1, unique_x2)]
    # The Array API does not require unique_values to return sorted output,
    # so sort explicitly.
    return xp.sort(remaining)
628+
def _dask_impl(x1: Array, x2: Array) -> Array:  # numpydoc ignore=PR01,RT01
    """
    Dask implementation.

    Defers unique_values to the very end, because calling it earlier would
    make the array shapes unknown.
    """
    keep = _x1_not_in_x2(x1, x2)
    filtered = x1[keep]
    if assume_unique:
        return filtered
    # Note: da.unique_values sorts
    return xp.unique_values(filtered)
640+
def _jax_jit_impl(
    x1: Array, x2: Array, fill_value: object | None
) -> Array:  # numpydoc ignore=PR01,RT01
    """
    JAX implementation for use inside jax.jit.

    Inside jit, unique_values needs a fixed size= and arrays cannot be
    filtered by a boolean mask. The result therefore has the same size as x1,
    with unused slots holding fill_value.
    """
    # unique_values inside jax.jit is not supported unless it's got a fixed size
    keep = _x1_not_in_x2(x1, x2)

    if fill_value is None:
        # Default padding: a 0-d zero of x1's dtype.
        pad = xp.zeros((), dtype=x1.dtype)
    else:
        pad = xp.asarray(fill_value, dtype=x1.dtype)
        if cast(Array, pad).ndim != 0:
            msg = "`fill_value` must be a scalar."
            raise ValueError(msg)

    # Elements that also occur in x2 are overwritten with the padding value.
    # NOTE(review): the padding value is therefore itself an input to
    # unique_values below, so it presumably also shows up at its sorted
    # position rather than only as trailing padding - confirm.
    masked = xp.where(keep, x1, pad)
    # Note: jnp.unique_values sorts
    return xp.unique_values(masked, size=masked.size, fill_value=pad)
665+
666+ if is_dask_namespace (xp ):
667+ return _dask_impl (x1 , x2 )
668+
669+ if is_jax_namespace (xp ):
670+ import jax
671+
672+ try :
673+ return _generic_impl (x1 , x2 ) # eager mode
674+ except (
675+ jax .errors .ConcretizationTypeError ,
676+ jax .errors .NonConcreteBooleanIndexError ,
677+ ):
678+ return _jax_jit_impl (x1 , x2 , fill_value ) # inside jax.jit
679+
680+ return _generic_impl (x1 , x2 )
597681
598682
599683def sinc (x : Array , / , * , xp : ModuleType | None = None ) -> Array :
0 commit comments