1111
1212from ._at import at
1313from ._utils import _compat , _helpers
14- from ._utils ._compat import array_namespace , is_jax_array
14+ from ._utils ._compat import (
15+ array_namespace ,
16+ is_dask_namespace ,
17+ is_jax_array ,
18+ is_jax_namespace ,
19+ )
1520from ._utils ._typing import Array
1621
1722__all__ = [
@@ -539,6 +544,7 @@ def setdiff1d(
539544 / ,
540545 * ,
541546 assume_unique : bool = False ,
547+ fill_value : object | None = None ,
542548 xp : ModuleType | None = None ,
543549) -> Array :
544550 """
@@ -555,6 +561,11 @@ def setdiff1d(
555561 assume_unique : bool
556562 If ``True``, the input arrays are both assumed to be unique, which
557563 can speed up the calculation. Default is ``False``.
564+ fill_value : object, optional
565+ Pad the output array with this value.
566+
567+ This is exclusively used for JAX arrays when running inside ``jax.jit``,
568+ where all array shapes need to be known in advance.
558569 xp : array_namespace, optional
559570 The standard-compatible namespace for `x1` and `x2`. Default: infer.
560571
@@ -578,12 +589,86 @@ def setdiff1d(
578589 if xp is None :
579590 xp = array_namespace (x1 , x2 )
580591
581- if assume_unique :
582- x1 = xp .reshape (x1 , (- 1 ,))
583- else :
584- x1 = xp .unique_values (x1 )
585- x2 = xp .unique_values (x2 )
586- return x1 [_helpers .in1d (x1 , x2 , assume_unique = True , invert = True , xp = xp )]
592+ x1 = xp .reshape (x1 , (- 1 ,))
593+ x2 = xp .reshape (x2 , (- 1 ,))
594+ if x1 .shape == (0 ,) or x2 .shape == (0 ,):
595+ return x1
596+
597+ def _x1_not_in_x2 (x1 : Array , x2 : Array ) -> Array : # numpydoc ignore=PR01,RT01
598+ """For each element of x1, return True if it is not also in x2."""
599+ # Even when assume_unique=True, there is no provision for x to be sorted
600+ x2 = xp .sort (x2 )
601+ idx = xp .searchsorted (x2 , x1 )
602+
603+ # FIXME at() is faster but needs JAX jit support for bool mask
604+ # idx = at(idx, idx == x2.shape[0]).set(0)
605+ idx = xp .where (idx == x2 .shape [0 ], xp .zeros_like (idx ), idx )
606+
607+ return xp .take (x2 , idx , axis = 0 ) != x1
608+
609+ def _generic_impl (x1 : Array , x2 : Array ) -> Array : # numpydoc ignore=PR01,RT01
610+ """Generic implementation (including eager JAX)."""
611+ # Note: there is no provision in the Array API for xp.unique_values to sort
612+ if not assume_unique :
613+ # Call unique_values early to speed up the algorithm
614+ x1 = xp .unique_values (x1 )
615+ x2 = xp .unique_values (x2 )
616+ mask = _x1_not_in_x2 (x1 , x2 )
617+ x1 = x1 [mask ]
618+ return x1 if assume_unique else xp .sort (x1 )
619+
620+ def _dask_impl (x1 : Array , x2 : Array ) -> Array : # numpydoc ignore=PR01,RT01
621+ """
622+ Dask implementation.
623+
624+ Works around unique_values returning unknown shapes.
625+ """
626+ # Do not call unique_values yet, as it would make array shapes unknown
627+ mask = _x1_not_in_x2 (x1 , x2 )
628+ x1 = x1 [mask ]
629+ # Note: da.unique_values sorts
630+ return x1 if assume_unique else xp .unique_values (x1 )
631+
632+ def _jax_jit_impl (
633+ x1 : Array , x2 : Array , fill_value : object | None
634+ ) -> Array : # numpydoc ignore=PR01,RT01
635+ """
636+ JAX implementation inside jax.jit.
637+
638+ Works around unique_values requiring a size= parameter
639+ and not being able to filter by a boolean mask.
640+ Returns array the same size as x1, padded with fill_value.
641+ """
642+ # unique_values inside jax.jit is not supported unless it's got a fixed size
643+ mask = _x1_not_in_x2 (x1 , x2 )
644+
645+ if fill_value is None :
646+ fill_value = xp .zeros ((), dtype = x1 .dtype )
647+ else :
648+ fill_value = xp .asarray (fill_value , dtype = x1 .dtype )
649+ if cast (Array , fill_value ).ndim != 0 :
650+ msg = "`fill_value` must be a scalar."
651+ raise ValueError (msg )
652+
653+ x1 = xp .where (mask , x1 , fill_value )
654+ # Note: jnp.unique_values sorts
655+ return xp .unique_values (x1 , size = x1 .size , fill_value = fill_value )
656+
657+ if is_dask_namespace (xp ):
658+ return _dask_impl (x1 , x2 )
659+
660+ if is_jax_namespace (xp ):
661+ import jax
662+
663+ try :
664+ return _generic_impl (x1 , x2 ) # eager mode
665+ except (
666+ jax .errors .ConcretizationTypeError ,
667+ jax .errors .NonConcreteBooleanIndexError ,
668+ ):
669+ return _jax_jit_impl (x1 , x2 , fill_value ) # inside jax.jit
670+
671+ return _generic_impl (x1 , x2 )
587672
588673
589674def sinc (x : Array , / , * , xp : ModuleType | None = None ) -> Array :
0 commit comments