|
3 | 3 | # Christos Aridas |
4 | 4 | # License: MIT |
5 | 5 |
|
6 | | -from __future__ import print_function |
7 | | - |
8 | 6 | from collections import Counter |
9 | 7 |
|
10 | 8 | import pytest |
11 | 9 | import numpy as np |
12 | 10 |
|
13 | | -from pytest import raises |
14 | | - |
15 | 11 | from sklearn.datasets import load_iris |
16 | 12 |
|
17 | 13 | from imblearn.datasets import make_imbalance |
18 | 14 |
|
19 | | -data = load_iris() |
20 | | -X, Y = data.data, data.target |
21 | 15 |
|
| 16 | +@pytest.fixture |
| 17 | +def iris(): |
| 18 | + return load_iris(return_X_y=True) |
22 | 19 |
|
23 | | -def test_make_imbalanced_backcompat(): |
| 20 | + |
| 21 | +def test_make_imbalanced_backcompat(iris): |
24 | 22 | # check an error is raised with we don't pass sampling_strategy and ratio |
25 | | - with raises(TypeError, match="missing 1 required positional argument"): |
26 | | - make_imbalance(X, Y) |
| 23 | + with pytest.raises(TypeError, match="missing 1 required positional argument"): |
| 24 | + make_imbalance(*iris) |
27 | 25 |
|
28 | 26 |
|
29 | | -def test_make_imbalance_error(): |
| 27 | +@pytest.mark.parametrize( |
| 28 | + "sampling_strategy, err_msg", |
| 29 | + [({0: -100, 1: 50, 2: 50}, "in a class cannot be negative"), |
| 30 | + ({0: 10, 1: 70}, "should be less or equal to the original"), |
| 31 | + ('random-string', "has to be a dictionary or a function")] |
| 32 | +) |
| 33 | +def test_make_imbalance_error(iris, sampling_strategy, err_msg): |
30 | 34 | # we are reusing part of utils.check_sampling_strategy, however this is not |
31 | 35 | # cover in the common tests so we will repeat it here |
32 | | - sampling_strategy = {0: -100, 1: 50, 2: 50} |
33 | | - with raises(ValueError, match="in a class cannot be negative"): |
34 | | - make_imbalance(X, Y, sampling_strategy) |
35 | | - sampling_strategy = {0: 10, 1: 70} |
36 | | - with raises(ValueError, match="should be less or equal to the original"): |
37 | | - make_imbalance(X, Y, sampling_strategy) |
38 | | - y_ = np.zeros((X.shape[0], )) |
39 | | - sampling_strategy = {0: 10} |
40 | | - with raises(ValueError, match="needs to have more than 1 class."): |
41 | | - make_imbalance(X, y_, sampling_strategy) |
42 | | - sampling_strategy = 'random-string' |
43 | | - with raises(ValueError, match="has to be a dictionary or a function"): |
44 | | - make_imbalance(X, Y, sampling_strategy) |
45 | | - |
46 | | - |
47 | | -def test_make_imbalance_dict(): |
48 | | - sampling_strategy = {0: 10, 1: 20, 2: 30} |
49 | | - X_, y_ = make_imbalance(X, Y, sampling_strategy=sampling_strategy) |
50 | | - assert Counter(y_) == sampling_strategy |
51 | | - |
52 | | - sampling_strategy = {0: 10, 1: 20} |
53 | | - X_, y_ = make_imbalance(X, Y, sampling_strategy=sampling_strategy) |
54 | | - assert Counter(y_) == {0: 10, 1: 20, 2: 50} |
| 36 | + X, y = iris |
| 37 | + with pytest.raises(ValueError, match=err_msg): |
| 38 | + make_imbalance(X, y, sampling_strategy) |
| 39 | + |
| 40 | + |
| 41 | +def test_make_imbalance_error_single_class(iris): |
| 42 | + X, y = iris |
| 43 | + y = np.zeros_like(y) |
| 44 | + with pytest.raises(ValueError, match="needs to have more than 1 class."): |
| 45 | + make_imbalance(X, y, {0: 10}) |
| 46 | + |
| 47 | + |
| 48 | +@pytest.mark.parametrize( |
| 49 | + "sampling_strategy, expected_counts", |
| 50 | + [({0: 10, 1: 20, 2: 30}, {0: 10, 1: 20, 2: 30}), |
| 51 | + ({0: 10, 1: 20}, {0: 10, 1: 20, 2: 50})] |
| 52 | +) |
| 53 | +def test_make_imbalance_dict(iris, sampling_strategy, expected_counts): |
| 54 | + X, y = iris |
| 55 | + _, y_ = make_imbalance(X, y, sampling_strategy=sampling_strategy) |
| 56 | + assert Counter(y_) == expected_counts |
55 | 57 |
|
56 | 58 |
|
57 | 59 | @pytest.mark.filterwarnings("ignore:'ratio' has been deprecated in 0.4") |
58 | | -def test_make_imbalance_ratio(): |
59 | | - # check that using 'ratio' is working |
60 | | - sampling_strategy = {0: 10, 1: 20, 2: 30} |
61 | | - X_, y_ = make_imbalance(X, Y, ratio=sampling_strategy) |
62 | | - assert Counter(y_) == sampling_strategy |
63 | | - |
64 | | - sampling_strategy = {0: 10, 1: 20} |
65 | | - X_, y_ = make_imbalance(X, Y, ratio=sampling_strategy) |
66 | | - assert Counter(y_) == {0: 10, 1: 20, 2: 50} |
| 60 | +@pytest.mark.parametrize( |
| 61 | + "sampling_strategy, expected_counts", |
| 62 | + [({0: 10, 1: 20, 2: 30}, {0: 10, 1: 20, 2: 30}), |
| 63 | + ({0: 10, 1: 20}, {0: 10, 1: 20, 2: 50})] |
| 64 | +) |
| 65 | +def test_make_imbalance_dict_ratio(iris, sampling_strategy, expected_counts): |
| 66 | + X, y = iris |
| 67 | + _, y_ = make_imbalance(X, y, ratio=sampling_strategy) |
| 68 | + assert Counter(y_) == expected_counts |
0 commit comments