33Data generation mechanism
44=========================
55
6- This example illustrates the Geometric SMOTE data
7- generation mechanism and the usage of its
6+ This example illustrates the Geometric SMOTE data
7+ generation mechanism and the usage of its
88hyperparameters.
99
1010"""
1616import matplotlib .pyplot as plt
1717
1818from sklearn .datasets import make_blobs
19- from imblearn .over_sampling import SMOTE
20-
21- from gsmote import GeometricSMOTE
19+ from imblearn .over_sampling import SMOTE , GeometricSMOTE
2220
2321print (__doc__ )
2422
@@ -47,11 +45,11 @@ def generate_imbalanced_data(
4745def plot_scatter (X , y , title ):
4846 """Function to plot some data as a scatter plot."""
4947 plt .figure ()
50- plt .scatter (X [y == 1 , 0 ], X [y == 1 , 1 ], label = ' Positive Class' )
51- plt .scatter (X [y == 0 , 0 ], X [y == 0 , 1 ], label = ' Negative Class' )
48+ plt .scatter (X [y == 1 , 0 ], X [y == 1 , 1 ], label = " Positive Class" )
49+ plt .scatter (X [y == 0 , 0 ], X [y == 0 , 1 ], label = " Negative Class" )
5250 plt .xlim (* XLIM )
5351 plt .ylim (* YLIM )
54- plt .gca ().set_aspect (' equal' , adjustable = ' box' )
52+ plt .gca ().set_aspect (" equal" , adjustable = " box" )
5553 plt .legend ()
5654 plt .title (title )
5755
@@ -66,9 +64,9 @@ def plot_hyperparameters(oversampler, X, y, param, vals, n_subplots):
6664 for ax , val in zip (ax_arr , vals ):
6765 oversampler .set_params (** {param : val })
6866 X_res , y_res = oversampler .fit_resample (X , y )
69- ax .scatter (X_res [y_res == 1 , 0 ], X_res [y_res == 1 , 1 ], label = ' Positive Class' )
70- ax .scatter (X_res [y_res == 0 , 0 ], X_res [y_res == 0 , 1 ], label = ' Negative Class' )
71- ax .set_title (f' { val } ' )
67+ ax .scatter (X_res [y_res == 1 , 0 ], X_res [y_res == 1 , 1 ], label = " Positive Class" )
68+ ax .scatter (X_res [y_res == 0 , 0 ], X_res [y_res == 0 , 1 ], label = " Negative Class" )
69+ ax .set_title (f" { val } " )
7270 ax .set_xlim (* XLIM )
7371 ax .set_ylim (* YLIM )
7472
@@ -79,8 +77,8 @@ def plot_comparison(oversamplers, X, y):
7977 fig , ax_arr = plt .subplots (1 , 2 , figsize = (15 , 5 ))
8078 for ax , (name , ovs ) in zip (ax_arr , oversamplers ):
8179 X_res , y_res = ovs .fit_resample (X , y )
82- ax .scatter (X_res [y_res == 1 , 0 ], X_res [y_res == 1 , 1 ], label = ' Positive Class' )
83- ax .scatter (X_res [y_res == 0 , 0 ], X_res [y_res == 0 , 1 ], label = ' Negative Class' )
80+ ax .scatter (X_res [y_res == 1 , 0 ], X_res [y_res == 1 , 1 ], label = " Positive Class" )
81+ ax .scatter (X_res [y_res == 0 , 0 ], X_res [y_res == 0 , 1 ], label = " Negative Class" )
8482 ax .set_title (name )
8583 ax .set_xlim (* XLIM )
8684 ax .set_ylim (* YLIM )
@@ -98,7 +96,7 @@ def plot_comparison(oversamplers, X, y):
9896X , y = generate_imbalanced_data (
9997 200 , 2 , [(- 2.0 , 2.25 ), (1.0 , 2.0 )], 0.25 , [- 0.7 , 2.3 ], [- 0.5 , 3.1 ]
10098)
101- plot_scatter (X , y , ' Imbalanced data' )
99+ plot_scatter (X , y , " Imbalanced data" )
102100
103101###############################################################################
104102# Geometric hyperparameters
@@ -133,13 +131,13 @@ def plot_comparison(oversamplers, X, y):
133131gsmote = GeometricSMOTE (
134132 k_neighbors = 1 ,
135133 deformation_factor = 0.0 ,
136- selection_strategy = ' minority' ,
134+ selection_strategy = " minority" ,
137135 random_state = RANDOM_STATE ,
138136)
139137truncation_factors = np .array ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
140138n_subplots = [2 , 3 ]
141- plot_hyperparameters (gsmote , X , y , ' truncation_factor' , truncation_factors , n_subplots )
142- plot_hyperparameters (gsmote , X , y , ' truncation_factor' , - truncation_factors , n_subplots )
139+ plot_hyperparameters (gsmote , X , y , " truncation_factor" , truncation_factors , n_subplots )
140+ plot_hyperparameters (gsmote , X , y , " truncation_factor" , - truncation_factors , n_subplots )
143141
144142###############################################################################
145143# Deformation factor
@@ -151,12 +149,12 @@ def plot_comparison(oversamplers, X, y):
151149gsmote = GeometricSMOTE (
152150 k_neighbors = 1 ,
153151 truncation_factor = 0.0 ,
154- selection_strategy = ' minority' ,
152+ selection_strategy = " minority" ,
155153 random_state = RANDOM_STATE ,
156154)
157155deformation_factors = np .array ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
158156n_subplots = [2 , 3 ]
159- plot_hyperparameters (gsmote , X , y , ' deformation_factor' , truncation_factors , n_subplots )
157+ plot_hyperparameters (gsmote , X , y , " deformation_factor" , truncation_factors , n_subplots )
160158
161159###############################################################################
162160# Selection strategy
@@ -177,10 +175,10 @@ def plot_comparison(oversamplers, X, y):
177175 deformation_factor = 0.5 ,
178176 random_state = RANDOM_STATE ,
179177)
180- selection_strategies = np .array ([' minority' , ' majority' , ' combined' ])
178+ selection_strategies = np .array ([" minority" , " majority" , " combined" ])
181179n_subplots = [1 , 3 ]
182180plot_hyperparameters (
183- gsmote , X , y , ' selection_strategy' , selection_strategies , n_subplots
181+ gsmote , X , y , " selection_strategy" , selection_strategies , n_subplots
184182)
185183
186184###############################################################################
@@ -193,7 +191,7 @@ def plot_comparison(oversamplers, X, y):
193191
194192X_new = np .vstack ([X , np .array ([2.0 , 2.0 ])])
195193y_new = np .hstack ([y , np .ones (1 , dtype = np .int8 )])
196- plot_scatter (X_new , y_new , ' Imbalanced data' )
194+ plot_scatter (X_new , y_new , " Imbalanced data" )
197195
198196###############################################################################
199197# When the number of ``k_neighbors`` is increased, SMOTE results to the
@@ -202,11 +200,11 @@ def plot_comparison(oversamplers, X, y):
202200# ``majority``.
203201
204202oversamplers = [
205- (' SMOTE' , SMOTE (k_neighbors = 2 , random_state = RANDOM_STATE )),
203+ (" SMOTE" , SMOTE (k_neighbors = 2 , random_state = RANDOM_STATE )),
206204 (
207- ' Geometric SMOTE' ,
205+ " Geometric SMOTE" ,
208206 GeometricSMOTE (
209- k_neighbors = 2 , selection_strategy = ' combined' , random_state = RANDOM_STATE
207+ k_neighbors = 2 , selection_strategy = " combined" , random_state = RANDOM_STATE
210208 ),
211209 ),
212210]
0 commit comments