diff --git a/boruta/boruta_py.py b/boruta/boruta_py.py index a248f7b..789ebe8 100644 --- a/boruta/boruta_py.py +++ b/boruta/boruta_py.py @@ -12,11 +12,13 @@ import numpy as np import scipy as sp from sklearn.utils import check_random_state, check_X_y -from sklearn.base import TransformerMixin, BaseEstimator +from sklearn.base import BaseEstimator +from sklearn.feature_selection import SelectorMixin +from sklearn.utils.validation import check_is_fitted import warnings -class BorutaPy(BaseEstimator, TransformerMixin): +class BorutaPy(BaseEstimator, SelectorMixin): """ Improved Python implementation of the Boruta R package. @@ -452,6 +454,10 @@ def _transform(self, X, weak=False, return_df=False): X = X[:, indices] return X + def _get_support_mask(self): + check_is_fitted(self, 'support_') + return self.support_ + def _get_tree_num(self, n_feat): depth = None try: