22import pandas as pd
33import pytest
44from sklearn .ensemble import RandomForestClassifier
5+ from sklearn .tree import DecisionTreeClassifier , ExtraTreeClassifier
56
67from boruta import BorutaPy
78
@@ -26,8 +27,8 @@ def Xy():
2627
2728 # 5 relevant features
2829 X [:, 0 ] = z
29- X [:, 1 ] = (y * np .abs (np .random .normal (0 , 1 , 1000 ))
30- + np .random .normal (0 , 0.1 , 1000 ))
30+ X [:, 1 ] = (y * np .abs (np .random .normal (0 , 1 , 1000 )) +
31+ np .random .normal (0 , 0.1 , 1000 ))
3132 X [:, 2 ] = y + np .random .normal (0 , 1 , 1000 )
3233 X [:, 3 ] = y ** 2 + np .random .normal (0 , 1 , 1000 )
3334 X [:, 4 ] = np .sqrt (y ) + np .random .binomial (2 , 0.1 , 1000 )
@@ -65,3 +66,18 @@ def test_dataframe_is_returned(Xy):
6566 bt = BorutaPy (rfc )
6667 bt .fit (X_df , y_df )
6768 assert isinstance (bt .transform (X_df , return_df = True ), pd .DataFrame )
69+
70+
71+ @pytest .mark .parametrize ("tree" , [ExtraTreeClassifier (), DecisionTreeClassifier ()])
72+ def test_boruta_with_decision_trees (tree , Xy ):
73+ msg = (
74+ f"The estimator { tree } does not take the parameter "
75+ "n_estimators. Use Random Forests or gradient boosting machines "
76+ "instead."
77+ )
78+ X , y = Xy
79+ bt = BorutaPy (tree )
80+ with pytest .raises (ValueError ) as record :
81+ bt .fit (X , y )
82+
83+ assert str (record .value ) == msg
0 commit comments