1717from imblearn .utils import check_neighbors_object
1818from imblearn .utils import check_sampling_strategy
1919from imblearn .utils import check_target_type
20+ from imblearn .utils import get_classes_counts
2021from imblearn .utils ._validation import ArraysTransformer
2122from imblearn .utils ._validation import _deprecate_positional_args
2223
# Shared fixtures: an imbalanced three-class target (50 / 100 / 25 samples of
# classes 1 / 2 / 3) and an imbalanced binary target (25 samples of class 1
# followed by 100 samples of class 0), each paired with the result of
# ``get_classes_counts`` on it.
multiclass_target = np.repeat([1, 2, 3], [50, 100, 25])
binary_target = np.repeat([1, 0], [25, 100])
multiclass_classes_counts = get_classes_counts(multiclass_target)
binary_classes_counts = get_classes_counts(binary_target)
2528
2629
2730def test_check_neighbors_object ():
@@ -70,11 +73,11 @@ def test_check_target_type_ova(target, output_target, is_ova):
7073 assert binarize_target == is_ova
7174
7275
def test_check_sampling_strategy_error_dict_cleaning_methods():
    """A dict strategy combined with 'clean-sampling' must raise ValueError."""
    dict_strategy = {1: 0, 2: 0, 3: 0}
    expected_message = "dict for cleaning methods is not supported"
    with pytest.raises(ValueError, match=expected_message):
        check_sampling_strategy(
            dict_strategy, multiclass_classes_counts, "clean-sampling"
        )
7982
8083
@@ -83,19 +86,19 @@ def test_check_sampling_strategy_warning():
8386 [
8487 (
8588 0.5 ,
86- binary_target ,
89+ binary_classes_counts ,
8790 "clean-sampling" ,
8891 "'clean-sampling' methods do let the user specify the sampling ratio" , # noqa
8992 ),
9093 (
9194 0.1 ,
92- np .array ([0 ] * 10 + [1 ] * 20 ),
95+ get_classes_counts ( np .array ([0 ] * 10 + [1 ] * 20 ) ),
9396 "over-sampling" ,
9497 "remove samples from the minority class while trying to generate new" , # noqa
9598 ),
9699 (
97100 0.1 ,
98- np .array ([0 ] * 10 + [1 ] * 20 ),
101+ get_classes_counts ( np .array ([0 ] * 10 + [1 ] * 20 ) ),
99102 "under-sampling" ,
100103 "generate new sample in the majority class while trying to remove" ,
101104 ),
@@ -108,15 +111,21 @@ def test_check_sampling_strategy_float_error(ratio, y, type, err_msg):
108111
def test_check_sampling_strategy_error():
    """Check the ValueError raised for three kinds of invalid input.

    The cases are, in order: an unknown sampling type, a target holding a
    single class, and an unknown string strategy.
    """
    # Each entry: (sampling_strategy, raw target, sampling_type, message regex)
    failing_cases = [
        (
            "auto",
            np.array([1, 2, 3]),
            "rnd",
            "'sampling_type' should be one of",
        ),
        (
            "auto",
            np.ones((10,)),
            "over-sampling",
            "The target 'y' needs to have more than 1 class.",
        ),
        (
            "rnd",
            np.array([1, 2, 3]),
            "over-sampling",
            "When 'sampling_strategy' is a string, it needs to be one of",
        ),
    ]
    for strategy, target, sampling_type, error_regex in failing_cases:
        with pytest.raises(ValueError, match=error_regex):
            # The counts are computed inside the context manager so that the
            # helper call stays under the same ``raises`` guard as before.
            check_sampling_strategy(
                strategy, get_classes_counts(target), sampling_type
            )
120129
121130
122131@pytest .mark .parametrize (
@@ -136,7 +145,9 @@ def test_check_sampling_strategy_error_wrong_string(
136145 ),
137146 ):
138147 check_sampling_strategy (
139- sampling_strategy , np .array ([1 , 2 , 3 ]), sampling_type
148+ sampling_strategy ,
149+ get_classes_counts (np .array ([1 , 2 , 3 ])),
150+ sampling_type ,
140151 )
141152
142153
@@ -153,14 +164,18 @@ def test_sampling_strategy_class_target_unknown(
153164):
154165 y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
155166 with pytest .raises (ValueError , match = "are not present in the data." ):
156- check_sampling_strategy (sampling_strategy , y , sampling_method )
167+ check_sampling_strategy (
168+ sampling_strategy , get_classes_counts (y ), sampling_method
169+ )
157170
158171
159172def test_sampling_strategy_dict_error ():
160173 y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
161174 sampling_strategy = {1 : - 100 , 2 : 50 , 3 : 25 }
162175 with pytest .raises (ValueError , match = "in a class cannot be negative." ):
163- check_sampling_strategy (sampling_strategy , y , "under-sampling" )
176+ check_sampling_strategy (
177+ sampling_strategy , get_classes_counts (y ), "under-sampling"
178+ )
164179 sampling_strategy = {1 : 45 , 2 : 100 , 3 : 70 }
165180 error_regex = (
166181 "With over-sampling methods, the number of samples in a"
@@ -169,7 +184,9 @@ def test_sampling_strategy_dict_error():
169184 " samples are asked."
170185 )
171186 with pytest .raises (ValueError , match = error_regex ):
172- check_sampling_strategy (sampling_strategy , y , "over-sampling" )
187+ check_sampling_strategy (
188+ sampling_strategy , get_classes_counts (y ), "over-sampling"
189+ )
173190
174191 error_regex = (
175192 "With under-sampling methods, the number of samples in a"
@@ -178,21 +195,27 @@ def test_sampling_strategy_dict_error():
178195 " are asked."
179196 )
180197 with pytest .raises (ValueError , match = error_regex ):
181- check_sampling_strategy (sampling_strategy , y , "under-sampling" )
198+ check_sampling_strategy (
199+ sampling_strategy , get_classes_counts (y ), "under-sampling"
200+ )
182201
183202
@pytest.mark.parametrize("sampling_strategy", [-10, 10])
def test_sampling_strategy_float_error_not_in_range(sampling_strategy):
    """Numeric strategies of -10 and 10 must be rejected as out of range."""
    target = np.array([1] * 50 + [2] * 100)
    expected_message = "it should be in the range"
    with pytest.raises(ValueError, match=expected_message):
        check_sampling_strategy(
            sampling_strategy, get_classes_counts(target), "under-sampling"
        )
189210
190211
def test_sampling_strategy_float_error_not_binary():
    """A float strategy on a three-class target must raise ValueError."""
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    sampling_strategy = 0.5
    expected_message = "the type of target is binary"
    with pytest.raises(ValueError, match=expected_message):
        check_sampling_strategy(
            sampling_strategy, get_classes_counts(target), "under-sampling"
        )
196219
197220
198221@pytest .mark .parametrize (
@@ -202,7 +225,9 @@ def test_sampling_strategy_list_error_not_clean_sampling(sampling_method):
202225 y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
203226 with pytest .raises (ValueError , match = "cannot be a list for samplers" ):
204227 sampling_strategy = [1 , 2 , 3 ]
205- check_sampling_strategy (sampling_strategy , y , sampling_method )
228+ check_sampling_strategy (
229+ sampling_strategy , get_classes_counts (y ), sampling_method
230+ )
206231
207232
208233def _sampling_strategy_func (y ):
@@ -215,42 +240,87 @@ def _sampling_strategy_func(y):
215240@pytest .mark .parametrize (
216241 "sampling_strategy, sampling_type, expected_sampling_strategy, target" ,
217242 [
218- ("auto" , "under-sampling" , {1 : 25 , 2 : 25 }, multiclass_target ),
219- ("auto" , "clean-sampling" , {1 : 25 , 2 : 25 }, multiclass_target ),
220- ("auto" , "over-sampling" , {1 : 50 , 3 : 75 }, multiclass_target ),
221- ("all" , "over-sampling" , {1 : 50 , 2 : 0 , 3 : 75 }, multiclass_target ),
222- ("all" , "under-sampling" , {1 : 25 , 2 : 25 , 3 : 25 }, multiclass_target ),
223- ("all" , "clean-sampling" , {1 : 25 , 2 : 25 , 3 : 25 }, multiclass_target ),
224- ("majority" , "under-sampling" , {2 : 25 }, multiclass_target ),
225- ("majority" , "clean-sampling" , {2 : 25 }, multiclass_target ),
226- ("minority" , "over-sampling" , {3 : 75 }, multiclass_target ),
227- ("not minority" , "over-sampling" , {1 : 50 , 2 : 0 }, multiclass_target ),
228- ("not minority" , "under-sampling" , {1 : 25 , 2 : 25 }, multiclass_target ),
229- ("not minority" , "clean-sampling" , {1 : 25 , 2 : 25 }, multiclass_target ),
230- ("not majority" , "over-sampling" , {1 : 50 , 3 : 75 }, multiclass_target ),
231- ("not majority" , "under-sampling" , {1 : 25 , 3 : 25 }, multiclass_target ),
232- ("not majority" , "clean-sampling" , {1 : 25 , 3 : 25 }, multiclass_target ),
243+ ("auto" , "under-sampling" , {1 : 25 , 2 : 25 }, multiclass_classes_counts ),
244+ ("auto" , "clean-sampling" , {1 : 25 , 2 : 25 }, multiclass_classes_counts ),
245+ ("auto" , "over-sampling" , {1 : 50 , 3 : 75 }, multiclass_classes_counts ),
246+ (
247+ "all" ,
248+ "over-sampling" ,
249+ {1 : 50 , 2 : 0 , 3 : 75 },
250+ multiclass_classes_counts ,
251+ ),
252+ (
253+ "all" ,
254+ "under-sampling" ,
255+ {1 : 25 , 2 : 25 , 3 : 25 },
256+ multiclass_classes_counts ,
257+ ),
258+ (
259+ "all" ,
260+ "clean-sampling" ,
261+ {1 : 25 , 2 : 25 , 3 : 25 },
262+ multiclass_classes_counts ,
263+ ),
264+ ("majority" , "under-sampling" , {2 : 25 }, multiclass_classes_counts ),
265+ ("majority" , "clean-sampling" , {2 : 25 }, multiclass_classes_counts ),
266+ ("minority" , "over-sampling" , {3 : 75 }, multiclass_classes_counts ),
267+ (
268+ "not minority" ,
269+ "over-sampling" ,
270+ {1 : 50 , 2 : 0 },
271+ multiclass_classes_counts ,
272+ ),
273+ (
274+ "not minority" ,
275+ "under-sampling" ,
276+ {1 : 25 , 2 : 25 },
277+ multiclass_classes_counts ,
278+ ),
279+ (
280+ "not minority" ,
281+ "clean-sampling" ,
282+ {1 : 25 , 2 : 25 },
283+ multiclass_classes_counts ,
284+ ),
285+ (
286+ "not majority" ,
287+ "over-sampling" ,
288+ {1 : 50 , 3 : 75 },
289+ multiclass_classes_counts ,
290+ ),
291+ (
292+ "not majority" ,
293+ "under-sampling" ,
294+ {1 : 25 , 3 : 25 },
295+ multiclass_classes_counts ,
296+ ),
297+ (
298+ "not majority" ,
299+ "clean-sampling" ,
300+ {1 : 25 , 3 : 25 },
301+ multiclass_classes_counts ,
302+ ),
233303 (
234304 {1 : 70 , 2 : 100 , 3 : 70 },
235305 "over-sampling" ,
236306 {1 : 20 , 2 : 0 , 3 : 45 },
237- multiclass_target ,
307+ multiclass_classes_counts ,
238308 ),
239309 (
240310 {1 : 30 , 2 : 45 , 3 : 25 },
241311 "under-sampling" ,
242312 {1 : 30 , 2 : 45 , 3 : 25 },
243- multiclass_target ,
313+ multiclass_classes_counts ,
244314 ),
245- ([1 ], "clean-sampling" , {1 : 25 }, multiclass_target ),
315+ ([1 ], "clean-sampling" , {1 : 25 }, multiclass_classes_counts ),
246316 (
247317 _sampling_strategy_func ,
248318 "over-sampling" ,
249319 {1 : 50 , 2 : 0 , 3 : 75 },
250- multiclass_target ,
320+ multiclass_classes_counts ,
251321 ),
252- (0.5 , "over-sampling" , {1 : 25 }, binary_target ),
253- (0.5 , "under-sampling" , {0 : 50 }, binary_target ),
322+ (0.5 , "over-sampling" , {1 : 25 }, binary_classes_counts ),
323+ (0.5 , "under-sampling" , {0 : 50 }, binary_classes_counts ),
254324 ],
255325)
256326def test_check_sampling_strategy (
@@ -271,23 +341,27 @@ def test_sampling_strategy_dict_over_sampling():
271341 r" the majority class \(class #2 -> 100\)"
272342 )
273343 with warns (UserWarning , expected_msg ):
274- check_sampling_strategy (sampling_strategy , y , "over-sampling" )
344+ check_sampling_strategy (
345+ sampling_strategy , get_classes_counts (y ), "over-sampling"
346+ )
275347
276348
def test_sampling_strategy_callable_args():
    """Check a callable strategy that is given an extra keyword argument."""
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    factors = {1: 1.5, 2: 1, 3: 3}

    def sampling_strategy_func(classes_counts, multiplier):
        """Scale the count of each class by its entry in ``multiplier``."""
        scaled = {}
        for label, n_samples in classes_counts.items():
            scaled[label] = int(n_samples * multiplier[label])
        return scaled

    resolved = check_sampling_strategy(
        sampling_strategy_func,
        get_classes_counts(target),
        "over-sampling",
        multiplier=factors,
    )
    # Expected values are the scaled counts minus the original 50 / 100 / 25.
    assert resolved == {1: 25, 2: 0, 3: 50}
293367
@@ -314,11 +388,20 @@ def test_sampling_strategy_check_order(
314388 # dictionary is sorted. Refer to issue #428.
315389 y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
316390 sampling_strategy_ = check_sampling_strategy (
317- sampling_strategy , y , sampling_type
391+ sampling_strategy , get_classes_counts ( y ) , sampling_type
318392 )
319393 assert sampling_strategy_ == expected_result
320394
321395
# FIXME: remove in 0.9
def test_sampling_strategy_deprecation_array_target():
    """Passing a raw target array is expected to emit a FutureWarning."""
    sampling_strategy = "auto"
    # Only the call under test sits inside the ``warns`` context.
    with pytest.warns(FutureWarning):
        check_sampling_strategy(
            sampling_strategy, binary_target, "under-sampling"
        )
404+
322405def test_arrays_transformer_plain_list ():
323406 X = np .array ([[0 , 0 ], [1 , 1 ]])
324407 y = np .array ([[0 , 0 ], [1 , 1 ]])
0 commit comments