1414from sklearn .utils .multiclass import check_classification_targets
1515
1616from .utils import check_sampling_strategy , check_target_type
17+ from .utils ._validation import ArraysTransformer
1718
1819
1920class SamplerMixin (BaseEstimator , metaclass = ABCMeta ):
@@ -72,6 +73,7 @@ def fit_resample(self, X, y):
7273 The corresponding label of `X_resampled`.
7374 """
7475 check_classification_targets (y )
76+ arrays_transformer = ArraysTransformer (X , y )
7577 X , y , binarize_y = self ._check_X_y (X , y )
7678
7779 self .sampling_strategy_ = check_sampling_strategy (
@@ -80,21 +82,10 @@ def fit_resample(self, X, y):
8082
8183 output = self ._fit_resample (X , y )
8284
83- if self ._X_columns is not None or self ._y_name is not None :
84- import pandas as pd
85-
86- if self ._X_columns is not None :
87- X_ = pd .DataFrame (output [0 ], columns = self ._X_columns )
88- X_ = X_ .astype (self ._X_dtypes )
89- else :
90- X_ = output [0 ]
91-
9285 y_ = (label_binarize (output [1 ], np .unique (y ))
9386 if binarize_y else output [1 ])
9487
95- if self ._y_name is not None :
96- y_ = pd .Series (y_ , dtype = self ._y_dtype , name = self ._y_name )
97-
88+ X_ , y_ = arrays_transformer .transform (output [0 ], y_ )
9889 return (X_ , y_ ) if len (output ) == 2 else (X_ , y_ , output [2 ])
9990
10091 # define an alias for back-compatibility
@@ -137,22 +128,6 @@ def __init__(self, sampling_strategy="auto"):
137128 self .sampling_strategy = sampling_strategy
138129
139130 def _check_X_y (self , X , y , accept_sparse = None ):
140- if hasattr (X , "loc" ):
141- # store information to build dataframe
142- self ._X_columns = X .columns
143- self ._X_dtypes = X .dtypes
144- else :
145- self ._X_columns = None
146- self ._X_dtypes = None
147-
148- if hasattr (y , "loc" ):
149- # store information to build a series
150- self ._y_name = y .name
151- self ._y_dtype = y .dtype
152- else :
153- self ._y_name = None
154- self ._y_dtype = None
155-
156131 if accept_sparse is None :
157132 accept_sparse = ["csr" , "csc" ]
158133 y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
@@ -265,8 +240,8 @@ def fit_resample(self, X, y):
265240 y_resampled : array-like of shape (n_samples_new,)
266241 The corresponding label of `X_resampled`.
267242 """
268- # store the columns name to reconstruct a dataframe
269- self . _columns = X . columns if hasattr ( X , "loc" ) else None
243+ arrays_transformer = ArraysTransformer ( X , y )
244+
270245 if self .validate :
271246 check_classification_targets (y )
272247 X , y , binarize_y = self ._check_X_y (
@@ -280,22 +255,12 @@ def fit_resample(self, X, y):
280255 output = self ._fit_resample (X , y )
281256
282257 if self .validate :
283- if self ._X_columns is not None or self ._y_name is not None :
284- import pandas as pd
285-
286- if self ._X_columns is not None :
287- X_ = pd .DataFrame (output [0 ], columns = self ._X_columns )
288- X_ = X_ .astype (self ._X_dtypes )
289- else :
290- X_ = output [0 ]
291258
292259 y_ = (label_binarize (output [1 ], np .unique (y ))
293260 if binarize_y else output [1 ])
294-
295- if self ._y_name is not None :
296- y_ = pd .Series (y_ , dtype = self ._y_dtype , name = self ._y_name )
297-
261+ X_ , y_ = arrays_transformer .transform (output [0 ], y_ )
298262 return (X_ , y_ ) if len (output ) == 2 else (X_ , y_ , output [2 ])
263+
299264 return output
300265
301266 def _fit_resample (self , X , y ):
0 commit comments