Skip to content

Commit 4b7522e

Browse files
authored
add error message when user passes decision trees (#141)
* add tests after rebasing * add test for tree error
1 parent a9dba80 commit 4b7522e

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

boruta/boruta_py.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def _fit(self, X, y):
326326

327327
# set n_estimators
328328
if self.n_estimators != 'auto':
329-
self.estimator.set_params(n_estimators=self.n_estimators)
329+
self._set_n_estimators(self.n_estimators)
330330

331331
# main feature selection loop
332332
while np.any(dec_reg == 0) and _iter < self.max_iter:
@@ -335,7 +335,7 @@ def _fit(self, X, y):
335335
# number of features that aren't rejected
336336
not_rejected = np.where(dec_reg >= 0)[0].shape[0]
337337
n_tree = self._get_tree_num(not_rejected)
338-
self.estimator.set_params(n_estimators=n_tree)
338+
self._set_n_estimators(n_estimators=n_tree)
339339

340340
# make sure we start with a new tree in each iteration
341341
if self._is_lightgbm:
@@ -454,6 +454,17 @@ def _transform(self, X, weak=False, return_df=False):
454454
X = X[:, indices]
455455
return X
456456

457+
def _set_n_estimators(self, n_estimators):
458+
try:
459+
self.estimator.set_params(n_estimators=n_estimators)
460+
except ValueError:
461+
raise ValueError(
462+
f"The estimator {self.estimator} does not take the parameter "
463+
"n_estimators. Use Random Forests or gradient boosting machines "
464+
"instead."
465+
)
466+
return self
467+
457468
def _get_tree_num(self, n_feat):
458469
depth = None
459470
try:

boruta/test/test_boruta.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pandas as pd
33
import pytest
44
from sklearn.ensemble import RandomForestClassifier
5+
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
56

67
from boruta import BorutaPy
78

@@ -26,8 +27,8 @@ def Xy():
2627

2728
# 5 relevant features
2829
X[:, 0] = z
29-
X[:, 1] = (y * np.abs(np.random.normal(0, 1, 1000))
30-
+ np.random.normal(0, 0.1, 1000))
30+
X[:, 1] = (y * np.abs(np.random.normal(0, 1, 1000)) +
31+
np.random.normal(0, 0.1, 1000))
3132
X[:, 2] = y + np.random.normal(0, 1, 1000)
3233
X[:, 3] = y**2 + np.random.normal(0, 1, 1000)
3334
X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000)
@@ -65,3 +66,18 @@ def test_dataframe_is_returned(Xy):
6566
bt = BorutaPy(rfc)
6667
bt.fit(X_df, y_df)
6768
assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame)
69+
70+
71+
@pytest.mark.parametrize("tree", [ExtraTreeClassifier(), DecisionTreeClassifier()])
72+
def test_boruta_with_decision_trees(tree, Xy):
73+
msg = (
74+
f"The estimator {tree} does not take the parameter "
75+
"n_estimators. Use Random Forests or gradient boosting machines "
76+
"instead."
77+
)
78+
X, y = Xy
79+
bt = BorutaPy(tree)
80+
with pytest.raises(ValueError) as record:
81+
bt.fit(X, y)
82+
83+
assert str(record.value) == msg

0 commit comments

Comments
 (0)