1010from sklearn .utils import _safe_indexing
1111
1212from ..base import BaseUnderSampler
13- from ...dask ._support import is_dask_container
13+ from ...dask ._support import is_dask_collection
1414from ...utils import check_target_type
1515from ...utils import Substitution
16- from ...utils ._docstring import _random_state_docstring
16+ from ...utils ._docstring import (
17+ _random_state_docstring ,
18+ _validate_if_dask_collection_docstring
19+ )
1720from ...utils ._validation import _deprecate_positional_args
1821
1922
2023@Substitution (
2124 sampling_strategy = BaseUnderSampler ._sampling_strategy_docstring ,
2225 random_state = _random_state_docstring ,
26+ validate_if_dask_collection = _validate_if_dask_collection_docstring ,
2327)
2428class RandomUnderSampler (BaseUnderSampler ):
2529 """Class to perform random under-sampling.
@@ -38,6 +42,8 @@ class RandomUnderSampler(BaseUnderSampler):
3842 replacement : bool, default=False
3943 Whether the sample is with or without replacement.
4044
45+ {validate_if_dask_collection}
46+
4147 Attributes
4248 ----------
4349 sample_indices_ : ndarray of shape (n_new_samples,)
@@ -74,22 +80,23 @@ class RandomUnderSampler(BaseUnderSampler):
7480
@_deprecate_positional_args
def __init__(
    self,
    *,
    sampling_strategy="auto",
    random_state=None,
    replacement=False,
    validate_if_dask_collection=False,
):
    """Store the sampler's configuration.

    All parameters are keyword-only. ``sampling_strategy`` and
    ``validate_if_dask_collection`` are forwarded to the parent
    constructor; ``random_state`` and ``replacement`` are kept on the
    instance exactly as given.
    """
    super().__init__(
        sampling_strategy=sampling_strategy,
        validate_if_dask_collection=validate_if_dask_collection,
    )
    self.random_state = random_state
    self.replacement = replacement
8296
8397 def _check_X_y (self , X , y ):
84- if is_dask_container (y ) and hasattr (y , "to_dask_array" ):
85- y = y .to_dask_array ()
86- y .compute_chunk_sizes ()
87- y , binarize_y , self ._uniques = check_target_type (
88- y ,
89- indicate_one_vs_all = True ,
90- return_unique = True ,
91- )
92- if not any ([is_dask_container (arr ) for arr in (X , y )]):
98+ y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
99+ if not any ([is_dask_collection (arr ) for arr in (X , y )]):
93100 X , y = self ._validate_data (
94101 X ,
95102 y ,
@@ -98,24 +105,23 @@ def _check_X_y(self, X, y):
98105 dtype = None ,
99106 force_all_finite = False ,
100107 )
101- elif is_dask_container (X ) and hasattr (X , "to_dask_array" ):
102- X = X .to_dask_array ()
103- X .compute_chunk_sizes ()
104108 return X , y , binarize_y
105109
@staticmethod
def _find_target_class_indices(y, target_class):
    """Return the positions in ``y`` whose value equals ``target_class``.

    The positions are computed with ``np.flatnonzero`` on the
    element-wise comparison ``y == target_class``. When ``y`` is a dask
    collection, the result is evaluated with ``dask.compute`` before
    being returned; otherwise it is returned as is.
    """
    target_class_indices = np.flatnonzero(y == target_class)
    if is_dask_collection(y):
        # Imported inside the branch, so dask is only imported when
        # ``y`` is a dask collection.
        from dask import compute

        # ``compute`` returns a tuple with one entry per argument;
        # unwrap the single result.
        return compute(target_class_indices)[0]
    return target_class_indices
112118
113119 def _fit_resample (self , X , y ):
114120 random_state = check_random_state (self .random_state )
115121
116122 idx_under = []
117123
118- for target_class in self ._uniques :
124+ for target_class in self ._classes_counts :
119125 target_class_indices = self ._find_target_class_indices (
120126 y , target_class
121127 )
0 commit comments